diff --git a/api/client.go b/api/client.go index ef162c13555cdba73c64ce38cc970c343c263eca..c0c5e5133803a66610ecfa21792c525643969502 100644 --- a/api/client.go +++ b/api/client.go @@ -500,7 +500,7 @@ func (c *Client) GetErrorsChannel() <-chan interfaces.ClientError { // Handles both auth confirm and requests func (c *Client) StartNetworkFollower(timeout time.Duration) error { u := c.GetUser() - jww.INFO.Printf("StartNetworkFollower() \n\tTransmisstionID: %s "+ + jww.INFO.Printf("StartNetworkFollower() \n\tTransmissionID: %s "+ "\n\tReceptionID: %s", u.TransmissionID, u.ReceptionID) return c.followerServices.start(timeout) @@ -580,9 +580,9 @@ func (c *Client) GetNetworkInterface() interfaces.NetworkManager { } // GetNodeRegistrationStatus gets the current state of node registration. It -// returns the the total number of nodes in the NDF and the number of those -// which are currently registers with. An error is returned if the network is -// not healthy. +// returns the total number of nodes in the NDF and the number of those which +// are currently registers with. An error is returned if the network is not +// healthy. func (c *Client) GetNodeRegistrationStatus() (int, int, error) { // Return an error if the network is not healthy if !c.GetHealth().IsHealthy() { diff --git a/api/send.go b/api/send.go index 28b8435425b8941290581342ee54838a8de27eda..3c6cc0178451408221692bb801458e69b4350ef8 100644 --- a/api/send.go +++ b/api/send.go @@ -56,7 +56,7 @@ func (c *Client) SendCMIX(msg format.Message, recipientID *id.ID, // SendManyCMIX sends many "raw" CMIX message payloads to each of the // provided recipients. Used for group chat functionality. Returns the // round ID of the round the payload was sent or an error if it fails. -func (c *Client) SendManyCMIX(messages map[id.ID]format.Message, +func (c *Client) SendManyCMIX(messages []message.TargetedCmixMessage, params params.CMIX) (id.Round, []ephemeral.Id, error) { return c.network.SendManyCMIX(messages, params) } diff --git a/api/utilsInterfaces_test.go b/api/utilsInterfaces_test.go index 380a691a4b21e23534b28a836ca8c30cf475afdf..3bb409ceea106d0881ff3186bad4e40ab8ce7b71 100644 --- a/api/utilsInterfaces_test.go +++ b/api/utilsInterfaces_test.go @@ -100,16 +100,13 @@ func (t *testNetworkManagerGeneric) Follow(report interfaces.ClientErrorReport) func (t *testNetworkManagerGeneric) CheckGarbledMessages() { return } - func (t *testNetworkManagerGeneric) GetVerboseRounds() string { return "" } - func (t *testNetworkManagerGeneric) SendE2E(message.Send, params.E2E, *stoppable.Single) ( []id.Round, cE2e.MessageID, time.Time, error) { rounds := []id.Round{id.Round(0), id.Round(1), id.Round(2)} return rounds, cE2e.MessageID{}, time.Time{}, nil - } func (t *testNetworkManagerGeneric) SendUnsafe(m message.Send, p params.Unsafe) ([]id.Round, error) { return nil, nil @@ -117,7 +114,7 @@ func (t *testNetworkManagerGeneric) SendUnsafe(m message.Send, p params.Unsafe) func (t *testNetworkManagerGeneric) SendCMIX(message format.Message, rid *id.ID, p params.CMIX) (id.Round, ephemeral.Id, error) { return id.Round(0), ephemeral.Id{}, nil } -func (t *testNetworkManagerGeneric) SendManyCMIX(messages map[id.ID]format.Message, p params.CMIX) (id.Round, []ephemeral.Id, error) { +func (t *testNetworkManagerGeneric) SendManyCMIX(messages []message.TargetedCmixMessage, p params.CMIX) (id.Round, []ephemeral.Id, error) { return 0, []ephemeral.Id{}, nil } func (t *testNetworkManagerGeneric) GetInstance() *network.Instance { diff --git a/bindings/fileTransfer.go b/bindings/fileTransfer.go new file mode 100644 index 0000000000000000000000000000000000000000..2ed9f9b02c9999dd11c3685bc20a3612d9d49bf9 --- /dev/null +++ b/bindings/fileTransfer.go @@ -0,0 +1,228 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package bindings + +import ( + "encoding/json" + ft "gitlab.com/elixxir/client/fileTransfer" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/xx_network/primitives/id" + "time" +) + +// FileTransfer contains the file transfer manager. +type FileTransfer struct { + m *ft.Manager +} + +// FileTransferSentProgressFunc contains a function callback that tracks the +// progress of sending a file. It is called when a file part is sent, a file +// part arrives, the transfer completes, or on error. +type FileTransferSentProgressFunc interface { + SentProgressCallback(completed bool, sent, arrived, total int, err error) +} + +// FileTransferReceivedProgressFunc contains a function callback that tracks the +// progress of receiving a file. It is called when a file part is received, the +// transfer completes, or on error. +type FileTransferReceivedProgressFunc interface { + ReceivedProgressCallback(completed bool, received, total int, err error) +} + +// FileTransferReceiveFunc contains a function callback that notifies the +// receiver of an incoming file transfer. It is called on the reception of the +// initial file transfer message. +type FileTransferReceiveFunc interface { + ReceiveCallback(tid []byte, fileName string, sender []byte, size int, + preview []byte) +} + +// NewFileTransferManager creates a new file transfer manager and starts the +// sending and receiving threads. The receiveFunc is called everytime a new file +// transfer is received. The parameters string is a JSON formatted string of the +// fileTransfer.Params object. If it is left empty, then defaults are used. It +// must match the following format: {"MaxThroughput":150000} +func NewFileTransferManager(client *Client, receiveFunc FileTransferReceiveFunc, + parameters string) (FileTransfer, error) { + + receiveCB := func(tid ftCrypto.TransferID, fileName string, sender *id.ID, + size uint32, preview []byte) { + receiveFunc.ReceiveCallback( + tid.Bytes(), fileName, sender.Bytes(), int(size), preview) + } + + // JSON unmarshal parameters string + p := ft.DefaultParams() + if parameters != "" { + err := json.Unmarshal([]byte(parameters), &p) + if err != nil { + return FileTransfer{}, err + } + } + + // Create new file transfer manager + m, err := ft.NewManager(&client.api, receiveCB, p) + if err != nil { + return FileTransfer{}, err + } + + // Start sending and receiving threads + err = client.api.AddService(m.StartProcesses) + if err != nil { + return FileTransfer{}, err + } + + return FileTransfer{m}, nil +} + +// Send sends a file to the recipient. The sender must have an E2E relationship +// with the recipient. +// The file name is the name of the file to show a user. It has a max length of +// 32 bytes. +// The file data cannot be larger than 4 mB +// The retry float is the total amount of data to send relative to the data +// size. Data will be resent on error and will resend up to [(1 + retry) * +// fileSize]. +// The preview stores a preview of the data (such as a thumbnail) and is +// capped at 4 kB in size. +// Returns a unique transfer ID used to identify the transfer. +// PeriodMS is the duration, in milliseconds, to wait between progress callback +// calls. Set this large enough to prevent spamming. +func (f FileTransfer) Send(fileName string, fileData []byte, recipientID []byte, + retry float32, preview []byte, progressFunc FileTransferSentProgressFunc, + periodMS int) ([]byte, error) { + + // Create SentProgressCallback + progressCB := func(completed bool, sent, arrived, total uint16, err error) { + progressFunc.SentProgressCallback( + completed, int(sent), int(arrived), int(total), err) + } + + // Convert recipient ID bytes to id.ID + recipient, err := id.Unmarshal(recipientID) + if err != nil { + return []byte{}, err + } + + // Convert period to time.Duration + period := time.Duration(periodMS) * time.Millisecond + + // Send file + tid, err := f.m.Send( + fileName, fileData, recipient, retry, preview, progressCB, period) + if err != nil { + return nil, err + } + + // Return transfer ID as bytes and error + return tid.Bytes(), nil +} + +// RegisterSendProgressCallback allows for the registration of a callback to +// track the progress of an individual sent file transfer. The callback will be +// called immediately when added to report the current status of the transfer. +// It will then call every time a file part is sent, a file part arrives, the +// transfer completes, or an error occurs. It is called at most once every +// period, which means if events occur faster than the period, then they will +// not be reported and instead the progress will be reported once at the end of +// the period. +// The period is specified in milliseconds. +func (f FileTransfer) RegisterSendProgressCallback(transferID []byte, + progressFunc FileTransferSentProgressFunc, periodMS int) error { + + // Unmarshal transfer ID + tid := ftCrypto.UnmarshalTransferID(transferID) + + // Create SentProgressCallback + progressCB := func(completed bool, sent, arrived, total uint16, err error) { + progressFunc.SentProgressCallback( + completed, int(sent), int(arrived), int(total), err) + } + + // Convert period to time.Duration + period := time.Duration(periodMS) * time.Millisecond + + return f.m.RegisterSendProgressCallback(tid, progressCB, period) +} + +// Resend resends a file if Send fails. +func (f FileTransfer) Resend(transferID []byte) error { + // Unmarshal transfer ID + tid := ftCrypto.UnmarshalTransferID(transferID) + + return f.m.Resend(tid) +} + +// CloseSend deletes a sent file transfer from the sent transfer map and from +// storage once a transfer has completed or reached the retry limit. Returns an +// error if the transfer has not run out of retries. +func (f FileTransfer) CloseSend(transferID []byte) error { + // Unmarshal transfer ID + tid := ftCrypto.UnmarshalTransferID(transferID) + + return f.m.CloseSend(tid) +} + +// RegisterReceiveProgressCallback allows for the registration of a callback to +// track the progress of an individual received file transfer. The callback will +// be called immediately when added to report the current status of the +// transfer. It will then call every time a file part is received, the transfer +// completes, or an error occurs. It is called at most once ever period, which +// means if events occur faster than the period, then they will not be reported +// and instead the progress will be reported once at the end of the period. +// Once the callback reports that the transfer has completed, the recipient +// can get the full file by calling Receive. +// The period is specified in milliseconds. +func (f FileTransfer) RegisterReceiveProgressCallback(transferID []byte, + progressFunc FileTransferReceivedProgressFunc, periodMS int) error { + // Unmarshal transfer ID + tid := ftCrypto.UnmarshalTransferID(transferID) + + // Create ReceivedProgressCallback + progressCB := func(completed bool, received, total uint16, err error) { + progressFunc.ReceivedProgressCallback( + completed, int(received), int(total), err) + } + + // Convert period to time.Duration + period := time.Duration(periodMS) * time.Millisecond + + return f.m.RegisterReceiveProgressCallback(tid, progressCB, period) +} + +// Receive returns the fully assembled file on the completion of the transfer. +// It deletes the transfer from the received transfer map and from storage. +// Returns an error if the transfer is not complete, the full file cannot be +// verified, or if the transfer cannot be found. +func (f FileTransfer) Receive(transferID []byte) ([]byte, error) { + // Unmarshal transfer ID + tid := ftCrypto.UnmarshalTransferID(transferID) + + return f.m.Receive(tid) +} + +//////////////////////////////////////////////////////////////////////////////// +// Utility Functions // +//////////////////////////////////////////////////////////////////////////////// + +// GetMaxFilePreviewSize returns the maximum file preview size, in bytes. +func (f FileTransfer) GetMaxFilePreviewSize() int { + return ft.PreviewMaxSize +} + +// GetMaxFileNameByteLength returns the maximum length, in bytes, allowed for a +// file name. +func (f FileTransfer) GetMaxFileNameByteLength() int { + return ft.FileNameMaxLen +} + +// GetMaxFileSize returns the maximum file size, in bytes, allowed to be +// transferred. +func (f FileTransfer) GetMaxFileSize() int { + return ft.FileMaxSize +} diff --git a/cmd/fileTransfer.go b/cmd/fileTransfer.go new file mode 100644 index 0000000000000000000000000000000000000000..5ba84a18cbc7ab6df98d0282e3f2748063dc55e5 --- /dev/null +++ b/cmd/fileTransfer.go @@ -0,0 +1,340 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package cmd + +import ( + "fmt" + "github.com/spf13/cobra" + jww "github.com/spf13/jwalterweatherman" + "github.com/spf13/viper" + "gitlab.com/elixxir/client/api" + ft "gitlab.com/elixxir/client/fileTransfer" + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/crypto/contact" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" + "gitlab.com/xx_network/primitives/utils" + "io/ioutil" + "strconv" + "time" +) + +const callbackPeriod = 25 * time.Millisecond + +// ftCmd starts the file transfer manager and allows the sending and receiving +// of files. +var ftCmd = &cobra.Command{ + Use: "fileTransfer", + Short: "Send and receive file for cMix client", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + + // Initialise a new client + client := initClient() + + // Print user's reception ID and save contact file + user := client.GetUser() + jww.INFO.Printf("User: %s", user.ReceptionID) + writeContact(user.GetContact()) + + // Start the network follower + err := client.StartNetworkFollower(5 * time.Second) + if err != nil { + jww.FATAL.Panicf("Failed to start the network follower: %+v", err) + } + + // Initialize the file transfer manager + maxThroughput := viper.GetInt("maxThroughput") + m, receiveChan := initFileTransferManager(client, maxThroughput) + + // Wait until connected or crash on timeout + connected := make(chan bool, 10) + client.GetHealth().AddChannel(connected) + waitUntilConnected(connected) + + // After connection, wait until registered with at least 85% of nodes + for numReg, total := 1, 100; numReg < (total*3)/4; { + time.Sleep(1 * time.Second) + + numReg, total, err = client.GetNodeRegistrationStatus() + if err != nil { + jww.FATAL.Panicf("Failed to get node registration status: %+v", + err) + } + + jww.INFO.Printf("Registering with nodes (%d/%d)...", numReg, total) + } + + // Start thread that receives new file transfers and prints them to log + receiveQuit := make(chan struct{}) + receiveDone := make(chan struct{}) + go receiveNewFileTransfers(receiveChan, receiveDone, receiveQuit, m) + + // If set, send the file to the recipient + sendDone := make(chan struct{}) + if viper.IsSet("sendFile") { + recipientContactPath := viper.GetString("sendFile") + filePath := viper.GetString("filePath") + filePreviewPath := viper.GetString("filePreviewPath") + filePreviewString := viper.GetString("filePreviewString") + retry := float32(viper.GetFloat64("retry")) + + sendFile(filePath, filePreviewPath, filePreviewString, + recipientContactPath, retry, m, sendDone) + } + + // Wait until either the file finishes sending or the file finishes + // being received, stop the receiving thread, and exit + for done := false; !done; { + select { + case <-sendDone: + jww.DEBUG.Printf("Finished sending message. Stopping threads " + + "and network follower.") + done = true + case <-receiveDone: + jww.DEBUG.Printf("Finished receiving message. Stopping " + + "threads and network follower.") + done = true + } + } + + // Stop reception thread + receiveQuit <- struct{}{} + + // Stop network follower + err = client.StopNetworkFollower() + if err != nil { + jww.WARN.Printf("Failed to stop network follower: %+v", err) + } + + jww.DEBUG.Print("File transfer finished stopping threads and network " + + "follower.") + }, +} + +// receivedFtResults is used to return received new file transfer results on a +// channel from a callback. +type receivedFtResults struct { + tid ftCrypto.TransferID + fileName string + sender *id.ID + size uint32 + preview []byte +} + +// initFileTransferManager creates a new file transfer manager with a new +// reception callback. Returns the file transfer manager and the channel that +// will be triggered when the callback is called. +func initFileTransferManager(client *api.Client, maxThroughput int) ( + *ft.Manager, chan receivedFtResults) { + + // Create interfaces.ReceiveCallback that returns the results on a channel + receiveChan := make(chan receivedFtResults, 100) + receiveCB := func(tid ftCrypto.TransferID, fileName string, sender *id.ID, + size uint32, preview []byte) { + receiveChan <- receivedFtResults{tid, fileName, sender, size, preview} + } + + // Create new parameters + p := ft.DefaultParams() + if maxThroughput != 0 { + p = ft.NewParams(maxThroughput) + } + + // Create new manager + manager, err := ft.NewManager(client, receiveCB, p) + if err != nil { + jww.FATAL.Panicf("Failed to create new file transfer manager: %+v", err) + } + + // Start the file transfer sending and receiving threads + err = client.AddService(manager.StartProcesses) + if err != nil { + jww.FATAL.Panicf("Failed to start file transfer threads: %+v", err) + } + + return manager, receiveChan +} + +// sendFile sends the file to the recipient and prints the progress. +func sendFile(filePath, filePreviewPath, filePreviewString, + recipientContactPath string, retry float32, m *ft.Manager, + done chan struct{}) { + + // Get file from path + fileData, err := utils.ReadFile(filePath) + if err != nil { + jww.FATAL.Panicf("Failed to read file %q: %+v", filePath, err) + } + + // Get file preview from path + filePreviewData := []byte(filePreviewString) + if filePreviewPath != "" { + filePreviewData, err = utils.ReadFile(filePreviewPath) + if err != nil { + jww.FATAL.Panicf("Failed to read file preview %q: %+v", + filePreviewPath, err) + } + } + + // Get recipient contact from file + recipient := getContactFromFile(recipientContactPath) + + jww.DEBUG.Printf("Sending file %q of size %d to recipient %s.", + filePath, len(fileData), recipient.ID) + + // Create sent progress callback that prints the results + progressCB := func(completed bool, sent, arrived, total uint16, err error) { + jww.DEBUG.Printf("Sent progress callback for %q "+ + "{completed: %t, sent: %d, arrived: %d, total: %d, err: %v}\n", + filePath, completed, sent, arrived, total, err) + if (sent == 0 && arrived == 0) || (arrived == total) || completed || + err != nil { + fmt.Printf("Sent progress callback for %q "+ + "{completed: %t, sent: %d, arrived: %d, total: %d, err: %v}\n", + filePath, completed, sent, arrived, total, err) + } + + if completed { + fmt.Printf("Completed sending file.\n") + done <- struct{}{} + } else if err != nil { + fmt.Printf("Failed sending file: %+v\n", err) + done <- struct{}{} + } + } + + // Send the file + _, err = m.Send(filePath, fileData, recipient.ID, retry, filePreviewData, + progressCB, callbackPeriod) + if err != nil { + jww.FATAL.Panicf("Failed to send file %q to %s: %+v", + filePath, recipient.ID, err) + } +} + +// receiveNewFileTransfers waits to receive new file transfers and prints its +// information to the log. +func receiveNewFileTransfers(receive chan receivedFtResults, done, + quit chan struct{}, m *ft.Manager) { + jww.DEBUG.Print("Starting thread waiting to receive NewFileTransfer " + + "E2E message.") + select { + case <-quit: + jww.DEBUG.Print("Quitting thread waiting for NewFileTransfer E2E " + + "message.") + return + case r := <-receive: + jww.DEBUG.Printf("Received new file %q transfer %s from %s of size %d "+ + "bytes with preview: %q", + r.fileName, r.tid, r.sender, r.size, r.preview) + fmt.Printf("Received new file transfer %q of size %d "+ + "bytes with preview: %q\n", r.fileName, r.size, r.preview) + + cb := newReceiveProgressCB(r.tid, done, m) + err := m.RegisterReceiveProgressCallback(r.tid, cb, callbackPeriod) + if err != nil { + jww.FATAL.Panicf("Failed to register new receive progress "+ + "callback for transfer %s: %+v", r.tid, err) + } + } +} + +// newReceiveProgressCB creates a new reception progress callback that prints +// the results to the log. +func newReceiveProgressCB(tid ftCrypto.TransferID, done chan struct{}, + m *ft.Manager) interfaces.ReceivedProgressCallback { + return func(completed bool, received, total uint16, err error) { + jww.DEBUG.Printf("Receive progress callback for transfer %s "+ + "{completed: %t, received: %d, total: %d, err: %v}", + tid, completed, received, total, err) + + if received == total || completed || err != nil { + fmt.Printf("Received progress callback "+ + "{completed: %t, received: %d, total: %d, err: %v}\n", + completed, received, total, err) + } + + if completed { + receivedFile, err2 := m.Receive(tid) + if err2 != nil { + jww.FATAL.Panicf("Failed to receive file %s: %+v", tid, err) + } + fmt.Printf("Completed receiving file:\n%s\n", receivedFile) + done <- struct{}{} + } else if err != nil { + fmt.Printf("Failed sending file: %+v\n", err) + done <- struct{}{} + } + } +} + +// getContactFromFile loads the contact from the given file path. +func getContactFromFile(path string) contact.Contact { + data, err := ioutil.ReadFile(path) + jww.INFO.Printf("Read in contact file of size %d bytes", len(data)) + if err != nil { + jww.FATAL.Panicf("Failed to read contact file: %+v", err) + } + + c, err := contact.Unmarshal(data) + if err != nil { + jww.FATAL.Panicf("Failed to unmarshal contact: %+v", err) + } + + return c +} + +//////////////////////////////////////////////////////////////////////////////// +// Command Line Flags // +//////////////////////////////////////////////////////////////////////////////// + +// init initializes commands and flags for Cobra. +func init() { + ftCmd.Flags().String("sendFile", "", + "Sends a file to a recipient with with the contact at this path.") + bindPFlagCheckErr("sendFile") + + ftCmd.Flags().String("filePath", "testFile-"+timeNanoString()+".txt", + "The path to the file to send. Also used as the file name.") + bindPFlagCheckErr("filePath") + + ftCmd.Flags().String("filePreviewPath", "", + "The path to the file preview to send. Set either this flag or "+ + "filePreviewString.") + bindPFlagCheckErr("filePreviewPath") + + ftCmd.Flags().String("filePreviewString", "", + "File preview data. Set either this flag or filePreviewPath.") + bindPFlagCheckErr("filePreviewString") + + ftCmd.Flags().Int("maxThroughput", 0, + "Maximum data transfer speed to send file parts (in bytes per second)") + bindPFlagCheckErr("maxThroughput") + + ftCmd.Flags().Float64("retry", 0.5, + "Retry rate.") + bindPFlagCheckErr("retry") + + rootCmd.AddCommand(ftCmd) +} + +// timeNanoString returns the current UNIX time in nanoseconds as a string. +func timeNanoString() string { + return strconv.Itoa(int(netTime.Now().UnixNano())) +} + +// bindPFlagCheckErr binds the key to a pflag.Flag used by Cobra and prints an +// error if one occurs. +func bindPFlagCheckErr(key string) { + err := viper.BindPFlag(key, ftCmd.Flags().Lookup(key)) + if err != nil { + jww.ERROR.Printf("viper.BindPFlag failed for %q: %+v", key, err) + } +} diff --git a/dummy/utils_test.go b/dummy/utils_test.go index 7312effd4d1d7b9d138e467d764f08fd5ddbd401..5ad1a5b65ceb6770897795ff4b13c9ce63866d1a 100644 --- a/dummy/utils_test.go +++ b/dummy/utils_test.go @@ -133,7 +133,7 @@ func (tnm *testNetworkManager) SendCMIX(message format.Message, return 0, ephemeral.Id{}, nil } -func (tnm *testNetworkManager) SendManyCMIX(map[id.ID]format.Message, params.CMIX) ( +func (tnm *testNetworkManager) SendManyCMIX([]message.TargetedCmixMessage, params.CMIX) ( id.Round, []ephemeral.Id, error) { return 0, nil, nil } diff --git a/fileTransfer/fileMessage.go b/fileTransfer/fileMessage.go new file mode 100644 index 0000000000000000000000000000000000000000..b73b948c90cc4d99bad2d459be6a8f98df61f2a5 --- /dev/null +++ b/fileTransfer/fileMessage.go @@ -0,0 +1,130 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "encoding/binary" + "github.com/pkg/errors" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" +) + +// Size constants. +const ( + paddingLen = ftCrypto.NonceSize // The length of the padding in bytes + partNumLen = 2 // The length of the part number in bytes + fmMinSize = partNumLen + paddingLen // Minimum size for the partMessage +) + +// Error messages. +const ( + newFmSizeErr = "size of external payload (%d) must be greater than %d" + unmarshalFmSizeErr = "size of passed in bytes (%d) must be greater than %d" + setFileFmErr = "length of part bytes (%d) must be smaller than maximum payload size %d" +) + +/* ++-----------------------------------------+ +| CMIX Message Contents | ++---------+-------------+-----------------+ +| Padding | Part Number | File Data | +| 8 bytes | 2 bytes | remaining space | ++---------+-------------+-----------------+ +*/ + +// partMessage contains part of the data being transferred and 256-bit padding +// that is used as a nonce. +type partMessage struct { + data []byte // Serial of all contents + padding []byte // Random padding bytes + partNum []byte // The part number of the file + part []byte // File part data +} + +// newPartMessage generates a new part message that fits into the specified +// external payload size. An error is returned if the external payload size is +// too small to fit the part message. +func newPartMessage(externalPayloadSize int) (partMessage, error) { + if externalPayloadSize < fmMinSize { + return partMessage{}, + errors.Errorf(newFmSizeErr, externalPayloadSize, fmMinSize) + } + + return mapPartMessage(make([]byte, externalPayloadSize)), nil +} + +// mapPartMessage maps the data to the components of a partMessage. It is mapped +// by reference; a copy is not made. +func mapPartMessage(data []byte) partMessage { + return partMessage{ + data: data, + padding: data[:paddingLen], + partNum: data[paddingLen : paddingLen+partNumLen], + part: data[paddingLen+partNumLen:], + } +} + +// unmarshalPartMessage converts the bytes into a partMessage. An error is +// returned if the size of the data is too small for a partMessage. +func unmarshalPartMessage(b []byte) (partMessage, error) { + if len(b) < fmMinSize { + return partMessage{}, + errors.Errorf(unmarshalFmSizeErr, len(b), fmMinSize) + } + + return mapPartMessage(b), nil +} + +// marshal returns the byte representation of the partMessage. +func (m partMessage) marshal() []byte { + return m.data +} + +// getPadding returns the padding in the message. +func (m partMessage) getPadding() []byte { + return m.padding +} + +// setPadding sets the partMessage padding to the given bytes. Note that this +// padding should be random bytes generated via the appropriate crypto function. +func (m partMessage) setPadding(b []byte) { + copy(m.padding, b) +} + +// getPartNum returns the file part number. +func (m partMessage) getPartNum() uint16 { + return binary.LittleEndian.Uint16(m.partNum) +} + +// setPartNum sets the file part number. +func (m partMessage) setPartNum(num uint16) { + b := make([]byte, partNumLen) + binary.LittleEndian.PutUint16(b, num) + copy(m.partNum, b) +} + +// getPart returns the file part data from the message. +func (m partMessage) getPart() []byte { + return m.part +} + +// setPart sets the partMessage part to the given bytes. An error is returned if +// the size of the provided part data is too large to store. +func (m partMessage) setPart(b []byte) error { + if len(b) > len(m.part) { + return errors.Errorf(setFileFmErr, len(b), len(m.part)) + } + + copy(m.part, b) + + return nil +} + +// getPartSize returns the number of bytes available to store part data. +func (m partMessage) getPartSize() int { + return len(m.part) +} diff --git a/fileTransfer/fileMessage_test.go b/fileTransfer/fileMessage_test.go new file mode 100644 index 0000000000000000000000000000000000000000..7189d95facd053c1815a4a08fb210620538e13c5 --- /dev/null +++ b/fileTransfer/fileMessage_test.go @@ -0,0 +1,281 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "encoding/binary" + "fmt" + "math/rand" + "testing" +) + +// Tests that newPartMessage returns a partMessage of the expected size. +func Test_newPartMessage(t *testing.T) { + externalPayloadSize := 256 + + fm, err := newPartMessage(externalPayloadSize) + if err != nil { + t.Errorf("newPartMessage returned an error: %+v", err) + } + + if len(fm.data) != externalPayloadSize { + t.Errorf("Size of partMessage data does not match payload size."+ + "\nexpected: %d\nreceived: %d", externalPayloadSize, len(fm.data)) + } +} + +// Error path: tests that newPartMessage returns the expected error when the +// external payload size is too small. +func Test_newPartMessage_SmallPayloadSizeError(t *testing.T) { + externalPayloadSize := fmMinSize - 1 + expectedErr := fmt.Sprintf(newFmSizeErr, externalPayloadSize, fmMinSize) + + _, err := newPartMessage(externalPayloadSize) + if err == nil || err.Error() != expectedErr { + t.Errorf("newPartMessage did not return the expected error when the "+ + "given external payload size is too small."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that mapPartMessage maps the data to the correct parts of the +// partMessage. +func Test_mapPartMessage(t *testing.T) { + // Generate expected values + _, expectedData, expectedPadding, expectedPartNum, expectedFile := newRandomFileMessage() + + fm := mapPartMessage(expectedData) + + if !bytes.Equal(expectedData, fm.data) { + t.Errorf("Incorrect data.\nexpected: %q\nreceived: %q", + expectedData, fm.data) + } + + if !bytes.Equal(expectedPadding, fm.padding) { + t.Errorf("Incorrect padding data.\nexpected: %q\nreceived: %q", + expectedPadding, fm.padding) + } + + if !bytes.Equal(expectedPartNum, fm.partNum) { + t.Errorf("Incorrect part number.\nexpected: %q\nreceived: %q", + expectedPartNum, fm.partNum) + } + + if !bytes.Equal(expectedFile, fm.part) { + t.Errorf("Incorrect part data.\nexpected: %q\nreceived: %q", + expectedFile, fm.part) + } + +} + +// Tests that unmarshalPartMessage returns a partMessage with the expected +// values. +func Test_unmarshalPartMessage(t *testing.T) { + // Generate expected values + _, expectedData, expectedPadding, expectedPartNumb, expectedFile := newRandomFileMessage() + + fm, err := unmarshalPartMessage(expectedData) + if err != nil { + t.Errorf("unmarshalPartMessage return an error: %+v", err) + } + + if !bytes.Equal(expectedData, fm.data) { + t.Errorf("Incorrect data.\nexpected: %q\nreceived: %q", + expectedData, fm.data) + } + + if !bytes.Equal(expectedPadding, fm.padding) { + t.Errorf("Incorrect padding data.\nexpected: %q\nreceived: %q", + expectedPadding, fm.padding) + } + + if !bytes.Equal(expectedPartNumb, fm.partNum) { + t.Errorf("Incorrect part number.\nexpected: %q\nreceived: %q", + expectedPartNumb, fm.partNum) + } + + if !bytes.Equal(expectedFile, fm.part) { + t.Errorf("Incorrect part data.\nexpected: %q\nreceived: %q", + expectedFile, fm.part) + } +} + +// Error path: tests that unmarshalPartMessage returns the expected error when +// the provided data is too small to be unmarshalled into a partMessage. +func Test_unmarshalPartMessage_SizeError(t *testing.T) { + data := make([]byte, fmMinSize-1) + expectedErr := fmt.Sprintf(unmarshalFmSizeErr, len(data), fmMinSize) + + _, err := unmarshalPartMessage(data) + if err == nil || err.Error() != expectedErr { + t.Errorf("unmarshalPartMessage did not return the expected error when "+ + "the given bytes are too small to be a partMessage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that partMessage.marshal returns the correct data. +func Test_fileMessage_marshal(t *testing.T) { + fm, expectedData, _, _, _ := newRandomFileMessage() + + data := fm.marshal() + + if !bytes.Equal(expectedData, data) { + t.Errorf("Marshalled data does not match expected."+ + "\nexpected: %q\nreceived: %q", expectedData, data) + } +} + +// Tests that partMessage.getPadding returns the correct padding data. +func Test_fileMessage_getPadding(t *testing.T) { + fm, _, expectedPadding, _, _ := newRandomFileMessage() + + padding := fm.getPadding() + + if !bytes.Equal(expectedPadding, padding) { + t.Errorf("Padding data does not match expected."+ + "\nexpected: %q\nreceived: %q", expectedPadding, padding) + } +} + +// Tests that partMessage.setPadding sets the correct data. +func Test_fileMessage_setPadding(t *testing.T) { + fm, err := newPartMessage(256) + if err != nil { + t.Errorf("Failed to create new partMessage: %+v", err) + } + + expectedPadding := make([]byte, paddingLen) + rand.New(rand.NewSource(42)).Read(expectedPadding) + + fm.setPadding(expectedPadding) + + if !bytes.Equal(expectedPadding, fm.getPadding()) { + t.Errorf("Failed to set correct padding.\nexpected: %q\nreceived: %q", + expectedPadding, fm.getPadding()) + } +} + +// Tests that partMessage.getPartNum returns the correct part number. +func Test_fileMessage_getPartNum(t *testing.T) { + fm, _, _, expectedPartNum, _ := newRandomFileMessage() + + partNum := fm.getPartNum() + expected := binary.LittleEndian.Uint16(expectedPartNum) + + if expected != partNum { + t.Errorf("Part number does not match expected."+ + "\nexpected: %d\nreceived: %d", expected, partNum) + } +} + +// Tests that partMessage.setPartNum sets the correct part number. +func Test_fileMessage_setPartNum(t *testing.T) { + fm, err := newPartMessage(256) + if err != nil { + t.Errorf("Failed to create new partMessage: %+v", err) + } + + expectedPartNum := make([]byte, partNumLen) + rand.New(rand.NewSource(42)).Read(expectedPartNum) + expected := binary.LittleEndian.Uint16(expectedPartNum) + + fm.setPartNum(expected) + + if expected != fm.getPartNum() { + t.Errorf("Failed to set correct part number.\nexpected: %d\nreceived: %d", + expected, fm.getPartNum()) + } +} + +// Tests that partMessage.getPart returns the correct part data. +func Test_fileMessage_getFile(t *testing.T) { + fm, _, _, _, expectedFile := newRandomFileMessage() + + file := fm.getPart() + + if !bytes.Equal(expectedFile, file) { + t.Errorf("File data does not match expected."+ + "\nexpected: %q\nreceived: %q", expectedFile, file) + } +} + +// Tests that partMessage.setPart sets the correct part data. +func Test_fileMessage_setFile(t *testing.T) { + fm, err := newPartMessage(256) + if err != nil { + t.Errorf("Failed to create new partMessage: %+v", err) + } + + fileData := make([]byte, 64) + rand.New(rand.NewSource(42)).Read(fileData) + expectedFile := make([]byte, fm.getPartSize()) + copy(expectedFile, fileData) + + err = fm.setPart(expectedFile) + if err != nil { + t.Errorf("setPart returned an error: %+v", err) + } + + if !bytes.Equal(expectedFile, fm.getPart()) { + t.Errorf("Failed to set correct part data.\nexpected: %q\nreceived: %q", + expectedFile, fm.getPart()) + } +} + +// Error path: tests that partMessage.setPart returns the expected error when +// the provided part data is too large for the message. +func Test_fileMessage_setFile_FileTooLargeError(t *testing.T) { + fm, err := newPartMessage(fmMinSize + 1) + if err != nil { + t.Errorf("Failed to create new partMessage: %+v", err) + } + + expectedErr := fmt.Sprintf(setFileFmErr, fm.getPartSize()+1, fm.getPartSize()) + + err = fm.setPart(make([]byte, fm.getPartSize()+1)) + if err == nil || err.Error() != expectedErr { + t.Errorf("setPart did not return the expected error when the given "+ + "part data is too large to fit in the partMessage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that partMessage.getPartSize returns the expected available space for +// the part data. +func Test_fileMessage_getFileSize(t *testing.T) { + expectedSize := 256 + + fm, err := newPartMessage(fmMinSize + expectedSize) + if err != nil { + t.Errorf("Failed to create new partMessage: %+v", err) + } + + if expectedSize != fm.getPartSize() { + t.Errorf("File size incorrect.\nexpected: %d\nreceived: %d", + expectedSize, fm.getPartSize()) + } +} + +// newRandomFileMessage generates a new partMessage filled with random data and +// return the partMessage and its individual parts. +func newRandomFileMessage() (partMessage, []byte, []byte, []byte, []byte) { + prng := rand.New(rand.NewSource(42)) + padding := make([]byte, paddingLen) + prng.Read(padding) + partNum := make([]byte, partNumLen) + prng.Read(partNum) + part := make([]byte, 64) + prng.Read(part) + data := append(append(padding, partNum...), part...) + + fm := mapPartMessage(data) + + return fm, data, padding, partNum, part +} diff --git a/fileTransfer/ftMessages.pb.go b/fileTransfer/ftMessages.pb.go new file mode 100644 index 0000000000000000000000000000000000000000..6ca9bdba1d70a40232bab6e7b77eae833defb9e9 --- /dev/null +++ b/fileTransfer/ftMessages.pb.go @@ -0,0 +1,131 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: fileTransfer/ftMessages.proto + +package fileTransfer + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type NewFileTransfer struct { + FileName string `protobuf:"bytes,1,opt,name=fileName,proto3" json:"fileName,omitempty"` + TransferKey []byte `protobuf:"bytes,2,opt,name=transferKey,proto3" json:"transferKey,omitempty"` + TransferMac []byte `protobuf:"bytes,3,opt,name=transferMac,proto3" json:"transferMac,omitempty"` + NumParts uint32 `protobuf:"varint,4,opt,name=numParts,proto3" json:"numParts,omitempty"` + Size uint32 `protobuf:"varint,5,opt,name=size,proto3" json:"size,omitempty"` + Retry float32 `protobuf:"fixed32,6,opt,name=retry,proto3" json:"retry,omitempty"` + Preview []byte `protobuf:"bytes,7,opt,name=preview,proto3" json:"preview,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *NewFileTransfer) Reset() { *m = NewFileTransfer{} } +func (m *NewFileTransfer) String() string { return proto.CompactTextString(m) } +func (*NewFileTransfer) ProtoMessage() {} +func (*NewFileTransfer) Descriptor() ([]byte, []int) { + return fileDescriptor_9d574f363dd34365, []int{0} +} + +func (m *NewFileTransfer) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_NewFileTransfer.Unmarshal(m, b) +} +func (m *NewFileTransfer) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_NewFileTransfer.Marshal(b, m, deterministic) +} +func (m *NewFileTransfer) XXX_Merge(src proto.Message) { + xxx_messageInfo_NewFileTransfer.Merge(m, src) +} +func (m *NewFileTransfer) XXX_Size() int { + return xxx_messageInfo_NewFileTransfer.Size(m) +} +func (m *NewFileTransfer) XXX_DiscardUnknown() { + xxx_messageInfo_NewFileTransfer.DiscardUnknown(m) +} + +var xxx_messageInfo_NewFileTransfer proto.InternalMessageInfo + +func (m *NewFileTransfer) GetFileName() string { + if m != nil { + return m.FileName + } + return "" +} + +func (m *NewFileTransfer) GetTransferKey() []byte { + if m != nil { + return m.TransferKey + } + return nil +} + +func (m *NewFileTransfer) GetTransferMac() []byte { + if m != nil { + return m.TransferMac + } + return nil +} + +func (m *NewFileTransfer) GetNumParts() uint32 { + if m != nil { + return m.NumParts + } + return 0 +} + +func (m *NewFileTransfer) GetSize() uint32 { + if m != nil { + return m.Size + } + return 0 +} + +func (m *NewFileTransfer) GetRetry() float32 { + if m != nil { + return m.Retry + } + return 0 +} + +func (m *NewFileTransfer) GetPreview() []byte { + if m != nil { + return m.Preview + } + return nil +} + +func init() { + proto.RegisterType((*NewFileTransfer)(nil), "parse.NewFileTransfer") +} + +func init() { proto.RegisterFile("fileTransfer/ftMessages.proto", fileDescriptor_9d574f363dd34365) } + +var fileDescriptor_9d574f363dd34365 = []byte{ + // 202 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x4d, 0xcb, 0xcc, 0x49, + 0x0d, 0x29, 0x4a, 0xcc, 0x2b, 0x4e, 0x4b, 0x2d, 0xd2, 0x4f, 0x2b, 0xf1, 0x4d, 0x2d, 0x2e, 0x4e, + 0x4c, 0x4f, 0x2d, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2d, 0x48, 0x2c, 0x2a, 0x4e, + 0x55, 0xba, 0xc8, 0xc8, 0xc5, 0xef, 0x97, 0x5a, 0xee, 0x86, 0xa4, 0x56, 0x48, 0x8a, 0x8b, 0x03, + 0xa4, 0xd7, 0x2f, 0x31, 0x37, 0x55, 0x82, 0x51, 0x81, 0x51, 0x83, 0x33, 0x08, 0xce, 0x17, 0x52, + 0xe0, 0xe2, 0x2e, 0x81, 0xaa, 0xf3, 0x4e, 0xad, 0x94, 0x60, 0x52, 0x60, 0xd4, 0xe0, 0x09, 0x42, + 0x16, 0x42, 0x56, 0xe1, 0x9b, 0x98, 0x2c, 0xc1, 0x8c, 0xaa, 0xc2, 0x37, 0x31, 0x19, 0x64, 0x7e, + 0x5e, 0x69, 0x6e, 0x40, 0x62, 0x51, 0x49, 0xb1, 0x04, 0x8b, 0x02, 0xa3, 0x06, 0x6f, 0x10, 0x9c, + 0x2f, 0x24, 0xc4, 0xc5, 0x52, 0x9c, 0x59, 0x95, 0x2a, 0xc1, 0x0a, 0x16, 0x07, 0xb3, 0x85, 0x44, + 0xb8, 0x58, 0x8b, 0x52, 0x4b, 0x8a, 0x2a, 0x25, 0xd8, 0x14, 0x18, 0x35, 0x98, 0x82, 0x20, 0x1c, + 0x21, 0x09, 0x2e, 0xf6, 0x82, 0xa2, 0xd4, 0xb2, 0xcc, 0xd4, 0x72, 0x09, 0x76, 0xb0, 0x1d, 0x30, + 0xae, 0x13, 0x5f, 0x14, 0x0f, 0xb2, 0xdf, 0x93, 0xd8, 0xc0, 0x3e, 0x36, 0x06, 0x04, 0x00, 0x00, + 0xff, 0xff, 0x3e, 0x0f, 0x1c, 0x27, 0x12, 0x01, 0x00, 0x00, +} diff --git a/fileTransfer/ftMessages.proto b/fileTransfer/ftMessages.proto new file mode 100644 index 0000000000000000000000000000000000000000..e90449697bb5c9b3d9b4a40132b651aee6298d02 --- /dev/null +++ b/fileTransfer/ftMessages.proto @@ -0,0 +1,21 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +syntax = "proto3"; + +package parse; +option go_package = "fileTransfer"; + +message NewFileTransfer { + string fileName = 1; // Name of the file; max 32 characters + bytes transferKey = 2; // 256 bit encryption key to identify the transfer + bytes transferMac = 3; // 256 bit MAC of the entire file + uint32 numParts = 4; // Number of file parts + uint32 size = 5; // The size of the file; max of 4 mB + float retry = 6; // Used to determine how many times to retry sending + bytes preview = 7; // A preview of the file; max of 4 kB +} \ No newline at end of file diff --git a/fileTransfer/generateProto.sh b/fileTransfer/generateProto.sh new file mode 100644 index 0000000000000000000000000000000000000000..fd7eea2f23c4a365dcf68948eaa617c783ad1285 --- /dev/null +++ b/fileTransfer/generateProto.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +protoc --go_out=paths=source_relative:. fileTransfer/ftMessages.proto diff --git a/fileTransfer/manager.go b/fileTransfer/manager.go new file mode 100644 index 0000000000000000000000000000000000000000..e99fc971b4c85c55ec00d918ca025ba280611e15 --- /dev/null +++ b/fileTransfer/manager.go @@ -0,0 +1,391 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "github.com/pkg/errors" + "gitlab.com/elixxir/client/api" + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" + "gitlab.com/elixxir/client/storage" + ftStorage "gitlab.com/elixxir/client/storage/fileTransfer" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/crypto/fastRNG" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/xx_network/primitives/id" + "time" +) + +const ( + // PreviewMaxSize is the maximum size, in bytes, for a file preview. + // Currently, it is set to 4 kB. + PreviewMaxSize = 4_000 + + // FileNameMaxLen is the maximum size, in bytes, for a file name. Currently, + // it is set to 32 bytes. + FileNameMaxLen = 32 + + // FileMaxSize is the maximum file size that can be transferred. Currently, + // it is set to 4 mB. + FileMaxSize = 4_000_000 + + // minPartsSendPerRound is the minimum number of file parts sent each round. + minPartsSendPerRound = 1 + + // maxPartsSendPerRound is the maximum number of file parts sent each round. + maxPartsSendPerRound = 11 + + // Size of the buffered channel that queues file parts to send + sendQueueBuffLen = 10_000 + + // Size of the buffered channel that reports if the network is healthy + networkHealthBuffLen = 10_000 +) + +// Error messages. +const ( + // newManager + newManagerSentErr = "failed to load or create new list of sent file transfers: %+v" + newManagerReceivedErr = "failed to load or create new list of received file transfers: %+v" + + // Manager.Send + fileNameSizeErr = "length of filename (%d) greater than max allowed length (%d)" + fileSizeErr = "size of file (%d bytes) greater than max allowed size (%d bytes)" + previewSizeErr = "size of preview (%d bytes) greater than max allowed size (%d bytes)" + getPartSizeErr = "failed to get file part size: %+v" + sendInitMsgErr = "failed to send initial file transfer message: %+v" + + // Manager.Resend + transferNotFailedErr = "transfer %s has not failed" + + // Manager.CloseSend + transferInProgressErr = "transfer %s has not completed or failed" +) + +// Stoppable and listener values. +const ( + rawMessageBuffSize = 10_000 + sendStoppableName = "FileTransferSend" + newFtStoppableName = "FileTransferNew" + newFtListenerName = "FileTransferNewListener" + filePartStoppableName = "FilePart" + filePartListenerName = "FilePartListener" + fileTransferStoppableName = "FileTransfer" +) + +// Manager is used to manage the sending and receiving of all file transfers. +type Manager struct { + // Callback that is called every time a new file transfer is received + receiveCB interfaces.ReceiveCallback + + // Storage-backed structure for tracking sent file transfers + sent *ftStorage.SentFileTransfers + + // Storage-backed structure for tracking received file transfers + received *ftStorage.ReceivedFileTransfers + + // Queue of parts to send + sendQueue chan queuedPart + + // Maximum data transfer speed in bytes per second + maxThroughput int + + // Client interfaces + client *api.Client + store *storage.Session + swb interfaces.Switchboard + net interfaces.NetworkManager + healthy chan bool + rng *fastRNG.StreamGenerator + getRoundResults getRoundResultsFunc +} + +// getRoundResultsFunc is a function that matches client.GetRoundResults. It is +// used to pass in an alternative function for testing. +type getRoundResultsFunc func(roundList []id.Round, timeout time.Duration, + roundCallback api.RoundEventCallback) error + +// queuedPart contains the unique information identifying a file part. +type queuedPart struct { + tid ftCrypto.TransferID + partNum uint16 +} + +// NewManager produces a new empty file transfer Manager. Does not start sending +// and receiving services. +func NewManager(client *api.Client, receiveCB interfaces.ReceiveCallback, + p Params) (*Manager, error) { + return newManager(client, client.GetStorage(), client.GetSwitchboard(), + client.GetNetworkInterface(), client.GetRng(), client.GetRoundResults, + client.GetStorage().GetKV(), receiveCB, p) +} + +// newManager builds the manager from fields explicitly passed in. This function +// is a helper function for NewManager to make it easier to test. +func newManager(client *api.Client, store *storage.Session, + swb interfaces.Switchboard, net interfaces.NetworkManager, + rng *fastRNG.StreamGenerator, getRoundResults getRoundResultsFunc, + kv *versioned.KV, receiveCB interfaces.ReceiveCallback, p Params) ( + *Manager, error) { + + // Create a new list of sent file transfers or load one if it exists in + // storage + sent, err := ftStorage.NewOrLoadSentFileTransfers(kv) + if err != nil { + return nil, errors.Errorf(newManagerSentErr, err) + } + + // Create a new list of received file transfers or load one if it exists in + // storage + received, err := ftStorage.NewOrLoadReceivedFileTransfers(kv) + if err != nil { + return nil, errors.Errorf(newManagerReceivedErr, err) + } + + return &Manager{ + receiveCB: receiveCB, + sent: sent, + received: received, + sendQueue: make(chan queuedPart, sendQueueBuffLen), + maxThroughput: p.MaxThroughput, + client: client, + store: store, + swb: swb, + net: net, + healthy: make(chan bool, networkHealthBuffLen), + rng: rng, + getRoundResults: getRoundResults, + }, nil +} + +// StartProcesses starts the processes needed to send and receive file parts. It +// starts three threads that (1) receives the initial NewFileTransfer E2E +// message; (2) receives each file part; and (3) sends file parts. It also +// registers the network health channel. +func (m *Manager) StartProcesses() (stoppable.Stoppable, error) { + // Create the two reception channels + newFtChan := make(chan message.Receive, rawMessageBuffSize) + filePartChan := make(chan message.Receive, rawMessageBuffSize) + + return m.startProcesses(newFtChan, filePartChan) +} + +// startProcesses starts the sending and receiving processes with the provided +// channels. +func (m *Manager) startProcesses(newFtChan, filePartChan chan message.Receive) ( + stoppable.Stoppable, error) { + + // Register network health channel that is used by the sending thread to + // ensure the network is healthy before sending + m.net.GetHealthTracker().AddChannel(m.healthy) + + // Start the new file transfer message reception thread + newFtStop := stoppable.NewSingle(newFtStoppableName) + m.swb.RegisterChannel(newFtListenerName, &id.ID{}, + message.NewFileTransfer, newFtChan) + go m.receiveNewFileTransfer(newFtChan, newFtStop) + + // Start the file part message reception thread + filePartStop := stoppable.NewSingle(filePartStoppableName) + m.swb.RegisterChannel(filePartListenerName, &id.ID{}, message.Raw, + filePartChan) + go m.receive(filePartChan, filePartStop) + + // Start the file part sending thread + sendStop := stoppable.NewSingle(sendStoppableName) + go m.sendThread(sendStop, getRandomNumParts) + + // Create a multi stoppable + multiStoppable := stoppable.NewMulti(fileTransferStoppableName) + multiStoppable.Add(newFtStop) + multiStoppable.Add(filePartStop) + multiStoppable.Add(sendStop) + + return multiStoppable, nil +} + +// Send starts the sending of a file transfer to the recipient. It sends the +// initial NewFileTransfer E2E message to the recipient to inform them of the +// incoming file parts. It partitions the file, puts it into storage, and queues +// each file for sending. Returns a unique ID identifying the file transfer. +func (m Manager) Send(fileName string, fileData []byte, recipient *id.ID, + retry float32, preview []byte, progressCB interfaces.SentProgressCallback, + period time.Duration) (ftCrypto.TransferID, error) { + + // Return an error if the file name is too long + if len(fileName) > FileNameMaxLen { + return ftCrypto.TransferID{}, errors.Errorf( + fileNameSizeErr, len(fileName), FileNameMaxLen) + } + + // Return an error if the file is too large + if len(fileData) > FileMaxSize { + return ftCrypto.TransferID{}, errors.Errorf( + fileSizeErr, len(fileData), FileMaxSize) + } + + // Return an error if the preview is too large + if len(preview) > PreviewMaxSize { + return ftCrypto.TransferID{}, errors.Errorf( + previewSizeErr, len(preview), PreviewMaxSize) + } + + // Generate new transfer key + rng := m.rng.GetStream() + transferKey, err := ftCrypto.NewTransferKey(rng) + if err != nil { + rng.Close() + return ftCrypto.TransferID{}, err + } + rng.Close() + + // Get the size of each file part + partSize, err := m.getPartSize() + if err != nil { + return ftCrypto.TransferID{}, errors.Errorf(getPartSizeErr, err) + } + + // Generate transfer MAC + mac := ftCrypto.CreateTransferMAC(fileData, transferKey) + + // Partition the file into parts + parts := partitionFile(fileData, partSize) + numParts := uint16(len(parts)) + fileSize := uint32(len(fileData)) + + // Send the initial file transfer message over E2E + err = m.sendNewFileTransfer(recipient, fileName, transferKey, mac, numParts, + fileSize, retry, preview) + if err != nil { + return ftCrypto.TransferID{}, errors.Errorf(sendInitMsgErr, err) + } + + // Calculate the number of fingerprints to generate + numFps := calcNumberOfFingerprints(numParts, retry) + + // Add the transfer to storage + rng = m.rng.GetStream() + transferID, err := m.sent.AddTransfer( + recipient, transferKey, parts, numFps, progressCB, period, rng) + if err != nil { + return ftCrypto.TransferID{}, err + } + rng.Close() + + m.queueParts(transferID, numParts) + + return transferID, nil +} + +// RegisterSendProgressCallback adds the sent progress callback to the sent +// transfer so that it will be called when updates for the transfer occur. The +// progress callback is called when initially added and on transfer updates, at +// most once per period. +func (m Manager) RegisterSendProgressCallback(tid ftCrypto.TransferID, + progressCB interfaces.SentProgressCallback, period time.Duration) error { + // Get the transfer for the given ID + transfer, err := m.sent.GetTransfer(tid) + if err != nil { + return err + } + + // Add the progress callback + transfer.AddProgressCB(progressCB, period) + + return nil +} + +// Resend resends a file if Send fails. Returns an error if CloseSend +// was already called or if the transfer did not run out of retries. +// TODO: add test +// TODO: write test +// TODO: can you resend? Can you reuse fingerprints? +// TODO: what to do if sendE2E fails? +func (m Manager) Resend(tid ftCrypto.TransferID) error { + // Get the transfer for the given ID + transfer, err := m.sent.GetTransfer(tid) + if err != nil { + return err + } + + // Check if the transfer has run out of fingerprints, which occurs when the + // retry limit is reached + if transfer.GetNumAvailableFps() > 0 { + return errors.Errorf(transferNotFailedErr, tid) + } + + return nil +} + +// CloseSend deletes a sent file transfer from the sent transfer map and from +// storage once a transfer has completed or reached the retry limit. Returns an +// error if the transfer has not run out of retries. +func (m Manager) CloseSend(tid ftCrypto.TransferID) error { + // Get the transfer for the given ID + transfer, err := m.sent.GetTransfer(tid) + if err != nil { + return err + } + + // Check if the transfer has completed or run out of fingerprints, which + // occurs when the retry limit is reached + completed, _, _, _ := transfer.GetProgress() + if transfer.GetNumAvailableFps() > 0 && !completed { + return errors.Errorf(transferInProgressErr, tid) + } + + // Delete the transfer from storage + return m.sent.DeleteTransfer(tid) +} + +// RegisterReceiveProgressCallback adds the reception progress callback to the +// received transfer so that it will be called when updates for the transfer +// occur. The progress callback is called when initially added and on transfer +// updates, at most once per period. +func (m Manager) RegisterReceiveProgressCallback(tid ftCrypto.TransferID, + progressCB interfaces.ReceivedProgressCallback, period time.Duration) error { + // Get the transfer for the given ID + transfer, err := m.received.GetTransfer(tid) + if err != nil { + return err + } + + // Add the progress callback + transfer.AddProgressCB(progressCB, period) + + return nil +} + +// Receive returns the fully assembled file on the completion of the transfer. +// It deletes the transfer from the received transfer map and from storage. +// Returns an error if the transfer is not complete, the full file cannot be +// verified, or if the transfer cannot be found. +func (m Manager) Receive(tid ftCrypto.TransferID) ([]byte, error) { + // Get the transfer for the given ID + transfer, err := m.received.GetTransfer(tid) + if err != nil { + return nil, err + } + + // Get the file from the transfer + file, err := transfer.GetFile() + if err != nil { + return nil, err + } + + // Return the file and delete the transfer from storage + return file, m.received.DeleteTransfer(tid) +} + +// calcNumberOfFingerprints is the formula used to calculate the number of +// fingerprints to generate, which is based off the number of file parts and the +// retry float. +func calcNumberOfFingerprints(numParts uint16, retry float32) uint16 { + return uint16(float32(numParts) * (1 + retry)) +} diff --git a/fileTransfer/manager_test.go b/fileTransfer/manager_test.go new file mode 100644 index 0000000000000000000000000000000000000000..4767dfa09d84f1e67f1dd8ffa6390c4f48385e22 --- /dev/null +++ b/fileTransfer/manager_test.go @@ -0,0 +1,681 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "errors" + "fmt" + "github.com/golang/protobuf/proto" + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/interfaces/message" + ftStorage "gitlab.com/elixxir/client/storage/fileTransfer" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/primitives/id" + "reflect" + "strings" + "sync" + "testing" + "time" +) + +// Tests that newManager does not return errors, that the sent and received +// transfer lists are new, and that the callback works. +func Test_newManager(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + + cbChan := make(chan bool) + cb := func(ftCrypto.TransferID, string, *id.ID, uint32, []byte) { + cbChan <- true + } + + m, err := newManager(nil, nil, nil, nil, nil, nil, kv, cb, DefaultParams()) + if err != nil { + t.Errorf("newManager returned an error: %+v", err) + } + + // Check that the SentFileTransfers is new and correct + expectedSent, _ := ftStorage.NewSentFileTransfers(kv) + if !reflect.DeepEqual(expectedSent, m.sent) { + t.Errorf("SentFileTransfers in manager incorrect."+ + "\nexpected: %+v\nreceived: %+v", expectedSent, m.sent) + } + + // Check that the ReceivedFileTransfers is new and correct + expectedReceived, _ := ftStorage.NewReceivedFileTransfers(kv) + if !reflect.DeepEqual(expectedReceived, m.received) { + t.Errorf("ReceivedFileTransfers in manager incorrect."+ + "\nexpected: %+v\nreceived: %+v", expectedReceived, m.received) + } + + // Check that the callback is called + go m.receiveCB(ftCrypto.TransferID{}, "", nil, 0, nil) + select { + case <-cbChan: + case <-time.NewTimer(time.Millisecond).C: + t.Error("Timed out waiting for callback to be called") + } +} + +// Tests that Manager.Send adds a new sent transfer, sends the NewFileTransfer +// E2E message, and adds all the file parts to the queue. +func TestManager_Send(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + prng := NewPrng(42) + recipient := id.NewIdFromString("recipient", id.User, t) + fileName := "testFile" + numParts := uint16(16) + partSize, _ := m.getPartSize() + fileData, _ := newFile(numParts, uint32(partSize), prng, t) + preview := []byte("filePreview") + retry := float32(1.5) + numFps := calcNumberOfFingerprints(numParts, retry) + + tid, err := m.Send(fileName, fileData, recipient, retry, preview, nil, 0) + if err != nil { + t.Errorf("Send returned an error: %+v", err) + } + + //// + // Check if the transfer exists + //// + transfer, err := m.sent.GetTransfer(tid) + if err != nil { + t.Errorf("Failed to get transfer %s: %+v", tid, err) + } + + if !recipient.Cmp(transfer.GetRecipient()) { + t.Errorf("New transfer has incorrect recipient."+ + "\nexpected: %s\nreceoved: %s", recipient, transfer.GetRecipient()) + } + if transfer.GetNumParts() != numParts { + t.Errorf("New transfer has incorrect number of parts."+ + "\nexpected: %d\nreceived: %d", numParts, transfer.GetNumParts()) + } + if transfer.GetNumFps() != numFps { + t.Errorf("New transfer has incorrect number of fingerprints."+ + "\nexpected: %d\nreceived: %d", numFps, transfer.GetNumFps()) + } + + //// + // Get NewFileTransfer E2E message + //// + sendMsg := m.net.(*testNetworkManager).GetE2eMsg(0) + if sendMsg.MessageType != message.NewFileTransfer { + t.Errorf("E2E message has wrong MessageType.\nexpected: %d\nreceived: %d", + message.NewFileTransfer, sendMsg.MessageType) + } + if !sendMsg.Recipient.Cmp(recipient) { + t.Errorf("E2E message has wrong Recipient.\nexpected: %s\nreceived: %s", + recipient, sendMsg.Recipient) + } + receivedNFT := &NewFileTransfer{} + err = proto.Unmarshal(sendMsg.Payload, receivedNFT) + if err != nil { + t.Errorf("Failed to unmarshal received NewFileTransfer: %+v", err) + } + expectedNFT := &NewFileTransfer{ + FileName: fileName, + TransferKey: transfer.GetTransferKey().Bytes(), + TransferMac: ftCrypto.CreateTransferMAC(fileData, transfer.GetTransferKey()), + NumParts: uint32(numParts), + Size: uint32(len(fileData)), + Retry: retry, + Preview: preview, + } + if !reflect.DeepEqual(expectedNFT, receivedNFT) { + t.Errorf("Received NewFileTransfer message does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expectedNFT, receivedNFT) + } + + //// + // Check queued parts + //// + if len(m.sendQueue) != int(numParts) { + t.Errorf("Failed to add all file parts to queue."+ + "\nexpected: %d\nreceived: %d", numParts, len(m.sendQueue)) + } +} + +// Error path: tests that Manager.Send returns the expected error when the +// provided file name is longer than FileNameMaxLen. +func TestManager_Send_FileNameLengthError(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + + fileName := strings.Repeat("A", FileNameMaxLen+1) + expectedErr := fmt.Sprintf(fileNameSizeErr, len(fileName), FileNameMaxLen) + + _, err := m.Send(fileName, nil, nil, 0, nil, nil, 0) + if err == nil || err.Error() != expectedErr { + t.Errorf("Send did not return the expected error when the file name "+ + "is too long.\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: tests that Manager.Send returns the expected error when the +// provided file is larger than FileMaxSize. +func TestManager_Send_FileSizeError(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + + fileData := make([]byte, FileMaxSize+1) + expectedErr := fmt.Sprintf(fileSizeErr, len(fileData), FileMaxSize) + + _, err := m.Send("", fileData, nil, 0, nil, nil, 0) + if err == nil || err.Error() != expectedErr { + t.Errorf("Send did not return the expected error when the file data "+ + "is too large.\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: tests that Manager.Send returns the expected error when the +// provided preview is larger than PreviewMaxSize. +func TestManager_Send_PreviewSizeError(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + + previewData := make([]byte, PreviewMaxSize+1) + expectedErr := fmt.Sprintf(previewSizeErr, len(previewData), PreviewMaxSize) + + _, err := m.Send("", nil, nil, 0, previewData, nil, 0) + if err == nil || err.Error() != expectedErr { + t.Errorf("Send did not return the expected error when the preview "+ + "data is too large.\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: tests that Manager.Send returns the expected error when the E2E +// message fails to send. +func TestManager_Send_SendE2eError(t *testing.T) { + m := newTestManager(true, nil, nil, nil, t) + prng := NewPrng(42) + recipient := id.NewIdFromString("recipient", id.User, t) + fileName := "testFile" + numParts := uint16(16) + partSize, _ := m.getPartSize() + fileData, _ := newFile(numParts, uint32(partSize), prng, t) + preview := []byte("filePreview") + retry := float32(1.5) + + expectedErr := fmt.Sprintf(sendE2eErr, recipient, "") + + _, err := m.Send(fileName, fileData, recipient, retry, preview, nil, 0) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("Send did not return the expected error when the E2E message "+ + "failed to send.\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that Manager.RegisterSendProgressCallback calls the callback when it is +// added to the transfer and that the callback is associated with the expected +// transfer and is called when calling from the transfer. +func TestManager_RegisterSendProgressCallback(t *testing.T) { + m, sti, _ := newTestManagerWithTransfers([]uint16{12, 4, 1}, false, nil, t) + expectedErr := errors.New("CallbackError") + + // Create new callback and channel for the callback to trigger + cbChan := make(chan sentProgressResults, 6) + cb := func(completed bool, sent, arrived, total uint16, err error) { + cbChan <- sentProgressResults{completed, sent, arrived, total, err} + } + + // Start thread waiting for callback to be called + done0, done1 := make(chan bool), make(chan bool) + go func() { + for i := 0; i < 2; i++ { + select { + case <-time.NewTimer(20 * time.Millisecond).C: + t.Errorf("Timed out waiting for callback call #%d.", i) + case r := <-cbChan: + switch i { + case 0: + err := checkSentProgress(r.completed, r.sent, r.arrived, + r.total, false, 0, 0, sti[0].numParts) + if err != nil { + t.Errorf("%d: %+v", i, err) + } + if r.err != nil { + t.Errorf("Callback returned an error (%d): %+v", i, r.err) + } + done0 <- true + case 1: + if r.err == nil || r.err != expectedErr { + t.Errorf("Callback did not return the expected error (%d)."+ + "\nexpected: %v\nreceived: %+v", i, expectedErr, r.err) + } + done1 <- true + } + } + } + }() + + err := m.RegisterSendProgressCallback(sti[0].tid, cb, time.Millisecond) + if err != nil { + t.Errorf("RegisterSendProgressCallback returned an error: %+v", err) + } + <-done0 + + transfer, _ := m.sent.GetTransfer(sti[0].tid) + + transfer.CallProgressCB(expectedErr) + + <-done1 +} + +// Error path: tests that Manager.RegisterSendProgressCallback returns an error +// when no transfer with the ID exists. +func TestManager_RegisterSendProgressCallback_NoTransferError(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) + + err := m.RegisterSendProgressCallback(tid, nil, 0) + if err == nil { + t.Error("RegisterSendProgressCallback did not return an error when " + + "no transfer with the ID exists.") + } +} + +func TestManager_Resend(t *testing.T) { + +} + +// Error path: tests that Manager.Resend returns an error when no transfer with +// the ID exists. +func TestManager_Resend_NoTransferError(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) + + err := m.Resend(tid) + if err == nil { + t.Error("Resend did not return an error when no transfer with the " + + "ID exists.") + } +} + +// Error path: tests that Manager.Resend returns the error when the transfer has +// not run out of fingerprints. +func TestManager_Resend_NoFingerprints(t *testing.T) { + m, sti, _ := newTestManagerWithTransfers([]uint16{16}, false, nil, t) + expectedErr := fmt.Sprintf(transferNotFailedErr, sti[0].tid) + // Delete the transfer + err := m.Resend(sti[0].tid) + if err == nil || err.Error() != expectedErr { + t.Errorf("Resend did not return the expected error when the transfer "+ + "has not run out of fingerprints.\nexpected: %s\nreceived: %v", + expectedErr, err) + } +} + +// Tests that Manager.CloseSend deletes the transfer when it has run out of +// fingerprints but is not complete. +func TestManager_CloseSend_NoFingerprints(t *testing.T) { + m, sti, _ := newTestManagerWithTransfers([]uint16{16}, false, nil, t) + prng := NewPrng(42) + partSize, _ := m.getPartSize() + + // Use up all the fingerprints in the transfer + transfer, _ := m.sent.GetTransfer(sti[0].tid) + for fpNum := uint16(0); fpNum < sti[0].numFps; fpNum++ { + partNum := fpNum % sti[0].numParts + _, _, _, _, err := transfer.GetEncryptedPart(partNum, partSize, prng) + if err != nil { + t.Errorf("Failed to encrypt part %d (%d): %+v", partNum, fpNum, err) + } + } + + // Delete the transfer + err := m.CloseSend(sti[0].tid) + if err != nil { + t.Errorf("CloseSend returned an error: %+v", err) + } + + // Check that the transfer was deleted + _, err = m.sent.GetTransfer(sti[0].tid) + if err == nil { + t.Errorf("Failed to delete transfer %s.", sti[0].tid) + } +} + +// Tests that Manager.CloseSend deletes the transfer when it completed but has +// fingerprints. +func TestManager_CloseSend_Complete(t *testing.T) { + m, sti, _ := newTestManagerWithTransfers([]uint16{3}, false, nil, t) + + // Set all parts to finished + transfer, _ := m.sent.GetTransfer(sti[0].tid) + err, _ := transfer.SetInProgress(0, 0, 1, 2) + if err != nil { + t.Errorf("Failed to set parts to in-progress: %+v", err) + } + err = transfer.FinishTransfer(0) + if err != nil { + t.Errorf("Failed to set parts to finished: %+v", err) + } + + // Delete the transfer + err = m.CloseSend(sti[0].tid) + if err != nil { + t.Errorf("CloseSend returned an error: %+v", err) + } + + // Check that the transfer was deleted + _, err = m.sent.GetTransfer(sti[0].tid) + if err == nil { + t.Errorf("Failed to delete transfer %s.", sti[0].tid) + } +} + +// Error path: tests that Manager.CloseSend returns an error when no transfer +// with the ID exists. +func TestManager_CloseSend_NoTransferError(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) + + err := m.CloseSend(tid) + if err == nil { + t.Error("CloseSend did not return an error when no transfer with the " + + "ID exists.") + } +} + +// Error path: tests that Manager.CloseSend returns an error when the transfer +// has not run out of fingerprints and is not complete +func TestManager_CloseSend_NotCompleteErr(t *testing.T) { + m, sti, _ := newTestManagerWithTransfers([]uint16{16}, false, nil, t) + expectedErr := fmt.Sprintf(transferInProgressErr, sti[0].tid) + + err := m.CloseSend(sti[0].tid) + if err == nil || err.Error() != expectedErr { + t.Errorf("CloseSend did not return the expected error when the transfer"+ + "is not complete.\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that Manager.RegisterReceiveProgressCallback calls the callback when it +// is added to the transfer and that the callback is associated with the +// expected transfer and is called when calling from the transfer. +func TestManager_RegisterReceiveProgressCallback(t *testing.T) { + m, _, rti := newTestManagerWithTransfers([]uint16{12, 4, 1}, false, nil, t) + expectedErr := errors.New("CallbackError") + + // Create new callback and channel for the callback to trigger + cbChan := make(chan receivedProgressResults, 6) + cb := func(completed bool, received, total uint16, err error) { + cbChan <- receivedProgressResults{completed, received, total, err} + } + + // Start thread waiting for callback to be called + done0, done1 := make(chan bool), make(chan bool) + go func() { + for i := 0; i < 2; i++ { + select { + case <-time.NewTimer(20 * time.Millisecond).C: + t.Errorf("Timed out waiting for callback call #%d.", i) + case r := <-cbChan: + switch i { + case 0: + err := checkReceivedProgress(r.completed, r.received, + r.total, false, 0, rti[0].numParts) + if err != nil { + t.Errorf("%d: %+v", i, err) + } + if r.err != nil { + t.Errorf("Callback returned an error (%d): %+v", i, r.err) + } + done0 <- true + case 1: + if r.err == nil || r.err != expectedErr { + t.Errorf("Callback did not return the expected error (%d)."+ + "\nexpected: %v\nreceived: %+v", i, expectedErr, r.err) + } + done1 <- true + } + } + } + }() + + err := m.RegisterReceiveProgressCallback(rti[0].tid, cb, time.Millisecond) + if err != nil { + t.Errorf("RegisterReceiveProgressCallback returned an error: %+v", err) + } + <-done0 + + transfer, _ := m.received.GetTransfer(rti[0].tid) + + transfer.CallProgressCB(expectedErr) + + <-done1 +} + +// Error path: tests that Manager.RegisterReceiveProgressCallback returns an +// error when no transfer with the ID exists. +func TestManager_RegisterReceiveProgressCallback_NoTransferError(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) + + err := m.RegisterReceiveProgressCallback(tid, nil, 0) + if err == nil { + t.Error("RegisterReceiveProgressCallback did not return an error " + + "when no transfer with the ID exists.") + } +} + +// Error path: tests that Manager.Receive returns an error when no transfer with +// the ID exists. +func TestManager_Receive_NoTransferError(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) + + _, err := m.Receive(tid) + if err == nil { + t.Error("Receive did not return an error when no transfer with the ID " + + "exists.") + } +} + +// Error path: tests that Manager.Receive returns an error when the file is +// incomplete. +func TestManager_Receive_GetFileError(t *testing.T) { + m, _, rti := newTestManagerWithTransfers([]uint16{12, 4, 1}, false, nil, t) + + _, err := m.Receive(rti[0].tid) + if err == nil || !strings.Contains(err.Error(), "missing") { + t.Error("Receive did not return the expected error when no transfer " + + "with the ID exists.") + } +} + +// Tests that calcNumberOfFingerprints matches some manually calculated +// results. +func Test_calcNumberOfFingerprints(t *testing.T) { + testValues := []struct { + numParts uint16 + retry float32 + result uint16 + }{ + {12, 0.5, 18}, + {13, 0.6667, 21}, + {1, 0.89, 1}, + {2, 0.75, 3}, + {119, 0.45, 172}, + } + + for i, val := range testValues { + result := calcNumberOfFingerprints(val.numParts, val.retry) + + if val.result != result { + t.Errorf("calcNumberOfFingerprints(%3d, %3.2f) result is "+ + "incorrect (%d).\nexpected: %d\nreceived: %d", + val.numParts, val.retry, i, val.result, result) + } + } +} + +// Tests that Manager satisfies the interfaces.FileTransfer interface. +func TestManager_FileTransferInterface(t *testing.T) { + var _ interfaces.FileTransfer = Manager{} +} + +// Sets up a mock file transfer with two managers that sends one file from one +// to another. +func Test_FileTransfer(t *testing.T) { + var wg sync.WaitGroup + + // Create callback with channel for receiving new file transfer + receiveNewCbChan := make(chan receivedFtResults, 100) + receiveNewCB := func(tid ftCrypto.TransferID, fileName string, + sender *id.ID, size uint32, preview []byte) { + receiveNewCbChan <- receivedFtResults{tid, fileName, sender, size, preview} + } + + // Create reception channels for both managers + newFtChan1 := make(chan message.Receive, rawMessageBuffSize) + filePartChan1 := make(chan message.Receive, rawMessageBuffSize) + newFtChan2 := make(chan message.Receive, rawMessageBuffSize) + filePartChan2 := make(chan message.Receive, rawMessageBuffSize) + + // Generate sending and receiving managers + m1 := newTestManager(false, filePartChan2, newFtChan2, nil, t) + m2 := newTestManager(false, filePartChan1, newFtChan1, receiveNewCB, t) + + stop1, err := m1.startProcesses(newFtChan1, filePartChan1) + if err != nil { + t.Errorf("Failed to start processes for sending manager: %+v", err) + } + + stop2, err := m2.startProcesses(newFtChan2, filePartChan2) + if err != nil { + t.Errorf("Failed to start processes for receving manager: %+v", err) + } + + // Create progress tracker for sending + sentCbChan := make(chan sentProgressResults, 20) + sentCb := func(completed bool, sent, arrived, total uint16, err error) { + sentCbChan <- sentProgressResults{completed, sent, arrived, total, err} + } + + // Start threads that tracks sent progress until complete + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 20; i++ { + select { + case <-time.NewTimer(250 * time.Millisecond).C: + t.Errorf("Timed out waiting for sent progress callback %d.", i) + case r := <-sentCbChan: + if r.completed { + return + } + } + } + t.Error("Sent progress callback never reported file finishing to send.") + }() + + // Create file and parameters + prng := NewPrng(42) + partSize, _ := m1.getPartSize() + fileName := "testFile" + file, parts := newFile(32, uint32(partSize), prng, t) + preview := parts[0] + recipient := id.NewIdFromString("recipient", id.User, t) + + // Send file + sendTid, err := m1.Send( + fileName, file, recipient, 0.5, preview, sentCb, time.Millisecond) + if err != nil { + t.Errorf("Send returned an error: %+v", err) + } + + // Wait for the receiving manager to get E2E message and call callback to + // get transfer ID of received transfer + var receiveTid ftCrypto.TransferID + select { + case <-time.NewTimer(10 * time.Millisecond).C: + t.Error("Timed out waiting for receive callback: ") + case r := <-receiveNewCbChan: + if !bytes.Equal(preview, r.preview) { + t.Errorf("File preview received from callback incorrect."+ + "\nexpected: %q\nreceived: %q", preview, r.preview) + } + if len(file) != int(r.size) { + t.Errorf("File size received from callback incorrect."+ + "\nexpected: %d\nreceived: %d", len(file), r.size) + } + receiveTid = r.tid + } + + // Register progress callback with receiving manager + receiveCbChan := make(chan receivedProgressResults, 100) + receiveCb := func(completed bool, received, total uint16, err error) { + receiveCbChan <- receivedProgressResults{completed, received, total, err} + } + + // Start threads that tracks received progress until complete + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 20; i++ { + select { + case <-time.NewTimer(250 * time.Millisecond).C: + t.Errorf("Timed out waiting for receive progress callback %d.", i) + case r := <-receiveCbChan: + if r.completed { + return + } + } + } + t.Error("Receive progress callback never reported file finishing to receive.") + }() + + err = m2.RegisterReceiveProgressCallback(receiveTid, receiveCb, time.Millisecond) + if err != nil { + t.Errorf("Failed to register receive progress callback: %+v", err) + } + + wg.Wait() + + // Check that the file can be received + receivedFile, err := m2.Receive(receiveTid) + if err != nil { + t.Errorf("Failed to receive file: %+v", err) + } + + // Check that the received file matches the sent file + if !bytes.Equal(file, receivedFile) { + t.Errorf("Received file does not match sent."+ + "\nexpected: %q\nrecevied: %q", file, receivedFile) + } + + // Check that the received transfer was deleted + _, err = m2.received.GetTransfer(receiveTid) + if err == nil { + t.Error("Failed to delete received file transfer once file has been " + + "received.") + } + + // Close the transfer on the sending manager + err = m1.CloseSend(sendTid) + if err != nil { + t.Errorf("Failed to close the send: %+v", err) + } + + // Check that the sent transfer was deleted + _, err = m1.sent.GetTransfer(sendTid) + if err == nil { + t.Error("Failed to delete sent file transfer once file has been sent " + + "and closed.") + } + + if err = stop1.Close(); err != nil { + t.Errorf("Failed to close sending manager threads: %+v", err) + } + + if err = stop2.Close(); err != nil { + t.Errorf("Failed to close receiving manager threads: %+v", err) + } +} diff --git a/fileTransfer/params.go b/fileTransfer/params.go new file mode 100644 index 0000000000000000000000000000000000000000..76fa1cf2d871f78aeba37b6b08e8b73e2dcdfb8f --- /dev/null +++ b/fileTransfer/params.go @@ -0,0 +1,33 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +const ( + defaultMaxThroughput = 150_000 // 150 kB per second +) + +// Params contains parameters used for file transfer. +type Params struct { + // MaxThroughput is the maximum data transfer speed to send file parts (in + // bytes per second) + MaxThroughput int +} + +// NewParams generates a new Params object filled with the given parameters. +func NewParams(maxThroughput int) Params { + return Params{ + MaxThroughput: maxThroughput, + } +} + +// DefaultParams returns a Params object filled with the default values. +func DefaultParams() Params { + return Params{ + MaxThroughput: defaultMaxThroughput, + } +} diff --git a/fileTransfer/params_test.go b/fileTransfer/params_test.go new file mode 100644 index 0000000000000000000000000000000000000000..8a41f1fc8de56b6c906ccecb6d9761b4096402f0 --- /dev/null +++ b/fileTransfer/params_test.go @@ -0,0 +1,38 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "reflect" + "testing" +) + +// Tests that NewParams returns the expected Params object. +func TestNewParams(t *testing.T) { + expected := Params{ + MaxThroughput: 42, + } + + received := NewParams(expected.MaxThroughput) + + if !reflect.DeepEqual(expected, received) { + t.Errorf("Received Params does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expected, received) + } +} + +// Tests that DefaultParams returns a Params object with the expected defaults. +func TestDefaultParams(t *testing.T) { + expected := Params{MaxThroughput: defaultMaxThroughput} + received := DefaultParams() + + if !reflect.DeepEqual(expected, received) { + t.Errorf("Received Params does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expected, received) + } +} diff --git a/fileTransfer/receive.go b/fileTransfer/receive.go new file mode 100644 index 0000000000000000000000000000000000000000..84ff35a626a5a75c19894ffa1bd1392cb2ab908f --- /dev/null +++ b/fileTransfer/receive.go @@ -0,0 +1,82 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" + "gitlab.com/elixxir/primitives/format" + "strings" +) + +// Error messages. +const ( + // Manager.readMessage + unmarshalPartMessageErr = "failed to unmarshal cMix message contents into file part message: %+v" +) + +// receive runs a loop that receives file message parts and stores them in their +// appropriate transfer. +func (m *Manager) receive(rawMsgs chan message.Receive, stop *stoppable.Single) { + jww.DEBUG.Print("Starting file part reception thread.") + + for { + select { + case <-stop.Quit(): + jww.DEBUG.Print("Stopping file part reception thread: stoppable " + + "triggered") + stop.ToStopped() + return + case receiveMsg := <-rawMsgs: + cMixMsg, err := m.readMessage(receiveMsg) + if err != nil { + // Print error as warning unless the fingerprint does not match, + // which means this message is not of the correct type and will + // be ignored + if strings.Contains(err.Error(), "fingerprint") { + jww.INFO.Print(err) + } else { + jww.WARN.Print(err) + } + continue + } + + // Denote that the message is a file part + m.store.GetGarbledMessages().Remove(cMixMsg) + } + } +} + +// readMessage unmarshal the payload in the message.Receive and stores it with +// the appropriate received transfer. The cMix message is returned so that, on +// error, it can be either marked as used not used. +func (m *Manager) readMessage(msg message.Receive) (format.Message, error) { + // Unmarshal payload into cMix message + cMixMsg := format.Unmarshal(msg.Payload) + + // Unmarshal cMix message contents into a file part message + partMsg, err := unmarshalPartMessage(cMixMsg.GetContents()) + if err != nil { + return cMixMsg, errors.Errorf(unmarshalPartMessageErr, err) + } + + // Add part to received transfer + transfer, _, err := m.received.AddPart(partMsg.getPart(), + partMsg.getPadding(), cMixMsg.GetMac(), partMsg.getPartNum(), + cMixMsg.GetKeyFP()) + if err != nil { + return cMixMsg, err + } + + // Call callback with updates + transfer.CallProgressCB(nil) + + return cMixMsg, nil +} diff --git a/fileTransfer/receiveNew.go b/fileTransfer/receiveNew.go new file mode 100644 index 0000000000000000000000000000000000000000..908eb0e7777ee24e1b7e3b46cf4e066b4d590ff1 --- /dev/null +++ b/fileTransfer/receiveNew.go @@ -0,0 +1,99 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "github.com/golang/protobuf/proto" + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/xx_network/primitives/id" +) + +// Error messages. +const ( + receiveMessageTypeErr = "received message is not of type NewFileTransfer" + protoUnmarshalErr = "failed to unmarshal request: %+v" +) + +// receiveNewFileTransfer starts a thread that waits for new file transfer +// messages. +func (m *Manager) receiveNewFileTransfer(rawMsgs chan message.Receive, + stop *stoppable.Single) { + jww.DEBUG.Print("Starting new file transfer message reception thread.") + + for { + select { + case <-stop.Quit(): + jww.DEBUG.Print("Stopping new file transfer message reception " + + "thread: stoppable triggered") + stop.ToStopped() + return + case receivedMsg := <-rawMsgs: + jww.DEBUG.Print("New file transfer message thread received message.") + + tid, fileName, sender, size, preview, err := + m.readNewFileTransferMessage(receivedMsg) + if err != nil { + if err.Error() == receiveMessageTypeErr { + jww.INFO.Printf("Failed to read message as new file "+ + "transfer message: %+v", err) + } else { + jww.WARN.Printf("Failed to read message as new file "+ + "transfer message: %+v", err) + } + continue + } + + // Call the reception callback + go m.receiveCB(tid, fileName, sender, size, preview) + + // Trigger a resend of all garbled messages + m.net.CheckGarbledMessages() + } + } +} + +// readNewFileTransferMessage reads the received message and adds it to the +// received transfer list. Returns the transfer ID, sender ID, file size, and +// file preview. +func (m *Manager) readNewFileTransferMessage(msg message.Receive) ( + ftCrypto.TransferID, string, *id.ID, uint32, []byte, error) { + + // Return an error if the message is not a NewFileTransfer + if msg.MessageType != message.NewFileTransfer { + return ftCrypto.TransferID{}, "", nil, 0, nil, + errors.New(receiveMessageTypeErr) + } + + // Unmarshal the request message + newFT := &NewFileTransfer{} + err := proto.Unmarshal(msg.Payload, newFT) + if err != nil { + return ftCrypto.TransferID{}, "", nil, 0, nil, + errors.Errorf(protoUnmarshalErr, err) + } + + // Get RNG from stream + rng := m.rng.GetStream() + defer rng.Close() + + // Add the transfer to the list of receiving transfers + key := ftCrypto.UnmarshalTransferKey(newFT.TransferKey) + numParts := uint16(newFT.NumParts) + numFps := calcNumberOfFingerprints(numParts, newFT.Retry) + tid, err := m.received.AddTransfer( + key, newFT.TransferMac, newFT.Size, numParts, numFps, rng) + if err != nil { + return ftCrypto.TransferID{}, "", nil, 0, nil, err + } + + return tid, newFT.FileName, msg.Sender, newFT.Size, newFT.Preview, nil +} diff --git a/fileTransfer/receiveNew_test.go b/fileTransfer/receiveNew_test.go new file mode 100644 index 0000000000000000000000000000000000000000..16863e80e7a8bdbdcb5b640b77daa6840e253251 --- /dev/null +++ b/fileTransfer/receiveNew_test.go @@ -0,0 +1,283 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "github.com/golang/protobuf/proto" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/xx_network/primitives/id" + "strings" + "testing" + "time" +) + +// Tests that Manager.receiveNewFileTransfer receives the sent message and that +// it reports the correct data to the callback. +func TestManager_receiveNewFileTransfer(t *testing.T) { + // Create new ReceiveCallback that sends the results on a channel + receiveChan := make(chan receivedFtResults) + receiveCB := func(tid ftCrypto.TransferID, fileName string, sender *id.ID, + size uint32, preview []byte) { + receiveChan <- receivedFtResults{tid, fileName, sender, size, preview} + } + + // Create new manager, stoppable, and channel to receive messages + m := newTestManager(false, nil, nil, receiveCB, t) + stop := stoppable.NewSingle(newFtStoppableName) + rawMsgs := make(chan message.Receive, rawMessageBuffSize) + + // Start receiving thread + go m.receiveNewFileTransfer(rawMsgs, stop) + + // Create new message.Receive with marshalled NewFileTransfer + key, _ := ftCrypto.NewTransferKey(NewPrng(42)) + protoMsg := &NewFileTransfer{ + FileName: "testFile", + TransferKey: key.Bytes(), + TransferMac: []byte("transferMac"), + NumParts: 16, + Size: 256, + Retry: 1.5, + Preview: []byte("filePreview"), + } + marshalledMsg, err := proto.Marshal(protoMsg) + if err != nil { + t.Errorf("Failed to marshal proto message: %+v", err) + } + receiveMsg := message.Receive{ + Payload: marshalledMsg, + MessageType: message.NewFileTransfer, + Sender: id.NewIdFromString("sender", id.User, t), + } + + // Add message to channel + rawMsgs <- receiveMsg + + // Wait to receive message + select { + case <-time.NewTimer(100 * time.Millisecond).C: + t.Error("Timed out waiting to receive message.") + case r := <-receiveChan: + if !receiveMsg.Sender.Cmp(r.sender) { + t.Errorf("Received sender ID does not match expected."+ + "\nexpected: %s\nreceived: %s", receiveMsg.Sender, r.sender) + } + if protoMsg.Size != r.size { + t.Errorf("Received file size does not match expected."+ + "\nexpected: %d\nreceived: %d", protoMsg.Size, r.size) + } + if !bytes.Equal(protoMsg.Preview, r.preview) { + t.Errorf("Received preview does not match expected."+ + "\nexpected: %q\nreceived: %q", protoMsg.Preview, r.preview) + } + } + + // Stop thread + err = stop.Close() + if err != nil { + t.Errorf("Failed to stop stoppable: %+v", err) + } +} + +// Tests that Manager.receiveNewFileTransfer stops receiving messages when the +// stoppable is triggered. +func TestManager_receiveNewFileTransfer_Stop(t *testing.T) { + // Create new ReceiveCallback that sends the results on a channel + receiveChan := make(chan receivedFtResults) + receiveCB := func(tid ftCrypto.TransferID, fileName string, sender *id.ID, + size uint32, preview []byte) { + receiveChan <- receivedFtResults{tid, fileName, sender, size, preview} + } + + // Create new manager, stoppable, and channel to receive messages + m := newTestManager(false, nil, nil, receiveCB, t) + stop := stoppable.NewSingle(newFtStoppableName) + rawMsgs := make(chan message.Receive, rawMessageBuffSize) + + // Start receiving thread + go m.receiveNewFileTransfer(rawMsgs, stop) + + // Create new message.Receive with marshalled NewFileTransfer + key, _ := ftCrypto.NewTransferKey(NewPrng(42)) + protoMsg := &NewFileTransfer{ + FileName: "testFile", + TransferKey: key.Bytes(), + TransferMac: []byte("transferMac"), + NumParts: 16, + Size: 256, + Retry: 1.5, + Preview: []byte("filePreview"), + } + marshalledMsg, err := proto.Marshal(protoMsg) + if err != nil { + t.Errorf("Failed to marshal proto message: %+v", err) + } + receiveMsg := message.Receive{ + Payload: marshalledMsg, + MessageType: message.NewFileTransfer, + Sender: id.NewIdFromString("sender", id.User, t), + } + + // Stop thread + err = stop.Close() + if err != nil { + t.Errorf("Failed to stop stoppable: %+v", err) + } + + for !stop.IsStopped() { + } + + // Add message to channel + rawMsgs <- receiveMsg + + // Wait to receive message + select { + case <-time.NewTimer(time.Millisecond).C: + case r := <-receiveChan: + t.Errorf("Callback called when the thread should have quit."+ + "\nreceived: %+v", r) + } +} + +// Tests that Manager.receiveNewFileTransfer does not report on the callback +// when the received message is of the wrong type. +func TestManager_receiveNewFileTransfer_InvalidMessageError(t *testing.T) { + // Create new ReceiveCallback that sends the results on a channel + receiveChan := make(chan receivedFtResults) + receiveCB := func(tid ftCrypto.TransferID, fileName string, sender *id.ID, + size uint32, preview []byte) { + receiveChan <- receivedFtResults{tid, fileName, sender, size, preview} + } + + // Create new manager, stoppable, and channel to receive messages + m := newTestManager(false, nil, nil, receiveCB, t) + stop := stoppable.NewSingle(newFtStoppableName) + rawMsgs := make(chan message.Receive, rawMessageBuffSize) + + // Start receiving thread + go m.receiveNewFileTransfer(rawMsgs, stop) + + // Create new message.Receive with wrong type + receiveMsg := message.Receive{ + MessageType: message.NoType, + } + + // Add message to channel + rawMsgs <- receiveMsg + + // Wait to receive message + select { + case <-time.NewTimer(time.Millisecond).C: + case r := <-receiveChan: + t.Errorf("Callback called when the message is of the wrong type."+ + "\nreceived: %+v", r) + } + + // Stop thread + err := stop.Close() + if err != nil { + t.Errorf("Failed to stop stoppable: %+v", err) + } +} + +// Tests that Manager.readNewFileTransferMessage returns the expected sender ID, +// file size, and preview. +func TestManager_readNewFileTransferMessage(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + + // Create new message.Send containing marshalled NewFileTransfer + recipient := id.NewIdFromString("recipient", id.User, t) + expectedFileName := "testFile" + key, _ := ftCrypto.NewTransferKey(NewPrng(42)) + mac := []byte("transferMac") + numParts, expectedFileSize, retry := uint16(16), uint32(256), float32(1.5) + expectedPreview := []byte("filePreview") + sendMsg, err := newNewFileTransferE2eMessage(recipient, expectedFileName, + key, mac, numParts, expectedFileSize, retry, expectedPreview) + if err != nil { + t.Errorf("Failed to create new Send message: %+v", err) + } + + // Create message.Receive with marshalled NewFileTransfer + receiveMsg := message.Receive{ + Payload: sendMsg.Payload, + MessageType: message.NewFileTransfer, + Sender: id.NewIdFromString("sender", id.User, t), + } + + // Read the message + _, fileName, sender, fileSize, preview, err := m.readNewFileTransferMessage(receiveMsg) + if err != nil { + t.Errorf("readNewFileTransferMessage returned an error: %+v", err) + } + + if expectedFileName != fileName { + t.Errorf("Returned file name does not match expected."+ + "\nexpected: %q\nreceived: %q", expectedFileName, fileName) + } + + if !receiveMsg.Sender.Cmp(sender) { + t.Errorf("Returned sender ID does not match expected."+ + "\nexpected: %s\nreceived: %s", receiveMsg.Sender, sender) + } + + if expectedFileSize != fileSize { + t.Errorf("Returned file size does not match expected."+ + "\nexpected: %d\nreceived: %d", expectedFileSize, fileSize) + } + + if !bytes.Equal(expectedPreview, preview) { + t.Errorf("Returned preview does not match expected."+ + "\nexpected: %q\nreceived: %q", expectedPreview, preview) + } +} + +// Error path: tests that Manager.readNewFileTransferMessage returns the +// expected error when the message.Receive has the wrong MessageType. +func TestManager_readNewFileTransferMessage_MessageTypeError(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + expectedErr := receiveMessageTypeErr + + // Create message.Receive with marshalled NewFileTransfer + receiveMsg := message.Receive{ + MessageType: message.NoType, + } + + // Read the message + _, _, _, _, _, err := m.readNewFileTransferMessage(receiveMsg) + if err == nil || err.Error() != expectedErr { + t.Errorf("readNewFileTransferMessage did not return the expected "+ + "error when the message.Receive has the wrong MessageType."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: tests that Manager.readNewFileTransferMessage returns the +// expected error when the payload of the message.Receive cannot be +// unmarshalled. +func TestManager_readNewFileTransferMessage_ProtoUnmarshalError(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + expectedErr := strings.Split(protoUnmarshalErr, "%")[0] + + // Create message.Receive with marshalled NewFileTransfer + receiveMsg := message.Receive{ + Payload: []byte("invalidPayload"), + MessageType: message.NewFileTransfer, + } + + // Read the message + _, _, _, _, _, err := m.readNewFileTransferMessage(receiveMsg) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("readNewFileTransferMessage did not return the expected "+ + "error when the payload could not be unmarshalled."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} diff --git a/fileTransfer/receive_test.go b/fileTransfer/receive_test.go new file mode 100644 index 0000000000000000000000000000000000000000..e8d732d319bcff656edc064d06ef7e15058408af --- /dev/null +++ b/fileTransfer/receive_test.go @@ -0,0 +1,356 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/primitives/id" + "reflect" + "testing" + "time" +) + +// Tests that Manager.receive returns the correct progress on the callback when +// receiving a single message. +func TestManager_receive(t *testing.T) { + // Build a manager for sending and a manger for receiving + m1 := newTestManager(false, nil, nil, nil, t) + m2 := newTestManager(false, nil, nil, nil, t) + + // Create transfer components + prng := NewPrng(42) + recipient := id.NewIdFromString("recipient", id.User, t) + key, _ := ftCrypto.NewTransferKey(prng) + numParts := uint16(16) + numFps := calcNumberOfFingerprints(numParts, 0.5) + partSize, _ := m1.getPartSize() + file, parts := newFile(numParts, uint32(partSize), prng, t) + fileSize := uint32(len(file)) + mac := ftCrypto.CreateTransferMAC(file, key) + + // Add transfer to sending manager + stID, err := m1.sent.AddTransfer( + recipient, key, parts, numFps, nil, 0, prng) + if err != nil { + t.Errorf("Failed to add new sent transfer: %+v", err) + } + + // Add transfer to receiving manager + rtID, err := m2.received.AddTransfer( + key, mac, fileSize, numParts, numFps, prng) + if err != nil { + t.Errorf("Failed to add new received transfer: %+v", err) + } + + // Generate receive callback that should be called when a message is read + type progressResults struct { + completed bool + received, total uint16 + err error + } + cbChan := make(chan progressResults) + cb := func(completed bool, received, total uint16, err error) { + cbChan <- progressResults{completed, received, total, err} + } + + done0, done1 := make(chan bool), make(chan bool) + go func() { + for i := 0; i < 2; i++ { + select { + case <-time.NewTimer(10 * time.Millisecond).C: + t.Error("Timed out waiting for callback to be called.") + case r := <-cbChan: + switch i { + case 0: + done0 <- true + case 1: + err = checkReceivedProgress(r.completed, r.received, r.total, false, + 1, numParts) + if r.err != nil { + t.Errorf("Callback returned an error: %+v", err) + } + if err != nil { + t.Error(err) + } + done1 <- true + } + } + } + }() + + rt, err := m2.received.GetTransfer(rtID) + if err != nil { + t.Errorf("Failed to get received transfer %s: %+v", rtID, err) + } + rt.AddProgressCB(cb, time.Millisecond) + + <-done0 + + // Build message.Receive with cMix message that has file part + st, err := m1.sent.GetTransfer(stID) + if err != nil { + t.Errorf("Failed to get sent transfer %s: %+v", stID, err) + } + cMixMsg, err := m1.newCmixMessage(st, 0, prng) + if err != nil { + t.Errorf("Failed to create new cMix message: %+v", err) + } + receiveMsg := message.Receive{Payload: cMixMsg.Marshal()} + + // Start reception thread + rawMsgs := make(chan message.Receive, rawMessageBuffSize) + stop := stoppable.NewSingle(filePartStoppableName) + go m2.receive(rawMsgs, stop) + + // Send message on channel; + rawMsgs <- receiveMsg + + <-done1 + + // Create cMix message with wrong fingerprint + cMixMsg.SetKeyFP(format.NewFingerprint([]byte("invalidFP"))) + receiveMsg = message.Receive{Payload: cMixMsg.Marshal()} + + done := make(chan bool) + go func() { + select { + case <-time.NewTimer(10 * time.Millisecond).C: + case r := <-cbChan: + t.Errorf("Callback should not be called for invalid message: %+v", r) + } + done <- true + }() + + // Send message on channel; + rawMsgs <- receiveMsg + + <-done +} + +// Tests that Manager.receive the progress callback is not called when the +// stoppable is triggered. +func TestManager_receive_Stop(t *testing.T) { + // Build a manager for sending and a manger for receiving + m1 := newTestManager(false, nil, nil, nil, t) + m2 := newTestManager(false, nil, nil, nil, t) + + // Create transfer components + prng := NewPrng(42) + recipient := id.NewIdFromString("recipient", id.User, t) + key, _ := ftCrypto.NewTransferKey(prng) + numParts := uint16(16) + numFps := calcNumberOfFingerprints(numParts, 0.5) + partSize, _ := m1.getPartSize() + file, parts := newFile(numParts, uint32(partSize), prng, t) + fileSize := uint32(len(file)) + mac := ftCrypto.CreateTransferMAC(file, key) + + // Add transfer to sending manager + stID, err := m1.sent.AddTransfer( + recipient, key, parts, numFps, nil, 0, prng) + if err != nil { + t.Errorf("Failed to add new sent transfer: %+v", err) + } + + // Add transfer to receiving manager + rtID, err := m2.received.AddTransfer( + key, mac, fileSize, numParts, numFps, prng) + if err != nil { + t.Errorf("Failed to add new received transfer: %+v", err) + } + + // Generate receive callback that should be called when a message is read + type progressResults struct { + completed bool + received, total uint16 + err error + } + cbChan := make(chan progressResults) + cb := func(completed bool, received, total uint16, err error) { + cbChan <- progressResults{completed, received, total, err} + } + + done0, done1 := make(chan bool), make(chan bool) + go func() { + for i := 0; i < 2; i++ { + select { + case <-time.NewTimer(20 * time.Millisecond).C: + done1 <- true + case r := <-cbChan: + switch i { + case 0: + done0 <- true + case 1: + t.Errorf("Callback should not have been called: %+v", r) + done1 <- true + } + } + } + }() + + rt, err := m2.received.GetTransfer(rtID) + if err != nil { + t.Errorf("Failed to get received transfer %s: %+v", rtID, err) + } + rt.AddProgressCB(cb, time.Millisecond) + + <-done0 + + // Build message.Receive with cMix message that has file part + st, err := m1.sent.GetTransfer(stID) + if err != nil { + t.Errorf("Failed to get sent transfer %s: %+v", stID, err) + } + cMixMsg, err := m1.newCmixMessage(st, 0, prng) + if err != nil { + t.Errorf("Failed to create new cMix message: %+v", err) + } + receiveMsg := message.Receive{ + Payload: cMixMsg.Marshal(), + } + + // Start reception thread + rawMsgs := make(chan message.Receive, rawMessageBuffSize) + stop := stoppable.NewSingle(filePartStoppableName) + go m2.receive(rawMsgs, stop) + + // Trigger stoppable + err = stop.Close() + if err != nil { + t.Errorf("Failed to close stoppable: %+v", err) + } + + for stop.IsStopping() { + + } + + // Send message on channel; + rawMsgs <- receiveMsg + + <-done1 +} + +// Tests that Manager.readMessage reads the message without errors and that it +// reports the correct progress on the callback. It also gets the file and +// checks that the part is where it should be. +func TestManager_readMessage(t *testing.T) { + + // Build a manager for sending and a manger for receiving + m1 := newTestManager(false, nil, nil, nil, t) + m2 := newTestManager(false, nil, nil, nil, t) + + // Create transfer components + prng := NewPrng(42) + recipient := id.NewIdFromString("recipient", id.User, t) + key, _ := ftCrypto.NewTransferKey(prng) + numParts := uint16(16) + numFps := calcNumberOfFingerprints(numParts, 0.5) + partSize, _ := m1.getPartSize() + file, parts := newFile(numParts, uint32(partSize), prng, t) + fileSize := uint32(len(file)) + mac := ftCrypto.CreateTransferMAC(file, key) + + // Add transfer to sending manager + stID, err := m1.sent.AddTransfer( + recipient, key, parts, numFps, nil, 0, prng) + if err != nil { + t.Errorf("Failed to add new sent transfer: %+v", err) + } + + // Add transfer to receiving manager + rtID, err := m2.received.AddTransfer( + key, mac, fileSize, numParts, numFps, prng) + if err != nil { + t.Errorf("Failed to add new received transfer: %+v", err) + } + + // Generate receive callback that should be called when a message is read + type progressResults struct { + completed bool + received, total uint16 + err error + } + cbChan := make(chan progressResults, 2) + cb := func(completed bool, received, total uint16, err error) { + cbChan <- progressResults{completed, received, total, err} + } + + done0, done1 := make(chan bool), make(chan bool) + go func() { + for i := 0; i < 2; i++ { + select { + case <-time.NewTimer(10 * time.Millisecond).C: + t.Error("Timed out waiting for callback to be called.") + case r := <-cbChan: + switch i { + case 0: + done0 <- true + case 1: + err := checkReceivedProgress(r.completed, r.received, r.total, false, + 1, numParts) + if r.err != nil { + t.Errorf("Callback returned an error: %+v", err) + } + if err != nil { + t.Error(err) + } + done1 <- true + } + } + } + }() + + rt, err := m2.received.GetTransfer(rtID) + if err != nil { + t.Errorf("Failed to get received transfer %s: %+v", rtID, err) + } + rt.AddProgressCB(cb, time.Millisecond) + + <-done0 + + // Build message.Receive with cMix message that has file part + st, err := m1.sent.GetTransfer(stID) + if err != nil { + t.Errorf("Failed to get sent transfer %s: %+v", stID, err) + } + cMixMsg, err := m1.newCmixMessage(st, 0, prng) + if err != nil { + t.Errorf("Failed to create new cMix message: %+v", err) + } + receiveMsg := message.Receive{ + Payload: cMixMsg.Marshal(), + } + + // Read receive message + receivedCMixMsg, err := m2.readMessage(receiveMsg) + if err != nil { + t.Errorf("readMessage returned an error: %+v", err) + } + + if !reflect.DeepEqual(cMixMsg, receivedCMixMsg) { + t.Errorf("Received cMix message does not match sent."+ + "\nexpected: %+v\nreceived: %+v", cMixMsg, receivedCMixMsg) + } + + <-done1 + + // Get the file and check that the part was added to it + fileData, err := rt.GetFile() + if err == nil { + t.Error("GetFile did not return an error when parts are missing.") + } + + if !bytes.Equal(parts[0], fileData[:partSize]) { + t.Errorf("Part failed to be added to part store."+ + "\nexpected: %q\nreceived: %q", parts[0], fileData[:partSize]) + } +} diff --git a/fileTransfer/send.go b/fileTransfer/send.go new file mode 100644 index 0000000000000000000000000000000000000000..7d9f9e9b5865fb154ec2f071b8b447eb788d3dd7 --- /dev/null +++ b/fileTransfer/send.go @@ -0,0 +1,449 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "encoding/binary" + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/api" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/stoppable" + ftStorage "gitlab.com/elixxir/client/storage/fileTransfer" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/crypto/csprng" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" + "time" +) + +// Error messages. +const ( + // Manager.sendParts + sendManyCmixWarn = "Failed to send %d file parts %v via SendManyCMIX: %+v" + setInProgressErr = "Failed to set parts %v to in-progress for transfer %s" + getRoundResultsErr = "Failed to get round results for round %d: %+v" + + // Manager.buildMessages + noSentTransferWarn = "Could not get transfer %s for part %d: %+v" + maxRetriesErr = "Stopping message transfer: %+v" + newCmixMessageErr = "Failed to assemble cMix message for file part %d on transfer %s: %+v" + + // Manager.makeRoundEventCallback + finishTransferPanic = "Failed to set part(s) to finished for transfer %s: %+v" + roundFailureWarn = "Failed to send file parts for transmission %s on round %d: %s" + unsetInProgressPanic = "Failed to remove parts from in-progress list for transfer %s after %s" + roundFailureCbErr = "Failed to send parts %v for transfer %s on round %d: %s" + roundFailErr = "round failed" + trackerTimeoutErr = "tracker timed out" + + // getRandomNumParts + getRandomNumPartsRandPanic = "Failed to generate random number of file parts to send: %+v" +) + +// Duration to wait for round to finish before timing out. +const roundResultsTimeout = 60 * time.Second + +// Duration to wait for send batch to fill before sending partial batch. +const pollSleepDuration = 100 * time.Millisecond + +// sendThread waits on the sendQueue channel for parts to send. Once its +// receives a random number between 1 and 11 of file parts, they are encrypted, +// put into cMix messages, and sent to their recipients. Failed messages are +// added to the end of the queue. +func (m *Manager) sendThread(stop *stoppable.Single, getNumParts getRngNum) { + jww.DEBUG.Print("Starting file part sending thread.") + + // Calculate the average amount of data sent via SendManyCMIX + avgNumMessages := (minPartsSendPerRound + maxPartsSendPerRound) / 2 + avgSendSize := avgNumMessages * (8192 / 8) + + // Calculate the delay needed to reach max throughput + delay := time.Duration((int(time.Second) * avgSendSize) / m.maxThroughput) + + // Batch of parts read from the queue to be sent + var partList []queuedPart + + // The size of each batch + var numParts int + + timer := time.NewTimer(pollSleepDuration) + lastSend := time.Time{} + + // Loop forever polling the sendQueue channel for new file parts to send. If + // the channel is empty, then polling is suspended for pollSleepDuration. If + // the network is not healthy, then polling is suspended until the network + // becomes healthy. + for { + timer = time.NewTimer(pollSleepDuration) + select { + case <-stop.Quit(): + // Close the thread when the stoppable is triggered + m.closeSendThread(partList, stop) + + return + case healthy := <-m.healthy: + // If the network is unhealthy, wait until it becomes healthy + jww.TRACE.Print("Suspending file part sending thread: network is " + + "unhealthy") + for !healthy { + healthy = <-m.healthy + } + jww.TRACE.Print("Resuming file part sending thread: network is " + + "healthy") + case part := <-m.sendQueue: + // When a part is received from the queue, add it to the list of + // parts to be sent + + // If the batch is empty (a send just occurred), start a new batch + if partList == nil { + rng := m.rng.GetStream() + numParts = getNumParts(rng) + rng.Close() + partList = make([]queuedPart, 0, numParts) + } + + partList = append(partList, part) + + // If the batch is full, then send the parts + if len(partList) == numParts { + quit := m.handleSend(&partList, &lastSend, delay, stop) + if quit { + return + } + } else { + timer = time.NewTimer(pollSleepDuration) + } + case <-timer.C: + // If the timeout is reached, send an incomplete batch + + // Skip if there are no parts to send + if len(partList) == 0 { + timer = time.NewTimer(pollSleepDuration) + continue + } + + quit := m.handleSend(&partList, &lastSend, delay, stop) + if quit { + return + } + } + } +} + +// closeSendThread safely stops the sending thread by saving unsent parts to the +// queue and setting the stoppable to stopped. +func (m *Manager) closeSendThread(partList []queuedPart, stop *stoppable.Single) { + // Exit the thread if the stoppable is triggered + jww.DEBUG.Print("Stopping file part sending thread: stoppable triggered") + + // Add all the unsent parts back in the queue + for _, part := range partList { + m.sendQueue <- part + } + + stop.ToStopped() +} + +// handleSend handles the sending of parts with bandwidth limitations. On a +// successful send, the partList is cleared and lastSend is updated to the send +// timestamp. When the stoppable is triggered, closing is automatically handled. +// Returns true if the stoppable has been triggered and the sending thread +// should quit. +func (m *Manager) handleSend(partList *[]queuedPart, lastSend *time.Time, + delay time.Duration, stop *stoppable.Single) bool { + // Bandwidth limiter: wait to send until the delay has been reached so that + // the bandwidth is limited to the maximum throughput + if netTime.Since(*lastSend) < delay { + select { + case <-stop.Quit(): + // Close the thread when the stoppable is triggered + m.closeSendThread(*partList, stop) + + return true + case <-time.NewTimer(delay - netTime.Since(*lastSend)).C: + } + } + + // Send all the messages + rng := m.rng.GetStream() + err := m.sendParts(*partList, rng) + *lastSend = netTime.Now() + if err != nil { + jww.FATAL.Panic(err) + } + rng.Close() + + // Clear partList once done + *partList = nil + + return false +} + +// sendParts handles the composing and sending of a cMix message for each part +// in the list. All errors returned are fatal errors. +func (m *Manager) sendParts(partList []queuedPart, rng csprng.Source) error { + + // Build cMix messages + messages, transfers, groupedParts, partsToResend, err := m.buildMessages( + partList, rng) + if err != nil { + return err + } + + // Exit if there are no parts to send + if len(messages) == 0 { + return nil + } + + // Send parts + rid, _, err := m.net.SendManyCMIX(messages, params.GetDefaultCMIX()) + if err != nil { + // If an error occurs, then print a warning and add the file parts back + // to the queue to try sending again + jww.WARN.Printf(sendManyCmixWarn, len(messages), groupedParts, err) + + // Add parts back to queue + for _, partIndex := range partsToResend { + m.sendQueue <- partList[partIndex] + } + + return nil + } + + // Set all parts to in-progress + for tid, transfer := range transfers { + err, exists := transfer.SetInProgress(rid, groupedParts[tid]...) + if err != nil { + return errors.Errorf(setInProgressErr, groupedParts[tid], tid) + } + + transfer.CallProgressCB(nil) + + // Set up tracker waiting for the round to end to update state and + // progress; skip if the tracker has already been launched for this + // transfer and round ID + if !exists { + roundResultCB := m.makeRoundEventCallback(rid, tid, transfer) + err = m.getRoundResults( + []id.Round{rid}, roundResultsTimeout, roundResultCB) + if err != nil { + return errors.Errorf(getRoundResultsErr, rid, err) + } + } + } + + return nil +} + +// buildMessages builds the list of cMix messages to send via SendManyCmix. Also +// returns three separate lists used for later progress tracking. The first, a +// map that contains each unique transfer for each part in the list. The second, +// a map of part numbers being sent grouped by their transfer. The last a list +// of partList index of parts that will be sent. Any part that encounters a +// non-fatal error will be skipped and will not be included in an of the lists. +// All errors returned are fatal errors. +func (m *Manager) buildMessages(partList []queuedPart, rng csprng.Source) ( + []message.TargetedCmixMessage, map[ftCrypto.TransferID]*ftStorage.SentTransfer, + map[ftCrypto.TransferID][]uint16, []int, error) { + messages := make([]message.TargetedCmixMessage, 0, len(partList)) + transfers := map[ftCrypto.TransferID]*ftStorage.SentTransfer{} + groupedParts := map[ftCrypto.TransferID][]uint16{} + partsToResend := make([]int, 0, len(partList)) + + for i, part := range partList { + // Lookup the transfer by the ID; if the transfer does not exist, then + // print a warning and skip this message + transfer, err := m.sent.GetTransfer(part.tid) + if err != nil { + jww.WARN.Printf(noSentTransferWarn, part.tid, part.partNum, err) + continue + } + + // Generate new cMix message with encrypted file part + cmixMsg, err := m.newCmixMessage(transfer, part.partNum, rng) + if err == ftStorage.MaxRetriesErr { + // If the max number of retries has been reached, then report the + // error on the callback, delete the transfer, and skip to the next + // message + go transfer.CallProgressCB(errors.Errorf(maxRetriesErr, err)) + continue + } else if err != nil { + // For all other errors, return an error + return nil, nil, nil, nil, + errors.Errorf(newCmixMessageErr, part.partNum, part.tid, err) + } + + // Construct TargetedCmixMessage + msg := message.TargetedCmixMessage{ + Recipient: transfer.GetRecipient(), + Message: cmixMsg, + } + + // Add to list of messages to send + messages = append(messages, msg) + transfers[part.tid] = transfer + groupedParts[part.tid] = append(groupedParts[part.tid], part.partNum) + partsToResend = append(partsToResend, i) + } + + return messages, transfers, groupedParts, partsToResend, nil +} + +// newCmixMessage creates a new cMix message with an encrypted file part, its +// MAC, and fingerprint. +func (m *Manager) newCmixMessage(transfer *ftStorage.SentTransfer, + partNum uint16, rng csprng.Source) (format.Message, error) { + // Create new empty cMix message + cmixMsg := format.NewMessage(m.store.Cmix().GetGroup().GetP().ByteLen()) + + // Create new empty file part message of size equal to the available payload + // size in the cMix message + partMsg, err := newPartMessage(cmixMsg.ContentsSize()) + if err != nil { + return cmixMsg, err + } + + // Get encrypted file part, file part MAC, padding (nonce), and fingerprint + encPart, mac, padding, fp, err := transfer.GetEncryptedPart( + partNum, partMsg.getPartSize(), rng) + if err != nil { + return cmixMsg, err + } + + // Construct file part message from padding ( + partMsg.setPadding(padding) + partMsg.setPartNum(partNum) + err = partMsg.setPart(encPart) + if err != nil { + return cmixMsg, err + } + + // Construct cMix message + cmixMsg.SetContents(partMsg.marshal()) + cmixMsg.SetKeyFP(fp) + cmixMsg.SetMac(mac) + + return cmixMsg, nil +} + +// makeRoundEventCallback returns an api.RoundEventCallback that is called once +// the round file parts were sent on either succeeds or fails. If the round +// succeeds, then all file parts are marked as finished and the progress +// callback is called with the current progress. If the round fails, then each +// part is removed from the in-progress list, added to the end of the sending +// queue, and the callback called with an error. +func (m *Manager) makeRoundEventCallback(rid id.Round, tid ftCrypto.TransferID, + transfer *ftStorage.SentTransfer) api.RoundEventCallback { + + return func(allSucceeded, timedOut bool, _ map[id.Round]api.RoundResult) { + if allSucceeded { + // If the round succeeded, then set all parts for this round to + // finished and call the progress callback + err := transfer.FinishTransfer(rid) + if err != nil { + jww.FATAL.Panicf(finishTransferPanic, tid, err) + } + + transfer.CallProgressCB(nil) + } else { + // If the round failed, then remove all parts for this round from + // the in-progress list, call the progress callback with an error, + // and add the parts back into the queue + + // Determine the correct error message, for logging + roundErr := roundFailErr + if timedOut { + roundErr = trackerTimeoutErr + } + + jww.WARN.Printf(roundFailureWarn, tid, rid, roundErr) + + // Remove parts from in-progress list + partsToResend, err := transfer.UnsetInProgress(rid) + if err != nil { + jww.FATAL.Panicf(unsetInProgressPanic, tid, roundErr) + } + + // Return the error on the progress callback + cbErr := errors.Errorf( + roundFailureCbErr, partsToResend, tid, rid, roundErr) + transfer.CallProgressCB(cbErr) + + // Add all the unsent parts back in the queue + for _, partNum := range partsToResend { + m.sendQueue <- queuedPart{ + tid: tid, + partNum: partNum, + } + } + } + } +} + +// queueParts sends an entry for each part in a transfer into the sendQueue +// channel. +func (m *Manager) queueParts(tid ftCrypto.TransferID, numParts uint16) { + for i := uint16(0); i < numParts; i++ { + m.sendQueue <- queuedPart{tid, i} + } +} + +// getPartSize determines the maximum size for each file part in bytes. The size +// is calculated based on the content size of a cMix message. Returns an error +// if a file part message cannot fit into a cMix message payload. +func (m *Manager) getPartSize() (int, error) { + // Create new empty cMix message + cmixMsg := format.NewMessage(m.store.Cmix().GetGroup().GetP().ByteLen()) + + // Create new empty file part message of size equal to the available payload + // size in the cMix message + partMsg, err := newPartMessage(cmixMsg.ContentsSize()) + if err != nil { + return 0, err + } + + return partMsg.getPartSize(), nil +} + +// partitionFile splits the file into parts of the specified part size. +func partitionFile(file []byte, partSize int) [][]byte { + // Initialize part list to the correct size + numParts := (len(file) + partSize - 1) / partSize + parts := make([][]byte, 0, numParts) + buff := bytes.NewBuffer(file) + + for n := buff.Next(partSize); len(n) > 0; n = buff.Next(partSize) { + parts = append(parts, n) + } + + return parts +} + +// getRngNum takes in a PRNG source and returns a random number. This type makes +// it easier to test by allowing custom functions that return expected values. +type getRngNum func(rng csprng.Source) int + +// getRandomNumParts returns a random number between minPartsSendPerRound and +// maxPartsSendPerRound, inclusive. +func getRandomNumParts(rng csprng.Source) int { + // Generate random bytes + b, err := csprng.Generate(8, rng) + if err != nil { + jww.FATAL.Panicf(getRandomNumPartsRandPanic, err) + } + + // Convert bytes to integer + num := binary.LittleEndian.Uint64(b) + + // Return random number that is minPartsSendPerRound <= num <= max + return int((num % (maxPartsSendPerRound)) + minPartsSendPerRound) +} diff --git a/fileTransfer/sendNew.go b/fileTransfer/sendNew.go new file mode 100644 index 0000000000000000000000000000000000000000..22282ea6a0a2547c7eff8c83f0ec79e4e3e8699c --- /dev/null +++ b/fileTransfer/sendNew.go @@ -0,0 +1,77 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "github.com/golang/protobuf/proto" + "github.com/pkg/errors" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/interfaces/params" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/xx_network/primitives/id" +) + +// Error messages. +const ( + protoMarshalErr = "failed to form new file transfer message: %+v" + sendE2eErr = "failed to send new file transfer message via E2E to recipient %s: %+v" +) + +// sendNewFileTransfer sends the initial file transfer message over E2E. +func (m *Manager) sendNewFileTransfer(recipient *id.ID, fileName string, + key ftCrypto.TransferKey, mac []byte, numParts uint16, fileSize uint32, + retry float32, preview []byte) error { + + // Create Send message with marshalled NewFileTransfer + sendMsg, err := newNewFileTransferE2eMessage(recipient, fileName, key, mac, + numParts, fileSize, retry, preview) + if err != nil { + return errors.Errorf(protoMarshalErr, err) + } + + // Send E2E message + rounds, _, _, err := m.net.SendE2E(sendMsg, params.GetDefaultE2E(), nil) + if err != nil && len(rounds) == 0 { + return errors.Errorf(sendE2eErr, recipient, err) + } + + return nil +} + +// newNewFileTransferE2eMessage generates the message.Send for the given +// recipient containing the marshalled NewFileTransfer message. +func newNewFileTransferE2eMessage(recipient *id.ID, fileName string, + key ftCrypto.TransferKey, mac []byte, numParts uint16, fileSize uint32, + retry float32, preview []byte) (message.Send, error) { + + // Construct NewFileTransfer message + protoMsg := &NewFileTransfer{ + FileName: fileName, + TransferKey: key.Bytes(), + TransferMac: mac, + NumParts: uint32(numParts), + Size: fileSize, + Retry: retry, + Preview: preview, + } + + // Marshal the message + marshalledMsg, err := proto.Marshal(protoMsg) + if err != nil { + return message.Send{}, err + } + + // Create message.Send of the type NewFileTransfer + sendMsg := message.Send{ + Recipient: recipient, + Payload: marshalledMsg, + MessageType: message.NewFileTransfer, + } + + return sendMsg, nil +} diff --git a/fileTransfer/sendNew_test.go b/fileTransfer/sendNew_test.go new file mode 100644 index 0000000000000000000000000000000000000000..54342a84e84bedfe30877e101dbf3193200422a4 --- /dev/null +++ b/fileTransfer/sendNew_test.go @@ -0,0 +1,119 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "fmt" + "github.com/golang/protobuf/proto" + "gitlab.com/elixxir/client/interfaces/message" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/xx_network/primitives/id" + "reflect" + "strings" + "testing" +) + +// Tests that the E2E message sent via Manager.sendNewFileTransfer matches +// expected. +func TestManager_sendNewFileTransfer(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + + recipient := id.NewIdFromString("recipient", id.User, t) + fileName := "testFile" + key, _ := ftCrypto.NewTransferKey(NewPrng(42)) + mac := []byte("transferMac") + numParts, fileSize, retry := uint16(16), uint32(256), float32(1.5) + preview := []byte("filePreview") + + expected, err := newNewFileTransferE2eMessage(recipient, fileName, key, mac, + numParts, fileSize, retry, preview) + if err != nil { + t.Errorf("Failed to create new Send message: %+v", err) + } + + err = m.sendNewFileTransfer(recipient, fileName, key, mac, numParts, + fileSize, retry, preview) + if err != nil { + t.Errorf("sendNewFileTransfer returned an error: %+v", err) + } + + received := m.net.(*testNetworkManager).GetE2eMsg(0) + + if !reflect.DeepEqual(expected, received) { + t.Errorf("Received E2E message does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expected, received) + } +} + +// Error path: tests that Manager.sendNewFileTransfer returns the expected error +// when SendE2E fails. +func TestManager_sendNewFileTransfer_E2eError(t *testing.T) { + // Create new test manager with a SendE2E error triggered + m := newTestManager(true, nil, nil, nil, t) + + recipient := id.NewIdFromString("recipient", id.User, t) + key, _ := ftCrypto.NewTransferKey(NewPrng(42)) + + expectedErr := fmt.Sprintf(sendE2eErr, recipient, "") + err := m.sendNewFileTransfer(recipient, "", key, nil, 16, 256, 1.5, nil) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("sendNewFileTransfer di dnot return the expected error when "+ + "SendE2E failed.\nexpected: %s\nreceived: %+v", expectedErr, err) + } + + if len(m.net.(*testNetworkManager).e2eMessages) > 0 { + t.Errorf("%d E2E messeage(s) found when SendE2E should have failed.", + len(m.net.(*testNetworkManager).e2eMessages)) + } +} + +// Tests that newNewFileTransferE2eMessage returns a message.Send with the +// correct recipient and message type and that the payload can be unmarshalled +// into the correct NewFileTransfer. +func Test_newNewFileTransferE2eMessage(t *testing.T) { + recipient := id.NewIdFromString("recipient", id.User, t) + key, _ := ftCrypto.NewTransferKey(NewPrng(42)) + expected := &NewFileTransfer{ + FileName: "testFile", + TransferKey: key.Bytes(), + TransferMac: []byte("transferMac"), + NumParts: 16, + Size: 256, + Retry: 1.5, + Preview: []byte("filePreview"), + } + + sendMsg, err := newNewFileTransferE2eMessage(recipient, expected.FileName, + key, expected.TransferMac, uint16(expected.NumParts), expected.Size, + expected.Retry, expected.Preview) + if err != nil { + t.Errorf("newNewFileTransferE2eMessage returned an error: %+v", err) + } + + if sendMsg.MessageType != message.NewFileTransfer { + t.Errorf("Send message has wrong MessageType."+ + "\nexpected: %d\nreceived: %d", message.NewFileTransfer, + sendMsg.MessageType) + } + + if !sendMsg.Recipient.Cmp(recipient) { + t.Errorf("Send message has wrong Recipient."+ + "\nexpected: %s\nreceived: %s", recipient, sendMsg.Recipient) + } + + received := &NewFileTransfer{} + err = proto.Unmarshal(sendMsg.Payload, received) + if err != nil { + t.Errorf("Failed to unmarshal received NewFileTransfer: %+v", err) + } + + if !reflect.DeepEqual(expected, received) { + t.Errorf("Received NewFileTransfer does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expected, received) + } +} diff --git a/fileTransfer/send_test.go b/fileTransfer/send_test.go new file mode 100644 index 0000000000000000000000000000000000000000..09474a13dbbe41d17bf8d38ef751f355fc5a1072 --- /dev/null +++ b/fileTransfer/send_test.go @@ -0,0 +1,922 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "fmt" + "gitlab.com/elixxir/client/stoppable" + ftStorage "gitlab.com/elixxir/client/storage/fileTransfer" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/ekv" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/crypto/csprng" + "gitlab.com/xx_network/primitives/id" + "math/rand" + "reflect" + "sort" + "strings" + "sync" + "testing" + "time" +) + +// Tests that Manager.sendThread successfully sends the parts and reports their +// progress on the callback. +func TestManager_sendThread(t *testing.T) { + m, sti, _ := newTestManagerWithTransfers([]uint16{12, 4, 1}, false, nil, t) + + // Add three transfers + partsToSend := [][]uint16{ + {0, 1, 3, 5, 6, 7}, + {1, 2, 3, 0}, + {0}, + } + + var wg sync.WaitGroup + for i, st := range sti { + wg.Add(1) + go func(i int, st sentTransferInfo) { + for j := 0; j < 2; j++ { + select { + case <-time.NewTimer(20 * time.Millisecond).C: + t.Errorf("Timed out waiting for callback #%d", i) + case r := <-st.cbChan: + if j > 0 { + err := checkSentProgress(r.completed, r.sent, r.arrived, + r.total, false, uint16(len(partsToSend[i])), 0, + st.numParts) + if err != nil { + t.Errorf("%d: %+v", i, err) + } + if r.err != nil { + t.Errorf("Callback returned an error (%d): %+v", i, r.err) + } + } + } + } + wg.Done() + }(i, st) + } + + // Create queued part list, add parts from each transfer, and shuffle + queuedParts := make([]queuedPart, 0, 11) + for i, sendingParts := range partsToSend { + for _, part := range sendingParts { + queuedParts = append(queuedParts, queuedPart{sti[i].tid, part}) + } + } + rand.Shuffle(len(queuedParts), func(i, j int) { + queuedParts[i], queuedParts[j] = queuedParts[j], queuedParts[i] + }) + + // Crate custom getRngNum function that always returns 11 + getNumParts := func(rng csprng.Source) int { + return len(queuedParts) + } + + // Start sending thread + stop := stoppable.NewSingle("testSendThreadStoppable") + go m.sendThread(stop, getNumParts) + + // Add parts to queue + for _, part := range queuedParts { + m.sendQueue <- part + } + + wg.Wait() + + err := stop.Close() + if err != nil { + t.Errorf("Failed to stop stoppable: %+v", err) + } + + if len(m.net.(*testNetworkManager).GetMsgList(0)) != len(queuedParts) { + t.Errorf("Not all messages were received.\nexpected: %d\nreceived: %d", + len(queuedParts), len(m.net.(*testNetworkManager).GetMsgList(0))) + } +} + +// Tests that Manager.sendThread successfully sends a partially filled batch +// of the correct length when its times out waiting for messages. +func TestManager_sendThread_Timeout(t *testing.T) { + m, sti, _ := newTestManagerWithTransfers([]uint16{12, 4, 1}, false, nil, t) + + // Add three transfers + partsToSend := [][]uint16{ + {0, 1, 3, 5, 6, 7}, + {1, 2, 3, 0}, + {0}, + } + + var wg sync.WaitGroup + for i, st := range sti[:1] { + wg.Add(1) + go func(i int, st sentTransferInfo) { + for j := 0; j < 2; j++ { + select { + case <-time.NewTimer(20*time.Millisecond + pollSleepDuration).C: + t.Errorf("Timed out waiting for callback #%d", i) + case r := <-st.cbChan: + if j > 0 { + err := checkSentProgress(r.completed, r.sent, r.arrived, + r.total, false, 5, 0, + st.numParts) + if err != nil { + t.Errorf("%d: %+v", i, err) + } + if r.err != nil { + t.Errorf("Callback returned an error (%d): %+v", i, r.err) + } + } + } + } + wg.Done() + }(i, st) + } + + // Create queued part list, add parts from each transfer, and shuffle + queuedParts := make([]queuedPart, 0, 11) + for i, sendingParts := range partsToSend { + for _, part := range sendingParts { + queuedParts = append(queuedParts, queuedPart{sti[i].tid, part}) + } + } + + // Crate custom getRngNum function that always returns 11 + getNumParts := func(rng csprng.Source) int { + return len(queuedParts) + } + + // Start sending thread + stop := stoppable.NewSingle("testSendThreadStoppable") + go m.sendThread(stop, getNumParts) + + // Add parts to queue + for _, part := range queuedParts[:5] { + m.sendQueue <- part + } + + time.Sleep(pollSleepDuration) + + wg.Wait() + + err := stop.Close() + if err != nil { + t.Errorf("Failed to stop stoppable: %+v", err) + } + + if len(m.net.(*testNetworkManager).GetMsgList(0)) != len(queuedParts[:5]) { + t.Errorf("Not all messages were received.\nexpected: %d\nreceived: %d", + len(queuedParts[:5]), len(m.net.(*testNetworkManager).GetMsgList(0))) + } +} + +// Tests that Manager.sendParts sends all the correct cMix messages and calls +// the progress callbacks with the correct values. +func TestManager_sendParts(t *testing.T) { + m, sti, _ := newTestManagerWithTransfers([]uint16{12, 4, 1}, false, nil, t) + prng := NewPrng(42) + + // Add three transfers + partsToSend := [][]uint16{ + {0, 1, 3, 5, 6, 7}, + {1, 2, 3, 0}, + {0}, + } + + var wg sync.WaitGroup + for i, st := range sti { + wg.Add(1) + go func(i int, st sentTransferInfo) { + for j := 0; j < 2; j++ { + select { + case <-time.NewTimer(20 * time.Millisecond).C: + t.Errorf("Timed out waiting for callback #%d", i) + case r := <-st.cbChan: + if j > 0 { + err := checkSentProgress(r.completed, r.sent, r.arrived, + r.total, false, uint16(len(partsToSend[i])), 0, + st.numParts) + if err != nil { + t.Errorf("%d: %+v", i, err) + } + if r.err != nil { + t.Errorf("Callback returned an error (%d): %+v", i, r.err) + } + } + } + } + wg.Done() + }(i, st) + } + + // Create queued part list, add parts from each transfer, and shuffle + queuedParts := make([]queuedPart, 0, 11) + for i, sendingParts := range partsToSend { + for _, part := range sendingParts { + queuedParts = append(queuedParts, queuedPart{sti[i].tid, part}) + } + } + rand.Shuffle(len(queuedParts), func(i, j int) { + queuedParts[i], queuedParts[j] = queuedParts[j], queuedParts[i] + }) + + err := m.sendParts(queuedParts, prng) + if err != nil { + t.Errorf("sendParts returned an error: %+v", err) + } + + m.net.(*testNetworkManager).GetMsgList(0) + + // Check that each recipient is connected to the correct transfer and that + // the fingerprint is correct + for i, tcm := range m.net.(*testNetworkManager).GetMsgList(0) { + index := 0 + for ; !sti[index].recipient.Cmp(tcm.Recipient); index++ { + + } + transfer, err := m.sent.GetTransfer(sti[index].tid) + if err != nil { + t.Errorf("Failed to get transfer %s: %+v", sti[index].tid, err) + } + + var fpFound bool + for _, fp := range ftCrypto.GenerateFingerprints( + transfer.GetTransferKey(), 15) { + if fp == tcm.Message.GetKeyFP() { + fpFound = true + } + } + if !fpFound { + t.Errorf("Fingeprint %s not found (%d).", tcm.Message.GetKeyFP(), i) + } + } + + wg.Wait() +} + +// Error path: tests that, on SendManyCMIX failure, Manager.sendParts adds the +// parts back into the queue, does not call the callback, and does not update +// the progress. +func TestManager_sendParts_SendManyCmixError(t *testing.T) { + m, sti, _ := newTestManagerWithTransfers([]uint16{12, 4, 1}, true, nil, t) + prng := NewPrng(42) + partsToSend := [][]uint16{ + {0, 1, 3, 5, 6, 7}, + {1, 2, 3, 0}, + {0}, + } + + var wg sync.WaitGroup + for i, st := range sti { + wg.Add(1) + go func(i int, st sentTransferInfo) { + for j := 0; j < 2; j++ { + select { + case <-time.NewTimer(10 * time.Millisecond).C: + if j < 1 { + t.Errorf("Timed out waiting for callback #%d (%d)", i, j) + } + case r := <-st.cbChan: + if j > 0 { + t.Errorf("Callback called on send failure: %+v", r) + } + } + } + wg.Done() + }(i, st) + } + + // Create queued part list, add parts from each transfer, and shuffle + queuedParts := make([]queuedPart, 0, 11) + for i, sendingParts := range partsToSend { + for _, part := range sendingParts { + queuedParts = append(queuedParts, queuedPart{sti[i].tid, part}) + } + } + + err := m.sendParts(queuedParts, prng) + if err != nil { + t.Errorf("sendParts returned an error: %+v", err) + } + + if len(m.net.(*testNetworkManager).GetMsgList(0)) > 0 { + t.Errorf("Sent %d cMix message(s) when sending should have failed.", + len(m.net.(*testNetworkManager).GetMsgList(0))) + } + + if len(m.sendQueue) != len(queuedParts) { + t.Errorf("Failed to add all parts to queue after send failure."+ + "\nexpected: %d\nreceived: %d", len(queuedParts), len(m.sendQueue)) + } + + wg.Wait() +} + +// Tests that Manager.buildMessages returns the expected values for a group +// of 11 file parts from three different transfers. +func TestManager_buildMessages(t *testing.T) { + m, sti, _ := newTestManagerWithTransfers([]uint16{12, 4, 1}, false, nil, t) + prng := NewPrng(42) + partsToSend := [][]uint16{ + {0, 1, 3, 5, 6, 7}, + {1, 2, 3, 0}, + {0}, + } + idMap := map[ftCrypto.TransferID]int{ + sti[0].tid: 0, + sti[1].tid: 1, + sti[2].tid: 2, + } + + // Create queued part list, add parts from each transfer, and shuffle + queuedParts := make([]queuedPart, 0, 11) + for i, sendingParts := range partsToSend { + for _, part := range sendingParts { + queuedParts = append(queuedParts, queuedPart{sti[i].tid, part}) + } + } + rand.Shuffle(len(queuedParts), func(i, j int) { + queuedParts[i], queuedParts[j] = queuedParts[j], queuedParts[i] + }) + + // Build the messages + prng = NewPrng(2) + messages, transfers, groupedParts, partsToResend, err := m.buildMessages( + queuedParts, prng) + if err != nil { + t.Errorf("buildMessages returned an error: %+v", err) + } + + // Check that the transfer map has all the transfers + for i, st := range sti { + _, exists := transfers[st.tid] + if !exists { + t.Errorf("Transfer %s (#%d) not in transfers map.", st.tid, i) + } + } + + // Check that partsToResend contains all parts because none should have + // errored + if len(partsToResend) != len(queuedParts) { + sort.SliceStable(partsToResend, func(i, j int) bool { + return partsToResend[i] < partsToResend[j] + }) + for i, j := range partsToResend { + if i != j { + t.Errorf("Part at index %d not found. Found %d.", i, j) + } + } + } + + // Check that all the parts are present and grouped correctly + for tid, partNums := range groupedParts { + sort.SliceStable(partNums, func(i, j int) bool { + return partNums[i] < partNums[j] + }) + sort.SliceStable(partsToSend[idMap[tid]], func(i, j int) bool { + return partsToSend[idMap[tid]][i] < partsToSend[idMap[tid]][j] + }) + if !reflect.DeepEqual(partsToSend[idMap[tid]], partNums) { + t.Errorf("Incorrect parts for transfer %s."+ + "\nexpected: %v\nreceived: %v", + tid, partsToSend[idMap[tid]], partNums) + } + } + + // Check that each recipient is connected to the correct transfer and that + // the fingerprint is correct + for i, tcm := range messages { + index := 0 + for ; !sti[index].recipient.Cmp(tcm.Recipient); index++ { + + } + transfer, err := m.sent.GetTransfer(sti[index].tid) + if err != nil { + t.Errorf("Failed to get transfer %s: %+v", sti[index].tid, err) + } + + var fpFound bool + for _, fp := range ftCrypto.GenerateFingerprints( + transfer.GetTransferKey(), 15) { + if fp == tcm.Message.GetKeyFP() { + fpFound = true + } + } + if !fpFound { + t.Errorf("Fingeprint %s not found (%d).", tcm.Message.GetKeyFP(), i) + } + } +} + +// Tests that Manager.buildMessages skips file parts with deleted transfers or +// transfers that have run out of fingerprints. +func TestManager_buildMessages_MessageBuildFailureError(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + + // Add transfer + + type callbackResults struct { + completed bool + sent, arrived, total uint16 + err error + } + + callbackChan := make(chan callbackResults, 10) + progressCB := func(completed bool, sent, arrived, total uint16, err error) { + callbackChan <- callbackResults{completed, sent, arrived, total, err} + } + + done0, done1 := make(chan bool), make(chan bool) + go func() { + for i := 0; i < 2; i++ { + select { + case <-time.NewTimer(20 * time.Millisecond).C: + t.Error("Timed out waiting for callback.") + case r := <-callbackChan: + switch i { + case 0: + done0 <- true + case 1: + expectedErr := fmt.Sprintf( + maxRetriesErr, ftStorage.MaxRetriesErr) + if i > 0 && (r.err == nil || r.err.Error() != expectedErr) { + t.Errorf("Callback received unexpected error when max "+ + "retries should have been reached."+ + "\nexpected: %s\nreceived: %v", expectedErr, r.err) + } + done1 <- true + } + } + } + }() + + prng := NewPrng(42) + recipient := id.NewIdFromString("recipient", id.User, t) + key, _ := ftCrypto.NewTransferKey(prng) + _, parts := newFile(4, 64, prng, t) + tid, err := m.sent.AddTransfer( + recipient, key, parts, 3, progressCB, time.Millisecond, prng) + if err != nil { + t.Errorf("Failed to add new transfer: %+v", err) + } + + <-done0 + + // Create queued part list add parts + queuedParts := []queuedPart{ + {tid, 0}, + {tid, 1}, + {tid, 2}, + {tid, 3}, + {ftCrypto.UnmarshalTransferID([]byte("invalidID")), 3}, + } + + // Build the messages + prng = NewPrng(46) + messages, transfers, groupedParts, partsToResend, err := m.buildMessages( + queuedParts, prng) + if err != nil { + t.Errorf("buildMessages returned an error: %+v", err) + } + + <-done1 + + // Check that partsToResend contains all parts because none should have + // errored + if len(partsToResend) != len(queuedParts)-2 { + sort.SliceStable(partsToResend, func(i, j int) bool { + return partsToResend[i] < partsToResend[j] + }) + for i, j := range partsToResend { + if i != j { + t.Errorf("Part at index %d not found. Found %d.", i, j) + } + } + } + + if len(messages) != 3 { + t.Errorf("Length of messages incorrect."+ + "\nexpected: %d\nreceived: %d", 3, len(messages)) + } + + if len(transfers) != 1 { + t.Errorf("Length of transfers incorrect."+ + "\nexpected: %d\nreceived: %d", 1, len(transfers)) + } + + if len(groupedParts) != 1 { + t.Errorf("Length of grouped parts incorrect."+ + "\nexpected: %d\nreceived: %d", 1, len(groupedParts)) + } +} + +// Tests that Manager.buildMessages returns the expected error when a queued +// part has an invalid part number. +func TestManager_buildMessages_NewCmixMessageError(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + + // Add transfer + prng := NewPrng(42) + recipient := id.NewIdFromString("recipient", id.User, t) + key, _ := ftCrypto.NewTransferKey(prng) + _, parts := newFile(12, 405, prng, t) + tid, err := m.sent.AddTransfer(recipient, key, parts, 15, nil, 0, prng) + if err != nil { + t.Errorf("Failed to add new transfer: %+v", err) + } + + // Create queued part list and add a single part with an invalid part number + queuedParts := []queuedPart{{tid, uint16(len(parts))}} + + // Build the messages + expectedErr := fmt.Sprintf( + newCmixMessageErr, queuedParts[0].partNum, tid, "") + _, _, _, _, err = m.buildMessages(queuedParts, prng) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("buildMessages did not return the expected error when the "+ + "queuedPart part number is invalid.\nexpected: %s\nreceived: %+v", + expectedErr, err) + } + +} + +// Tests that Manager.newCmixMessage returns a format.Message with the correct +// MAC, fingerprint, and contents. +func TestManager_newCmixMessage(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + prng := NewPrng(42) + tid, _ := ftCrypto.NewTransferID(prng) + key, _ := ftCrypto.NewTransferKey(prng) + recipient := id.NewIdFromString("recipient", id.User, t) + partSize, _ := m.getPartSize() + _, parts := newFile(16, uint32(partSize), prng, t) + numFps := calcNumberOfFingerprints(uint16(len(parts)), 1.5) + kv := versioned.NewKV(make(ekv.Memstore)) + + transfer, err := ftStorage.NewSentTransfer(recipient, tid, key, parts, + numFps, nil, 0, kv) + if err != nil { + t.Errorf("Failed to create a new SentTransfer: %+v", err) + } + + cmixMsg, err := m.newCmixMessage(transfer, 0, prng) + if err != nil { + t.Errorf("newCmixMessage returned an error: %+v", err) + } + + fp := ftCrypto.GenerateFingerprints(key, numFps)[0] + + if cmixMsg.GetKeyFP() != fp { + t.Errorf("cMix message has wrong fingerprint."+ + "\nexpected: %s\nrecieved: %s", fp, cmixMsg.GetKeyFP()) + } + + partMsg, err := unmarshalPartMessage(cmixMsg.GetContents()) + if err != nil { + t.Errorf("Failed to unmarshal part message: %+v", err) + } + + decrPart, err := ftCrypto.DecryptPart(key, partMsg.getPart(), + partMsg.getPadding(), cmixMsg.GetMac(), partMsg.getPartNum()) + if err != nil { + t.Errorf("Failed to decrypt file part: %+v", err) + } + + if !bytes.Equal(decrPart, parts[0]) { + t.Errorf("Decrypted part does not match expected."+ + "\nexpected: %q\nreceived: %q", parts[0], decrPart) + } +} + +// Tests that Manager.makeRoundEventCallback returns a callback that calls the +// progress callback when a round succeeds. +func TestManager_makeRoundEventCallback(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + + // Add transfer + type callbackResults struct { + completed bool + sent, arrived, total uint16 + err error + } + + callbackChan := make(chan callbackResults, 10) + progressCB := func(completed bool, sent, arrived, total uint16, err error) { + callbackChan <- callbackResults{completed, sent, arrived, total, err} + } + + done0, done1 := make(chan bool), make(chan bool) + go func() { + for i := 0; i < 2; i++ { + select { + case <-time.NewTimer(10 * time.Millisecond).C: + t.Error("Timed out waiting for callback.") + case r := <-callbackChan: + switch i { + case 0: + done0 <- true + case 1: + if r.err != nil { + t.Errorf("Callback received error: %v", r.err) + } + if !r.completed { + t.Error("File not marked as completed.") + } + done1 <- true + } + } + } + }() + + prng := NewPrng(42) + recipient := id.NewIdFromString("recipient", id.User, t) + key, _ := ftCrypto.NewTransferKey(prng) + _, parts := newFile(4, 64, prng, t) + tid, err := m.sent.AddTransfer( + recipient, key, parts, 6, progressCB, time.Millisecond, prng) + if err != nil { + t.Errorf("Failed to add new transfer: %+v", err) + } + + <-done0 + + // Create queued part list add parts + queuedParts := []queuedPart{ + {tid, 0}, + {tid, 1}, + {tid, 2}, + {tid, 3}, + } + + rid := id.Round(42) + + _, transfers, groupedParts, _, err := m.buildMessages(queuedParts, prng) + + // Set all parts to in-progress + for tid, transfer := range transfers { + _, _ = transfer.SetInProgress(rid, groupedParts[tid]...) + } + + roundEventCB := m.makeRoundEventCallback(rid, tid, transfers[tid]) + + roundEventCB(true, false, nil) + + <-done1 +} + +// Tests that Manager.makeRoundEventCallback returns a callback that calls the +// progress callback with the correct error when a round fails. +func TestManager_makeRoundEventCallback_RoundFailure(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + + rid := id.Round(42) + + // Add transfer + type callbackResults struct { + completed bool + sent, arrived, total uint16 + err error + } + + callbackChan := make(chan callbackResults, 10) + progressCB := func(completed bool, sent, arrived, total uint16, err error) { + callbackChan <- callbackResults{completed, sent, arrived, total, err} + } + + prng := NewPrng(42) + recipient := id.NewIdFromString("recipient", id.User, t) + key, _ := ftCrypto.NewTransferKey(prng) + _, parts := newFile(4, 64, prng, t) + tid, err := m.sent.AddTransfer( + recipient, key, parts, 6, progressCB, time.Millisecond, prng) + if err != nil { + t.Errorf("Failed to add new transfer: %+v", err) + } + partsToSend := []uint16{0, 1, 2, 3} + + done0, done1 := make(chan bool), make(chan bool) + go func() { + for i := 0; i < 2; i++ { + select { + case <-time.NewTimer(10 * time.Millisecond).C: + t.Error("Timed out waiting for callback.") + case r := <-callbackChan: + switch i { + case 0: + done0 <- true + case 1: + expectedErr := fmt.Sprintf( + roundFailureCbErr, partsToSend, tid, rid, roundFailErr) + if r.err == nil || !strings.Contains(r.err.Error(), expectedErr) { + t.Errorf("Callback received unexpected error when round "+ + "failed.\nexpected: %s\nreceived: %v", expectedErr, r.err) + } + done1 <- true + } + } + } + }() + + // Create queued part list add parts + queuedParts := make([]queuedPart, len(partsToSend)) + for i := range queuedParts { + queuedParts[i] = queuedPart{tid, uint16(i)} + } + + _, transfers, groupedParts, _, err := m.buildMessages(queuedParts, prng) + + <-done0 + + // Set all parts to in-progress + for tid, transfer := range transfers { + _, _ = transfer.SetInProgress(rid, groupedParts[tid]...) + } + + roundEventCB := m.makeRoundEventCallback(rid, tid, transfers[tid]) + + roundEventCB(false, false, nil) + + <-done1 +} + +// Error path: tests that Manager.makeRoundEventCallback panics when +// SentTransfer.FinishTransfer returns an error because the file parts had not +// previously been set to in-progress. +func TestManager_makeRoundEventCallback_FinishTransferError(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + + prng := NewPrng(42) + recipient := id.NewIdFromString("recipient", id.User, t) + key, _ := ftCrypto.NewTransferKey(prng) + _, parts := newFile(4, 64, prng, t) + tid, err := m.sent.AddTransfer( + recipient, key, parts, 6, nil, time.Millisecond, prng) + if err != nil { + t.Errorf("Failed to add new transfer: %+v", err) + } + + // Create queued part list add parts + queuedParts := []queuedPart{{tid, 0}, {tid, 1}, {tid, 2}, {tid, 3}} + rid := id.Round(42) + + _, transfers, _, _, err := m.buildMessages(queuedParts, prng) + + expectedErr := strings.Split(finishTransferPanic, "%")[0] + defer func() { + err := recover() + if err == nil || !strings.Contains(err.(string), expectedErr) { + t.Errorf("makeRoundEventCallback failed to panic or returned the "+ + "wrong error.\nexpected: %s\nreceived: %+v", expectedErr, err) + } + }() + + roundEventCB := m.makeRoundEventCallback(rid, tid, transfers[tid]) + + roundEventCB(true, false, nil) +} + +// Tests that Manager.queueParts adds all the expected parts to the sendQueue +// channel. +func TestManager_queueParts(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + csPrng := NewPrng(42) + prng := rand.New(rand.NewSource(42)) + + // Create map of expected parts + n := 26 + parts := make(map[ftCrypto.TransferID]map[uint16]bool, n) + for i := 0; i < n; i++ { + tid, _ := ftCrypto.NewTransferID(csPrng) + o := uint16(prng.Int31n(64)) + parts[tid] = make(map[uint16]bool, o) + for j := uint16(0); j < o; j++ { + parts[tid][j] = true + } + } + + // Queue all parts + for tid, partList := range parts { + m.queueParts(tid, uint16(len(partList))) + } + + // Read through all parts in channel ensuring none are missing or duplicate + for len(parts) > 0 { + select { + case part := <-m.sendQueue: + partList, exists := parts[part.tid] + if !exists { + t.Errorf("Could not find transfer %s", part.tid) + } else { + _, exists := partList[part.partNum] + if !exists { + t.Errorf("Could not find part number %d for transfer %s", + part.partNum, part.tid) + } + + delete(parts[part.tid], part.partNum) + + if len(parts[part.tid]) == 0 { + delete(parts, part.tid) + } + } + case <-time.NewTimer(time.Millisecond).C: + if len(parts) != 0 { + t.Errorf("Timed out reading parts from channel. Failed to "+ + "read parts for %d/%d transfers.", len(parts), n) + } + return + default: + if len(parts) != 0 { + t.Errorf("Failed to read parts for %d/%d transfers.", + len(parts), n) + } + } + } +} + +// Tests that the part size returned by Manager.getPartSize matches the manually +// calculated part size. +func TestManager_getPartSize(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + + // Calculate the expected part size + primeByteLen := m.store.Cmix().GetGroup().GetP().ByteLen() + cmixMsgUsedLen := format.AssociatedDataSize + filePartMsgUsedLen := fmMinSize + expected := 2*primeByteLen - cmixMsgUsedLen - filePartMsgUsedLen + + // Get the part size + partSize, err := m.getPartSize() + if err != nil { + t.Errorf("getPartSize returned an error: %+v", err) + } + + if expected != partSize { + t.Errorf("Returned part size does not match expected."+ + "\nexpected: %d\nreceived: %d", expected, partSize) + } +} + +// Tests that partitionFile partitions the given file into the expected parts. +func Test_partitionFile(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + partSize := 96 + fileData, expectedParts := newFile(24, uint32(partSize), prng, t) + + receivedParts := partitionFile(fileData, partSize) + + if !reflect.DeepEqual(expectedParts, receivedParts) { + t.Errorf("File parts do not match expected."+ + "\nexpected: %q\nreceived: %q", expectedParts, receivedParts) + } + + fullFile := bytes.Join(receivedParts, nil) + if !bytes.Equal(fileData, fullFile) { + t.Errorf("Full file does not match expected."+ + "\nexpected: %q\nreceived: %q", fileData, fullFile) + } +} + +// Tests that getRandomNumParts returns values between minPartsSendPerRound and +// maxPartsSendPerRound. +func Test_getRandomNumParts(t *testing.T) { + prng := NewPrng(42) + + for i := 0; i < 100; i++ { + num := getRandomNumParts(prng) + if num < minPartsSendPerRound { + t.Errorf("Number %d is smaller than minimum %d (%d)", + num, minPartsSendPerRound, i) + } else if num > maxPartsSendPerRound { + t.Errorf("Number %d is greater than maximum %d (%d)", + num, maxPartsSendPerRound, i) + } + } +} + +// Tests that getRandomNumParts panics for a PRNG that errors. +func Test_getRandomNumParts_PrngPanic(t *testing.T) { + prng := NewPrngErr() + + defer func() { + // Error if a panic does not occur + if err := recover(); err == nil { + t.Error("Failed to panic for broken PRNG.") + } + }() + + _ = getRandomNumParts(prng) +} + +// Tests that getRandomNumParts satisfies the getRngNum type. +func Test_getRandomNumParts_GetRngNumType(t *testing.T) { + var _ getRngNum = getRandomNumParts +} diff --git a/fileTransfer/utils_test.go b/fileTransfer/utils_test.go new file mode 100644 index 0000000000000000000000000000000000000000..e0a0fa6f00c12004a90baeba5c0e20b437e0d9a1 --- /dev/null +++ b/fileTransfer/utils_test.go @@ -0,0 +1,564 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "encoding/binary" + "github.com/pkg/errors" + "gitlab.com/elixxir/client/api" + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/network/gateway" + "gitlab.com/elixxir/client/stoppable" + "gitlab.com/elixxir/client/storage" + ftStorage "gitlab.com/elixxir/client/storage/fileTransfer" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/client/switchboard" + "gitlab.com/elixxir/comms/network" + "gitlab.com/elixxir/crypto/e2e" + "gitlab.com/elixxir/crypto/fastRNG" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/ekv" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/comms/connect" + "gitlab.com/xx_network/crypto/csprng" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/id/ephemeral" + "gitlab.com/xx_network/primitives/ndf" + "io" + "math/rand" + "strconv" + "sync" + "testing" + "time" +) + +// newFile generates a file with random data of size numParts * partSize. +// Returns the full file and the file parts. If the partSize allows, each part +// starts with a "|<[PART_001]" and ends with a ">|". +func newFile(numParts uint16, partSize uint32, prng io.Reader, t *testing.T) ( + []byte, [][]byte) { + const ( + prefix = "|<[PART_%3d]" + suffix = ">|" + ) + // Create file buffer of the expected size + fileBuff := bytes.NewBuffer(make([]byte, 0, uint32(numParts)*partSize)) + partList := make([][]byte, numParts) + + // Create new rand.Rand with the seed generated from the io.Reader + b := make([]byte, 8) + _, err := prng.Read(b) + if err != nil { + t.Errorf("Failed to generate random seed: %+v", err) + } + seed := binary.LittleEndian.Uint64(b) + randPrng := rand.New(rand.NewSource(int64(seed))) + + for partNum := range partList { + s := RandStringBytes(int(partSize), randPrng) + if len(s) >= (len(prefix) + len(suffix)) { + partList[partNum] = []byte( + prefix + s[:len(s)-(len(prefix)+len(suffix))] + suffix) + } else { + partList[partNum] = []byte(s) + } + + fileBuff.Write(partList[partNum]) + } + + return fileBuff.Bytes(), partList +} + +const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +// RandStringBytes generates a random string of length n consisting of the +// characters in letterBytes. +func RandStringBytes(n int, prng *rand.Rand) string { + b := make([]byte, n) + for i := range b { + b[i] = letterBytes[prng.Intn(len(letterBytes))] + } + return string(b) +} + +// checkReceivedProgress compares the output of ReceivedTransfer.GetProgress to +// expected values. +func checkReceivedProgress(completed bool, received, total uint16, + eCompleted bool, eReceived, eTotal uint16) error { + if eCompleted != completed || eReceived != received || eTotal != total { + return errors.Errorf("Returned progress does not match expected."+ + "\n completed received total"+ + "\nexpected: %5t %3d %3d"+ + "\nreceived: %5t %3d %3d", + eCompleted, eReceived, eTotal, + completed, received, total) + } + + return nil +} + +// checkSentProgress compares the output of SentTransfer.GetProgress to expected +// values. +func checkSentProgress(completed bool, sent, arrived, total uint16, + eCompleted bool, eSent, eArrived, eTotal uint16) error { + if eCompleted != completed || eSent != sent || eArrived != arrived || + eTotal != total { + return errors.Errorf("Returned progress does not match expected."+ + "\n completed sent arrived total"+ + "\nexpected: %5t %3d %3d %3d"+ + "\nreceived: %5t %3d %3d %3d", + eCompleted, eSent, eArrived, eTotal, + completed, sent, arrived, total) + } + + return nil +} + +//////////////////////////////////////////////////////////////////////////////// +// PRNG // +//////////////////////////////////////////////////////////////////////////////// + +// Prng is a PRNG that satisfies the csprng.Source interface. +type Prng struct{ prng io.Reader } + +func NewPrng(seed int64) csprng.Source { return &Prng{rand.New(rand.NewSource(seed))} } +func (s *Prng) Read(b []byte) (int, error) { return s.prng.Read(b) } +func (s *Prng) SetSeed([]byte) error { return nil } + +// PrngErr is a PRNG that satisfies the csprng.Source interface. However, it +// always returns an error +type PrngErr struct{} + +func NewPrngErr() csprng.Source { return &PrngErr{} } +func (s *PrngErr) Read([]byte) (int, error) { return 0, errors.New("ReadFailure") } +func (s *PrngErr) SetSeed([]byte) error { return errors.New("SetSeedFailure") } + +//////////////////////////////////////////////////////////////////////////////// +// Test Managers // +//////////////////////////////////////////////////////////////////////////////// + +// newTestManager creates a new Manager that has groups stored for testing. One +// of the groups in the list is also returned. +func newTestManager(sendErr bool, sendChan, sendE2eChan chan message.Receive, + receiveCB interfaces.ReceiveCallback, t *testing.T) *Manager { + + kv := versioned.NewKV(make(ekv.Memstore)) + sent, err := ftStorage.NewSentFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to createw new SentFileTransfers: %+v", err) + } + received, err := ftStorage.NewReceivedFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to createw new ReceivedFileTransfers: %+v", err) + } + + net := newTestNetworkManager(sendErr, sendChan, sendE2eChan, t) + + // Register channel for health tracking + healthy := make(chan bool) + net.GetHealthTracker().AddChannel(healthy) + + // Returns an error on function and round failure on callback if sendErr is + // set; otherwise, it reports round successes and returns nil + rr := func(_ []id.Round, _ time.Duration, cb api.RoundEventCallback) error { + cb(!sendErr, false, nil) + if sendErr { + return errors.New("SendError") + } + + return nil + } + + avgNumMessages := (minPartsSendPerRound + maxPartsSendPerRound) / 2 + avgSendSize := avgNumMessages * (8192 / 8) + + m := &Manager{ + receiveCB: receiveCB, + sent: sent, + received: received, + sendQueue: make(chan queuedPart, sendQueueBuffLen), + maxThroughput: int(time.Second) * avgSendSize, + store: storage.InitTestingSession(t), + swb: switchboard.New(), + net: net, + healthy: healthy, + rng: fastRNG.NewStreamGenerator(1000, 10, csprng.NewSystemRNG), + getRoundResults: rr, + } + + return m +} + +// newTestManagerWithTransfers creates a new test manager with transfers added +// to it. +func newTestManagerWithTransfers(numParts []uint16, sendErr bool, + receiveCB interfaces.ReceiveCallback, t *testing.T) ( + *Manager, []sentTransferInfo, []receivedTransferInfo) { + m := newTestManager(sendErr, nil, nil, receiveCB, t) + sti := make([]sentTransferInfo, len(numParts)) + rti := make([]receivedTransferInfo, len(numParts)) + var err error + + partSize, err := m.getPartSize() + if err != nil { + t.Errorf("Failed to get part size: %+v", err) + } + + for i := range sti { + prng := NewPrng(int64(42 + i)) + file, parts := newFile(numParts[i], uint32(partSize), prng, t) + key, _ := ftCrypto.NewTransferKey(prng) + + sti[i] = sentTransferInfo{ + recipient: id.NewIdFromString("recipient"+strconv.Itoa(i), id.User, t), + key: key, + parts: parts, + file: file, + numParts: numParts[i], + numFps: calcNumberOfFingerprints(numParts[i], 0.5), + retry: 0.5, + period: time.Millisecond, + prng: prng, + } + + cbChan := make(chan sentProgressResults, 6) + + cb := func(completed bool, sent, arrived, total uint16, err error) { + cbChan <- sentProgressResults{completed, sent, arrived, total, err} + } + + sti[i].cbChan = cbChan + sti[i].cb = cb + + sti[i].tid, err = m.sent.AddTransfer(sti[i].recipient, sti[i].key, + sti[i].parts, sti[i].numFps, sti[i].cb, sti[i].period, sti[i].prng) + if err != nil { + t.Errorf("Failed to add sent transfer #%d: %+v", i, err) + } + } + + for i := range rti { + prng := NewPrng(int64(42 + i)) + file, parts := newFile(numParts[i], uint32(partSize), prng, t) + key, _ := ftCrypto.NewTransferKey(prng) + + rti[i] = receivedTransferInfo{ + key: key, + mac: ftCrypto.CreateTransferMAC(file, key), + parts: parts, + file: file, + fileSize: uint32(len(file)), + numParts: numParts[i], + numFps: calcNumberOfFingerprints(numParts[i], 0.5), + retry: 0.5, + period: time.Millisecond, + prng: prng, + } + + cbChan := make(chan receivedProgressResults, 6) + + cb := func(completed bool, received, total uint16, err error) { + cbChan <- receivedProgressResults{completed, received, total, err} + } + + rti[i].cbChan = cbChan + rti[i].cb = cb + + rti[i].tid, err = m.received.AddTransfer(rti[i].key, rti[i].mac, + rti[i].fileSize, rti[i].numParts, rti[i].numFps, rti[i].prng) + if err != nil { + t.Errorf("Failed to add received transfer #%d: %+v", i, err) + } + } + + return m, sti, rti +} + +// receivedFtResults is used to return received new file transfer results on a +// channel from a callback. +type receivedFtResults struct { + tid ftCrypto.TransferID + fileName string + sender *id.ID + size uint32 + preview []byte +} + +// sentProgressResults is used to return sent progress results on a channel from +// a callback. +type sentProgressResults struct { + completed bool + sent, arrived, total uint16 + err error +} + +// sentTransferInfo contains information on a sent transfer. +type sentTransferInfo struct { + recipient *id.ID + key ftCrypto.TransferKey + tid ftCrypto.TransferID + parts [][]byte + file []byte + numParts uint16 + numFps uint16 + retry float32 + cb interfaces.SentProgressCallback + cbChan chan sentProgressResults + period time.Duration + prng csprng.Source +} + +// receivedProgressResults is used to return received progress results on a +// channel from a callback. +type receivedProgressResults struct { + completed bool + received, total uint16 + err error +} + +// receivedTransferInfo contains information on a received transfer. +type receivedTransferInfo struct { + key ftCrypto.TransferKey + tid ftCrypto.TransferID + mac []byte + parts [][]byte + file []byte + fileSize uint32 + numParts uint16 + numFps uint16 + retry float32 + cb interfaces.ReceivedProgressCallback + cbChan chan receivedProgressResults + period time.Duration + prng csprng.Source +} + +//////////////////////////////////////////////////////////////////////////////// +// Test Network Manager // +//////////////////////////////////////////////////////////////////////////////// + +func newTestNetworkManager(sendErr bool, sendChan, + sendE2eChan chan message.Receive, t *testing.T) interfaces.NetworkManager { + instanceComms := &connect.ProtoComms{ + Manager: connect.NewManagerTesting(t), + } + + thisInstance, err := network.NewInstanceTesting(instanceComms, getNDF(), + getNDF(), nil, nil, t) + if err != nil { + t.Fatalf("Failed to create new test instance: %v", err) + } + + return &testNetworkManager{ + instance: thisInstance, + rid: 0, + messages: make(map[id.Round][]message.TargetedCmixMessage), + sendErr: sendErr, + health: newTestHealthTracker(), + sendChan: sendChan, + sendE2eChan: sendE2eChan, + } +} + +// testNetworkManager is a test implementation of NetworkManager interface. +type testNetworkManager struct { + instance *network.Instance + updateRid bool + rid id.Round + messages map[id.Round][]message.TargetedCmixMessage + e2eMessages []message.Send + sendErr bool + health testHealthTracker + sendChan chan message.Receive + sendE2eChan chan message.Receive + sync.RWMutex +} + +func (tnm *testNetworkManager) GetMsgList(rid id.Round) []message.TargetedCmixMessage { + tnm.RLock() + defer tnm.RUnlock() + return tnm.messages[rid] +} + +func (tnm *testNetworkManager) GetE2eMsg(i int) message.Send { + tnm.RLock() + defer tnm.RUnlock() + return tnm.e2eMessages[i] +} + +func (tnm *testNetworkManager) SendE2E(msg message.Send, _ params.E2E, _ *stoppable.Single) ( + []id.Round, e2e.MessageID, time.Time, error) { + tnm.Lock() + defer tnm.Unlock() + + if tnm.sendErr { + return nil, e2e.MessageID{}, time.Time{}, errors.New("SendE2E error") + } + + tnm.e2eMessages = append(tnm.e2eMessages, msg) + + if tnm.sendE2eChan != nil { + tnm.sendE2eChan <- message.Receive{ + Payload: msg.Payload, + MessageType: msg.MessageType, + } + } + + return []id.Round{0, 1, 2, 3}, e2e.MessageID{}, time.Time{}, nil +} + +func (tnm *testNetworkManager) SendUnsafe(message.Send, params.Unsafe) ([]id.Round, error) { + return []id.Round{}, nil +} + +func (tnm *testNetworkManager) SendCMIX(format.Message, *id.ID, params.CMIX) (id.Round, ephemeral.Id, error) { + return 0, ephemeral.Id{}, nil +} + +func (tnm *testNetworkManager) SendManyCMIX(messages []message.TargetedCmixMessage, _ params.CMIX) ( + id.Round, []ephemeral.Id, error) { + tnm.Lock() + defer func() { + // Increment the round every two calls to SendManyCMIX + if tnm.updateRid { + tnm.rid++ + tnm.updateRid = false + } else { + tnm.updateRid = true + } + }() + defer tnm.Unlock() + + if tnm.sendErr { + return 0, nil, errors.New("SendManyCMIX error") + } + + tnm.messages[tnm.rid] = messages + + if tnm.sendChan != nil { + for _, msg := range messages { + tnm.sendChan <- message.Receive{ + Payload: msg.Message.Marshal(), + RoundId: tnm.rid, + } + } + } + + return tnm.rid, nil, nil +} + +type dummyEventMgr struct{} + +func (d *dummyEventMgr) Report(int, string, string, string) {} +func (tnm *testNetworkManager) GetEventManager() interfaces.EventManager { + return &dummyEventMgr{} +} + +func (tnm *testNetworkManager) GetInstance() *network.Instance { return tnm.instance } +func (tnm *testNetworkManager) GetHealthTracker() interfaces.HealthTracker { return tnm.health } +func (tnm *testNetworkManager) Follow(interfaces.ClientErrorReport) (stoppable.Stoppable, error) { + return nil, nil +} +func (tnm *testNetworkManager) CheckGarbledMessages() {} +func (tnm *testNetworkManager) InProgressRegistrations() int { return 0 } +func (tnm *testNetworkManager) GetSender() *gateway.Sender { return nil } +func (tnm *testNetworkManager) GetAddressSize() uint8 { return 0 } +func (tnm *testNetworkManager) RegisterAddressSizeNotification(string) (chan uint8, error) { + return nil, nil +} +func (tnm *testNetworkManager) UnregisterAddressSizeNotification(string) {} +func (tnm *testNetworkManager) SetPoolFilter(gateway.Filter) {} +func (tnm *testNetworkManager) GetVerboseRounds() string { return "" } + +type testHealthTracker struct { + chIndex, fnIndex uint64 + channels map[uint64]chan bool + funcs map[uint64]func(bool) + healthy bool +} + +//////////////////////////////////////////////////////////////////////////////// +// Test Health Tracker // +//////////////////////////////////////////////////////////////////////////////// + +func newTestHealthTracker() testHealthTracker { + return testHealthTracker{ + chIndex: 0, + fnIndex: 0, + channels: make(map[uint64]chan bool), + funcs: make(map[uint64]func(bool)), + healthy: true, + } +} + +func (tht testHealthTracker) AddChannel(c chan bool) uint64 { + tht.channels[tht.chIndex] = c + tht.chIndex++ + return tht.chIndex - 1 +} + +func (tht testHealthTracker) RemoveChannel(chanID uint64) { delete(tht.channels, chanID) } + +func (tht testHealthTracker) AddFunc(f func(bool)) uint64 { + tht.funcs[tht.fnIndex] = f + tht.fnIndex++ + return tht.fnIndex - 1 +} + +func (tht testHealthTracker) RemoveFunc(funcID uint64) { delete(tht.funcs, funcID) } +func (tht testHealthTracker) IsHealthy() bool { return tht.healthy } +func (tht testHealthTracker) WasHealthy() bool { return tht.healthy } + +//////////////////////////////////////////////////////////////////////////////// +// NDF Primes // +//////////////////////////////////////////////////////////////////////////////// + +func getNDF() *ndf.NetworkDefinition { + return &ndf.NetworkDefinition{ + E2E: ndf.Group{ + Prime: "E2EE983D031DC1DB6F1A7A67DF0E9A8E5561DB8E8D49413394C049B7A" + + "8ACCEDC298708F121951D9CF920EC5D146727AA4AE535B0922C688B55B3D" + + "D2AEDF6C01C94764DAB937935AA83BE36E67760713AB44A6337C20E78615" + + "75E745D31F8B9E9AD8412118C62A3E2E29DF46B0864D0C951C394A5CBBDC" + + "6ADC718DD2A3E041023DBB5AB23EBB4742DE9C1687B5B34FA48C3521632C" + + "4A530E8FFB1BC51DADDF453B0B2717C2BC6669ED76B4BDD5C9FF558E88F2" + + "6E5785302BEDBCA23EAC5ACE92096EE8A60642FB61E8F3D24990B8CB12EE" + + "448EEF78E184C7242DD161C7738F32BF29A841698978825B4111B4BC3E1E" + + "198455095958333D776D8B2BEEED3A1A1A221A6E37E664A64B83981C46FF" + + "DDC1A45E3D5211AAF8BFBC072768C4F50D7D7803D2D4F278DE8014A47323" + + "631D7E064DE81C0C6BFA43EF0E6998860F1390B5D3FEACAF1696015CB79C" + + "3F9C2D93D961120CD0E5F12CBB687EAB045241F96789C38E89D796138E63" + + "19BE62E35D87B1048CA28BE389B575E994DCA755471584A09EC723742DC3" + + "5873847AEF49F66E43873", + Generator: "2", + }, + CMIX: ndf.Group{ + Prime: "9DB6FB5951B66BB6FE1E140F1D2CE5502374161FD6538DF1648218642" + + "F0B5C48C8F7A41AADFA187324B87674FA1822B00F1ECF8136943D7C55757" + + "264E5A1A44FFE012E9936E00C1D3E9310B01C7D179805D3058B2A9F4BB6F" + + "9716BFE6117C6B5B3CC4D9BE341104AD4A80AD6C94E005F4B993E14F091E" + + "B51743BF33050C38DE235567E1B34C3D6A5C0CEAA1A0F368213C3D19843D" + + "0B4B09DCB9FC72D39C8DE41F1BF14D4BB4563CA28371621CAD3324B6A2D3" + + "92145BEBFAC748805236F5CA2FE92B871CD8F9C36D3292B5509CA8CAA77A" + + "2ADFC7BFD77DDA6F71125A7456FEA153E433256A2261C6A06ED3693797E7" + + "995FAD5AABBCFBE3EDA2741E375404AE25B", + Generator: "5C7FF6B06F8F143FE8288433493E4769C4D988ACE5BE25A0E2480" + + "9670716C613D7B0CEE6932F8FAA7C44D2CB24523DA53FBE4F6EC3595892D" + + "1AA58C4328A06C46A15662E7EAA703A1DECF8BBB2D05DBE2EB956C142A33" + + "8661D10461C0D135472085057F3494309FFA73C611F78B32ADBB5740C361" + + "C9F35BE90997DB2014E2EF5AA61782F52ABEB8BD6432C4DD097BC5423B28" + + "5DAFB60DC364E8161F4A2A35ACA3A10B1C4D203CC76A470A33AFDCBDD929" + + "59859ABD8B56E1725252D78EAC66E71BA9AE3F1DD2487199874393CD4D83" + + "2186800654760E1E34C09E4D155179F9EC0DC4473F996BDCE6EED1CABED8" + + "B6F116F7AD9CF505DF0F998E34AB27514B0FFE7", + }, + } +} diff --git a/go.mod b/go.mod index 25c52a24bf5f20383e0bf30c78f1cd5b11d47ed2..27a55715ff942c4cbdc2bd6fcd82ce1694ebcbd7 100644 --- a/go.mod +++ b/go.mod @@ -18,9 +18,9 @@ require ( github.com/spf13/viper v1.7.1 gitlab.com/elixxir/bloomfilter v0.0.0-20200930191214-10e9ac31b228 gitlab.com/elixxir/comms v0.0.4-0.20211101174956-590ba1b47887 - gitlab.com/elixxir/crypto v0.0.7-0.20211111193842-1e8c19599a2a + gitlab.com/elixxir/crypto v0.0.7-0.20211118181958-04390d5356fa gitlab.com/elixxir/ekv v0.1.5 - gitlab.com/elixxir/primitives v0.0.3-0.20211102233208-a716d5c670b6 + gitlab.com/elixxir/primitives v0.0.3-0.20211111194525-20889b10db75 gitlab.com/xx_network/comms v0.0.4-0.20211014163953-e774276b83ae gitlab.com/xx_network/crypto v0.0.5-0.20211014163843-57b345890686 gitlab.com/xx_network/primitives v0.0.4-0.20211014163031-53405cf191fb diff --git a/go.sum b/go.sum index a46c45f7b34aa5d72ab9bdab631c0a824293903f..3fec2e96ed182103a775e254ff4bafa318b1d2f3 100644 --- a/go.sum +++ b/go.sum @@ -258,8 +258,8 @@ gitlab.com/elixxir/comms v0.0.4-0.20211101174956-590ba1b47887/go.mod h1:rQpTeFVS gitlab.com/elixxir/crypto v0.0.0-20200804182833-984246dea2c4/go.mod h1:ucm9SFKJo+K0N2GwRRpaNr+tKXMIOVWzmyUD0SbOu2c= gitlab.com/elixxir/crypto v0.0.3/go.mod h1:ZNgBOblhYToR4m8tj4cMvJ9UsJAUKq+p0gCp07WQmhA= gitlab.com/elixxir/crypto v0.0.7-0.20211022013957-3a7899285c4c/go.mod h1:teuTEXyqsqo4N/J1sshcTg9xYOt+wNTurop7pkZOiCg= -gitlab.com/elixxir/crypto v0.0.7-0.20211111193842-1e8c19599a2a h1:A9K8N963tMrF6iKjxw6MqK4k/FOKAJe/c+N+TIOrj6w= -gitlab.com/elixxir/crypto v0.0.7-0.20211111193842-1e8c19599a2a/go.mod h1:teuTEXyqsqo4N/J1sshcTg9xYOt+wNTurop7pkZOiCg= +gitlab.com/elixxir/crypto v0.0.7-0.20211118181958-04390d5356fa h1:XwoFEIRO4QXpKzeBZehebO/9kU/C3xew/MS5yDyvu1o= +gitlab.com/elixxir/crypto v0.0.7-0.20211118181958-04390d5356fa/go.mod h1:teuTEXyqsqo4N/J1sshcTg9xYOt+wNTurop7pkZOiCg= gitlab.com/elixxir/ekv v0.1.5 h1:R8M1PA5zRU1HVnTyrtwybdABh7gUJSCvt1JZwUSeTzk= gitlab.com/elixxir/ekv v0.1.5/go.mod h1:e6WPUt97taFZe5PFLPb1Dupk7tqmDCTQu1kkstqJvw4= gitlab.com/elixxir/primitives v0.0.0-20200731184040-494269b53b4d/go.mod h1:OQgUZq7SjnE0b+8+iIAT2eqQF+2IFHn73tOo+aV11mg= @@ -267,8 +267,8 @@ gitlab.com/elixxir/primitives v0.0.0-20200804170709-a1896d262cd9/go.mod h1:p0Vel gitlab.com/elixxir/primitives v0.0.0-20200804182913-788f47bded40/go.mod h1:tzdFFvb1ESmuTCOl1z6+yf6oAICDxH2NPUemVgoNLxc= gitlab.com/elixxir/primitives v0.0.1/go.mod h1:kNp47yPqja2lHSiS4DddTvFpB/4D9dB2YKnw5c+LJCE= gitlab.com/elixxir/primitives v0.0.3-0.20211014164029-06022665b576/go.mod h1:zZy8AlOISFm5IG4G4sylypnz7xNBfZ5mpXiibqJT8+8= -gitlab.com/elixxir/primitives v0.0.3-0.20211102233208-a716d5c670b6 h1:ymWyFBFLcRQiuSId54dq8PVeiV4W7a9737kV46Thjlk= -gitlab.com/elixxir/primitives v0.0.3-0.20211102233208-a716d5c670b6/go.mod h1:zZy8AlOISFm5IG4G4sylypnz7xNBfZ5mpXiibqJT8+8= +gitlab.com/elixxir/primitives v0.0.3-0.20211111194525-20889b10db75 h1:LQ9EmPJvm7T6WxwNltjnYKLF0QSYCrS+hyteb8Dfttk= +gitlab.com/elixxir/primitives v0.0.3-0.20211111194525-20889b10db75/go.mod h1:zZy8AlOISFm5IG4G4sylypnz7xNBfZ5mpXiibqJT8+8= gitlab.com/xx_network/comms v0.0.0-20200805174823-841427dd5023/go.mod h1:owEcxTRl7gsoM8c3RQ5KAm5GstxrJp5tn+6JfQ4z5Hw= gitlab.com/xx_network/comms v0.0.4-0.20211014163953-e774276b83ae h1:jmZWmSm8eH40SX5B5uOw2XaYoHYqVn8daTfa6B80AOs= gitlab.com/xx_network/comms v0.0.4-0.20211014163953-e774276b83ae/go.mod h1:wR9Vx0KZLrIs0g2Efcp0UwFPStjcDRWkg/DJLVQI2vw= diff --git a/groupChat/send.go b/groupChat/send.go index 916410d1759add008948cd8d748c0bcf47cd9add..868d420134b6ebb79c0687f58648311af6445559 100644 --- a/groupChat/send.go +++ b/groupChat/send.go @@ -11,6 +11,7 @@ import ( "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" gs "gitlab.com/elixxir/client/groupChat/groupStore" + "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/crypto/group" "gitlab.com/elixxir/primitives/format" @@ -64,12 +65,13 @@ func (m *Manager) Send(groupID *id.ID, message []byte) (id.Round, time.Time, // createMessages generates a list of cMix messages and a list of corresponding // recipient IDs. -func (m *Manager) createMessages(groupID *id.ID, msg []byte, - timestamp time.Time) (map[id.ID]format.Message, error) { +func (m *Manager) createMessages(groupID *id.ID, msg []byte, timestamp time.Time) ( + []message.TargetedCmixMessage, error) { g, exists := m.gs.Get(groupID) if !exists { - return map[id.ID]format.Message{}, errors.Errorf(newNoGroupErr, groupID) + return []message.TargetedCmixMessage{}, + errors.Errorf(newNoGroupErr, groupID) } return m.newMessages(g, msg, timestamp) @@ -78,9 +80,9 @@ func (m *Manager) createMessages(groupID *id.ID, msg []byte, // newMessages is a private function that allows the passing in of a timestamp // and streamGen instead of a fastRNG.StreamGenerator for easier testing. func (m *Manager) newMessages(g gs.Group, msg []byte, timestamp time.Time) ( - map[id.ID]format.Message, error) { + []message.TargetedCmixMessage, error) { // Create list of cMix messages - messages := make(map[id.ID]format.Message) + messages := make([]message.TargetedCmixMessage, 0, len(g.Members)) // Create channels to receive messages and errors on type msgInfo struct { @@ -120,7 +122,10 @@ func (m *Manager) newMessages(g gs.Group, msg []byte, timestamp time.Time) ( // Return on the first error that occurs return nil, err case info := <-msgChan: - messages[*info.id] = info.msg + messages = append(messages, message.TargetedCmixMessage{ + Recipient: info.id, + Message: info.msg, + }) } } diff --git a/groupChat/send_test.go b/groupChat/send_test.go index 395ffc3ee9e3b640b38e88e0b2e11331d9e8036d..8fd29ba8c60719ef6f0ca2eff8c1d39049ef1c42 100644 --- a/groupChat/send_test.go +++ b/groupChat/send_test.go @@ -11,6 +11,7 @@ import ( "bytes" "encoding/base64" gs "gitlab.com/elixxir/client/groupChat/groupStore" + "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/storage" "gitlab.com/elixxir/crypto/group" "gitlab.com/elixxir/primitives/format" @@ -26,18 +27,18 @@ import ( func TestManager_Send(t *testing.T) { prng := rand.New(rand.NewSource(42)) m, g := newTestManagerWithStore(prng, 10, 0, nil, nil, t) - message := []byte("Group chat message.") + messageBytes := []byte("Group chat message.") sender := m.gs.GetUser().DeepCopy() - _, _, err := m.Send(g.ID, message) + _, _, err := m.Send(g.ID, messageBytes) if err != nil { t.Errorf("Send() returned an error: %+v", err) } // Get messages sent with or return an error if no messages were sent - var messages map[id.ID]format.Message + var messages []message.TargetedCmixMessage if len(m.net.(*testNetworkManager).messages) > 0 { - messages = m.net.(*testNetworkManager).GetMsgMap(0) + messages = m.net.(*testNetworkManager).GetMsgList(0) } else { t.Error("No group cMix messages received.") } @@ -47,56 +48,56 @@ func TestManager_Send(t *testing.T) { // Loop through each message and make sure the recipient ID matches a member // in the group and that each message can be decrypted and have the expected // values - for rid, msg := range messages { + for _, msg := range messages { // Check if recipient ID is in member list var foundMember group.Member for _, mem := range g.Members { - if rid.Cmp(mem.ID) { + if msg.Recipient.Cmp(mem.ID) { foundMember = mem } } // Error if the recipient ID is not found in the member list if foundMember == (group.Member{}) { - t.Errorf("Failed to find ID %s in memorship list.", rid) + t.Errorf("Failed to find ID %s in memorship list.", msg.Recipient) continue } - publicMsg, err := unmarshalPublicMsg(msg.GetContents()) + publicMessage, err := unmarshalPublicMsg(msg.Message.GetContents()) if err != nil { t.Errorf("Failed to unmarshal publicMsg: %+v", err) } // Attempt to read the message messageID, timestamp, senderID, readMsg, err := m.decryptMessage( - g, msg, publicMsg, timeNow) + g, msg.Message, publicMessage, timeNow) if err != nil { - t.Errorf("Failed to read message for %s: %+v", rid.String(), err) + t.Errorf("Failed to read message for %s: %+v", msg.Recipient, err) } - internalMsg, _ := newInternalMsg(publicMsg.GetPayloadSize()) - internalMsg.SetTimestamp(timestamp) - internalMsg.SetSenderID(m.gs.GetUser().ID) - internalMsg.SetPayload(message) - expectedMsgID := group.NewMessageID(g.ID, internalMsg.Marshal()) + internalMessage, _ := newInternalMsg(publicMessage.GetPayloadSize()) + internalMessage.SetTimestamp(timestamp) + internalMessage.SetSenderID(m.gs.GetUser().ID) + internalMessage.SetPayload(messageBytes) + expectedMsgID := group.NewMessageID(g.ID, internalMessage.Marshal()) if expectedMsgID != messageID { t.Errorf("Message ID received for %s too different from expected."+ - "\nexpected: %s\nreceived: %s", &rid, expectedMsgID, messageID) + "\nexpected: %s\nreceived: %s", msg.Recipient, expectedMsgID, messageID) } if !timestamp.Round(5 * time.Second).Equal(timeNow.Round(5 * time.Second)) { t.Errorf("Timestamp received for %s too different from expected."+ - "\nexpected: %s\nreceived: %s", &rid, timeNow, timestamp) + "\nexpected: %s\nreceived: %s", msg.Recipient, timeNow, timestamp) } if !senderID.Cmp(sender.ID) { t.Errorf("Sender ID received for %s incorrect."+ - "\nexpected: %s\nreceived: %s", &rid, sender.ID, senderID) + "\nexpected: %s\nreceived: %s", msg.Recipient, sender.ID, senderID) } - if !bytes.Equal(readMsg, message) { + if !bytes.Equal(readMsg, messageBytes) { t.Errorf("Message received for %s incorrect."+ - "\nexpected: %q\nreceived: %q", &rid, message, readMsg) + "\nexpected: %q\nreceived: %q", msg.Recipient, messageBytes, readMsg) } } } @@ -142,9 +143,9 @@ func TestManager_createMessages(t *testing.T) { prng := rand.New(rand.NewSource(42)) m, g := newTestManagerWithStore(prng, 10, 0, nil, nil, t) - message := []byte("Test group message.") + testMsg := []byte("Test group message.") sender := m.gs.GetUser() - messages, err := m.createMessages(g.ID, message, netTime.Now()) + messages, err := m.createMessages(g.ID, testMsg, netTime.Now()) if err != nil { t.Errorf("createMessages() returned an error: %+v", err) } @@ -152,28 +153,28 @@ func TestManager_createMessages(t *testing.T) { recipients := append(g.Members[:2], g.Members[3:]...) i := 0 - for rid, msg := range messages { + for _, msg := range messages { for _, recipient := range recipients { - if !rid.Cmp(recipient.ID) { + if !msg.Recipient.Cmp(recipient.ID) { continue } - publicMsg, err := unmarshalPublicMsg(msg.GetContents()) + publicMessage, err := unmarshalPublicMsg(msg.Message.GetContents()) if err != nil { t.Errorf("Failed to unmarshal publicMsg: %+v", err) } messageID, timestamp, testSender, testMessage, err := m.decryptMessage( - g, msg, publicMsg, netTime.Now()) + g, msg.Message, publicMessage, netTime.Now()) if err != nil { t.Errorf("Failed to find member to read message %d: %+v", i, err) } - internalMsg, _ := newInternalMsg(publicMsg.GetPayloadSize()) - internalMsg.SetTimestamp(timestamp) - internalMsg.SetSenderID(m.gs.GetUser().ID) - internalMsg.SetPayload(message) - expectedMsgID := group.NewMessageID(g.ID, internalMsg.Marshal()) + internalMessage, _ := newInternalMsg(publicMessage.GetPayloadSize()) + internalMessage.SetTimestamp(timestamp) + internalMessage.SetSenderID(m.gs.GetUser().ID) + internalMessage.SetPayload(testMsg) + expectedMsgID := group.NewMessageID(g.ID, internalMessage.Marshal()) if messageID != expectedMsgID { t.Errorf("Failed to read correct message ID for message %d."+ @@ -185,9 +186,9 @@ func TestManager_createMessages(t *testing.T) { "\nexpected: %s\nreceived: %s", i, sender.ID, testSender) } - if !bytes.Equal(message, testMessage) { + if !bytes.Equal(testMsg, testMessage) { t.Errorf("Failed to read correct message for message %d."+ - "\nexpected: %s\nreceived: %s", i, message, testMessage) + "\nexpected: %s\nreceived: %s", i, testMsg, testMessage) } } i++ @@ -217,10 +218,10 @@ func TestGroup_newMessages(t *testing.T) { prng := rand.New(rand.NewSource(42)) m, g := newTestManager(prng, t) - message := []byte("Test group message.") + testMsg := []byte("Test group message.") sender := m.gs.GetUser() timestamp := netTime.Now() - messages, err := m.newMessages(g, message, timestamp) + messages, err := m.newMessages(g, testMsg, timestamp) if err != nil { t.Errorf("newMessages() returned an error: %+v", err) } @@ -228,28 +229,28 @@ func TestGroup_newMessages(t *testing.T) { recipients := append(g.Members[:2], g.Members[3:]...) i := 0 - for rid, msg := range messages { + for _, msg := range messages { for _, recipient := range recipients { - if !rid.Cmp(recipient.ID) { + if !msg.Recipient.Cmp(recipient.ID) { continue } - publicMsg, err := unmarshalPublicMsg(msg.GetContents()) + publicMessage, err := unmarshalPublicMsg(msg.Message.GetContents()) if err != nil { t.Errorf("Failed to unmarshal publicMsg: %+v", err) } messageID, testTimestamp, testSender, testMessage, err := m.decryptMessage( - g, msg, publicMsg, netTime.Now()) + g, msg.Message, publicMessage, netTime.Now()) if err != nil { t.Errorf("Failed to find member to read message %d.", i) } - internalMsg, _ := newInternalMsg(publicMsg.GetPayloadSize()) - internalMsg.SetTimestamp(timestamp) - internalMsg.SetSenderID(m.gs.GetUser().ID) - internalMsg.SetPayload(message) - expectedMsgID := group.NewMessageID(g.ID, internalMsg.Marshal()) + internalMessage, _ := newInternalMsg(publicMessage.GetPayloadSize()) + internalMessage.SetTimestamp(timestamp) + internalMessage.SetSenderID(m.gs.GetUser().ID) + internalMessage.SetPayload(testMsg) + expectedMsgID := group.NewMessageID(g.ID, internalMessage.Marshal()) if messageID != expectedMsgID { t.Errorf("Failed to read correct message ID for message %d."+ @@ -266,9 +267,9 @@ func TestGroup_newMessages(t *testing.T) { "\nexpected: %s\nreceived: %s", i, sender.ID, testSender) } - if !bytes.Equal(message, testMessage) { + if !bytes.Equal(testMsg, testMessage) { t.Errorf("Failed to read correct message for message %d."+ - "\nexpected: %s\nreceived: %s", i, message, testMessage) + "\nexpected: %s\nreceived: %s", i, testMsg, testMessage) } } i++ @@ -295,13 +296,13 @@ func TestGroup_newCmixMsg(t *testing.T) { m, g := newTestManager(prng, t) // Create test parameters - message := []byte("Test group message.") + testMsg := []byte("Test group message.") mem := g.Members[3] timeNow := netTime.Now() // Create cMix message prng = rand.New(rand.NewSource(42)) - msg, err := m.newCmixMsg(g, message, timeNow, mem, prng) + msg, err := m.newCmixMsg(g, testMsg, timeNow, mem, prng) if err != nil { t.Errorf("newCmixMsg() returned an error: %+v", err) } @@ -316,12 +317,12 @@ func TestGroup_newCmixMsg(t *testing.T) { // Create expected messages cmixMsg := format.NewMessage(m.store.Cmix().GetGroup().GetP().ByteLen()) - publicMsg, _ := newPublicMsg(cmixMsg.ContentsSize()) - internalMsg, _ := newInternalMsg(publicMsg.GetPayloadSize()) - internalMsg.SetTimestamp(timeNow) - internalMsg.SetSenderID(m.gs.GetUser().ID) - internalMsg.SetPayload(message) - payload := internalMsg.Marshal() + publicMessage, _ := newPublicMsg(cmixMsg.ContentsSize()) + internalMessage, _ := newInternalMsg(publicMessage.GetPayloadSize()) + internalMessage.SetTimestamp(timeNow) + internalMessage.SetSenderID(m.gs.GetUser().ID) + internalMessage.SetPayload(testMsg) + payload := internalMessage.Marshal() // Check if key fingerprint is correct expectedFp := group.NewKeyFingerprint(g.Key, salt, mem.ID) @@ -339,24 +340,24 @@ func TestGroup_newCmixMsg(t *testing.T) { } // Attempt to unmarshal public group message - publicMsg, err = unmarshalPublicMsg(msg.GetContents()) + publicMessage, err = unmarshalPublicMsg(msg.GetContents()) if err != nil { t.Errorf("Failed to unmarshal cMix message contents: %+v", err) } // Attempt to decrypt payload - decryptedPayload := group.Decrypt(key, expectedFp, publicMsg.GetPayload()) - internalMsg, err = unmarshalInternalMsg(decryptedPayload) + decryptedPayload := group.Decrypt(key, expectedFp, publicMessage.GetPayload()) + internalMessage, err = unmarshalInternalMsg(decryptedPayload) if err != nil { t.Errorf("Failed to unmarshal decrypted payload contents: %+v", err) } // Check for expected values in internal message - if !internalMsg.GetTimestamp().Equal(timeNow) { + if !internalMessage.GetTimestamp().Equal(timeNow) { t.Errorf("Internal message has wrong timestamp."+ - "\nexpected: %s\nreceived: %s", timeNow, internalMsg.GetTimestamp()) + "\nexpected: %s\nreceived: %s", timeNow, internalMessage.GetTimestamp()) } - sid, err := internalMsg.GetSenderID() + sid, err := internalMessage.GetSenderID() if err != nil { t.Fatalf("Failed to get sender ID from internal message: %+v", err) } @@ -364,9 +365,9 @@ func TestGroup_newCmixMsg(t *testing.T) { t.Errorf("Internal message has wrong sender ID."+ "\nexpected: %s\nreceived: %s", m.gs.GetUser().ID, sid) } - if !bytes.Equal(internalMsg.GetPayload(), message) { + if !bytes.Equal(internalMessage.GetPayload(), testMsg) { t.Errorf("Internal message has wrong payload."+ - "\nexpected: %s\nreceived: %s", message, internalMsg.GetPayload()) + "\nexpected: %s\nreceived: %s", testMsg, internalMessage.GetPayload()) } } @@ -392,12 +393,12 @@ func TestGroup_newCmixMsg_InternalMsgSizeError(t *testing.T) { m, g := newTestManager(prng, t) // Create test parameters - message := make([]byte, 341) + testMsg := make([]byte, 341) mem := group.Member{ID: id.NewIdFromString("memberID", id.User, t)} // Create cMix message prng = rand.New(rand.NewSource(42)) - _, err := m.newCmixMsg(g, message, netTime.Now(), mem, prng) + _, err := m.newCmixMsg(g, testMsg, netTime.Now(), mem, prng) if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Errorf("newCmixMsg() failed to return the expected error"+ "\nexpected: %s\nreceived: %+v", expectedErr, err) @@ -483,16 +484,16 @@ func Test_newSalt_ReadLengthError(t *testing.T) { // Tests that the marshaled internalMsg can be unmarshaled and has all the // original values. func Test_setInternalPayload(t *testing.T) { - internalMsg, err := newInternalMsg(internalMinLen * 2) + internalMessage, err := newInternalMsg(internalMinLen * 2) if err != nil { t.Errorf("Failed to create a new internalMsg: %+v", err) } timestamp := netTime.Now() sender := id.NewIdFromString("sender ID", id.User, t) - message := []byte("This is an internal message.") + testMsg := []byte("This is an internal message.") - payload := setInternalPayload(internalMsg, timestamp, sender, message) + payload := setInternalPayload(internalMessage, timestamp, sender, testMsg) if err != nil { t.Errorf("setInternalPayload() returned an error: %+v", err) } @@ -517,9 +518,9 @@ func Test_setInternalPayload(t *testing.T) { sender, testSender) } - if !bytes.Equal(message, unmarshalled.GetPayload()) { + if !bytes.Equal(testMsg, unmarshalled.GetPayload()) { t.Errorf("Payload does not match original.\nexpected: %v\nreceived: %v", - message, unmarshalled.GetPayload()) + testMsg, unmarshalled.GetPayload()) } } @@ -527,17 +528,17 @@ func Test_setInternalPayload(t *testing.T) { // original values. func Test_setPublicPayload(t *testing.T) { prng := rand.New(rand.NewSource(42)) - publicMsg, err := newPublicMsg(publicMinLen * 2) + publicMessage, err := newPublicMsg(publicMinLen * 2) if err != nil { t.Errorf("Failed to create a new publicMsg: %+v", err) } var salt [group.SaltLen]byte prng.Read(salt[:]) - encryptedPayload := make([]byte, publicMsg.GetPayloadSize()) + encryptedPayload := make([]byte, publicMessage.GetPayloadSize()) copy(encryptedPayload, "This is an internal message.") - payload := setPublicPayload(publicMsg, salt, encryptedPayload) + payload := setPublicPayload(publicMessage, salt, encryptedPayload) if err != nil { t.Errorf("setPublicPayload() returned an error: %+v", err) } diff --git a/groupChat/utils_test.go b/groupChat/utils_test.go index 069772587c34f8b8597b8b56951b383c0c7fe752..612ad97f0c6831cc01632b71217346f4ca0532eb 100644 --- a/groupChat/utils_test.go +++ b/groupChat/utils_test.go @@ -218,7 +218,7 @@ func newTestNetworkManager(sendErr int, t *testing.T) interfaces.NetworkManager return &testNetworkManager{ instance: thisInstance, - messages: []map[id.ID]format.Message{}, + messages: [][]message.TargetedCmixMessage{}, sendErr: sendErr, } } @@ -226,14 +226,14 @@ func newTestNetworkManager(sendErr int, t *testing.T) interfaces.NetworkManager // testNetworkManager is a test implementation of NetworkManager interface. type testNetworkManager struct { instance *network.Instance - messages []map[id.ID]format.Message + messages [][]message.TargetedCmixMessage e2eMessages []message.Send errSkip int sendErr int sync.RWMutex } -func (tnm *testNetworkManager) GetMsgMap(i int) map[id.ID]format.Message { +func (tnm *testNetworkManager) GetMsgList(i int) []message.TargetedCmixMessage { tnm.RLock() defer tnm.RUnlock() return tnm.messages[i] @@ -274,8 +274,9 @@ func (tnm *testNetworkManager) SendCMIX(format.Message, *id.ID, params.CMIX) (id return 0, ephemeral.Id{}, nil } -func (tnm *testNetworkManager) SendManyCMIX(messages map[id.ID]format.Message, - _ params.CMIX) (id.Round, []ephemeral.Id, error) { +func (tnm *testNetworkManager) SendManyCMIX( + messages []message.TargetedCmixMessage, _ params.CMIX) (id.Round, + []ephemeral.Id, error) { if tnm.sendErr == 1 { return 0, nil, errors.New("SendManyCMIX error") } diff --git a/interfaces/fileTransfer.go b/interfaces/fileTransfer.go new file mode 100644 index 0000000000000000000000000000000000000000..3f1c888af925a9603434ecadcf0a8c5fdab29ae7 --- /dev/null +++ b/interfaces/fileTransfer.go @@ -0,0 +1,85 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package interfaces + +import ( + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/xx_network/primitives/id" + "time" +) + +// SentProgressCallback is a callback function that tracks the progress of +// sending a file. +type SentProgressCallback func(completed bool, sent, arrived, total uint16, + err error) + +// ReceivedProgressCallback is a callback function that tracks the progress of +// receiving a file. +type ReceivedProgressCallback func(completed bool, received, total uint16, + err error) + +// ReceiveCallback is a callback function that notifies the receiver of an +// incoming file transfer. +type ReceiveCallback func(tid ftCrypto.TransferID, fileName string, + sender *id.ID, size uint32, preview []byte) + +// FileTransfer facilities the sending and receiving of large file transfers. +// It allows for progress tracking of both inbound and outbound transfers. +type FileTransfer interface { + // Send sends a file to the recipient. The sender must have an E2E + // relationship with the recipient. + // The retry float is the total amount of data to send relative to the data + // size. Data will be resent on error and will resend up to [(1 + retry) * + // fileSize]. + // The preview stores a preview of the data (such as a thumbnail) and is + // capped at 4 kB in size. + // Returns a unique transfer ID used to identify the transfer. + Send(fileName string, fileData []byte, recipient *id.ID, retry float32, + preview []byte, progressCB SentProgressCallback, period time.Duration) ( + ftCrypto.TransferID, error) + + // RegisterSendProgressCallback allows for the registration of a callback to + // track the progress of an individual sent file transfer. The callback will + // be called immediately when added to report the current status of the + // transfer. It will then call every time a file part is sent, a file part + // arrives, the transfer completes, or an error occurs. It is called at most + // once ever period, which means if events occur faster than the period, + // then they will not be reported and instead the progress will be reported + // once at the end of the period. + RegisterSendProgressCallback(tid ftCrypto.TransferID, + progressCB SentProgressCallback, period time.Duration) error + + // Resend resends a file if Send fails. Returns an error if CloseSend + // was already called or if the transfer did not run out of retries. + Resend(tid ftCrypto.TransferID) error + + // CloseSend deletes a file from the internal storage once a transfer has + // completed or reached the retry limit. Returns an error if the transfer + // has not run out of retries. + CloseSend(tid ftCrypto.TransferID) error + + // RegisterReceiveProgressCallback allows for the registration of a callback + // to track the progress of an individual received file transfer. The + // callback will be called immediately when added to report the current + // status of the transfer. It will then call every time a file part is + // received, the transfer completes, or an error occurs. It is called at + // most once ever period, which means if events occur faster than the + // period, then they will not be reported and instead the progress will be + // reported once at the end of the period. + // Once the callback reports that the transfer has completed, the recipient + // can get the full file by calling Receive. + RegisterReceiveProgressCallback(tid ftCrypto.TransferID, + progressCB ReceivedProgressCallback, period time.Duration) error + + // Receive returns the full file on the completion of the transfer as + // reported by a registered ReceivedProgressCallback. It deletes internal + // references to the data and unregisters any attached progress callback. + // Returns an error if the transfer is not complete, the full file cannot be + // verified, or if the transfer cannot be found. + Receive(tid ftCrypto.TransferID) ([]byte, error) +} diff --git a/interfaces/message/sendMany.go b/interfaces/message/sendMany.go new file mode 100644 index 0000000000000000000000000000000000000000..a14dcef0a10203424e4fb21fdca342767988aef1 --- /dev/null +++ b/interfaces/message/sendMany.go @@ -0,0 +1,20 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package message + +import ( + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/primitives/id" +) + +// TargetedCmixMessage defines a recipient target pair in a sendMany cMix +// message. +type TargetedCmixMessage struct { + Recipient *id.ID + Message format.Message +} diff --git a/interfaces/message/type.go b/interfaces/message/type.go index 5c8012fea55a6f7c6afcc953ef37dbf0daa7b6e0..daf1a6dabdc1762774ac10c4b7e87491ba05bbba 100644 --- a/interfaces/message/type.go +++ b/interfaces/message/type.go @@ -53,4 +53,8 @@ const ( /* Group chat message types */ // A group chat request message sent to all members in a group. GroupCreationRequest = 40 + + // NewFileTransfer is the initial message that initiates a large file + // transfer. + NewFileTransfer = 50 ) diff --git a/interfaces/networkManager.go b/interfaces/networkManager.go index b0d5351b70e290ec024d5d4d36e72c2eed2f4f1b..88d4623dfb9373206bf9fdf8d8919a7904139f55 100644 --- a/interfaces/networkManager.go +++ b/interfaces/networkManager.go @@ -25,7 +25,7 @@ type NetworkManager interface { SendE2E(m message.Send, p params.E2E, stop *stoppable.Single) ([]id.Round, e2e.MessageID, time.Time, error) SendUnsafe(m message.Send, p params.Unsafe) ([]id.Round, error) SendCMIX(message format.Message, recipient *id.ID, p params.CMIX) (id.Round, ephemeral.Id, error) - SendManyCMIX(messages map[id.ID]format.Message, p params.CMIX) (id.Round, []ephemeral.Id, error) + SendManyCMIX(messages []message.TargetedCmixMessage, p params.CMIX) (id.Round, []ephemeral.Id, error) GetInstance() *network.Instance GetHealthTracker() HealthTracker GetEventManager() EventManager diff --git a/keyExchange/utils_test.go b/keyExchange/utils_test.go index 6acb4a73a6224a0230a32ec3a4b3ccabe3738814..29fa74e3455a141de6fc98f6bdccfd5481523d65 100644 --- a/keyExchange/utils_test.go +++ b/keyExchange/utils_test.go @@ -87,7 +87,7 @@ func (t *testNetworkManagerGeneric) SendCMIX(message format.Message, rid *id.ID, } -func (t *testNetworkManagerGeneric) SendManyCMIX(messages map[id.ID]format.Message, p params.CMIX) (id.Round, []ephemeral.Id, error) { +func (t *testNetworkManagerGeneric) SendManyCMIX(messages []message.TargetedCmixMessage, p params.CMIX) (id.Round, []ephemeral.Id, error) { return id.Round(0), []ephemeral.Id{}, nil } @@ -219,7 +219,7 @@ func (t *testNetworkManagerFullExchange) SendCMIX(message format.Message, eid *i return id.Round(0), ephemeral.Id{}, nil } -func (t *testNetworkManagerFullExchange) SendManyCMIX(messages map[id.ID]format.Message, p params.CMIX) (id.Round, []ephemeral.Id, error) { +func (t *testNetworkManagerFullExchange) SendManyCMIX(messages []message.TargetedCmixMessage, p params.CMIX) (id.Round, []ephemeral.Id, error) { return id.Round(0), []ephemeral.Id{}, nil } diff --git a/network/ephemeral/testutil.go b/network/ephemeral/testutil.go index 634518a1a6ba255fc55b47b092e1ed6b42efe1a7..32ce9574c6081a2cb3b201ee1228794e2f06e754 100644 --- a/network/ephemeral/testutil.go +++ b/network/ephemeral/testutil.go @@ -63,7 +63,7 @@ func (t *testNetworkManager) SendCMIX(format.Message, *id.ID, params.CMIX) (id.R return 0, ephemeral.Id{}, nil } -func (t *testNetworkManager) SendManyCMIX(messages map[id.ID]format.Message, p params.CMIX) (id.Round, []ephemeral.Id, error) { +func (t *testNetworkManager) SendManyCMIX(messages []message.TargetedCmixMessage, p params.CMIX) (id.Round, []ephemeral.Id, error) { return 0, []ephemeral.Id{}, nil } diff --git a/network/message/garbled.go b/network/message/garbled.go index e8fb10cbe49845f7370639877c4dcdba45342a39..54ca39fa456095b45ba30a1cd63f7f71bdc40db3 100644 --- a/network/message/garbled.go +++ b/network/message/garbled.go @@ -8,11 +8,15 @@ package message import ( + "fmt" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/id/ephemeral" "gitlab.com/xx_network/primitives/netTime" + "time" ) // Messages can arrive in the network out of order. When message handling fails @@ -76,6 +80,39 @@ func (m *Manager) handleGarbledMessages() { continue } } + } else { + // todo: figure out how to get the ephermal reception id in here. + // we have the raw data, but do not know what address space was + // used int he round + // todo: figure out how to get the round id, the recipient id, and the round timestamp + /* + ephid, err := ephemeral.Marshal(garbledMsg.GetEphemeralRID()) + if err!=nil{ + jww.WARN.Printf("failed to get the ephemeral id for a garbled " + + "message, clearing the message: %+v", err) + garbledMsgs.Remove(garbledMsg) + continue + } + + ephid.Clear(m.)*/ + + raw := message.Receive{ + Payload: grbldMsg.Marshal(), + MessageType: message.Raw, + Sender: &id.ID{}, + EphemeralID: ephemeral.Id{}, + Timestamp: time.Time{}, + Encryption: message.None, + RecipientID: &id.ID{}, + RoundId: 0, + RoundTimestamp: time.Time{}, + } + im := fmt.Sprintf("Garbled/RAW Message reprecessed: keyFP: %v, "+ + "msgDigest: %s", grbldMsg.GetKeyFP(), grbldMsg.Digest()) + jww.INFO.Print(im) + m.Internal.Events.Report(1, "MessageReception", "Garbled", im) + m.Session.GetGarbledMessages().Add(grbldMsg) + m.Switchboard.Speak(raw) } // fail the message if any part of the decryption fails, // unless it is the last attempts and has been in the buffer long diff --git a/network/message/sendCmixUtils.go b/network/message/sendCmixUtils.go index b889b0ad86a7a09e7df10ab1ea72bd03541a4057..15674cd2f0de004d306e44ebc9e1280845fedfa5 100644 --- a/network/message/sendCmixUtils.go +++ b/network/message/sendCmixUtils.go @@ -10,6 +10,7 @@ package message import ( "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" preimage2 "gitlab.com/elixxir/client/interfaces/preimage" "gitlab.com/elixxir/client/storage" @@ -121,7 +122,8 @@ func processRound(instance *network.Instance, session *storage.Session, // the recipient. func buildSlotMessage(msg format.Message, recipient *id.ID, target *id.ID, stream *fastRNG.Stream, senderId *id.ID, bestRound *pb.RoundInfo, - roundKeys *cmix.RoundKeys, param params.CMIX) (*pb.GatewaySlot, format.Message, ephemeral.Id, + roundKeys *cmix.RoundKeys, param params.CMIX) (*pb.GatewaySlot, + format.Message, ephemeral.Id, error) { // Set the ephemeral ID @@ -214,38 +216,42 @@ func handleMissingNodeKeys(instance *network.Instance, } } -// messageMapToStrings serializes a map of IDs and messages into a string of IDs -// and a string of message digests. Intended for use in printing to logs. -func messageMapToStrings(msgList map[id.ID]format.Message) (string, string) { +// messageListToStrings serializes a list of message.TargetedCmixMessage into a +// string of comma seperated recipient IDs and a string of comma seperated +// message digests. Duplicate recipient IDs are printed once. Intended for use +// in printing to log. +func messageListToStrings(msgList []message.TargetedCmixMessage) (string, string) { idStrings := make([]string, 0, len(msgList)) - msgDigests := make([]string, 0, len(msgList)) - for uid, msg := range msgList { - idStrings = append(idStrings, uid.String()) - msgDigests = append(msgDigests, msg.Digest()) + idMap := make(map[id.ID]bool, len(msgList)) + msgDigests := make([]string, len(msgList)) + for i, msg := range msgList { + if !idMap[*msg.Recipient] { + idStrings = append(idStrings, msg.Recipient.String()) + idMap[*msg.Recipient] = true + } + msgDigests[i] = msg.Message.Digest() } - return strings.Join(idStrings, ","), strings.Join(msgDigests, ",") + return strings.Join(idStrings, ", "), strings.Join(msgDigests, ", ") } -// messagesToDigestString serializes a list of messages into a string of message -// digests. Intended for use in printing to the logs. +// messagesToDigestString serializes a list of cMix messages into a string of +// comma seperated message digests. Intended for use in printing to log. func messagesToDigestString(msgs []format.Message) string { - msgDigests := make([]string, 0, len(msgs)) - for _, msg := range msgs { - msgDigests = append(msgDigests, msg.Digest()) + msgDigests := make([]string, len(msgs)) + for i, msg := range msgs { + msgDigests[i] = msg.Digest() } - return strings.Join(msgDigests, ",") + return strings.Join(msgDigests, ", ") } -// ephemeralIdListToString serializes a list of ephemeral IDs into a human- -// readable format. Intended for use in printing to logs. +// ephemeralIdListToString serializes a list of ephemeral IDs into a string of +// comma seperated integer representations. Intended for use in printing to log. func ephemeralIdListToString(idList []ephemeral.Id) string { - idStrings := make([]string, 0, len(idList)) - - for i := 0; i < len(idList); i++ { - ephIdStr := strconv.FormatInt(idList[i].Int64(), 10) - idStrings = append(idStrings, ephIdStr) + idStrings := make([]string, len(idList)) + for i, ephID := range idList { + idStrings[i] = strconv.FormatInt(ephID.Int64(), 10) } return strings.Join(idStrings, ",") diff --git a/network/message/sendManyCmix.go b/network/message/sendManyCmix.go index 31e9ef78ccc7baea47909d687a2b0671994511e4..04b833344bf42e524cba6756bc4770785ffad18a 100644 --- a/network/message/sendManyCmix.go +++ b/network/message/sendManyCmix.go @@ -8,11 +8,15 @@ package message import ( + "fmt" "github.com/golang-collections/collections/set" "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/client/network/gateway" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/comms/network" @@ -31,43 +35,39 @@ import ( // round the payload was sent or an error if it fails. // WARNING: Potentially Unsafe func (m *Manager) SendManyCMIX(sender *gateway.Sender, - messages map[id.ID]format.Message, p params.CMIX) (id.Round, []ephemeral.Id, - error) { + messages []message.TargetedCmixMessage, p params.CMIX, + stop *stoppable.Single) (id.Round, []ephemeral.Id, error) { - // Create message copies - messagesCopy := make(map[id.ID]format.Message, len(messages)) - for rid, msg := range messages { - messagesCopy[rid] = msg.Copy() - } - - return sendManyCmixHelper(sender, messagesCopy, p, m.Instance, m.Session, - m.nodeRegistration, m.Rng, m.TransmissionID, m.Comms) + return sendManyCmixHelper(sender, messages, p, m.blacklistedNodes, + m.Instance, m.Session, m.nodeRegistration, m.Rng, m.Internal.Events, + m.TransmissionID, m.Comms, stop) } // sendManyCmixHelper is a helper function for Manager.SendManyCMIX. // -// NOTE: Payloads sent are not end to end encrypted, metadata is NOT protected -// with this call; see SendE2E for end to end encryption and full privacy +// NOTE: Payloads sent are not end-to-end encrypted, metadata is NOT protected +// with this call; see SendE2E for end-to-end encryption and full privacy // protection. Internal SendManyCMIX, which bypasses the network check, will -// attempt to send to the network without checking state. It has a built in +// attempt to send to the network without checking state. It has a built-in // retry system which can be configured through the params object. // // If the message is successfully sent, the ID of the round sent it is returned, // which can be registered with the network instance to get a callback on its // status. -func sendManyCmixHelper(sender *gateway.Sender, msgs map[id.ID]format.Message, - param params.CMIX, instance *network.Instance, session *storage.Session, - nodeRegistration chan network.NodeGateway, rng *fastRNG.StreamGenerator, - senderId *id.ID, comms sendCmixCommsInterface) (id.Round, []ephemeral.Id, error) { +func sendManyCmixHelper(sender *gateway.Sender, + msgs []message.TargetedCmixMessage, param params.CMIX, + blacklistedNodes map[string]interface{}, instance *network.Instance, + session *storage.Session, nodeRegistration chan network.NodeGateway, + rng *fastRNG.StreamGenerator, events interfaces.EventManager, + senderId *id.ID, comms sendCmixCommsInterface, stop *stoppable.Single) ( + id.Round, []ephemeral.Id, error) { timeStart := netTime.Now() attempted := set.New() - stream := rng.GetStream() - defer stream.Close() maxTimeout := sender.GetHostParams().SendTimeout - recipientString, msgDigests := messageMapToStrings(msgs) + recipientString, msgDigests := messageListToStrings(msgs) jww.INFO.Printf("Looking for round to send cMix messages to [%s] "+ "(msgDigest: %s)", recipientString, msgDigests) @@ -96,9 +96,25 @@ func sendManyCmixHelper(sender *gateway.Sender, msgs map[id.ID]format.Message, continue } - // Add the round on to the list of attempted so it is not tried again + // Add the round on to the list of attempted rounds so that it is not + // tried again attempted.Insert(bestRound) + // Determine whether the selected round contains any nodes that are + // blacklisted by the params.Network object + containsBlacklisted := false + for _, nodeId := range bestRound.Topology { + if _, isBlacklisted := blacklistedNodes[string(nodeId)]; isBlacklisted { + containsBlacklisted = true + break + } + } + if containsBlacklisted { + jww.WARN.Printf("Round %d contains blacklisted node, skipping...", + bestRound.ID) + continue + } + // Retrieve host and key information from round firstGateway, roundKeys, err := processRound(instance, session, nodeRegistration, bestRound, recipientString, msgDigests) @@ -111,27 +127,30 @@ func sendManyCmixHelper(sender *gateway.Sender, msgs map[id.ID]format.Message, // Build a slot for every message and recipient slots := make([]*pb.GatewaySlot, len(msgs)) - ephemeralIds := make([]ephemeral.Id, len(msgs)) encMsgs := make([]format.Message, len(msgs)) - i := 0 - for recipient, msg := range msgs { - slots[i], encMsgs[i], ephemeralIds[i], err = buildSlotMessage( - msg, &recipient, firstGateway, stream, senderId, bestRound, roundKeys, param) + ephemeralIDs := make([]ephemeral.Id, len(msgs)) + stream := rng.GetStream() + for i, msg := range msgs { + slots[i], encMsgs[i], ephemeralIDs[i], err = buildSlotMessage( + msg.Message, msg.Recipient, firstGateway, stream, senderId, + bestRound, roundKeys, param) if err != nil { + stream.Close() jww.INFO.Printf("error building slot received: %v", err) return 0, []ephemeral.Id{}, errors.Errorf("failed to build "+ - "slot message for %s: %+v", recipient, err) + "slot message for %s: %+v", msg.Recipient, err) } - i++ } + stream.Close() + // Serialize lists into a printable format - ephemeralIdsString := ephemeralIdListToString(ephemeralIds) + ephemeralIDsString := ephemeralIdListToString(ephemeralIDs) encMsgsDigest := messagesToDigestString(encMsgs) jww.INFO.Printf("Sending to EphIDs [%s] (%s) on round %d, "+ "(msgDigest: %s, ecrMsgDigest: %s) via gateway %s", - ephemeralIdsString, recipientString, bestRound.ID, msgDigests, + ephemeralIDsString, recipientString, bestRound.ID, msgDigests, encMsgsDigest, firstGateway) // Wrap slots in the proper message type @@ -156,14 +175,20 @@ func sendManyCmixHelper(sender *gateway.Sender, msgs map[id.ID]format.Message, } return result, err } - result, err := sender.SendToPreferred([]*id.ID{firstGateway}, sendFunc, nil) + result, err := sender.SendToPreferred( + []*id.ID{firstGateway}, sendFunc, stop) + + // Exit if the thread has been stopped + if stoppable.CheckErr(err) { + return 0, []ephemeral.Id{}, err + } // If the comm errors or the message fails to send, continue retrying if err != nil { if !strings.Contains(err.Error(), unrecoverableError) { jww.ERROR.Printf("SendManyCMIX failed to send to EphIDs [%s] "+ "(sources: %s) on round %d, trying a new round %+v", - ephemeralIdsString, recipientString, bestRound.ID, err) + ephemeralIDsString, recipientString, bestRound.ID, err) jww.INFO.Printf("error received, continuing: %v", err) continue } @@ -174,13 +199,15 @@ func sendManyCmixHelper(sender *gateway.Sender, msgs map[id.ID]format.Message, // Return if it sends properly gwSlotResp := result.(*pb.GatewaySlotResponse) if gwSlotResp.Accepted { - jww.INFO.Printf("Successfully sent to EphIDs %v (sources: [%s]) in "+ - "round %d", ephemeralIdsString, recipientString, bestRound.ID) - return id.Round(bestRound.ID), ephemeralIds, nil + m := fmt.Sprintf("Successfully sent to EphIDs %s (sources: [%s]) "+ + "in round %d", ephemeralIDsString, recipientString, bestRound.ID) + jww.INFO.Print(m) + events.Report(1, "MessageSendMany", "Metric", m) + return id.Round(bestRound.ID), ephemeralIDs, nil } else { jww.FATAL.Panicf("Gateway %s returned no error, but failed to "+ "accept message when sending to EphIDs [%s] (%s) on round %d", - firstGateway, ephemeralIdsString, recipientString, bestRound.ID) + firstGateway, ephemeralIDsString, recipientString, bestRound.ID) } } diff --git a/network/message/sendManyCmix_test.go b/network/message/sendManyCmix_test.go index b787b3abc4445f5266a2cea51ced2429cdddba40..476004edd18f481692d70794b3c22d3d8fa79b17 100644 --- a/network/message/sendManyCmix_test.go +++ b/network/message/sendManyCmix_test.go @@ -29,6 +29,7 @@ import ( // Unit test func Test_attemptSendManyCmix(t *testing.T) { sess1 := storage.InitTestingSession(t) + events := &dummyEvent{} numRecipients := 3 recipients := make([]*id.ID, numRecipients) @@ -121,13 +122,17 @@ func Test_attemptSendManyCmix(t *testing.T) { messages[i] = msgCmix } - msgMap := make(map[id.ID]format.Message, numRecipients) + msgList := make([]message.TargetedCmixMessage, numRecipients) for i := 0; i < numRecipients; i++ { - msgMap[*recipients[i]] = msgCmix + msgList[i] = message.TargetedCmixMessage{ + Recipient: recipients[i], + Message: msgCmix, + } } - _, _, err = sendManyCmixHelper(sender, msgMap, params.GetDefaultCMIX(), m.Instance, - m.Session, m.nodeRegistration, m.Rng, m.TransmissionID, &MockSendCMIXComms{t: t}) + _, _, err = sendManyCmixHelper(sender, msgList, params.GetDefaultCMIX(), + make(map[string]interface{}), m.Instance, m.Session, m.nodeRegistration, + m.Rng, events, m.TransmissionID, &MockSendCMIXComms{t: t}, nil) if err != nil { t.Errorf("Failed to sendcmix: %+v", err) } diff --git a/network/rounds/remoteFilters_test.go b/network/rounds/remoteFilters_test.go index 6addedec8555e5bf7a55dec115813d8af0ac7993..924f7f970d9dd3c9f8d915b4f3f1f3524417de11 100644 --- a/network/rounds/remoteFilters_test.go +++ b/network/rounds/remoteFilters_test.go @@ -129,6 +129,10 @@ func TestValidFilterRange(t *testing.T) { RoundRange: roundRange, } + // Fix for test on Windows machines: provides extra buffer between + // time.Now() for the reception.Identity and the mixmessages.ClientBlooms + time.Sleep(time.Millisecond) + msg := &mixmessages.ClientBlooms{ Period: int64(12 * time.Hour), FirstTimestamp: time.Now().UnixNano(), diff --git a/network/send.go b/network/send.go index 8fff3ac152e65a55822d6e824aba7f7f6d4f09eb..4752b9f6a4d41a666fb09785192469330b2e8fa6 100644 --- a/network/send.go +++ b/network/send.go @@ -36,10 +36,10 @@ func (m *manager) SendCMIX(msg format.Message, recipient *id.ID, param params.CM // SendManyCMIX sends many "raw" CMIX message payloads to each of the // provided recipients. Used for group chat functionality. Returns the // round ID of the round the payload was sent or an error if it fails. -func (m *manager) SendManyCMIX(messages map[id.ID]format.Message, +func (m *manager) SendManyCMIX(msgs []message.TargetedCmixMessage, p params.CMIX) (id.Round, []ephemeral.Id, error) { - return m.message.SendManyCMIX(m.sender, messages, p) + return m.message.SendManyCMIX(m.sender, msgs, p, nil) } // SendUnsafe sends an unencrypted payload to the provided recipient diff --git a/single/manager_test.go b/single/manager_test.go index a1658a0b733bb21df35e17deef53072a76f7cf21..e240211160072a32b9b966d893a5610103afe6f8 100644 --- a/single/manager_test.go +++ b/single/manager_test.go @@ -315,7 +315,7 @@ func (tnm *testNetworkManager) SendCMIX(msg format.Message, _ *id.ID, _ params.C return id.Round(rand.Uint64()), ephemeral.Id{}, nil } -func (tnm *testNetworkManager) SendManyCMIX(messages map[id.ID]format.Message, p params.CMIX) (id.Round, []ephemeral.Id, error) { +func (tnm *testNetworkManager) SendManyCMIX(messages []message.TargetedCmixMessage, p params.CMIX) (id.Round, []ephemeral.Id, error) { if tnm.cmixTimeout != 0 { time.Sleep(tnm.cmixTimeout) } else if tnm.cmixErr { @@ -326,7 +326,7 @@ func (tnm *testNetworkManager) SendManyCMIX(messages map[id.ID]format.Message, p defer tnm.Unlock() for _, msg := range messages { - tnm.msgs = append(tnm.msgs, msg) + tnm.msgs = append(tnm.msgs, msg.Message) } return id.Round(rand.Uint64()), []ephemeral.Id{}, nil diff --git a/storage/e2e/key_test.go b/storage/e2e/key_test.go index 0f93e69a2818967cbfd53caaf26d034a8c224296..c548b8c78b496dc2981269c9c4a1a33fe0977d6e 100644 --- a/storage/e2e/key_test.go +++ b/storage/e2e/key_test.go @@ -9,6 +9,7 @@ package e2e import ( "bytes" + "gitlab.com/elixxir/client/storage/utility" "gitlab.com/elixxir/client/storage/versioned" "gitlab.com/elixxir/crypto/cyclic" dh "gitlab.com/elixxir/crypto/diffieHellman" @@ -197,7 +198,7 @@ func getSession(t *testing.T) *Session { grp: grp, } - keyState, err := newStateVector(versioned.NewKV(make(ekv.Memstore)), "keyState", rand.Uint32()) + keyState, err := utility.NewStateVector(versioned.NewKV(make(ekv.Memstore)), "keyState", rand.Uint32()) if err != nil { panic(err) } diff --git a/storage/e2e/relationship_test.go b/storage/e2e/relationship_test.go index 1ad211fcd55a389e758711a22ac8bcff24fdd76e..78912e0b89242f5131a1be45d75c7308f1b70fea 100644 --- a/storage/e2e/relationship_test.go +++ b/storage/e2e/relationship_test.go @@ -14,7 +14,6 @@ import ( "gitlab.com/elixxir/ekv" "gitlab.com/xx_network/primitives/id" "reflect" - "sync" "testing" ) @@ -42,7 +41,7 @@ func TestRelationship_MarshalUnmarshal(t *testing.T) { t.Fatal(err) } - // compare sb2 sesh list and map + // compare sb2 session list and map if !relationshipsEqual(sb, sb2) { t.Error("session buffs not equal") } @@ -289,7 +288,7 @@ func TestRelationship_GetNewestRekeyableSession(t *testing.T) { } // make the very newest session unrekeyable: the previous session - //sb.sessions[1].negotiationStatus = Confirmed + // sb.sessions[1].negotiationStatus = Confirmed sb.sessions[0].negotiationStatus = Unconfirmed session5 := sb.getNewestRekeyableSession() @@ -320,9 +319,9 @@ func TestRelationship_GetSessionForSending(t *testing.T) { unconfirmedRekey.partnerPubKey, unconfirmedRekey.partnerSource, Sending, unconfirmedRekey.e2eParams) sb.sessions[0].negotiationStatus = Unconfirmed - sb.sessions[0].keyState.numkeys = 2000 + sb.sessions[0].keyState.SetNumKeysTEST(2000, t) sb.sessions[0].rekeyThreshold = 1000 - sb.sessions[0].keyState.numAvailable = 600 + sb.sessions[0].keyState.SetNumAvailableTEST(600, t) sending := sb.getSessionForSending() if sending.GetID() != sb.sessions[0].GetID() { t.Error("got an unexpected session") @@ -340,9 +339,9 @@ func TestRelationship_GetSessionForSending(t *testing.T) { unconfirmedActive.partnerPubKey, unconfirmedActive.partnerSource, Sending, unconfirmedActive.e2eParams) sb.sessions[0].negotiationStatus = Unconfirmed - sb.sessions[0].keyState.numkeys = 2000 + sb.sessions[0].keyState.SetNumKeysTEST(2000, t) sb.sessions[0].rekeyThreshold = 1000 - sb.sessions[0].keyState.numAvailable = 2000 + sb.sessions[0].keyState.SetNumAvailableTEST(2000, t) sending = sb.getSessionForSending() if sending.GetID() != sb.sessions[0].GetID() { t.Error("got an unexpected session") @@ -361,9 +360,9 @@ func TestRelationship_GetSessionForSending(t *testing.T) { confirmedRekey.partnerPubKey, confirmedRekey.partnerSource, Sending, confirmedRekey.e2eParams) sb.sessions[0].negotiationStatus = Confirmed - sb.sessions[0].keyState.numkeys = 2000 + sb.sessions[0].keyState.SetNumKeysTEST(2000, t) sb.sessions[0].rekeyThreshold = 1000 - sb.sessions[0].keyState.numAvailable = 600 + sb.sessions[0].keyState.SetNumAvailableTEST(600, t) sending = sb.getSessionForSending() if sending.GetID() != sb.sessions[0].GetID() { t.Error("got an unexpected session") @@ -381,8 +380,8 @@ func TestRelationship_GetSessionForSending(t *testing.T) { Sending, confirmedActive.e2eParams) sb.sessions[0].negotiationStatus = Confirmed - sb.sessions[0].keyState.numkeys = 2000 - sb.sessions[0].keyState.numAvailable = 2000 + sb.sessions[0].keyState.SetNumKeysTEST(2000, t) + sb.sessions[0].keyState.SetNumAvailableTEST(2000, t) sb.sessions[0].rekeyThreshold = 1000 sending = sb.getSessionForSending() if sending.GetID() != sb.sessions[0].GetID() { @@ -468,7 +467,7 @@ func TestSessionBuff_TriggerNegotiation(t *testing.T) { session.partnerPubKey, session.partnerSource, Sending, session.e2eParams) session.negotiationStatus = Confirmed - // The added session isn't ready for rekey so it's not returned here + // The added session isn't ready for rekey, so it's not returned here negotiations := sb.TriggerNegotiation() if len(negotiations) != 0 { t.Errorf("should have had zero negotiations: %+v", negotiations) @@ -478,7 +477,7 @@ func TestSessionBuff_TriggerNegotiation(t *testing.T) { session2 = sb.AddSession(session2.myPrivKey, session2.partnerPubKey, session2.partnerPubKey, session2.partnerSource, Sending, session2.e2eParams) - session2.keyState.numAvailable = 4 + session2.keyState.SetNumAvailableTEST(4, t) session2.negotiationStatus = Confirmed negotiations = sb.TriggerNegotiation() if len(negotiations) != 1 { @@ -553,8 +552,8 @@ func makeTestRelationshipManager(t *testing.T) *Manager { // Revises a session to fit a sessionbuff and saves it to the sessionbuff's kv store func adaptToBuff(session *Session, buff *relationship, t *testing.T) { session.relationship = buff - session.keyState.kv = buff.manager.kv - err := session.keyState.save() + session.keyState.SetKvTEST(buff.manager.kv, t) + err := session.keyState.SaveTEST(t) if err != nil { t.Fatal(err) } @@ -595,38 +594,5 @@ func relationshipsEqual(buff *relationship, buff2 *relationship) bool { } func Test_relationship_getNewestRekeyableSession(t *testing.T) { - type fields struct { - manager *Manager - t RelationshipType - kv *versioned.KV - sessions []*Session - sessionByID map[SessionID]*Session - fingerprint []byte - mux sync.RWMutex - sendMux sync.Mutex - } - tests := []struct { - name string - fields fields - want *Session - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := &relationship{ - manager: tt.fields.manager, - t: tt.fields.t, - kv: tt.fields.kv, - sessions: tt.fields.sessions, - sessionByID: tt.fields.sessionByID, - fingerprint: tt.fields.fingerprint, - mux: tt.fields.mux, - sendMux: tt.fields.sendMux, - } - if got := r.getNewestRekeyableSession(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("getNewestRekeyableSession() = %v, want %v", got, tt.want) - } - }) - } + // TODO: Add test cases. } diff --git a/storage/e2e/session.go b/storage/e2e/session.go index 9f74ebf35efe310214237e6da36f9f4029c872be..b49404bb3d3b3931917fd568f9ed026a0faf6a6d 100644 --- a/storage/e2e/session.go +++ b/storage/e2e/session.go @@ -13,6 +13,7 @@ import ( "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/storage/utility" "gitlab.com/elixxir/client/storage/versioned" "gitlab.com/elixxir/crypto/cyclic" dh "gitlab.com/elixxir/crypto/diffieHellman" @@ -63,7 +64,7 @@ type Session struct { // Received Keys dirty bits // Each bit represents a single Key - keyState *stateVector + keyState *utility.StateVector //mutex mux sync.RWMutex @@ -213,12 +214,12 @@ func (s *Session) Delete() { sessionErr := s.kv.Delete(sessionKey, currentSessionVersion) if stateVectorErr != nil && sessionErr != nil { - jww.ERROR.Printf("Error deleting state vector with key %v: %v", stateVectorKey, stateVectorErr.Error()) + jww.ERROR.Printf("Error deleting state vector %s: %v", s.keyState, stateVectorErr.Error()) jww.ERROR.Panicf("Error deleting session with key %v: %v", sessionKey, sessionErr) } else if sessionErr != nil { jww.ERROR.Panicf("Error deleting session with key %v: %v", sessionKey, sessionErr) } else if stateVectorErr != nil { - jww.ERROR.Panicf("Error deleting state vector with key %v: %v", stateVectorKey, stateVectorErr.Error()) + jww.ERROR.Panicf("Error deleting state vector %s: %v", s.keyState, stateVectorErr.Error()) } } @@ -330,7 +331,7 @@ func (s *Session) unmarshal(b []byte) error { s.partner, _ = id.Unmarshal(sd.Partner) copy(s.partnerSource[:], sd.Trigger) - s.keyState, err = loadStateVector(s.kv, "") + s.keyState, err = utility.LoadStateVector(s.kv, "") if err != nil { return err } @@ -582,7 +583,7 @@ func (s *Session) generate(kv *versioned.KV) *versioned.KV { // To generate the state vector key correctly, // basekey must be computed as the session ID is the hash of basekey var err error - s.keyState, err = newStateVector(kv, "", numKeys) + s.keyState, err = utility.NewStateVector(kv, "", numKeys) if err != nil { jww.FATAL.Printf("Failed key generation: %s", err) } diff --git a/storage/e2e/session_test.go b/storage/e2e/session_test.go index 258cea9edab79a1e0fe6f214b4b6a78e4814e9d3..2d1d3dffd9c28d28e709de99ee8454a7e946f24c 100644 --- a/storage/e2e/session_test.go +++ b/storage/e2e/session_test.go @@ -10,6 +10,7 @@ package e2e import ( "errors" "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/storage/utility" "gitlab.com/elixxir/client/storage/versioned" dh "gitlab.com/elixxir/crypto/diffieHellman" "gitlab.com/elixxir/crypto/fastRNG" @@ -29,7 +30,7 @@ func TestSession_generate_noPrivateKeyReceive(t *testing.T) { partnerPrivKey := dh.GeneratePrivateKey(dh.DefaultPrivateKeyLength, grp, rng) partnerPubKey := dh.GeneratePublicKey(partnerPrivKey, grp) - //create context objects for general use + // create context objects for general use fps := newFingerprints() ctx := &context{ fa: &fps, @@ -37,7 +38,7 @@ func TestSession_generate_noPrivateKeyReceive(t *testing.T) { rng: fastRNG.NewStreamGenerator(1, 0, csprng.NewSystemRNG), } - //build the session + // build the session s := &Session{ partnerPubKey: partnerPubKey, e2eParams: params.GetDefaultE2ESessionParams(), @@ -47,33 +48,33 @@ func TestSession_generate_noPrivateKeyReceive(t *testing.T) { t: Receive, } - //run the generate command + // run the generate command s.generate(versioned.NewKV(make(ekv.Memstore))) - //check that it generated a private key + // check that it generated a private key if s.myPrivKey == nil { t.Errorf("Private key was not generated when missing") } - //verify the basekey is correct + // verify the base key is correct expectedBaseKey := dh.GenerateSessionKey(s.myPrivKey, s.partnerPubKey, grp) if expectedBaseKey.Cmp(s.baseKey) != 0 { t.Errorf("generated base key does not match expected base key") } - //verify the rekeyThreshold was generated + // verify the rekeyThreshold was generated if s.rekeyThreshold == 0 { t.Errorf("rekeyThreshold not generated") } - //verify keystates where created + // verify key states was created if s.keyState == nil { t.Errorf("keystates not generated") } - //verify keys were registered in the fingerprintMap - for keyNum := uint32(0); keyNum < s.keyState.numkeys; keyNum++ { + // verify keys were registered in the fingerprintMap + for keyNum := uint32(0); keyNum < s.keyState.GetNumKeys(); keyNum++ { key := newKey(s, keyNum) if _, ok := fps.toKey[key.Fingerprint()]; !ok { t.Errorf("key %v not in fingerprint map", keyNum) @@ -90,14 +91,14 @@ func TestSession_generate_PrivateKeySend(t *testing.T) { myPrivKey := dh.GeneratePrivateKey(dh.DefaultPrivateKeyLength, grp, rng) - //create context objects for general use + // create context objects for general use fps := newFingerprints() ctx := &context{ fa: &fps, grp: grp, } - //build the session + // build the session s := &Session{ myPrivKey: myPrivKey, partnerPubKey: partnerPubKey, @@ -108,33 +109,33 @@ func TestSession_generate_PrivateKeySend(t *testing.T) { t: Send, } - //run the generate command + // run the generate command s.generate(versioned.NewKV(make(ekv.Memstore))) - //check that it generated a private key + // check that it generated a private key if s.myPrivKey.Cmp(myPrivKey) != 0 { t.Errorf("Public key was generated when not missing") } - //verify the basekey is correct + // verify the base key is correct expectedBaseKey := dh.GenerateSessionKey(s.myPrivKey, s.partnerPubKey, grp) if expectedBaseKey.Cmp(s.baseKey) != 0 { t.Errorf("generated base key does not match expected base key") } - //verify the rekeyThreshold was generated + // verify the rekeyThreshold was generated if s.rekeyThreshold == 0 { t.Errorf("rekeyThreshold not generated") } - //verify keystates where created + // verify keyState was created if s.keyState == nil { t.Errorf("keystates not generated") } - //verify keys were not registered in the fingerprintMap - for keyNum := uint32(0); keyNum < s.keyState.numkeys; keyNum++ { + // verify keys were not registered in the fingerprintMap + for keyNum := uint32(0); keyNum < s.keyState.GetNumKeys(); keyNum++ { key := newKey(s, keyNum) if _, ok := fps.toKey[key.Fingerprint()]; ok { t.Errorf("key %v in fingerprint map", keyNum) @@ -187,9 +188,9 @@ func TestSession_Load(t *testing.T) { } // Key state should also be loaded and equivalent to the other session // during loadSession() - err = cmpKeyState(sessionA.keyState, sessionB.keyState) - if err != nil { - t.Error(err) + if !reflect.DeepEqual(sessionA.keyState, sessionB.keyState) { + t.Errorf("Two key states do not match.\nsessionA: %+v\nsessionB: %+v", + sessionA.keyState, sessionB.keyState) } // For everything else, just make sure it's populated if sessionB.relationship == nil { @@ -200,31 +201,6 @@ func TestSession_Load(t *testing.T) { } } -func cmpKeyState(a *stateVector, b *stateVector) error { - // ignore ctx, mux - if a.key != b.key { - return errors.New("keys differed") - } - if a.numAvailable != b.numAvailable { - return errors.New("numAvailable differed") - } - if a.firstAvailable != b.firstAvailable { - return errors.New("firstAvailable differed") - } - if a.numkeys != b.numkeys { - return errors.New("numkeys differed") - } - if len(a.vect) != len(b.vect) { - return errors.New("vect differed") - } - for i := range a.vect { - if a.vect[i] != b.vect[i] { - return errors.New("vect differed") - } - } - return nil -} - // Create a new session. Marshal and unmarshal it func TestSession_Serialization(t *testing.T) { s, ctx := makeTestSession() @@ -262,7 +238,7 @@ func cmpSerializedFields(a *Session, b *Session) error { return errors.New("minKeys differed") } if a.e2eParams.NumRekeys != b.e2eParams.NumRekeys { - return errors.New("numRekeys differed") + return errors.New("NumRekeys differed") } if a.e2eParams.MinNumKeys != b.e2eParams.MinNumKeys { return errors.New("minNumKeys differed") @@ -297,7 +273,7 @@ func TestSession_PopKey(t *testing.T) { } // PopKey should return the first available key if key.keyNum != 0 { - t.Error("First key popped should have keynum 0") + t.Error("First key popped should have keyNum 0") } } @@ -311,7 +287,7 @@ func TestSession_Delete(t *testing.T) { s.Delete() // Getting the keys that should have been stored should now result in an error - _, err = s.kv.Get(stateVectorKey, 0) + _, err = utility.LoadStateVector(s.kv, "") if err == nil { t.Error("State vector was gettable") } @@ -330,7 +306,7 @@ func TestSession_PopKey_Error(t *testing.T) { s, _ := makeTestSession() // Construct a specific state vector that will quickly run out of keys var err error - s.keyState, err = newStateVector(s.kv, "", 0) + s.keyState, err = utility.NewStateVector(s.kv, "", 0) if err != nil { t.Fatal(err) } @@ -342,7 +318,7 @@ func TestSession_PopKey_Error(t *testing.T) { } // PopRekey should return the next key -// There's no boundary, except for the number of keynums in the state vector +// There's no boundary, except for the number of keyNums in the state vector func TestSession_PopReKey(t *testing.T) { s, _ := makeTestSession() key, err := s.PopReKey() @@ -357,7 +333,7 @@ func TestSession_PopReKey(t *testing.T) { } // PopReKey should return the first available key if key.keyNum != 0 { - t.Error("First key popped should have keynum 0") + t.Error("First key popped should have keyNum 0") } } @@ -367,7 +343,7 @@ func TestSession_PopReKey_Err(t *testing.T) { s, _ := makeTestSession() // Construct a specific state vector that will quickly run out of keys var err error - s.keyState, err = newStateVector(s.kv, "", 0) + s.keyState, err = utility.NewStateVector(s.kv, "", 0) if err != nil { t.Fatal(err) } @@ -377,7 +353,7 @@ func TestSession_PopReKey_Err(t *testing.T) { } } -// Simple test that shows the base key can get got +// Simple test that shows the base key can be got func TestSession_GetBaseKey(t *testing.T) { s, _ := makeTestSession() baseKey := s.GetBaseKey() @@ -389,8 +365,8 @@ func TestSession_GetBaseKey(t *testing.T) { // Smoke test for GetID func TestSession_GetID(t *testing.T) { s, _ := makeTestSession() - id := s.GetID() - if len(id.Marshal()) == 0 { + sid := s.GetID() + if len(sid.Marshal()) == 0 { t.Error("Zero length for session ID!") } } @@ -430,24 +406,24 @@ func TestSession_IsConfirmed(t *testing.T) { func TestSession_Status(t *testing.T) { s, _ := makeTestSession() var err error - s.keyState, err = newStateVector(s.kv, "", 500) + s.keyState, err = utility.NewStateVector(s.kv, "", 500) if err != nil { t.Fatal(err) } - s.keyState.numAvailable = 0 + s.keyState.SetNumAvailableTEST(0, t) if s.Status() != RekeyEmpty { t.Error("status should have been rekey empty with no keys left") } - s.keyState.numAvailable = 1 + s.keyState.SetNumAvailableTEST(1, t) if s.Status() != Empty { t.Error("Status should have been empty") } // Passing the rekeyThreshold should result in a rekey being needed - s.keyState.numAvailable = s.keyState.numkeys - s.rekeyThreshold + s.keyState.SetNumAvailableTEST(s.keyState.GetNumKeys()-s.rekeyThreshold, t) if s.Status() != RekeyNeeded { t.Error("Just past the rekeyThreshold, rekey should be needed") } - s.keyState.numAvailable = s.keyState.numkeys + s.keyState.SetNumAvailableTEST(s.keyState.GetNumKeys(), t) s.rekeyThreshold = 450 if s.Status() != Active { t.Errorf("If all keys available, session should be active, recieved: %s", s.Status()) @@ -539,8 +515,8 @@ func TestSession_SetNegotiationStatus(t *testing.T) { func TestSession_TriggerNegotiation(t *testing.T) { s, _ := makeTestSession() // Set up num keys used to be > rekeyThreshold: should partnerSource negotiation - s.keyState.numAvailable = 50 - s.keyState.numkeys = 100 + s.keyState.SetNumAvailableTEST(50, t) + s.keyState.SetNumKeysTEST(100, t) s.rekeyThreshold = 49 s.negotiationStatus = Confirmed @@ -611,7 +587,7 @@ func makeTestSession() (*Session, *context) { myPrivKey := dh.GeneratePrivateKey(dh.DefaultPrivateKeyLength, grp, rng) baseKey := dh.GenerateSessionKey(myPrivKey, partnerPubKey, grp) - //create context objects for general use + // create context objects for general use fps := newFingerprints() ctx := &context{ fa: &fps, @@ -641,7 +617,7 @@ func makeTestSession() (*Session, *context) { partner: &id.ID{}, } var err error - s.keyState, err = newStateVector(s.kv, + s.keyState, err = utility.NewStateVector(s.kv, "", 1024) if err != nil { panic(err) diff --git a/storage/e2e/stateVector.go b/storage/e2e/stateVector.go deleted file mode 100644 index fea42c68cec3475ba27d1e6a16f884c7b8b73463..0000000000000000000000000000000000000000 --- a/storage/e2e/stateVector.go +++ /dev/null @@ -1,255 +0,0 @@ -/////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -/////////////////////////////////////////////////////////////////////////////// - -package e2e - -import ( - "encoding/json" - "fmt" - "github.com/pkg/errors" - jww "github.com/spf13/jwalterweatherman" - "gitlab.com/elixxir/client/storage/versioned" - "gitlab.com/xx_network/primitives/netTime" - "sync" -) - -const currentStateVectorVersion = 0 -const stateVectorKey = "stateVector" - -type stateVector struct { - kv *versioned.KV - key string - - // Bitfield for key states - // If a key is clean, its bit will be 0 - // Otherwise, it's dirty/used/not available, and its bit will be 1 - vect []uint64 - - firstAvailable uint32 - numkeys uint32 - numAvailable uint32 - - mux sync.RWMutex -} - -// Fields must be exported for json marshal to serialize them -type stateVectorDisk struct { - Vect []uint64 - FirstAvailable uint32 - NumAvailable uint32 - Numkeys uint32 -} - -func newStateVector(kv *versioned.KV, key string, numkeys uint32) (*stateVector, error) { - numBlocks := (numkeys + 63) / 64 - - sv := &stateVector{ - kv: kv, - vect: make([]uint64, numBlocks), - key: stateVectorKey + key, - firstAvailable: 0, - numAvailable: numkeys, - numkeys: numkeys, - } - - return sv, sv.save() -} - -func loadStateVector(kv *versioned.KV, key string) (*stateVector, error) { - sv := &stateVector{ - kv: kv, - key: stateVectorKey + key, - } - - obj, err := kv.Get(sv.key, currentStateVectorVersion) - if err != nil { - return nil, err - } - - err = sv.unmarshal(obj.Data) - if err != nil { - return nil, err - } - - return sv, nil -} - -func (sv *stateVector) save() error { - now := netTime.Now() - - data, err := sv.marshal() - if err != nil { - return err - } - - obj := versioned.Object{ - Version: currentStateVectorVersion, - Timestamp: now, - Data: data, - } - - return sv.kv.Set(sv.key, currentStateVectorVersion, &obj) -} - -func (sv *stateVector) Use(keynum uint32) { - sv.mux.Lock() - defer sv.mux.Unlock() - - block := keynum / 64 - pos := keynum % 64 - - sv.vect[block] |= 1 << pos - - if keynum == sv.firstAvailable { - sv.nextAvailable() - } - - sv.numAvailable-- - - if err := sv.save(); err != nil { - jww.FATAL.Printf("Failed to save %s on Use(): %s", sv, err) - } -} - -func (sv *stateVector) GetNumAvailable() uint32 { - sv.mux.RLock() - defer sv.mux.RUnlock() - return sv.numAvailable -} - -func (sv *stateVector) GetNumUsed() uint32 { - sv.mux.RLock() - defer sv.mux.RUnlock() - return sv.numkeys - sv.numAvailable -} - -func (sv *stateVector) Used(keynum uint32) bool { - sv.mux.RLock() - defer sv.mux.RUnlock() - - return sv.used(keynum) -} - -func (sv *stateVector) used(keynum uint32) bool { - block := keynum / 64 - pos := keynum % 64 - - return (sv.vect[block]>>pos)&1 == 1 -} - -func (sv *stateVector) Next() (uint32, error) { - sv.mux.Lock() - defer sv.mux.Unlock() - - if sv.firstAvailable >= sv.numkeys { - return sv.numkeys, errors.New("No keys remaining") - } - - next := sv.firstAvailable - - sv.nextAvailable() - sv.numAvailable-- - - if err := sv.save(); err != nil { - jww.FATAL.Printf("Failed to save %s on Next(): %s", sv, err) - } - - return next, nil - -} - -func (sv *stateVector) GetNumKeys() uint32 { - return sv.numkeys -} - -//returns a list of unused keys -func (sv *stateVector) GetUnusedKeyNums() []uint32 { - sv.mux.RLock() - defer sv.mux.RUnlock() - - keyNums := make([]uint32, 0, sv.numAvailable) - - for keyNum := sv.firstAvailable; keyNum < sv.numkeys; keyNum++ { - if !sv.used(keyNum) { - keyNums = append(keyNums, keyNum) - } - } - - return keyNums -} - -//returns a list of used keys -func (sv *stateVector) GetUsedKeyNums() []uint32 { - sv.mux.RLock() - defer sv.mux.RUnlock() - - keyNums := make([]uint32, 0, sv.numkeys-sv.numAvailable) - - for keyNum := sv.firstAvailable; keyNum < sv.numkeys; keyNum++ { - if sv.used(keyNum) { - keyNums = append(keyNums, keyNum) - } - } - - return keyNums -} - -//Adheres to the stringer interface -func (sv *stateVector) String() string { - return fmt.Sprintf("stateVector: %s", sv.key) -} - -//Deletes the state vector from storage -func (sv *stateVector) Delete() error { - return sv.kv.Delete(sv.key, currentStateVectorVersion) -} - -// finds the next used state and sets that as firstAvailable. This does not -// execute a store and a store must be executed after. -func (sv *stateVector) nextAvailable() { - - //plus one so we start at the next one - pos := sv.firstAvailable + 1 - block := pos / 64 - - for block < uint32(len(sv.vect)) && (sv.vect[block]>>(pos%64))&1 == 1 { - pos++ - block = pos / 64 - } - - sv.firstAvailable = pos -} - -//ekv functions -func (sv *stateVector) marshal() ([]byte, error) { - svd := stateVectorDisk{} - - svd.FirstAvailable = sv.firstAvailable - svd.Numkeys = sv.numkeys - svd.NumAvailable = sv.numAvailable - svd.Vect = sv.vect - - return json.Marshal(&svd) -} - -func (sv *stateVector) unmarshal(b []byte) error { - - svd := stateVectorDisk{} - - err := json.Unmarshal(b, &svd) - - if err != nil { - return err - } - - sv.firstAvailable = svd.FirstAvailable - sv.numkeys = svd.Numkeys - sv.numAvailable = svd.NumAvailable - sv.vect = svd.Vect - - return nil -} diff --git a/storage/e2e/stateVector_test.go b/storage/e2e/stateVector_test.go deleted file mode 100644 index e67d251c3c07c4f6ef3014080dee93a88a7846e4..0000000000000000000000000000000000000000 --- a/storage/e2e/stateVector_test.go +++ /dev/null @@ -1,277 +0,0 @@ -/////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -/////////////////////////////////////////////////////////////////////////////// - -package e2e - -import ( - "fmt" - "gitlab.com/elixxir/client/storage/versioned" - "gitlab.com/elixxir/ekv" - "math" - "math/bits" - "reflect" - "testing" -) - -// GetNumAvailable gets the number of slots left in the state vector -func TestStateVector_GetNumAvailable(t *testing.T) { - const numAvailable = 23 - sv := &stateVector{ - numAvailable: numAvailable, - } - // At the start, NumAvailable should be the same as numKeys - // as none of the keys have been used - if sv.GetNumAvailable() != numAvailable { - t.Errorf("expected %v available, actually %v available", numAvailable, sv.GetNumAvailable()) - } -} - -// Shows that GetNumUsed returns the number of slots used in the state vector -func TestStateVector_GetNumUsed(t *testing.T) { - const numAvailable = 23 - const numKeys = 50 - sv := &stateVector{ - numkeys: numKeys, - numAvailable: numAvailable, - } - - if sv.GetNumUsed() != numKeys-numAvailable { - t.Errorf("Expected %v used, got %v", numKeys-numAvailable, sv.GetNumUsed()) - } -} - -func TestStateVector_GetNumKeys(t *testing.T) { - const numKeys = 32 - sv, err := newStateVector(versioned.NewKV(make(ekv.Memstore)), "key", numKeys) - if err != nil { - t.Fatal(err) - } - - // GetNumKeys should always be the same as numKeys - if sv.GetNumKeys() != numKeys { - t.Errorf("expected %v available, actually %v available", numKeys, sv.GetNumAvailable()) - } -} - -// Shows that Next mutates vector state as expected -// Shows that Next can find key indexes all throughout the bitfield -func TestStateVector_Next(t *testing.T) { - // Expected results: all keynums, and beyond the last key - expectedFirstAvail := []uint32{139, 145, 300, 360, 420, 761, 868, 875, 893, 995} - - const numKeys = 1000 - sv, err := newStateVector(versioned.NewKV(make(ekv.Memstore)), "key", numKeys) - if err != nil { - t.Fatal(err) - } - - // Set all bits to dirty to start - for i := range sv.vect { - sv.vect[i] = 0xffffffffffffffff - } - - // Set a few clean bits randomly - const numBitsSet = 10 - for i := 0; i < numBitsSet; i++ { - keyNum := expectedFirstAvail[i] - // Set a bit clean in the state vector - vectIndex := keyNum / 64 - bitIndex := keyNum % 64 - sv.vect[vectIndex] &= ^bits.RotateLeft64(uint64(1), int(bitIndex)) - } - - sv.numAvailable = numBitsSet - sv.nextAvailable() - - // Calling Next ten times should give all of the keyNums we set - // It should change firstAvailable, but doesn't mutate the bit field itself - // (that should be done with Use) - for numCalls := 0; numCalls < numBitsSet; numCalls++ { - keyNum, err := sv.Next() - if err != nil { - t.Fatal(err) - } - if keyNum != expectedFirstAvail[numCalls] { - t.Errorf("keynum %v didn't match expected %v at index %v", keyNum, expectedFirstAvail[numCalls], numCalls) - } - } - - // One more call should cause an error - _, err = sv.Next() - if err == nil { - t.Error("Calling Next() after all keys have been found should result in error, as firstAvailable is more than numKeys") - } - // firstAvailable should now be beyond the end of the bitfield - if sv.firstAvailable < numKeys { - t.Error("Last Next() call should have set firstAvailable beyond numKeys") - } -} - -// Shows that Next mutates vector state as expected -// Shows that Next can find key indexes all throughout the bitfield -// A bug was found when the next avalible was in the first index of a word, this tests that case -func TestStateVector_Next_EdgeCase(t *testing.T) { - // Expected results: all keynums, and beyond the last key - //expectedFirstAvail := []uint32{139, 145, 300, 360, 420, 761, 868, 875, 893, 995} - - const numKeys = 1000 - sv, err := newStateVector(versioned.NewKV(make(ekv.Memstore)), "key", numKeys) - if err != nil { - t.Fatal(err) - } - - // Set a few clean bits randomly - sv.vect[0] = math.MaxUint64 - - sv.firstAvailable = 0 - sv.nextAvailable() - - // firstAvailable should now be beyond the end of the bitfield - if sv.firstAvailable != 64 { - t.Errorf("Next avalivle skiped the first of the next word, should be 64, is %d", sv.firstAvailable) - } -} - -// Shows that Use() mutates the state vector itself -func TestStateVector_Use(t *testing.T) { - // These keyNums will be set to dirty with Use - keyNums := []uint32{139, 145, 300, 360, 420, 761, 868, 875, 893, 995} - - const numKeys = 1000 - sv, err := newStateVector(versioned.NewKV(make(ekv.Memstore)), "key", numKeys) - if err != nil { - t.Fatal(err) - } - - // Expected vector states as bits are set - var expectedVect [][]uint64 - expectedVect = append(expectedVect, []uint64{0, 0, 0x800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) - expectedVect = append(expectedVect, []uint64{0, 0, 0x20800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) - expectedVect = append(expectedVect, []uint64{0, 0, 0x20800, 0, 0x100000000000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) - expectedVect = append(expectedVect, []uint64{0, 0, 0x20800, 0, 0x100000000000, 0x10000000000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) - expectedVect = append(expectedVect, []uint64{0, 0, 0x20800, 0, 0x100000000000, 0x10000000000, 0x1000000000, 0, 0, 0, 0, 0, 0, 0, 0, 0}) - expectedVect = append(expectedVect, []uint64{0, 0, 0x20800, 0, 0x100000000000, 0x10000000000, 0x1000000000, 0, 0, 0, 0, 0x200000000000000, 0, 0, 0, 0}) - expectedVect = append(expectedVect, []uint64{0, 0, 0x20800, 0, 0x100000000000, 0x10000000000, 0x1000000000, 0, 0, 0, 0, 0x200000000000000, 0, 0x1000000000, 0, 0}) - expectedVect = append(expectedVect, []uint64{0, 0, 0x20800, 0, 0x100000000000, 0x10000000000, 0x1000000000, 0, 0, 0, 0, 0x200000000000000, 0, 0x81000000000, 0, 0}) - expectedVect = append(expectedVect, []uint64{0, 0, 0x20800, 0, 0x100000000000, 0x10000000000, 0x1000000000, 0, 0, 0, 0, 0x200000000000000, 0, 0x2000081000000000, 0, 0}) - expectedVect = append(expectedVect, []uint64{0, 0, 0x20800, 0, 0x100000000000, 0x10000000000, 0x1000000000, 0, 0, 0, 0, 0x200000000000000, 0, 0x2000081000000000, 0, 0x800000000}) - - for numCalls := range keyNums { - // These calls to Use won't set nextAvailable, because the first keyNum set - sv.Use(keyNums[numCalls]) - if !reflect.DeepEqual(expectedVect[numCalls], sv.vect) { - t.Errorf("sv.vect differed from expected at index %v", numCalls) - fmt.Println(sv.vect) - } - } -} - -func TestStateVector_Used(t *testing.T) { - // These keyNums should be used - keyNums := []uint32{139, 145, 300, 360, 420, 761, 868, 875, 893, 995} - - const numKeys = 1000 - sv, err := newStateVector(versioned.NewKV(make(ekv.Memstore)), "key", numKeys) - if err != nil { - t.Fatal(err) - } - sv.vect = []uint64{0, 0, 0x20800, 0, 0x100000000000, 0x10000000000, 0x1000000000, 0, 0, 0, 0, 0x200000000000000, 0, 0x2000081000000000, 0, 0x800000000} - - for i := uint32(0); i < numKeys; i++ { - // if i is in keyNums, Used should be true - // otherwise, it should be false - found := false - for j := range keyNums { - if i == keyNums[j] { - found = true - break - } - } - if sv.Used(i) != found { - t.Errorf("at keynum %v Used should have been %v but was %v", i, found, sv.Used(i)) - } - } -} - -// Shows that the GetUsedKeyNums method returns the correct keynums -func TestStateVector_GetUsedKeyNums(t *testing.T) { - // These keyNums should be used - keyNums := []uint32{139, 145, 300, 360, 420, 761, 868, 875, 893, 995} - - const numKeys = 1000 - sv, err := newStateVector(versioned.NewKV(make(ekv.Memstore)), "key", numKeys) - if err != nil { - t.Fatal(err) - } - sv.vect = []uint64{0, 0, 0x20800, 0, 0x100000000000, 0x10000000000, 0x1000000000, 0, 0, 0, 0, 0x200000000000000, 0, 0x2000081000000000, 0, 0x800000000} - sv.numAvailable = uint32(numKeys - len(keyNums)) - - usedKeyNums := sv.GetUsedKeyNums() - for i := range keyNums { - if usedKeyNums[i] != keyNums[i] { - t.Errorf("used keynums at %v: expected %v, got %v", i, keyNums[i], usedKeyNums[i]) - } - } -} - -// Shows that GetUnusedKeyNums gets all clean keynums -func TestStateVector_GetUnusedKeyNums(t *testing.T) { - // These keyNums should not be used - keyNums := []uint32{139, 145, 300, 360, 420, 761, 868, 875, 893, 995} - - const numKeys = 1000 - sv, err := newStateVector(versioned.NewKV(make(ekv.Memstore)), "key", numKeys) - if err != nil { - t.Fatal(err) - } - sv.vect = []uint64{0xffffffffffffffff, 0xffffffffffffffff, 0xfffffffffffdf7ff, 0xffffffffffffffff, 0xffffefffffffffff, 0xfffffeffffffffff, 0xffffffefffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xfdffffffffffffff, 0xffffffffffffffff, 0xdffff7efffffffff, 0xffffffffffffffff, 0xfffffff7ffffffff} - sv.numAvailable = uint32(len(keyNums)) - sv.firstAvailable = keyNums[0] - - unusedKeyNums := sv.GetUnusedKeyNums() - for i := range keyNums { - if unusedKeyNums[i] != keyNums[i] { - t.Errorf("unused keynums at %v: expected %v, got %v", i, keyNums[i], unusedKeyNums[i]) - } - } - if len(keyNums) != len(unusedKeyNums) { - t.Error("array lengths differed, so arrays must be different") - } -} - -// Serializing and deserializing should result in the same state vector -func TestLoadStateVector(t *testing.T) { - keyNums := []uint32{139, 145, 300, 360, 420, 761, 868, 875, 893, 995} - const numKeys = 1000 - - kv := versioned.NewKV(make(ekv.Memstore)) - - sv, err := newStateVector(kv, "key", numKeys) - if err != nil { - t.Fatal(err) - } - sv.vect = []uint64{0xffffffffffffffff, 0xffffffffffffffff, - 0xfffffffffffdf7ff, 0xffffffffffffffff, 0xffffefffffffffff, - 0xfffffeffffffffff, 0xffffffefffffffff, 0xffffffffffffffff, - 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, - 0xfdffffffffffffff, 0xffffffffffffffff, 0xdffff7efffffffff, - 0xffffffffffffffff, 0xfffffff7ffffffff} - sv.numAvailable = uint32(len(keyNums)) - sv.firstAvailable = keyNums[0] - - err = sv.save() - if err != nil { - t.Fatal(err) - } - sv2, err := loadStateVector(kv, "key") - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(sv.vect, sv2.vect) { - t.Error("state vectors different after deserialization") - } -} diff --git a/storage/e2e/store_test.go b/storage/e2e/store_test.go index 016184b58e803a9b92d763be3c880b8a1b4dc05e..28f00b0cc6c08700705ff75264fd8ffa5e1f6f52 100644 --- a/storage/e2e/store_test.go +++ b/storage/e2e/store_test.go @@ -224,7 +224,7 @@ func TestStore_PopKey(t *testing.T) { } if key != nil { t.Errorf("PopKey() did not return a nil Key when it should not exist."+ - "\n\texpected: +%v\n\treceived: %+v", nil, key) + "\n\texpected: %+v\n\treceived: %+v", nil, key) } // Add a Key diff --git a/storage/fileTransfer/partInfo.go b/storage/fileTransfer/partInfo.go new file mode 100644 index 0000000000000000000000000000000000000000..1502921d10395f76a67b1a36397dd1dfb08f8fc0 --- /dev/null +++ b/storage/fileTransfer/partInfo.go @@ -0,0 +1,75 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "encoding/binary" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "strconv" +) + +// partInfo contains the transfer ID and fingerprint number for a file part. +type partInfo struct { + id ftCrypto.TransferID + fpNum uint16 +} + +// newPartInfo generates a new partInfo with the specified transfer ID and +// fingerprint number for a file part. +func newPartInfo(tid ftCrypto.TransferID, fpNum uint16) *partInfo { + pi := &partInfo{ + id: tid, + fpNum: fpNum, + } + + return pi +} + +// marshal serializes the partInfo into a byte slice. +func (pi *partInfo) marshal() []byte { + // Construct the buffer + buff := bytes.NewBuffer(nil) + buff.Grow(ftCrypto.TransferIdLength + 2) + + // Write the transfer ID to the buffer + buff.Write(pi.id.Bytes()) + + // Write the fingerprint number to the buffer + b := make([]byte, 2) + binary.LittleEndian.PutUint16(b, pi.fpNum) + buff.Write(b) + + // Return the serialized data + return buff.Bytes() +} + +// unmarshalPartInfo deserializes the byte slice into a partInfo. +func unmarshalPartInfo(b []byte) *partInfo { + buff := bytes.NewBuffer(b) + + // Read transfer ID from the buffer + transferIDBytes := buff.Next(ftCrypto.TransferIdLength) + transferID := ftCrypto.UnmarshalTransferID(transferIDBytes) + + // Read the fingerprint number from the buffer + fpNumBytes := buff.Next(2) + fpNum := binary.LittleEndian.Uint16(fpNumBytes) + + // Return the reconstructed partInfo + return &partInfo{ + id: transferID, + fpNum: fpNum, + } +} + +// String prints a string representation of partInfo. This functions satisfies +// the fmt.Stringer interface. +func (pi *partInfo) String() string { + return "{id:" + pi.id.String() + " fpNum:" + strconv.Itoa(int(pi.fpNum)) + "}" +} diff --git a/storage/fileTransfer/partInfo_test.go b/storage/fileTransfer/partInfo_test.go new file mode 100644 index 0000000000000000000000000000000000000000..b0ce83a3e10a4c5c9f309313b69b9a4602904cfc --- /dev/null +++ b/storage/fileTransfer/partInfo_test.go @@ -0,0 +1,74 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// +package fileTransfer + +import ( + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "reflect" + "testing" +) + +// Tests that newPartInfo creates the expected partInfo. +func Test_newPartInfo(t *testing.T) { + // Created expected partInfo + expected := &partInfo{ + id: ftCrypto.UnmarshalTransferID([]byte("TestTransferID")), + fpNum: 16, + } + + // Create new partInfo + received := newPartInfo(expected.id, expected.fpNum) + + if !reflect.DeepEqual(expected, received) { + t.Fatalf("New partInfo does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expected, received) + } +} + +// Tests that a partInfo that is marshalled with partInfo.marshal and +// unmarshalled with unmarshalPartInfo matches the original. +func Test_partInfo_marshal_unmarshalPartInfo(t *testing.T) { + expectedPI := newPartInfo( + ftCrypto.UnmarshalTransferID([]byte("TestTransferID")), 25) + + piBytes := expectedPI.marshal() + receivedPI := unmarshalPartInfo(piBytes) + + if !reflect.DeepEqual(expectedPI, receivedPI) { + t.Errorf("Marshalled and unmarshalled partInfo does not match original."+ + "\nexpected: %+v\nreceived: %+v", expectedPI, receivedPI) + } +} + +// Consistency test of partInfo.String. +func Test_partInfo_String(t *testing.T) { + prng := NewPrng(42) + expectedStrings := []string{ + "{id:U4x/lrFkvxuXu59LtHLon1sUhPJSCcnZND6SugndnVI= fpNum:0}", + "{id:39ebTXZCm2F6DJ+fDTulWwzA1hRMiIU1hBrL4HCbB1g= fpNum:1}", + "{id:CD9h03W8ArQd9PkZKeGP2p5vguVOdI6B555LvW/jTNw= fpNum:2}", + "{id:uoQ+6NY+jE/+HOvqVG2PrBPdGqwEzi6ih3xVec+ix44= fpNum:3}", + "{id:GwuvrogbgqdREIpC7TyQPKpDRlp4YgYWl4rtDOPGxPM= fpNum:4}", + "{id:rnvD4ElbVxL+/b4MECiH4QDazS2IX2kstgfaAKEcHHA= fpNum:5}", + "{id:ceeWotwtwlpbdLLhKXBeJz8FySMmgo4rBW44F2WOEGE= fpNum:6}", + "{id:SYlH/fNEQQ7UwRYCP6jjV2tv7Sf/iXS6wMr9mtBWkrE= fpNum:7}", + "{id:NhnnOJZN/ceejVNDc2Yc/WbXT+weG4lJGrcjbkt1IWI= fpNum:8}", + "{id:kM8r60LDyicyhWDxqsBnzqbov0bUqytGgEAsX7KCDog= fpNum:9}", + "{id:XTJg8d6XgoPUoJo2+WwglBdG4+1NpkaprotPp7T8OiA= fpNum:10}", + "{id:uvoade0yeoa4sMOa8c/Ss7USGep5Uzq/RI0sR50yYHU= fpNum:11}", + } + + for i, expected := range expectedStrings { + tid, _ := ftCrypto.NewTransferID(prng) + pi := newPartInfo(tid, uint16(i)) + + if expected != pi.String() { + t.Errorf("partInfo #%d string does not match expected."+ + "\nexpected: %s\nreceived: %s", i, expected, pi.String()) + } + } +} diff --git a/storage/fileTransfer/partStore.go b/storage/fileTransfer/partStore.go new file mode 100644 index 0000000000000000000000000000000000000000..f29655d61e2d553cd42fe288d01bacaf03a90b0b --- /dev/null +++ b/storage/fileTransfer/partStore.go @@ -0,0 +1,277 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "encoding/binary" + "github.com/pkg/errors" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/xx_network/primitives/netTime" + "strconv" + "sync" +) + +// Storage keys and versions. +const ( + partsStoreVersion = 0 + partsStoreKey = "FileTransferPart" + partsListVersion = 0 + partsListKey = "FileTransferList" +) + +// Error messages. +const ( + loadPartListErr = "failed to get parts list from storage: %+v" + loadPartsErr = "failed to load part #%d from storage: %+v" + savePartsErr = "failed to save part #%d to storage: %+v" +) + +// partStore stores the file parts in memory/storage. +type partStore struct { + parts map[uint16][]byte // File parts, keyed on their number in order + numParts uint16 // Number of parts in full file + kv *versioned.KV + mux sync.RWMutex +} + +// newPartStore generates a new empty partStore and saves it to storage. +func newPartStore(kv *versioned.KV, numParts uint16) (*partStore, error) { + + // Construct empty partStore of the specified size + ps := &partStore{ + parts: make(map[uint16][]byte, numParts), + numParts: numParts, + kv: kv, + } + + // Save to storage + return ps, ps.save() +} + +// newPartStore generates a new empty partStore and saves it to storage. +func newPartStoreFromParts(kv *versioned.KV, parts ...[]byte) (*partStore, + error) { + + // Construct empty partStore of the specified size + ps := &partStore{ + parts: partSliceToMap(parts...), + numParts: uint16(len(parts)), + kv: kv, + } + + // Save to storage + return ps, ps.save() +} + +// addPart adds a file part to the list of parts, saves it to storage, and +// regenerates the list of all file parts and saves it to storage. +func (ps *partStore) addPart(part []byte, partNum uint16) error { + ps.mux.Lock() + defer ps.mux.Unlock() + + ps.parts[partNum] = part + + err := ps.savePart(partNum) + if err != nil { + return err + } + + return ps.saveList() +} + +// getPart returns the part at the given part number. +func (ps *partStore) getPart(partNum uint16) ([]byte, bool) { + ps.mux.Lock() + defer ps.mux.Unlock() + + part, exists := ps.parts[partNum] + return part, exists +} + +// getFile returns all file parts concatenated into a single file. Returns the +// entire file as a byte slice and the number of parts missing. If the int is 0, +// then no parts are missing and the returned file is complete. +func (ps *partStore) getFile() ([]byte, int) { + ps.mux.Lock() + defer ps.mux.Unlock() + + // Get the length of one of the parts (all parts should be the same size) + partLength := 0 + for _, part := range ps.parts { + partLength = len(part) + break + } + + // Create new empty buffer of the size of the whole file + buff := bytes.NewBuffer(nil) + buff.Grow(int(ps.numParts) * partLength) + + // Loop through the map in order and add each file part that exists + var missingParts int + for i := uint16(0); i < ps.numParts; i++ { + part, exists := ps.parts[i] + if exists { + buff.Write(part) + } else { + missingParts++ + } + } + + return buff.Bytes(), missingParts +} + +// len returns the number of parts stored. +func (ps *partStore) len() int { + return len(ps.parts) +} + +// loadPartStore loads all the file parts from storage into memory. +func loadPartStore(kv *versioned.KV) (*partStore, error) { + // Get list of saved file parts + vo, err := kv.Get(partsListKey, partsListVersion) + if err != nil { + return nil, errors.Errorf(loadPartListErr, err) + } + + // Unmarshal saved data into a list + numParts, list := unmarshalPartList(vo.Data) + + // Initialize part map + ps := &partStore{ + parts: make(map[uint16][]byte, numParts), + numParts: numParts, + kv: kv, + } + + // Load each part from storage and add to the map + for _, partNum := range list { + vo, err := kv.Get(makePartsKey(partNum), partsStoreVersion) + if err != nil { + return nil, errors.Errorf(loadPartsErr, partNum, err) + } + + ps.parts[partNum] = vo.Data + } + + return ps, nil +} + +// save stores a list of all file parts and the individual parts in storage. +func (ps *partStore) save() error { + ps.mux.Lock() + defer ps.mux.Unlock() + + // Save the individual file parts to storage + for partNum := range ps.parts { + err := ps.savePart(partNum) + if err != nil { + return errors.Errorf(savePartsErr, partNum, err) + } + } + + // Save the part list to storage + return ps.saveList() +} + +// saveList stores the list of all file parts in storage. +func (ps *partStore) saveList() error { + obj := &versioned.Object{ + Version: partsStoreVersion, + Timestamp: netTime.Now(), + Data: ps.marshalList(), + } + + return ps.kv.Set(partsListKey, partsListVersion, obj) +} + +// savePart stores an individual file part to storage. +func (ps *partStore) savePart(partNum uint16) error { + obj := &versioned.Object{ + Version: partsStoreVersion, + Timestamp: netTime.Now(), + Data: ps.parts[partNum], + } + + return ps.kv.Set(makePartsKey(partNum), partsStoreVersion, obj) +} + +// delete removes all the file parts and file list from storage. +func (ps *partStore) delete() error { + ps.mux.Lock() + defer ps.mux.Unlock() + + for partNum := range ps.parts { + err := ps.kv.Delete(makePartsKey(partNum), partsStoreVersion) + if err != nil { + return err + } + } + + return ps.kv.Delete(partsListKey, partsListVersion) +} + +// marshalList creates a list of part numbers that are currently stored and +// returns them as a byte list to be saved to storage. +func (ps *partStore) marshalList() []byte { + // Create new buffer of the correct size + // (numParts (2 bytes) + (2*length of parts)) + buff := bytes.NewBuffer(nil) + buff.Grow(2 + (2 * len(ps.parts))) + + // Write numParts to buffer + b := make([]byte, 2) + binary.LittleEndian.PutUint16(b, ps.numParts) + buff.Write(b) + + for partNum := range ps.parts { + b = make([]byte, 2) + binary.LittleEndian.PutUint16(b, partNum) + buff.Write(b) + } + + return buff.Bytes() +} + +// unmarshalPartList unmarshalls a byte slice into a list of part numbers. +func unmarshalPartList(b []byte) (uint16, []uint16) { + buff := bytes.NewBuffer(b) + + // Read numParts from the buffer + numParts := binary.LittleEndian.Uint16(buff.Next(2)) + + // Initialize the list to the number of saved parts + list := make([]uint16, 0, len(b)/2) + + // Read each uint16 from the buffer and save into the list + for next := buff.Next(2); len(next) == 2; next = buff.Next(2) { + part := binary.LittleEndian.Uint16(next) + list = append(list, part) + } + + return numParts, list +} + +// partSliceToMap converts a slice of file parts, in order, to a map of file +// parts keyed on their part number. +func partSliceToMap(parts ...[]byte) map[uint16][]byte { + // Initialise map to the correct size + partMap := make(map[uint16][]byte, len(parts)) + + // Add each file part to the map + for partNum, part := range parts { + partMap[uint16(partNum)] = part + } + + return partMap +} + +// makePartsKey generates the key used to save a part to storage. +func makePartsKey(partNum uint16) string { + return partsStoreKey + strconv.Itoa(int(partNum)) +} diff --git a/storage/fileTransfer/partStore_test.go b/storage/fileTransfer/partStore_test.go new file mode 100644 index 0000000000000000000000000000000000000000..4a096e64cfbbdba216176185187079f895578d9c --- /dev/null +++ b/storage/fileTransfer/partStore_test.go @@ -0,0 +1,560 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "encoding/binary" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/ekv" + "io" + "math/rand" + "reflect" + "sort" + "strings" + "testing" +) + +// Tests that newPartStore produces the expected new partStore and that an empty +// parts list is saved to storage. +func Test_newPartStore(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + numParts := uint16(16) + + expectedPS := &partStore{ + parts: make(map[uint16][]byte, numParts), + numParts: numParts, + kv: kv, + } + + ps, err := newPartStore(kv, numParts) + if err != nil { + t.Errorf("newPartStore returned an error: %+v", err) + } + + if !reflect.DeepEqual(expectedPS, ps) { + t.Errorf("Returned incorrect partStore.\nexpected: %+v\nreceived: %+v", + expectedPS, ps) + } + + _, err = kv.Get(partsListKey, partsListVersion) + if err != nil { + t.Errorf("Failed to load part list from storage: %+v", err) + } +} + +// Tests that newPartStoreFromParts produces the expected partStore filled with +// the given parts. +func Test_newPartStoreFromParts(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + numParts := uint16(16) + + // Generate part slice and part map filled with the same data + partSlice := make([][]byte, numParts) + partMap := make(map[uint16][]byte, numParts) + for i := uint16(0); i < numParts; i++ { + b := make([]byte, 32) + prng.Read(b) + + partSlice[i] = b + partMap[i] = b + } + + expectedPS := &partStore{ + parts: partMap, + numParts: uint16(len(partMap)), + kv: kv, + } + + ps, err := newPartStoreFromParts(kv, partSlice...) + if err != nil { + t.Errorf("newPartStoreFromParts returned an error: %+v", err) + } + + if !reflect.DeepEqual(expectedPS, ps) { + t.Errorf("Returned incorrect partStore.\nexpected: %+v\nreceived: %+v", + expectedPS, ps) + } + + loadedPS, err := loadPartStore(kv) + if err != nil { + t.Errorf("Failed to load partStore from storage: %+v", err) + } + + if !reflect.DeepEqual(expectedPS, loadedPS) { + t.Errorf("Loaded partStore does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expectedPS, loadedPS) + } +} + +// Tests that a part added via partStore.addPart can be loaded from memory and +// storage. +func Test_partStore_addPart(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + numParts := uint16(16) + ps, _ := newRandomPartStore(numParts, kv, prng, t) + + expectedPart := []byte("part data") + partNum := uint16(17) + + err := ps.addPart(expectedPart, partNum) + if err != nil { + t.Errorf("addPart returned an error: %+v", err) + } + + // Check if part is in memory + part, exists := ps.parts[partNum] + if !exists || !bytes.Equal(expectedPart, part) { + t.Errorf("Failed to get part #%d from memory."+ + "\nexpected: %+v\nreceived: %+v", partNum, expectedPart, part) + } + + // Load part from storage + vo, err := kv.Get(makePartsKey(partNum), partsStoreVersion) + if err != nil { + t.Errorf("Failed to load part from storage: %+v", err) + } + + // Check that the part loaded from storage is correct + if !bytes.Equal(expectedPart, vo.Data) { + t.Errorf("Part saved to storage unexpected"+ + "\nexpected: %+v\nreceived: %+v", expectedPart, vo.Data) + } +} + +// Tests that partStore.getPart returns the expected part. +func Test_partStore_getPart(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + numParts := uint16(16) + ps, _ := newRandomPartStore(numParts, kv, prng, t) + + expectedPart := []byte("part data") + partNum := uint16(17) + + err := ps.addPart(expectedPart, partNum) + if err != nil { + t.Errorf("addPart returned an error: %+v", err) + } + + // Check if part is in memory + part, exists := ps.getPart(partNum) + if !exists || !bytes.Equal(expectedPart, part) { + t.Errorf("Failed to get part #%d from memory."+ + "\nexpected: %+v\nreceived: %+v", partNum, expectedPart, part) + } +} + +// Tests that partStore.getFile returns all parts concatenated in order. +func Test_partStore_getFile(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + ps, expectedFile := newRandomPartStore(16, kv, prng, t) + + // Pull data from file + receivedFile, missingParts := ps.getFile() + if missingParts != 0 { + t.Errorf("File has missing parts.\nexpected: %d\nreceived: %d", + 0, missingParts) + } + + // Check correctness of reconstructions + if !bytes.Equal(receivedFile, expectedFile) { + t.Fatalf("Full reconstructed file does not match expected."+ + "\nexpected: %v\nreceived: %v", expectedFile, receivedFile) + } +} + +// Tests that partStore.getFile returns all parts concatenated in order. +func Test_partStore_getFile_MissingPartsError(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + numParts := uint16(16) + ps, _ := newRandomPartStore(numParts, kv, prng, t) + + // Delete half the parts + for partNum := uint16(0); partNum < numParts; partNum++ { + if partNum%2 == 0 { + delete(ps.parts, partNum) + } + } + + // Pull data from file + _, missingParts := ps.getFile() + if missingParts != 8 { + t.Errorf("Missing incorrect number of parts."+ + "\nexpected: %d\nreceived: %d", 0, missingParts) + } +} + +// Tests that partStore.len returns 0 for a new map and the correct value when +// parts are added +func Test_partStore_len(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + numParts := uint16(16) + ps, _ := newPartStore(kv, numParts) + + if ps.len() != 0 { + t.Errorf("Length of new partStore is not 0."+ + "\nexpected: %d\nreceived: %d", 0, ps.len()) + } + + addedParts := 5 + for i := 0; i < addedParts; i++ { + _ = ps.addPart([]byte("test"), uint16(i)) + } + + if ps.len() != addedParts { + t.Errorf("Length of new partStore incorrect."+ + "\nexpected: %d\nreceived: %d", addedParts, ps.len()) + } +} + +// Tests that loadPartStore gets the expected partStore from storage. +func Test_loadPartStore(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + expectedPS, _ := newRandomPartStore(16, kv, prng, t) + + err := expectedPS.save() + if err != nil { + t.Errorf("Failed to save parts to storage: %+v", err) + } + + ps, err := loadPartStore(kv) + if err != nil { + t.Errorf("loadPartStore returned an error: %+v", err) + } + + if !reflect.DeepEqual(expectedPS, ps) { + t.Errorf("Loaded partStore does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expectedPS, ps) + } +} + +// Error path: tests that loadPartStore returns the expected error when no part +// list is saved to storage. +func Test_loadPartStore_NoSavedListErr(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + expectedErr := strings.Split(loadPartListErr, ":")[0] + _, err := loadPartStore(kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("loadPartStore failed to return the expected error when no "+ + "object is saved in storage.\nexpected: %s\nreceived: %+v", + expectedErr, err) + } +} + +// Error path: tests that loadPartStore returns the expected error when no parts +// are saved to storage. +func Test_loadPartStore_NoPartErr(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + expectedErr := strings.Split(loadPartsErr, "#")[0] + ps, _ := newRandomPartStore(16, kv, prng, t) + + err := ps.saveList() + if err != nil { + t.Errorf("Failed to save part list to storage: %+v", err) + } + + _, err = loadPartStore(kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("loadPartStore failed to return the expected error when no "+ + "object is saved in storage.\nexpected: %s\nreceived: %+v", + expectedErr, err) + } +} + +// Tests that partStore.save correctly stores the part list to storage by +// reading it from storage and ensuring all part numbers are present. It then +// makes sure that all the parts are stored in storage. +func Test_partStore_save(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + + numParts := uint16(16) + ps, _ := newRandomPartStore(numParts, kv, prng, t) + + err := ps.save() + if err != nil { + t.Errorf("save returned an error: %+v", err) + } + + vo, err := kv.Get(partsListKey, partsListVersion) + if err != nil { + t.Errorf("Failed to load part list from storage: %+v", err) + } + + numList := make(map[uint16]bool, numParts) + for i := uint16(0); i < numParts; i++ { + numList[i] = true + } + + buff := bytes.NewBuffer(vo.Data[2:]) + for next := buff.Next(2); len(next) == 2; next = buff.Next(2) { + partNum := binary.LittleEndian.Uint16(next) + if !numList[partNum] { + t.Errorf("Part number %d is not a correct part number.", partNum) + } else { + delete(numList, partNum) + } + } + + if len(numList) != 0 { + t.Errorf("File part numbers missing from list: %+v", numList) + } + + loadedNumParts, list := unmarshalPartList(vo.Data) + + if numParts != loadedNumParts { + t.Errorf("Loaded numParts does not match expected."+ + "\nexpected: %d\nrecieved: %d", numParts, loadedNumParts) + } + + for _, partNum := range list { + vo, err := kv.Get(makePartsKey(partNum), partsStoreVersion) + if err != nil { + t.Errorf("Failed to load part #%d from storage: %+v", partNum, err) + } + + if !bytes.Equal(ps.parts[partNum], vo.Data) { + t.Errorf("Part data #%d loaded from storage unexpected."+ + "\nexpected: %+v\nreceived: %+v", + partNum, ps.parts[partNum], vo.Data) + } + } +} + +// Tests that partStore.saveList saves the expected list to storage. +func Test_partStore_saveList(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + numParts := uint16(16) + expectedPS, _ := newRandomPartStore(numParts, kv, prng, t) + + err := expectedPS.saveList() + if err != nil { + t.Errorf("saveList returned an error: %+v", err) + } + + vo, err := kv.Get(partsListKey, partsListVersion) + if err != nil { + t.Errorf("Failed to load part list from storage: %+v", err) + } + + numList := make(map[uint16]bool, numParts) + for i := uint16(0); i < numParts; i++ { + numList[i] = true + } + + buff := bytes.NewBuffer(vo.Data[2:]) + for next := buff.Next(2); len(next) == 2; next = buff.Next(2) { + partNum := binary.LittleEndian.Uint16(next) + if !numList[partNum] { + t.Errorf("Part number %d is not a correct part number.", partNum) + } else { + delete(numList, partNum) + } + } + + if len(numList) != 0 { + t.Errorf("File part numbers missing from list: %+v", numList) + } +} + +// Tests that a part saved via partStore.savePart can be loaded from storage. +func Test_partStore_savePart(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + ps, err := newPartStore(kv, 16) + if err != nil { + t.Fatalf("Failed to create new partStore: %+v", err) + } + + expectedPart := []byte("part data") + partNum := uint16(5) + ps.parts[partNum] = expectedPart + + err = ps.savePart(partNum) + if err != nil { + t.Errorf("savePart returned an error: %+v", err) + } + + // Load part from storage + vo, err := kv.Get(makePartsKey(partNum), partsStoreVersion) + if err != nil { + t.Errorf("Failed to load part from storage: %+v", err) + } + + // Check that the part loaded from storage is correct + if !bytes.Equal(expectedPart, vo.Data) { + t.Errorf("Part saved to storage unexpected"+ + "\nexpected: %+v\nreceived: %+v", expectedPart, vo.Data) + } +} + +// Tests that partStore.delete deletes all stored items. +func Test_partStore_delete(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + numParts := uint16(16) + ps, _ := newRandomPartStore(numParts, kv, prng, t) + + err := ps.delete() + if err != nil { + t.Errorf("delete returned an error: %+v", err) + } + + _, err = kv.Get(partsListKey, partsListVersion) + if err == nil { + t.Error("Able to load part list from storage when it should have been " + + "deleted.") + } + + for partNum := range ps.parts { + _, err := kv.Get(makePartsKey(partNum), partsStoreVersion) + if err == nil { + t.Errorf("Loaded part #%d from storage when it should have been "+ + "deleted.", partNum) + } + } +} + +// Tests that a list marshalled via partStore.marshalList can be unmarshalled +// with unmarshalPartList to get the original list. +func Test_partStore_marshalList_unmarshalPartList(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + numParts := uint16(16) + ps, _ := newRandomPartStore(numParts, kv, prng, t) + + expected := make([]uint16, 0, 16) + for partNum := range ps.parts { + expected = append(expected, partNum) + } + + byteList := ps.marshalList() + + loadedNumParts, list := unmarshalPartList(byteList) + + if numParts != loadedNumParts { + t.Errorf("Loaded numParts does not match expected."+ + "\nexpected: %d\nrecieved: %d", numParts, loadedNumParts) + } + + sort.SliceStable(list, func(i, j int) bool { return list[i] < list[j] }) + sort.SliceStable(expected, func(i, j int) bool { return expected[i] < expected[j] }) + + if !reflect.DeepEqual(expected, list) { + t.Errorf("Failed to marshal and unmarshal part list."+ + "\nexpected: %+v\nreceived: %+v", expected, list) + } +} + +// Tests that partSliceToMap correctly maps all the parts in a slice of parts +// to a map of parts keyed on their part number. +func Test_partSliceToMap(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + + // Create list of file parts with random data + partSlice := make([][]byte, 8) + for i := range partSlice { + partSlice[i] = make([]byte, 16) + prng.Read(partSlice[i]) + } + + // Convert the slice of parts to a map + partMap := partSliceToMap(partSlice...) + + // Check that each part in the map matches a part in the slice + for partNum, slicePart := range partSlice { + // Check that the part exists in the map + mapPart, exists := partMap[uint16(partNum)] + if !exists { + t.Errorf("Part number %d does not exist in the map.", partNum) + } + + // Check that the part in the map is correct + if !bytes.Equal(slicePart, mapPart) { + t.Errorf("Part found in map does not match expected part in slice."+ + "\nexpected: %+v\nreceived: %+v", slicePart, mapPart) + } + + delete(partMap, uint16(partNum)) + } + + // Make sure there are no extra parts in the map + if len(partMap) != 0 { + t.Errorf("Part map contains %d extra parts not in the slice."+ + "\nparts: %+v", len(partMap), partMap) + } +} + +// Tests the consistency of makePartsKey. +func Test_makePartsKey_consistency(t *testing.T) { + expectedStrings := []string{ + partsStoreKey + "0", + partsStoreKey + "1", + partsStoreKey + "2", + partsStoreKey + "3", + } + + for i, expected := range expectedStrings { + key := makePartsKey(uint16(i)) + if key != expected { + t.Errorf("Key #%d does not match expected."+ + "\nexpected: %q\nreceived: %q", i, expected, key) + } + } +} + +// newRandomPartStore creates a new partStore filled with random data. Returns +// the random partStore and a slice of the original file. +func newRandomPartStore(numParts uint16, kv *versioned.KV, prng io.Reader, + t *testing.T) (*partStore, []byte) { + + partSize := 64 + + ps, err := newPartStore(kv, numParts) + if err != nil { + t.Fatalf("Failed to create new partStore: %+v", err) + } + + fileBuff := bytes.NewBuffer(nil) + fileBuff.Grow(int(numParts) * partSize) + + for partNum := uint16(0); partNum < numParts; partNum++ { + ps.parts[partNum] = make([]byte, partSize) + _, err := prng.Read(ps.parts[partNum]) + if err != nil { + t.Errorf("Failed to generate random part (%d): %+v", partNum, err) + } + fileBuff.Write(ps.parts[partNum]) + } + + return ps, fileBuff.Bytes() +} + +// newRandomPartSlice returns a list of file parts and the file in one piece. +func newRandomPartSlice(numParts uint16, prng io.Reader, t *testing.T) ([][]byte, []byte) { + partSize := 64 + fileBuff := bytes.NewBuffer(make([]byte, 0, int(numParts)*partSize)) + partList := make([][]byte, numParts) + for partNum := range partList { + partList[partNum] = make([]byte, partSize) + _, err := prng.Read(partList[partNum]) + if err != nil { + t.Errorf("Failed to generate random part (%d): %+v", partNum, err) + } + fileBuff.Write(partList[partNum]) + } + + return partList, fileBuff.Bytes() +} diff --git a/storage/fileTransfer/receiveFileTransfers.go b/storage/fileTransfer/receiveFileTransfers.go new file mode 100644 index 0000000000000000000000000000000000000000..99bc6921be29638ab849fe70916270f3a4d4a4d6 --- /dev/null +++ b/storage/fileTransfer/receiveFileTransfers.go @@ -0,0 +1,341 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "github.com/pkg/errors" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/crypto/csprng" + "gitlab.com/xx_network/primitives/netTime" + "sync" +) + +const ( + receivedFileTransfersPrefix = "FileTransferReceivedFileTransfersStore" + receivedFileTransfersKey = "ReceivedFileTransfers" + receivedFileTransfersVersion = 0 +) + +// Error messages for ReceivedFileTransfers. +const ( + saveReceivedTransfersListErr = "failed to save list of received items in transfer map to storage: %+v" + loadReceivedTransfersListErr = "failed to load list of received items in transfer map from storage: %+v" + loadReceivedFileTransfersErr = "failed to load received transfers from storage: %+v" + + newReceivedTransferErr = "failed to create new received transfer: %+v" + getReceivedTransferErr = "received transfer with ID %s not found" + addTransferNewIdErr = "could not generate new transfer ID: %+v" + noFingerprintErr = "no part found with fingerprint %s" + addPartErr = "failed to add part number %d/%d to transfer %s: %+v" + deleteReceivedTransferErr = "failed to delete received transfer with ID %s from store: %+v" +) + +// ReceivedFileTransfers contains information for tracking a received +// ReceivedTransfer to its transfer ID. It also maps a received part, partInfo, +// to the message fingerprint. +type ReceivedFileTransfers struct { + transfers map[ftCrypto.TransferID]*ReceivedTransfer + info map[format.Fingerprint]*partInfo + mux sync.Mutex + kv *versioned.KV +} + +// NewReceivedFileTransfers creates a new ReceivedFileTransfers with empty maps. +func NewReceivedFileTransfers(kv *versioned.KV) (*ReceivedFileTransfers, error) { + rft := &ReceivedFileTransfers{ + transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), + info: make(map[format.Fingerprint]*partInfo), + kv: kv.Prefix(receivedFileTransfersPrefix), + } + + return rft, rft.saveTransfersList() +} + +// AddTransfer creates a new empty ReceivedTransfer, adds it to the transfers +// map, and adds an entry for each file part fingerprint to the fingerprint map. +func (rft *ReceivedFileTransfers) AddTransfer(key ftCrypto.TransferKey, + transferMAC []byte, fileSize uint32, numParts, numFps uint16, + rng csprng.Source) (ftCrypto.TransferID, error) { + + rft.mux.Lock() + defer rft.mux.Unlock() + + // Generate new transfer ID + tid, err := ftCrypto.NewTransferID(rng) + if err != nil { + return tid, errors.Errorf(addTransferNewIdErr, err) + } + + // Generate a new ReceivedTransfer and add it to the map + rft.transfers[tid], err = NewReceivedTransfer( + tid, key, transferMAC, fileSize, numParts, numFps, rft.kv) + if err != nil { + return tid, errors.Errorf(newReceivedTransferErr, err) + } + + // Add part info for each file part to the map + rft.addFingerprints(key, tid, numFps) + + // Update list of transfers in storage + err = rft.saveTransfersList() + if err != nil { + return tid, errors.Errorf(saveReceivedTransfersListErr, err) + } + + return tid, nil +} + +// addFingerprints generates numFps fingerprints, creates a new partInfo +// for each, and adds each to the info map. +func (rft *ReceivedFileTransfers) addFingerprints(key ftCrypto.TransferKey, + tid ftCrypto.TransferID, numFps uint16) { + + // Generate list of fingerprints + fps := ftCrypto.GenerateFingerprints(key, numFps) + + // Add fingerprints to map + for fpNum, fp := range fps { + rft.info[fp] = newPartInfo(tid, uint16(fpNum)) + } +} + +// GetTransfer returns the ReceivedTransfer with the given transfer ID. An error +// is returned if no corresponding transfer is found. +func (rft *ReceivedFileTransfers) GetTransfer(tid ftCrypto.TransferID) ( + *ReceivedTransfer, error) { + rft.mux.Lock() + defer rft.mux.Unlock() + + rt, exists := rft.transfers[tid] + if !exists { + return nil, errors.Errorf(getReceivedTransferErr, tid) + } + + return rt, nil +} + +// DeleteTransfer removes the ReceivedTransfer with the associated transfer ID +// from memory and storage. +func (rft *ReceivedFileTransfers) DeleteTransfer(tid ftCrypto.TransferID) error { + rft.mux.Lock() + defer rft.mux.Unlock() + + // Return an error if the transfer does not exist + rt, exists := rft.transfers[tid] + if !exists { + return errors.Errorf(getReceivedTransferErr, tid) + } + + // Remove all unused fingerprints from map + for n, err := rt.fpVector.Next(); err == nil; n, err = rt.fpVector.Next() { + // Generate fingerprint + fp := ftCrypto.GenerateFingerprint(rt.key, uint16(n)) + + // Delete fingerprint from map + delete(rft.info, fp) + } + + // Delete all data the transfer saved to storage + err := rft.transfers[tid].delete() + if err != nil { + return errors.Errorf(deleteReceivedTransferErr, tid, err) + } + + // Delete the transfer from memory + delete(rft.transfers, tid) + + // Update the transfers list for the removed transfer + err = rft.saveTransfersList() + if err != nil { + return errors.Errorf(saveReceivedTransfersListErr, err) + } + + return nil +} + +// AddPart adds the file part to its corresponding transfer. The fingerprint +// number and transfer ID are looked up using the fingerprint. Then the part is +// added to the transfer with the corresponding transfer ID. Returns the +// transfer that the part was added to so that a progress callback can be +// called. Returns the transfer ID so that it can be used for logging. +func (rft *ReceivedFileTransfers) AddPart(encryptedPart, padding, mac []byte, + partNum uint16, fp format.Fingerprint) (*ReceivedTransfer, + ftCrypto.TransferID, error) { + rft.mux.Lock() + defer rft.mux.Unlock() + + // Lookup the part info for the given fingerprint + info, exists := rft.info[fp] + if !exists { + return nil, ftCrypto.TransferID{}, errors.Errorf(noFingerprintErr, fp) + } + + // Lookup the transfer with the ID in the part info + transfer, exists := rft.transfers[info.id] + if !exists { + return nil, info.id, errors.Errorf(getReceivedTransferErr, info.id) + } + + // Add the part to the transfer + err := transfer.AddPart(encryptedPart, padding, mac, partNum, info.fpNum) + if err != nil { + return transfer, info.id, errors.Errorf( + addPartErr, partNum, transfer.numParts, info.id, err) + } + + // Remove the part info from the map + delete(rft.info, fp) + + return transfer, info.id, nil +} + +// GetFile returns the combined file parts as a single byte slice for the given +// transfer ID. An error is returned if no file with the given transfer ID is +// found or if the file is missing parts. +func (rft *ReceivedFileTransfers) GetFile(tid ftCrypto.TransferID) ([]byte, error) { + rft.mux.Lock() + defer rft.mux.Unlock() + + // Check if the transfer exists + rt, exists := rft.transfers[tid] + if !exists { + return nil, errors.Errorf(getReceivedTransferErr, tid) + } + + return rt.GetFile() +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// LoadReceivedFileTransfers loads all ReceivedFileTransfers from storage. +func LoadReceivedFileTransfers(kv *versioned.KV) (*ReceivedFileTransfers, error) { + rft := &ReceivedFileTransfers{ + transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), + info: make(map[format.Fingerprint]*partInfo), + kv: kv.Prefix(receivedFileTransfersPrefix), + } + + // Get the list of transfer IDs corresponding to each received transfer from + // storage + transfersList, err := rft.loadTransfersList() + if err != nil { + return nil, errors.Errorf(loadReceivedTransfersListErr, err) + } + + // Load transfers and fingerprints into the maps + err = rft.load(transfersList) + if err != nil { + return nil, errors.Errorf(loadReceivedFileTransfersErr, err) + } + + return rft, nil +} + +// NewOrLoadReceivedFileTransfers loads all ReceivedFileTransfers from storage, +// if they exist. Otherwise, a new ReceivedFileTransfers is returned. +func NewOrLoadReceivedFileTransfers(kv *versioned.KV) (*ReceivedFileTransfers, error) { + rft := &ReceivedFileTransfers{ + transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), + info: make(map[format.Fingerprint]*partInfo), + kv: kv.Prefix(receivedFileTransfersPrefix), + } + + // If the transfer list cannot be loaded from storage, then create a new + // ReceivedFileTransfers + vo, err := rft.kv.Get(receivedFileTransfersKey, receivedFileTransfersVersion) + if err != nil { + return NewReceivedFileTransfers(kv) + } + + // Unmarshal data into list of saved transfer IDs + transfersList := unmarshalTransfersList(vo.Data) + + // Load transfers and fingerprints into the maps + err = rft.load(transfersList) + if err != nil { + return nil, errors.Errorf(loadReceivedFileTransfersErr, err) + } + + return rft, nil +} + +// saveTransfersList saves a list of items in the transfers map to storage. +func (rft *ReceivedFileTransfers) saveTransfersList() error { + // Create new versioned object with a list of items in the transfers map + obj := &versioned.Object{ + Version: receivedFileTransfersVersion, + Timestamp: netTime.Now(), + Data: rft.marshalTransfersList(), + } + + // Save list of items in the transfers map to storage + return rft.kv.Set(receivedFileTransfersKey, receivedFileTransfersVersion, obj) +} + +// loadTransfersList gets the list of transfer IDs corresponding to each saved +// received transfer from storage. +func (rft *ReceivedFileTransfers) loadTransfersList() ([]ftCrypto.TransferID, error) { + // Get transfers list from storage + vo, err := rft.kv.Get(receivedFileTransfersKey, receivedFileTransfersVersion) + if err != nil { + return nil, err + } + + // Unmarshal data into list of saved transfer IDs + return unmarshalTransfersList(vo.Data), nil +} + +// load gets each ReceivedTransfer in the list from storage and adds them to the +// map. Also adds all unused fingerprints in each ReceivedTransfer to the info +// map. +func (rft *ReceivedFileTransfers) load(list []ftCrypto.TransferID) error { + // Load each sentTransfer from storage into the map + for _, tid := range list { + // Load the transfer with the given transfer ID from storage + rt, err := loadReceivedTransfer(tid, rft.kv) + if err != nil { + return err + } + + // Add transfer to transfer map + rft.transfers[tid] = rt + + // Load all unused fingerprints into the info map + for n := uint32(0); n < rt.fpVector.GetNumKeys(); n++ { + if !rt.fpVector.Used(n) { + fpNum := uint16(n) + + // Generate fingerprint + fp := ftCrypto.GenerateFingerprint(rt.key, fpNum) + + // Add to map + rft.info[fp] = &partInfo{tid, fpNum} + } + } + } + + return nil +} + +// marshalTransfersList creates a list of all transfer IDs in the transfers map +// and serialises it. +func (rft *ReceivedFileTransfers) marshalTransfersList() []byte { + buff := bytes.NewBuffer(nil) + buff.Grow(ftCrypto.TransferIdLength * len(rft.transfers)) + + for tid := range rft.transfers { + buff.Write(tid.Bytes()) + } + + return buff.Bytes() +} diff --git a/storage/fileTransfer/receiveFileTransfers_test.go b/storage/fileTransfer/receiveFileTransfers_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c5c1d8ddd9f94958222c4f7f4aa0691a84d2465f --- /dev/null +++ b/storage/fileTransfer/receiveFileTransfers_test.go @@ -0,0 +1,867 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "fmt" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/ekv" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/primitives/netTime" + "math/rand" + "reflect" + "sort" + "strings" + "testing" +) + +// Tests that NewReceivedFileTransfers creates a new object with empty maps and +// that it is saved to storage +func TestNewReceivedFileTransfers(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + expectedRFT := &ReceivedFileTransfers{ + transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), + info: make(map[format.Fingerprint]*partInfo), + kv: kv.Prefix(receivedFileTransfersPrefix), + } + + rft, err := NewReceivedFileTransfers(kv) + if err != nil { + t.Errorf("NewReceivedFileTransfers returned an error: %+v", err) + } + + // Check that the new ReceivedFileTransfers matches the expected + if !reflect.DeepEqual(expectedRFT, rft) { + t.Errorf("New ReceivedFileTransfers does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expectedRFT, rft) + } + + // Ensure that the transfer list is saved to storage + _, err = expectedRFT.kv.Get( + receivedFileTransfersKey, receivedFileTransfersVersion) + if err != nil { + t.Errorf("Failed to load transfer list from storage: %+v", err) + } +} + +// Tests that ReceivedFileTransfers.AddTransfer adds a new transfer and adds all +// the fingerprints to the info map. +func TestReceivedFileTransfers_AddTransfer(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + rft, err := NewReceivedFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + } + + // Generate info for new transfer + key, _ := ftCrypto.NewTransferKey(prng) + mac := []byte("transferMAC") + fileSize := uint32(256) + numParts, numFps := uint16(16), uint16(24) + + // Add the transfer + tid, err := rft.AddTransfer(key, mac, fileSize, numParts, numFps, prng) + if err != nil { + t.Errorf("AddTransfer returned an error: %+v", err) + } + + // Check that the transfer was added to the map + transfer, exists := rft.transfers[tid] + if !exists { + t.Errorf("Transfer with ID %s not found.", tid) + } else { + if transfer.GetFileSize() != fileSize { + t.Errorf("New transfer has incorrect file size."+ + "\nexpected: %d\nreceived: %d", fileSize, transfer.GetFileSize()) + } + if transfer.GetNumParts() != numParts { + t.Errorf("New transfer has incorrect number of parts."+ + "\nexpected: %d\nreceived: %d", numParts, transfer.GetNumParts()) + } + if transfer.GetTransferKey() != key { + t.Errorf("New transfer has incorrect transfer key."+ + "\nexpected: %s\nreceived: %s", key, transfer.GetTransferKey()) + } + if transfer.GetNumFps() != numFps { + t.Errorf("New transfer has incorrect number of fingerprints."+ + "\nexpected: %d\nreceived: %d", numFps, transfer.GetNumFps()) + } + } + + // Check that the transfer was added to storage + _, err = loadReceivedTransfer(tid, rft.kv) + if err != nil { + t.Errorf("Transfer with ID %s not found in storage: %+v", tid, err) + } + + // Check that all the fingerprints are in the info map + for fpNum, fp := range ftCrypto.GenerateFingerprints(key, numFps) { + info, exists := rft.info[fp] + if !exists { + t.Errorf("Part fingerprint %s (#%d) not found.", fp, fpNum) + } + + if int(info.fpNum) != fpNum { + t.Errorf("Fingerprint %s has incorrect fingerprint number."+ + "\nexpected: %d\nreceived: %d", fp, fpNum, info.fpNum) + } + + if info.id != tid { + t.Errorf("Fingerprint %s has incorrect transfer ID."+ + "\nexpected: %s\nreceived: %s", fp, tid, info.id) + } + } +} + +// Error path: tests that ReceivedFileTransfers.AddTransfer returns the expected +// error when the PRNG returns an error. +func TestReceivedFileTransfers_AddTransfer_NewTransferIdRngError(t *testing.T) { + prng := NewPrngErr() + kv := versioned.NewKV(make(ekv.Memstore)) + rft, err := NewReceivedFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + } + + // Add the transfer + expectedErr := strings.Split(addTransferNewIdErr, "%")[0] + _, err = rft.AddTransfer(ftCrypto.TransferKey{}, nil, 0, 0, 0, prng) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("AddTransfer did not return the expected error when the PRNG "+ + "should have errored.\nexpected: %s\nrecieved: %+v", expectedErr, err) + } + +} + +// Tests that ReceivedFileTransfers.addFingerprints adds all the fingerprints +// to the map +func TestReceivedFileTransfers_addFingerprints(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + rft, err := NewReceivedFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + } + + key, _ := ftCrypto.NewTransferKey(prng) + tid, _ := ftCrypto.NewTransferID(prng) + numFps := uint16(24) + + rft.addFingerprints(key, tid, numFps) + + // Check that all the fingerprints are in the info map + for fpNum, fp := range ftCrypto.GenerateFingerprints(key, numFps) { + info, exists := rft.info[fp] + if !exists { + t.Errorf("Part fingerprint %s (#%d) not found.", fp, fpNum) + } + + if int(info.fpNum) != fpNum { + t.Errorf("Fingerprint %s has incorrect fingerprint number."+ + "\nexpected: %d\nreceived: %d", fp, fpNum, info.fpNum) + } + + if info.id != tid { + t.Errorf("Fingerprint %s has incorrect transfer ID."+ + "\nexpected: %s\nreceived: %s", fp, tid, info.id) + } + } +} + +// Tests that ReceivedFileTransfers.GetTransfer returns the newly added +// transfer. +func TestReceivedFileTransfers_GetTransfer(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + rft, err := NewReceivedFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + } + + // Generate random info for new transfer + key, _ := ftCrypto.NewTransferKey(prng) + mac := []byte("transferMAC") + fileSize := uint32(256) + numParts, numFps := uint16(16), uint16(24) + + // Add the transfer + tid, err := rft.AddTransfer(key, mac, fileSize, numParts, numFps, prng) + if err != nil { + t.Errorf("Failed to add new transfer: %+v", err) + } + + // Get the transfer + transfer, err := rft.GetTransfer(tid) + if err != nil { + t.Errorf("GetTransfer returned an error: %+v", err) + } + + if transfer.GetFileSize() != fileSize { + t.Errorf("New transfer has incorrect file size."+ + "\nexpected: %d\nreceived: %d", fileSize, transfer.GetFileSize()) + } + + if transfer.GetNumParts() != numParts { + t.Errorf("New transfer has incorrect number of parts."+ + "\nexpected: %d\nreceived: %d", numParts, transfer.GetNumParts()) + } + + if transfer.GetNumFps() != numFps { + t.Errorf("New transfer has incorrect number of fingerprints."+ + "\nexpected: %d\nreceived: %d", numFps, transfer.GetNumFps()) + } + + if transfer.GetTransferKey() != key { + t.Errorf("New transfer has incorrect transfer key."+ + "\nexpected: %s\nreceived: %s", key, transfer.GetTransferKey()) + } +} + +// Error path: tests that ReceivedFileTransfers.GetTransfer returns the expected +// error when the provided transfer ID does not correlate to any saved transfer. +func TestReceivedFileTransfers_GetTransfer_NoTransferError(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + rft, err := NewReceivedFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + } + + // Generate random info for new transfer + key, _ := ftCrypto.NewTransferKey(prng) + mac := []byte("transferMAC") + + // Add the transfer + _, err = rft.AddTransfer(key, mac, 256, 16, 24, prng) + if err != nil { + t.Errorf("Failed to add new transfer: %+v", err) + } + + // Get the transfer + invalidTid, _ := ftCrypto.NewTransferID(prng) + expectedErr := fmt.Sprintf(getReceivedTransferErr, invalidTid) + _, err = rft.GetTransfer(invalidTid) + if err == nil || err.Error() != expectedErr { + t.Errorf("GetTransfer did not return the expected error when no "+ + "transfer for the ID exists.\nexpected: %s\nreceived: %+v", + expectedErr, err) + } +} + +// Tests that DeleteTransfer removed a transfer from memory and storage. +func TestReceivedFileTransfers_DeleteTransfer(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + rft, err := NewReceivedFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + } + + // Add the transfer + key, _ := ftCrypto.NewTransferKey(prng) + mac := []byte("transferMAC") + numFps := uint16(24) + tid, err := rft.AddTransfer(key, mac, 256, 16, numFps, prng) + if err != nil { + t.Errorf("Failed to add new transfer: %+v", err) + } + + // Delete the transfer + err = rft.DeleteTransfer(tid) + if err != nil { + t.Errorf("DeleteTransfer returned an error: %+v", err) + } + + // Check that the transfer was deleted from the map + _, exists := rft.transfers[tid] + if exists { + t.Errorf("Transfer with ID %s found in map when it should have been "+ + "deleted.", tid) + } + + // Check that the transfer was deleted from storage + _, err = loadReceivedTransfer(tid, rft.kv) + if err == nil { + t.Errorf("Transfer with ID %s found in storage when it should have "+ + "been deleted.", tid) + } + + // Check that all the fingerprints in the info map were deleted + for fpNum, fp := range ftCrypto.GenerateFingerprints(key, numFps) { + _, exists := rft.info[fp] + if exists { + t.Errorf("Part fingerprint %s (#%d) found in map when it should "+ + "have been deleted.", fp, fpNum) + } + } +} + +// Error path: tests that ReceivedFileTransfers.DeleteTransfer returns the +// expected error when the provided transfer ID does not correlate to any saved +// transfer. +func TestReceivedFileTransfers_DeleteTransfer_NoTransferError(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + rft, err := NewReceivedFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + } + + // Delete the transfer + invalidTid, _ := ftCrypto.NewTransferID(prng) + expectedErr := fmt.Sprintf(getReceivedTransferErr, invalidTid) + err = rft.DeleteTransfer(invalidTid) + if err == nil || err.Error() != expectedErr { + t.Errorf("DeleteTransfer did not return the expected error when no "+ + "transfer for the ID exists.\nexpected: %s\nreceived: %+v", + expectedErr, err) + } +} + +// Tests that ReceivedFileTransfers.AddPart modifies the expected transfer in +// memory. +func TestReceivedFileTransfers_AddPart(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + rft, err := NewReceivedFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + } + + prng := NewPrng(42) + key, _ := ftCrypto.NewTransferKey(prng) + mac := []byte("transferMAC") + + tid, err := rft.AddTransfer(key, mac, 256, 16, 24, prng) + if err != nil { + t.Errorf("Failed to add new transfer: %+v", err) + } + + // Create encrypted part + expectedData := []byte("test") + partNum, fpNum := uint16(1), uint16(1) + encryptedPart, mac, padding := newEncryptedPartData( + key, expectedData, fpNum, t) + fp := ftCrypto.GenerateFingerprint(key, fpNum) + + // Add encrypted part + rt, _, err := rft.AddPart(encryptedPart, padding, mac, partNum, fp) + if err != nil { + t.Errorf("AddPart returned an error: %+v", err) + } + + // Make sure its fingerprint was removed from the map + _, exists := rft.info[fp] + if exists { + t.Errorf("Fingerprints %s for added part found when it should have "+ + "been deleted.", fp) + } + + // Check that the transfer is correct + expectedRT := rft.transfers[tid] + if !reflect.DeepEqual(expectedRT, rt) { + t.Errorf("Returned transfer does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expectedRT, rt) + } + + // Check that the correct part was stored + receivedPart := expectedRT.receivedParts.parts[partNum] + if !bytes.Equal(receivedPart, expectedData) { + t.Errorf("Part in memory is not expected."+ + "\nexpected: %q\nreceived: %q", expectedData, receivedPart) + } +} + +// Error path: tests that ReceivedFileTransfers.AddPart returns the expected +// error when the provided fingerprint does not correlate to any part. +func TestReceivedFileTransfers_AddPart_NoFingerprintError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + rft, err := NewReceivedFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + } + + // Create encrypted part + fp := format.NewFingerprint([]byte("invalidTransferKey")) + + // Add encrypted part + expectedErr := fmt.Sprintf(noFingerprintErr, fp) + _, _, err = rft.AddPart([]byte{}, []byte{}, []byte{}, 0, fp) + if err == nil || err.Error() != expectedErr { + t.Errorf("AddPart did not return the expected error when no part for "+ + "the fingerprint exists.\nexpected: %s\nreceived: %+v", + expectedErr, err) + } +} + +// Error path: tests that ReceivedFileTransfers.AddPart returns the expected +// error when the provided transfer ID does not correlate to any saved transfer. +func TestReceivedFileTransfers_AddPart_NoTransferError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + rft, err := NewReceivedFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + } + + prng := NewPrng(42) + key, _ := ftCrypto.NewTransferKey(prng) + mac := []byte("transferMAC") + + _, err = rft.AddTransfer(key, mac, 256, 16, 24, prng) + if err != nil { + t.Errorf("Failed to add new transfer: %+v", err) + } + + // Create encrypted part + fp := ftCrypto.GenerateFingerprint(key, 1) + invalidTid, _ := ftCrypto.NewTransferID(prng) + rft.info[fp].id = invalidTid + + // Add encrypted part + expectedErr := fmt.Sprintf(getReceivedTransferErr, invalidTid) + _, _, err = rft.AddPart([]byte{}, []byte{}, []byte{}, 0, fp) + if err == nil || err.Error() != expectedErr { + t.Errorf("AddPart did not return the expected error when no transfer "+ + "for the ID exists.\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: tests that ReceivedFileTransfers.AddPart returns the expected +// error when the encrypted part data, MAC, and padding are invalid. +func TestReceivedFileTransfers_AddPart_AddPartError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + rft, err := NewReceivedFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + } + + prng := NewPrng(42) + key, _ := ftCrypto.NewTransferKey(prng) + mac := []byte("transferMAC") + numParts := uint16(16) + + tid, err := rft.AddTransfer(key, mac, 256, numParts, 24, prng) + if err != nil { + t.Errorf("Failed to add new transfer: %+v", err) + } + + // Create encrypted part + partNum, fpNum := uint16(1), uint16(1) + encryptedPart := []byte("invalidPart") + mac = []byte("invalidMAC") + padding := make([]byte, 24) + fp := ftCrypto.GenerateFingerprint(key, fpNum) + + // Add encrypted part + expectedErr := fmt.Sprintf(addPartErr, partNum, numParts, tid, "") + _, _, err = rft.AddPart(encryptedPart, padding, mac, partNum, fp) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("AddPart did not return the expected error when the "+ + "encrypted part, padding, and MAC are invalid."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that ReceivedFileTransfers.GetFile returns the complete file. +func TestReceivedFileTransfers_GetFile(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + rft, err := NewReceivedFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + } + + prng := NewPrng(42) + numParts := uint16(16) + + // Generate file parts and expected file + parts, expectedFile := newRandomPartStore( + numParts, kv.Prefix("test"), rand.New(rand.NewSource(42)), t) + + key, _ := ftCrypto.NewTransferKey(prng) + mac := ftCrypto.CreateTransferMAC(expectedFile, key) + + fileSize := uint32(len(expectedFile)) + + tid, err := rft.AddTransfer(key, mac, fileSize, numParts, 24, prng) + if err != nil { + t.Errorf("Failed to add new transfer: %+v", err) + } + + for partNum, part := range parts.parts { + fp := ftCrypto.GenerateFingerprint(key, partNum) + encPart, partMac, padding := newEncryptedPartData(key, part, partNum, t) + _, _, err = rft.AddPart(encPart, padding, partMac, partNum, fp) + if err != nil { + t.Errorf("AddPart encountered an error: %+v", err) + } + } + + receivedFile, err := rft.GetFile(tid) + if err != nil { + t.Errorf("GetFile returned an error: %+v", err) + } + + if !bytes.Equal(expectedFile, receivedFile) { + t.Errorf("Received file does not match expected."+ + "\nexpected: %q\nreceived: %q", expectedFile, receivedFile) + } +} + +// Error path: tests that ReceivedFileTransfers.GetFile returns the expected +// error when no transfer with the ID exists. +func TestReceivedFileTransfers_GetFile_NoTransferError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + rft, err := NewReceivedFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + } + + tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) + + expectedErr := fmt.Sprintf(getReceivedTransferErr, tid) + _, err = rft.GetFile(tid) + if err == nil || err.Error() != expectedErr { + t.Errorf("GetFile did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// Tests that the ReceivedFileTransfers loaded from storage by +// LoadReceivedFileTransfers matches the original in memory. +func TestLoadReceivedFileTransfers(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + rft, err := NewReceivedFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to make new ReceivedFileTransfers: %+v", err) + } + + // Add 10 transfers to map in memory + list := make([]ftCrypto.TransferID, 10) + for i := range list { + tid, rt, _ := newRandomReceivedTransfer(16, 24, rft.kv, t) + + // Add to transfer + rft.transfers[tid] = rt + + // Add unused fingerprints + for n := uint32(0); n < rt.fpVector.GetNumKeys(); n++ { + if !rt.fpVector.Used(n) { + fp := ftCrypto.GenerateFingerprint(rt.key, uint16(n)) + rft.info[fp] = &partInfo{tid, uint16(n)} + } + } + + // Add ID to list + list[i] = tid + } + + // Save list to storage + if err = rft.saveTransfersList(); err != nil { + t.Errorf("Faileds to save transfers list: %+v", err) + } + + // Load ReceivedFileTransfers from storage + loadedRFT, err := LoadReceivedFileTransfers(kv) + if err != nil { + t.Errorf("LoadReceivedFileTransfers returned an error: %+v", err) + } + + if !reflect.DeepEqual(rft, loadedRFT) { + t.Errorf("Loaded ReceivedFileTransfers does not match original in "+ + "memory.\nexpected: %+v\nreceived: %+v", rft, loadedRFT) + } +} + +// Error path: tests that ReceivedFileTransfers returns the expected error when +// the transfer list cannot be loaded from storage. +func TestLoadReceivedFileTransfers_NoListInStorageError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + expectedErr := strings.Split(loadReceivedTransfersListErr, "%")[0] + + // Load ReceivedFileTransfers from storage + _, err := LoadReceivedFileTransfers(kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("LoadReceivedFileTransfers did not return the expected error "+ + "when there is no transfer list saved in storage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: tests that ReceivedFileTransfers returns the expected error when +// the first transfer loaded from storage does not exist. +func TestLoadReceivedFileTransfers_NoTransferInStorageError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + expectedErr := strings.Split(loadReceivedFileTransfersErr, "%")[0] + + // Save list of one transfer ID to storage + obj := &versioned.Object{ + Version: receivedFileTransfersVersion, + Timestamp: netTime.Now(), + Data: ftCrypto.UnmarshalTransferID([]byte("testID_01")).Bytes(), + } + err := kv.Prefix(receivedFileTransfersPrefix).Set( + receivedFileTransfersKey, receivedFileTransfersVersion, obj) + + // Load ReceivedFileTransfers from storage + _, err = LoadReceivedFileTransfers(kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("LoadReceivedFileTransfers did not return the expected error "+ + "when there is no transfer saved in storage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that the ReceivedFileTransfers loaded from storage by +// NewOrLoadReceivedFileTransfers matches the original in memory. +func TestNewOrLoadReceivedFileTransfers(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + rft, err := NewReceivedFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to make new ReceivedFileTransfers: %+v", err) + } + + // Add 10 transfers to map in memory + list := make([]ftCrypto.TransferID, 10) + for i := range list { + tid, rt, _ := newRandomReceivedTransfer(16, 24, rft.kv, t) + + // Add to transfer + rft.transfers[tid] = rt + + // Add unused fingerprints + for n := uint32(0); n < rt.fpVector.GetNumKeys(); n++ { + if !rt.fpVector.Used(n) { + fp := ftCrypto.GenerateFingerprint(rt.key, uint16(n)) + rft.info[fp] = &partInfo{tid, uint16(n)} + } + } + + // Add ID to list + list[i] = tid + } + + // Save list to storage + if err = rft.saveTransfersList(); err != nil { + t.Errorf("Faileds to save transfers list: %+v", err) + } + + // Load ReceivedFileTransfers from storage + loadedRFT, err := NewOrLoadReceivedFileTransfers(kv) + if err != nil { + t.Errorf("NewOrLoadReceivedFileTransfers returned an error: %+v", err) + } + + if !reflect.DeepEqual(rft, loadedRFT) { + t.Errorf("Loaded ReceivedFileTransfers does not match original in "+ + "memory.\nexpected: %+v\nreceived: %+v", rft, loadedRFT) + } +} + +// Tests that NewOrLoadReceivedFileTransfers returns a new ReceivedFileTransfers +// when there is none in storage. +func TestNewOrLoadReceivedFileTransfers_NewSentFileTransfers(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + + // Load ReceivedFileTransfers from storage + loadedRFT, err := NewOrLoadReceivedFileTransfers(kv) + if err != nil { + t.Errorf("NewOrLoadReceivedFileTransfers returned an error: %+v", err) + } + + newRFT, _ := NewReceivedFileTransfers(kv) + + if !reflect.DeepEqual(newRFT, loadedRFT) { + t.Errorf("Returned ReceivedFileTransfers does not match new."+ + "\nexpected: %+v\nreceived: %+v", newRFT, loadedRFT) + } +} + +// Error path: tests that NewOrLoadReceivedFileTransfers returns the expected +// error when the first transfer loaded from storage does not exist. +func TestNewOrLoadReceivedFileTransfers_NoTransferInStorageError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + expectedErr := strings.Split(loadReceivedFileTransfersErr, "%")[0] + + // Save list of one transfer ID to storage + obj := &versioned.Object{ + Version: receivedFileTransfersVersion, + Timestamp: netTime.Now(), + Data: ftCrypto.UnmarshalTransferID([]byte("testID_01")).Bytes(), + } + err := kv.Prefix(receivedFileTransfersPrefix).Set( + receivedFileTransfersKey, receivedFileTransfersVersion, obj) + + // Load ReceivedFileTransfers from storage + _, err = NewOrLoadReceivedFileTransfers(kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("NewOrLoadReceivedFileTransfers did not return the expected "+ + "error when there is no transfer saved in storage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that the list saved by ReceivedFileTransfers.saveTransfersList matches +// the list loaded by ReceivedFileTransfers.load. +func TestLoadReceivedFileTransfers_saveTransfersList_loadTransfersList(t *testing.T) { + rft, err := NewReceivedFileTransfers(versioned.NewKV(make(ekv.Memstore))) + if err != nil { + t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + } + + // Fill map with transfers + expectedList := make([]ftCrypto.TransferID, 7) + for i := range expectedList { + prng := NewPrng(int64(i)) + key, _ := ftCrypto.NewTransferKey(prng) + mac := []byte("transferMAC") + numParts := uint16(i) + numFps := uint16(float32(i) * 2.5) + + expectedList[i], err = rft.AddTransfer( + key, mac, 256, numParts, numFps, prng) + if err != nil { + t.Errorf("Failed to add new transfer #%d: %+v", i, err) + } + } + + // Save the list + err = rft.saveTransfersList() + if err != nil { + t.Errorf("saveTransfersList returned an error: %+v", err) + } + + // Load the list + loadedList, err := rft.loadTransfersList() + if err != nil { + t.Errorf("loadTransfersList returned an error: %+v", err) + } + + // Sort slices so they can be compared + sort.SliceStable(expectedList, func(i, j int) bool { + return bytes.Compare(expectedList[i].Bytes(), expectedList[j].Bytes()) == -1 + }) + sort.SliceStable(loadedList, func(i, j int) bool { + return bytes.Compare(loadedList[i].Bytes(), loadedList[j].Bytes()) == -1 + }) + + if !reflect.DeepEqual(expectedList, loadedList) { + t.Errorf("Loaded transfer list does not match expected."+ + "\nexpected: %v\nreceived: %v", expectedList, loadedList) + } +} + +// Tests that the list loaded by ReceivedFileTransfers.load matches the original +// saved to storage. +func TestLoadReceivedFileTransfers_load(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + rft, err := NewReceivedFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + } + + // Fill map with transfers + idList := make([]ftCrypto.TransferID, 7) + for i := range idList { + prng := NewPrng(int64(i)) + key, _ := ftCrypto.NewTransferKey(prng) + mac := []byte("transferMAC") + + idList[i], err = rft.AddTransfer(key, mac, 256, 16, 24, prng) + if err != nil { + t.Errorf("Failed to add new transfer #%d: %+v", i, err) + } + } + + // Save the list + err = rft.saveTransfersList() + if err != nil { + t.Errorf("saveTransfersList returned an error: %+v", err) + } + + // Build new ReceivedFileTransfers + newRFT := &ReceivedFileTransfers{ + transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), + info: make(map[format.Fingerprint]*partInfo), + kv: kv.Prefix(receivedFileTransfersPrefix), + } + + // Load saved transfers from storage + err = newRFT.load(idList) + if err != nil { + t.Errorf("load returned an error: %+v", err) + } + + // Check that all transfer were loaded from storage + for _, id := range idList { + transfer, exists := newRFT.transfers[id] + if !exists { + t.Errorf("Transfer %s not loaded from storage.", id) + } + + // Check that the loaded transfer matches the original in memory + if !reflect.DeepEqual(transfer, rft.transfers[id]) { + t.Errorf("Loaded transfer does not match original."+ + "\noriginal: %+v\nreceived: %+v", rft.transfers[id], transfer) + } + + // Make sure all fingerprints are present + for fpNum, fp := range ftCrypto.GenerateFingerprints(transfer.key, 24) { + info, exists := newRFT.info[fp] + if !exists { + t.Errorf("Fingerprint %d for transfer %s does not exist in "+ + "the map.", fpNum, id) + } + + if info.id != id { + t.Errorf("Fingerprint has wrong transfer ID."+ + "\nexpected: %s\nreceived: %s", id, info.id) + } + if int(info.fpNum) != fpNum { + t.Errorf("Fingerprint has wrong number."+ + "\nexpected: %d\nreceived: %d", fpNum, info.fpNum) + } + } + } +} + +// Tests that a transfer list marshalled with +// ReceivedFileTransfers.marshalTransfersList and unmarshalled with +// unmarshalTransfersList matches the original. +func TestReceivedFileTransfers_marshalTransfersList_unmarshalTransfersList(t *testing.T) { + prng := NewPrng(42) + rft := &ReceivedFileTransfers{ + transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), + } + + // Add 10 transfers to map in memory + for i := 0; i < 10; i++ { + tid, _ := ftCrypto.NewTransferID(prng) + rft.transfers[tid] = &ReceivedTransfer{} + } + + // Marshal into byte slice + marshalledBytes := rft.marshalTransfersList() + + // Unmarshal marshalled bytes into transfer ID list + list := unmarshalTransfersList(marshalledBytes) + + // Check that the list has all the transfer IDs in memory + for _, tid := range list { + if _, exists := rft.transfers[tid]; !exists { + t.Errorf("No transfer for ID %s exists.", tid) + } else { + delete(rft.transfers, tid) + } + } +} diff --git a/storage/fileTransfer/receiveTransfer.go b/storage/fileTransfer/receiveTransfer.go new file mode 100644 index 0000000000000000000000000000000000000000..7f5a5d897424fc46b9f6ddcc5acaa3db9a29b8b7 --- /dev/null +++ b/storage/fileTransfer/receiveTransfer.go @@ -0,0 +1,427 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "encoding/binary" + "github.com/pkg/errors" + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/storage/utility" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/xx_network/primitives/netTime" + "sync" + "time" +) + +// Storage constants +const ( + receivedTransfersPrefix = "FileTransferReceivedTransferStore" + receivedTransferKey = "ReceivedTransfer" + receivedTransferVersion = 0 + receivedFpVectorKey = "ReceivedFingerprintVector" +) + +// Error messages for ReceivedTransfer +const ( + newReceivedTransferPartStoreErr = "failed to create new part store: %+v" + newReceivedTransferFpVectorErr = "failed to create new StateVector for fingerprints: %+v" + loadReceivedStoreErr = "failed to load received transfer info from storage: %+v" + loadReceivePartStoreErr = "failed to load received part store from storage: %+v" + loadReceiveFpVectorErr = "failed to load received fingerprint vector from storage: %+v" + getFileErr = "missing %d/%d parts of the file" + getTransferMacErr = "failed to verify transfer MAC" + deleteReceivedTransferInfoErr = "failed to delete received transfer info from storage: %+v" + deleteReceivedFpVectorErr = "failed to delete received fingerprint vector from storage: %+v" + deleteReceivedFilePartsErr = "failed to delete received file parts from storage: %+v" +) + +// ReceivedTransfer contains information and progress data for receiving an in- +// progress file transfer. +type ReceivedTransfer struct { + // The transfer key is a randomly generated key created by the sender and + // used to generate MACs and fingerprints + key ftCrypto.TransferKey + + // The MAC for the entire file; used to verify the integrity of all parts + transferMAC []byte + + // Size of the entire file in bytes + fileSize uint32 + + // The number of file parts in the file + numParts uint16 + + // The number `of fingerprints to generate (function of numParts and the + // retry rate) + numFps uint16 + + // Stores the state of a fingerprint (used/unused) in a bitstream format + // (has its own storage backend) + fpVector *utility.StateVector + + // Saves each part in order (has its own storage backend) + receivedParts *partStore + + // List of callbacks to call for every send + progressCallbacks []*receivedCallbackTracker + + mux sync.RWMutex + kv *versioned.KV +} + +// NewReceivedTransfer generates a ReceivedTransfer with the specified +// transfer key, transfer ID, and a number of parts. +func NewReceivedTransfer(tid ftCrypto.TransferID, key ftCrypto.TransferKey, + transferMAC []byte, fileSize uint32, numParts, numFps uint16, + kv *versioned.KV) (*ReceivedTransfer, error) { + + // Create the ReceivedTransfer object + rf := &ReceivedTransfer{ + key: key, + transferMAC: transferMAC, + fileSize: fileSize, + numParts: numParts, + numFps: numFps, + progressCallbacks: []*receivedCallbackTracker{}, + kv: kv.Prefix(makeReceivedTransferPrefix(tid)), + } + + var err error + + // Create new StateVector for storing fingerprint usage + rf.fpVector, err = utility.NewStateVector( + rf.kv, receivedFpVectorKey, uint32(numFps)) + if err != nil { + return nil, errors.Errorf(newReceivedTransferFpVectorErr, err) + } + + // Create new part store + rf.receivedParts, err = newPartStore(rf.kv, numParts) + if err != nil { + return nil, errors.Errorf(newReceivedTransferPartStoreErr, err) + } + + // Save all fields without their own storage to storage + return rf, rf.saveInfo() +} + +// GetTransferKey returns the transfer Key for this received transfer. +func (rt *ReceivedTransfer) GetTransferKey() ftCrypto.TransferKey { + rt.mux.Lock() + defer rt.mux.Unlock() + + return rt.key +} + +// GetTransferMAC returns the transfer MAC for this received transfer. +func (rt *ReceivedTransfer) GetTransferMAC() []byte { + rt.mux.Lock() + defer rt.mux.Unlock() + + return rt.transferMAC +} + +// GetNumParts returns the number of file parts in this transfer. +func (rt *ReceivedTransfer) GetNumParts() uint16 { + rt.mux.Lock() + defer rt.mux.Unlock() + + return rt.numParts +} + +// GetNumFps returns the number of fingerprints. +func (rt *ReceivedTransfer) GetNumFps() uint16 { + rt.mux.Lock() + defer rt.mux.Unlock() + + return rt.numFps +} + +// GetNumAvailableFps returns the number of unused fingerprints. +func (rt *ReceivedTransfer) GetNumAvailableFps() uint16 { + rt.mux.Lock() + defer rt.mux.Unlock() + + return uint16(rt.fpVector.GetNumAvailable()) +} + +// GetFileSize returns the file size in bytes. +func (rt *ReceivedTransfer) GetFileSize() uint32 { + rt.mux.Lock() + defer rt.mux.Unlock() + + return rt.fileSize +} + +// GetProgress returns the current progress of the transfer. Completed is true +// when all parts have been received, received is the number of parts received, +// and total is the total number of parts excepted to be received. +func (rt *ReceivedTransfer) GetProgress() (completed bool, received, total uint16) { + rt.mux.RLock() + defer rt.mux.RUnlock() + + received = uint16(rt.fpVector.GetNumUsed()) + total = rt.numParts + + if received == total { + completed = true + } + + return completed, received, total +} + +// CallProgressCB calls all the progress callbacks with the most recent progress +// information. +func (rt *ReceivedTransfer) CallProgressCB(err error) { + rt.mux.RLock() + defer rt.mux.RUnlock() + + for _, cb := range rt.progressCallbacks { + cb.call(rt, err) + } +} + +// AddProgressCB appends a new interfaces.ReceivedProgressCallback to the list +// of progress callbacks to be called and calls it. The period is how often the +// callback should be called when there are updates. +func (rt *ReceivedTransfer) AddProgressCB( + cb interfaces.ReceivedProgressCallback, period time.Duration) { + rt.mux.Lock() + + // Add callback + rct := newReceivedCallbackTracker(cb, period) + rt.progressCallbacks = append(rt.progressCallbacks, rct) + + rt.mux.Unlock() + + // Trigger the initial call + rct.callNow(rt, nil) +} + +// AddPart decrypts an encrypted file part, adds it to the list of received +// parts and marks its fingerprint as used. +func (rt *ReceivedTransfer) AddPart(encryptedPart, padding, mac []byte, partNum, + fpNum uint16) error { + rt.mux.Lock() + defer rt.mux.Unlock() + + // Decrypt the encrypted file part + decryptedPart, err := ftCrypto.DecryptPart( + rt.key, encryptedPart, padding, mac, fpNum) + if err != nil { + return err + } + + // Add the part to the list of parts + err = rt.receivedParts.addPart(decryptedPart, partNum) + if err != nil { + return err + } + + // Mark the fingerprint as used + rt.fpVector.Use(uint32(fpNum)) + + return nil +} + +// GetFile returns all the file parts combined into a single byte slice. An +// error is returned if parts are missing or if the MAC cannot be verified. The +// incomplete or invalid file is returned despite errors. +func (rt *ReceivedTransfer) GetFile() ([]byte, error) { + rt.mux.Lock() + defer rt.mux.Unlock() + + fileData, missingParts := rt.receivedParts.getFile() + if missingParts != 0 { + return fileData, errors.Errorf(getFileErr, missingParts, rt.numParts) + } + + // Remove extra data added when sending as parts + fileData = fileData[:rt.fileSize] + + if !ftCrypto.VerifyTransferMAC(fileData, rt.key, rt.transferMAC) { + return fileData, errors.New(getTransferMacErr) + } + + return fileData, nil +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// loadReceivedTransfer loads the ReceivedTransfer with the given transfer ID +// from storage. +func loadReceivedTransfer(tid ftCrypto.TransferID, kv *versioned.KV) ( + *ReceivedTransfer, error) { + // Create the ReceivedTransfer object + rt := &ReceivedTransfer{ + progressCallbacks: []*receivedCallbackTracker{}, + kv: kv.Prefix(makeReceivedTransferPrefix(tid)), + } + + // Load transfer key and number of received parts from storage + err := rt.loadInfo() + if err != nil { + return nil, errors.Errorf(loadReceivedStoreErr, err) + } + + // Load the fingerprint vector from storage + rt.fpVector, err = utility.LoadStateVector(rt.kv, receivedFpVectorKey) + if err != nil { + return nil, errors.Errorf(loadReceiveFpVectorErr, err) + } + + // Load received part store from storage + rt.receivedParts, err = loadPartStore(rt.kv) + if err != nil { + return nil, errors.Errorf(loadReceivePartStoreErr, err) + } + + return rt, nil +} + +// saveInfo saves all fields in ReceivedTransfer that do not have their own +// storage (transfer key, transfer MAC, file size, number of file parts, and +// number of fingerprints) to storage. +func (rt *ReceivedTransfer) saveInfo() error { + rt.mux.Lock() + defer rt.mux.Unlock() + + // Create new versioned object for the ReceivedTransfer + vo := &versioned.Object{ + Version: receivedTransferVersion, + Timestamp: netTime.Now(), + Data: rt.marshal(), + } + + // Save versioned object + return rt.kv.Set(receivedTransferKey, receivedTransferVersion, vo) +} + +// loadInfo gets the transfer key, transfer MAC, file size, number of part, and +// number of fingerprints from storage and saves it to the ReceivedTransfer. +func (rt *ReceivedTransfer) loadInfo() error { + vo, err := rt.kv.Get(receivedTransferKey, receivedTransferVersion) + if err != nil { + return err + } + + // Unmarshal the transfer key and numParts + rt.key, rt.transferMAC, rt.fileSize, rt.numParts, rt.numFps = + unmarshalReceivedTransfer(vo.Data) + + return nil +} + +// delete deletes all data in the ReceivedTransfer from storage. +func (rt *ReceivedTransfer) delete() error { + rt.mux.Lock() + defer rt.mux.Unlock() + + // Delete received transfer info from storage + err := rt.deleteInfo() + if err != nil { + return errors.Errorf(deleteReceivedTransferInfoErr, err) + } + + // Delete fingerprint vector from storage + err = rt.fpVector.Delete() + if err != nil { + return errors.Errorf(deleteReceivedFpVectorErr, err) + } + + // Delete received file parts from storage + err = rt.receivedParts.delete() + if err != nil { + return errors.Errorf(deleteReceivedFilePartsErr, err) + } + + return nil +} + +// deleteInfo removes received transfer info (transfer key, transfer MAC, file +// size, number of parts, and number of fingerprints) from storage. +func (rt *ReceivedTransfer) deleteInfo() error { + return rt.kv.Delete(receivedTransferKey, receivedTransferVersion) +} + +// marshal serializes all primitive fields in ReceivedTransfer (key, +// transferMAC, fileSize, numParts, and numFps). +func (rt *ReceivedTransfer) marshal() []byte { + // Construct the buffer to the correct size (size of key + transfer MAC + // length (2 bytes) + transfer MAC + fileSize (4 bytes) + numParts (2 bytes) + // + numFps (2 bytes)) + buff := bytes.NewBuffer(nil) + buff.Grow(ftCrypto.TransferKeyLength + 2 + len(rt.transferMAC) + 4 + 2 + 2) + + // Write the key to the buffer + buff.Write(rt.key.Bytes()) + + // Write the length of the transfer MAC to the buffer + b := make([]byte, 2) + binary.LittleEndian.PutUint16(b, uint16(len(rt.transferMAC))) + buff.Write(b) + + // Write the transfer MAC to the buffer + buff.Write(rt.transferMAC) + + // Write the file size to the buffer + b = make([]byte, 4) + binary.LittleEndian.PutUint32(b, rt.fileSize) + buff.Write(b) + + // Write the number of parts to the buffer + b = make([]byte, 2) + binary.LittleEndian.PutUint16(b, rt.numParts) + buff.Write(b) + + // Write the number of fingerprints to the buffer + b = make([]byte, 2) + binary.LittleEndian.PutUint16(b, rt.numFps) + buff.Write(b) + + // Return the serialized data + return buff.Bytes() +} + +// unmarshalReceivedTransfer deserializes a byte slice into the primitive fields +// of ReceivedTransfer (key, transferMAC, fileSize, numParts, and numFps). +func unmarshalReceivedTransfer(data []byte) (key ftCrypto.TransferKey, + transferMAC []byte, size uint32, numParts, numFps uint16) { + + buff := bytes.NewBuffer(data) + + // Read the transfer key from the buffer + key = ftCrypto.UnmarshalTransferKey(buff.Next(ftCrypto.TransferKeyLength)) + + // Read the size of the transfer MAC from the buffer + transferMacSize := binary.LittleEndian.Uint16(buff.Next(2)) + + // Read the transfer MAC from the buffer + transferMAC = buff.Next(int(transferMacSize)) + + // Read the file size from the buffer + size = binary.LittleEndian.Uint32(buff.Next(4)) + + // Read the number of part from the buffer + numParts = binary.LittleEndian.Uint16(buff.Next(2)) + + // Read the number of fingerprints from the buffer + numFps = binary.LittleEndian.Uint16(buff.Next(2)) + + return key, transferMAC, size, numParts, numFps +} + +// makeReceivedTransferPrefix generates the unique prefix used on the key value +// store to store received transfers for the given transfer ID. +func makeReceivedTransferPrefix(tid ftCrypto.TransferID) string { + return receivedTransfersPrefix + tid.String() +} diff --git a/storage/fileTransfer/receiveTransfer_test.go b/storage/fileTransfer/receiveTransfer_test.go new file mode 100644 index 0000000000000000000000000000000000000000..9a5130e82c3109cc28fe86a306834d9cf005698b --- /dev/null +++ b/storage/fileTransfer/receiveTransfer_test.go @@ -0,0 +1,906 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "fmt" + "github.com/pkg/errors" + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/storage/utility" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/ekv" + "reflect" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +// Tests that NewReceivedTransfer correctly creates a new ReceivedTransfer and +// that it is saved to storage. +func Test_NewReceivedTransfer(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + tid, _ := ftCrypto.NewTransferID(prng) + key, _ := ftCrypto.NewTransferKey(prng) + mac := []byte("transferMAC") + kvPrefixed := kv.Prefix(makeReceivedTransferPrefix(tid)) + fileSize := uint32(256) + numParts, numFps := uint16(16), uint16(24) + fpVector, _ := utility.NewStateVector( + kvPrefixed, receivedFpVectorKey, uint32(numFps)) + + expected := &ReceivedTransfer{ + key: key, + transferMAC: mac, + fileSize: fileSize, + numParts: numParts, + numFps: numFps, + fpVector: fpVector, + receivedParts: &partStore{ + parts: make(map[uint16][]byte, numParts), + numParts: numParts, + kv: kvPrefixed, + }, + progressCallbacks: []*receivedCallbackTracker{}, + kv: kvPrefixed, + } + + // Create new ReceivedTransfer + rt, err := NewReceivedTransfer(tid, key, mac, fileSize, numParts, numFps, kv) + if err != nil { + t.Fatalf("NewReceivedTransfer returned an error: %v", err) + } + + // Check that the new object matches the expected + if !reflect.DeepEqual(expected, rt) { + t.Errorf("New ReceivedTransfer does not match expected."+ + "\nexpected: %#v\nreceived: %#v", expected, rt) + } + + // Make sure it is saved to storage + _, err = kvPrefixed.Get(receivedTransferKey, receivedTransferVersion) + if err != nil { + t.Fatalf("Failed to load ReceivedTransfer from storage: %+v", err) + } + + // Check that the fingerprint vector has correct values + if rt.fpVector.GetNumAvailable() != uint32(numFps) { + t.Errorf("Incorrect number of available keys in fingerprint list."+ + "\nexpected: %d\nreceived: %d", numFps, rt.fpVector.GetNumAvailable()) + } + if rt.fpVector.GetNumKeys() != uint32(numFps) { + t.Errorf("Incorrect number of keys in fingerprint list."+ + "\nexpected: %d\nreceived: %d", numFps, rt.fpVector.GetNumKeys()) + } + if rt.fpVector.GetNumUsed() != 0 { + t.Errorf("Incorrect number of used keys in fingerprint list."+ + "\nexpected: %d\nreceived: %d", 0, rt.fpVector.GetNumUsed()) + } +} + +// Tests that ReceivedTransfer.GetTransferKey returns the expected key. +func TestReceivedTransfer_GetTransferKey(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + + tid, _ := ftCrypto.NewTransferID(prng) + mac := []byte("transferMAC") + expectedKey, _ := ftCrypto.NewTransferKey(prng) + + rt, err := NewReceivedTransfer(tid, expectedKey, mac, 256, 16, 1, kv) + if err != nil { + t.Errorf("Failed to create new ReceivedTransfer: %+v", err) + } + + if expectedKey != rt.GetTransferKey() { + t.Errorf("Failed to get expected transfer key."+ + "\nexpected: %s\nreceived: %s", expectedKey, rt.GetTransferKey()) + } +} + +// Tests that ReceivedTransfer.GetTransferMAC returns the expected bytes. +func TestReceivedTransfer_GetTransferMAC(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + + tid, _ := ftCrypto.NewTransferID(prng) + expectedMAC := []byte("transferMAC") + key, _ := ftCrypto.NewTransferKey(prng) + + rt, err := NewReceivedTransfer(tid, key, expectedMAC, 256, 16, 1, kv) + if err != nil { + t.Errorf("Failed to create new ReceivedTransfer: %+v", err) + } + + if !bytes.Equal(expectedMAC, rt.GetTransferMAC()) { + t.Errorf("Failed to get expected transfer MAC."+ + "\nexpected: %v\nreceived: %v", expectedMAC, rt.GetTransferMAC()) + } +} + +// Tests that ReceivedTransfer.GetNumParts returns the expected number of parts. +func TestReceivedTransfer_GetNumParts(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + expectedNumParts := uint16(16) + _, rt, _ := newRandomReceivedTransfer(expectedNumParts, 20, kv, t) + + if expectedNumParts != rt.GetNumParts() { + t.Errorf("Failed to get expected number of parts."+ + "\nexpected: %d\nreceived: %d", expectedNumParts, rt.GetNumParts()) + } +} + +// Tests that ReceivedTransfer.GetNumFps returns the expected number of +// fingerprints. +func TestReceivedTransfer_GetNumFps(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + expectedNumFps := uint16(20) + _, rt, _ := newRandomReceivedTransfer(16, expectedNumFps, kv, t) + + if expectedNumFps != rt.GetNumFps() { + t.Errorf("Failed to get expected number of fingerprints."+ + "\nexpected: %d\nreceived: %d", expectedNumFps, rt.GetNumFps()) + } +} + +// Tests that ReceivedTransfer.GetNumAvailableFps returns the expected number of +// available fingerprints. +func TestReceivedTransfer_GetNumAvailableFps(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + numParts, numFps := uint16(16), uint16(24) + _, rt, _ := newRandomReceivedTransfer(numParts, numFps, kv, t) + + if numFps-numParts != rt.GetNumAvailableFps() { + t.Errorf("Failed to get expected number of available fingerprints."+ + "\nexpected: %d\nreceived: %d", + numFps-numParts, rt.GetNumAvailableFps()) + } +} + +// Tests that ReceivedTransfer.GetFileSize returns the expected file size. +func TestReceivedTransfer_GetFileSize(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, rt, file := newRandomReceivedTransfer(16, 20, kv, t) + expectedFileSize := len(file) + + if expectedFileSize != int(rt.GetFileSize()) { + t.Errorf("Failed to get expected file size."+ + "\nexpected: %d\nreceived: %d", expectedFileSize, rt.GetFileSize()) + } +} + +// checkReceivedProgress compares the output of ReceivedTransfer.GetProgress to +// expected values. +func checkReceivedProgress(completed bool, received, total uint16, + eCompleted bool, eReceived, eTotal uint16) error { + if eCompleted != completed || eReceived != received || eTotal != total { + return errors.Errorf("Returned progress does not match expected."+ + "\n completed received total"+ + "\nexpected: %5t %3d %3d"+ + "\nreceived: %5t %3d %3d", + eCompleted, eReceived, eTotal, + completed, received, total) + } + + return nil +} + +// Tests that ReceivedTransfer.GetProgress returns the expected progress metrics +// for various transfer states. +func TestReceivedTransfer_GetProgress(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + numParts := uint16(16) + _, rt, _ := newEmptyReceivedTransfer(numParts, 20, kv, t) + + completed, received, total := rt.GetProgress() + err := checkReceivedProgress(completed, received, total, false, 0, numParts) + if err != nil { + t.Error(err) + } + + _, _ = rt.fpVector.Next() + + completed, received, total = rt.GetProgress() + err = checkReceivedProgress(completed, received, total, false, 1, numParts) + if err != nil { + t.Error(err) + } + + for i := 0; i < 4; i++ { + _, _ = rt.fpVector.Next() + } + + completed, received, total = rt.GetProgress() + err = checkReceivedProgress(completed, received, total, false, 5, numParts) + if err != nil { + t.Error(err) + } + + for i := 0; i < 6; i++ { + _, _ = rt.fpVector.Next() + } + + completed, received, total = rt.GetProgress() + err = checkReceivedProgress(completed, received, total, false, 11, numParts) + if err != nil { + t.Error(err) + } + + for i := 0; i < 4; i++ { + _, _ = rt.fpVector.Next() + } + + completed, received, total = rt.GetProgress() + err = checkReceivedProgress(completed, received, total, false, 15, numParts) + if err != nil { + t.Error(err) + } + + _, _ = rt.fpVector.Next() + + completed, received, total = rt.GetProgress() + err = checkReceivedProgress(completed, received, total, true, 16, numParts) + if err != nil { + t.Error(err) + } +} + +// Tests that 5 different callbacks all receive the expected data when +// ReceivedTransfer.CallProgressCB is called at different stages of transfer. +func TestReceivedTransfer_CallProgressCB(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, rt, _ := newEmptyReceivedTransfer(16, 20, kv, t) + + type progressResults struct { + completed bool + received, total uint16 + err error + } + + period := time.Millisecond + + wg := sync.WaitGroup{} + var step0, step1, step2, step3 uint64 + numCallbacks := 5 + + for i := 0; i < numCallbacks; i++ { + progressChan := make(chan progressResults) + + cbFunc := func(completed bool, received, total uint16, err error) { + progressChan <- progressResults{completed, received, total, err} + } + wg.Add(1) + + go func(i int) { + defer wg.Done() + n := 0 + for { + select { + case <-time.NewTimer(time.Second).C: + t.Errorf("Timed out after %s waiting for callback (%d).", + period*5, i) + return + case r := <-progressChan: + switch n { + case 0: + if err := checkReceivedProgress(r.completed, r.received, + r.total, false, 0, rt.numParts); err != nil { + t.Errorf("%2d: %+v", i, err) + } + atomic.AddUint64(&step0, 1) + case 1: + if err := checkReceivedProgress(r.completed, r.received, + r.total, false, 0, rt.numParts); err != nil { + t.Errorf("%2d: %+v", i, err) + } + atomic.AddUint64(&step1, 1) + case 2: + if err := checkReceivedProgress(r.completed, r.received, + r.total, false, 4, rt.numParts); err != nil { + t.Errorf("%2d: %+v", i, err) + } + atomic.AddUint64(&step2, 1) + case 3: + if err := checkReceivedProgress(r.completed, r.received, + r.total, true, 16, rt.numParts); err != nil { + t.Errorf("%2d: %+v", i, err) + } + atomic.AddUint64(&step3, 1) + return + default: + t.Errorf("n (%d) is great than 3 (%d)", n, i) + return + } + n++ + } + } + }(i) + + rt.AddProgressCB(cbFunc, period) + } + + for !atomic.CompareAndSwapUint64(&step0, uint64(numCallbacks), 0) { + } + + rt.CallProgressCB(nil) + + for !atomic.CompareAndSwapUint64(&step1, uint64(numCallbacks), 0) { + } + + for i := 0; i < 4; i++ { + _, _ = rt.fpVector.Next() + } + + rt.CallProgressCB(nil) + + for !atomic.CompareAndSwapUint64(&step2, uint64(numCallbacks), 0) { + } + + for i := 0; i < 12; i++ { + _, _ = rt.fpVector.Next() + } + + rt.CallProgressCB(nil) + + wg.Wait() +} + +// Tests that ReceivedTransfer.AddProgressCB adds an item to the progress +// callback list. +func TestReceivedTransfer_AddProgressCB(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, rt, _ := newEmptyReceivedTransfer(16, 20, kv, t) + + type callbackResults struct { + completed bool + received, total uint16 + err error + } + cbChan := make(chan callbackResults) + cbFunc := interfaces.ReceivedProgressCallback( + func(completed bool, received, total uint16, err error) { + cbChan <- callbackResults{completed, received, total, err} + }) + + done := make(chan bool) + go func() { + select { + case <-time.NewTimer(time.Millisecond).C: + t.Error("Timed out waiting for progress callback to be called.") + case r := <-cbChan: + err := checkReceivedProgress( + r.completed, r.received, r.total, false, 0, 16) + if err != nil { + t.Error(err) + } + if r.err != nil { + t.Errorf("Callback returned an error: %+v", err) + } + } + done <- true + }() + + period := time.Millisecond + rt.AddProgressCB(cbFunc, period) + + if len(rt.progressCallbacks) != 1 { + t.Errorf("Callback list should only have one item."+ + "\nexpected: %d\nreceived: %d", 1, len(rt.progressCallbacks)) + } + + if rt.progressCallbacks[0].period != period { + t.Errorf("Callback has wrong lastCall.\nexpected: %s\nreceived: %s", + period, rt.progressCallbacks[0].period) + } + + if rt.progressCallbacks[0].lastCall != (time.Time{}) { + t.Errorf("Callback has wrong time.\nexpected: %s\nreceived: %s", + time.Time{}, rt.progressCallbacks[0].lastCall) + } + + if rt.progressCallbacks[0].scheduled { + t.Errorf("Callback has wrong scheduled.\nexpected: %t\nreceived: %t", + false, rt.progressCallbacks[0].scheduled) + } + <-done +} + +// Tests that ReceivedTransfer.AddPart adds a part in the correct place in the +// list of parts. +func TestReceivedTransfer_AddPart(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + numFps := uint16(20) + _, rt, _ := newEmptyReceivedTransfer(16, 20, kv, t) + + // Create encrypted part + expectedData := []byte("test") + partNum, fpNum := uint16(1), uint16(1) + encryptedPart, mac, padding := newEncryptedPartData( + rt.key, expectedData, fpNum, t) + + // Add encrypted part + err := rt.AddPart(encryptedPart, padding, mac, partNum, fpNum) + if err != nil { + t.Errorf("AddPart returned an error: %+v", err) + } + + receivedData, exists := rt.receivedParts.parts[partNum] + if !exists { + t.Errorf("Part #%d not found in part map.", partNum) + } else if !bytes.Equal(expectedData, receivedData) { + t.Fatalf("Part data in list does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expectedData, receivedData) + } + + // Check that the fingerprint vector has correct values + if rt.fpVector.GetNumAvailable() != uint32(numFps-1) { + t.Errorf("Incorrect number of available keys in fingerprint list."+ + "\nexpected: %d\nreceived: %d", numFps, rt.fpVector.GetNumAvailable()) + } + if rt.fpVector.GetNumKeys() != uint32(numFps) { + t.Errorf("Incorrect number of keys in fingerprint list."+ + "\nexpected: %d\nreceived: %d", numFps, rt.fpVector.GetNumKeys()) + } + if rt.fpVector.GetNumUsed() != 1 { + t.Errorf("Incorrect number of used keys in fingerprint list."+ + "\nexpected: %d\nreceived: %d", 1, rt.fpVector.GetNumUsed()) + } +} + +// Error path: tests that ReceivedTransfer.AddPart returns the expected error +// when the provided MAC is invalid. +func TestReceivedTransfer_AddPart_DecryptPartError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, rt, _ := newEmptyReceivedTransfer(16, 20, kv, t) + + // Create encrypted part + Data := []byte("test") + partNum, fpNum := uint16(1), uint16(1) + encryptedPart, _, padding := newEncryptedPartData(rt.key, Data, fpNum, t) + mac := []byte("invalidMAC") + + // Add encrypted part + expectedErr := "reconstructed MAC from decrypting does not match MAC from sender" + err := rt.AddPart(encryptedPart, padding, mac, partNum, fpNum) + if err == nil || err.Error() != expectedErr { + t.Errorf("AddPart did not return the expected error when the MAC is "+ + "invalid.\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that ReceivedTransfer.GetFile returns the expected file data. +func TestReceivedTransfer_GetFile(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, rt, expectedFile := newRandomReceivedTransfer(16, 20, kv, t) + + receivedFile, err := rt.GetFile() + if err != nil { + t.Errorf("GetFile returned an error: %+v", err) + } + + // Check that the received file is correct + if !bytes.Equal(expectedFile, receivedFile) { + t.Errorf("File does not match expected.\nexpected: %+v\nreceived: %+v", + expectedFile, receivedFile) + } +} + +// Error path: tests that ReceivedTransfer.GetFile returns the expected error +// when not all file parts are present. +func TestReceivedTransfer_GetFile_MissingPartsError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + numParts := uint16(16) + _, rt, _ := newEmptyReceivedTransfer(numParts, 20, kv, t) + expectedErr := fmt.Sprintf(getFileErr, numParts, numParts) + + _, err := rt.GetFile() + if err == nil || err.Error() != expectedErr { + t.Errorf("GetFile failed to return the expected error when all file "+ + "parts are missing.\nexpected: %s\nreceived: %+v", + expectedErr, err) + } +} + +// Error path: tests that ReceivedTransfer.GetFile returns the expected error +// when the stored transfer MAC does not match the one generated from the file. +func TestReceivedTransfer_GetFile_InvalidMacError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, rt, _ := newRandomReceivedTransfer(16, 20, kv, t) + + rt.transferMAC = []byte("invalidMAC") + + _, err := rt.GetFile() + if err == nil || err.Error() != getTransferMacErr { + t.Errorf("GetFile failed to return the expected error when the file "+ + "MAC cannot be verified.\nexpected: %s\nreceived: %+v", + getTransferMacErr, err) + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Function Testing // +//////////////////////////////////////////////////////////////////////////////// + +// Tests that loadReceivedTransfer returns a ReceivedTransfer that matches the +// original object in memory. +func Test_loadReceivedTransfer(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + tid, expectedRT, _ := newRandomReceivedTransfer(16, 20, kv, t) + + // Create encrypted part + expectedData := []byte("test") + partNum, fpNum := uint16(1), uint16(1) + encryptedPart, mac, padding := newEncryptedPartData( + expectedRT.key, expectedData, fpNum, t) + + // Add encrypted part + err := expectedRT.AddPart(encryptedPart, padding, mac, partNum, fpNum) + if err != nil { + t.Errorf("Failed to add test part: %+v", err) + } + + loadedRT, err := loadReceivedTransfer(tid, kv) + if err != nil { + t.Errorf("loadReceivedTransfer returned an error: %+v", err) + } + + if !reflect.DeepEqual(expectedRT, loadedRT) { + t.Errorf("Loaded ReceivedTransfer does not match expected"+ + ".\nexpected: %+v\nreceived: %+v", expectedRT, loadedRT) + } +} + +// Error path: tests that loadReceivedTransfer returns the expected error when +// no info is in storage to load. +func Test_loadReceivedTransfer_LoadInfoError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + tid := ftCrypto.UnmarshalTransferID([]byte("invalidTransferID")) + expectedErr := strings.Split(loadReceivedStoreErr, "%")[0] + + _, err := loadReceivedTransfer(tid, kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("loadReceivedTransfer did not return the expected error when "+ + "trying to load info from storage that does not exist."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: tests that loadReceivedTransfer returns the expected error when +// the fingerprint state vector has been deleted from storage. +func Test_loadReceivedTransfer_LoadStateVectorError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + tid, rt, _ := newRandomReceivedTransfer(16, 20, kv, t) + + // Create encrypted part + data := []byte("test") + partNum, fpNum := uint16(1), uint16(1) + encryptedPart, mac, padding := newEncryptedPartData(rt.key, data, fpNum, t) + + // Add encrypted part + err := rt.AddPart(encryptedPart, padding, mac, partNum, fpNum) + if err != nil { + t.Errorf("Failed to add test part: %+v", err) + } + + // Delete fingerprint state vector from storage + err = rt.fpVector.Delete() + if err != nil { + t.Errorf("Failed to delete fingerprint vector: %+v", err) + } + + expectedErr := strings.Split(loadReceiveFpVectorErr, "%")[0] + _, err = loadReceivedTransfer(tid, kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("loadReceivedTransfer did not return the expected error when "+ + "the state vector was deleted from storage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: tests that loadReceivedTransfer returns the expected error when +// the part store has been deleted from storage. +func Test_loadReceivedTransfer_LoadPartStoreError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + tid, rt, _ := newRandomReceivedTransfer(16, 20, kv, t) + + // Create encrypted part + data := []byte("test") + partNum, fpNum := uint16(1), uint16(1) + encryptedPart, mac, padding := newEncryptedPartData(rt.key, data, fpNum, t) + + // Add encrypted part + err := rt.AddPart(encryptedPart, padding, mac, partNum, fpNum) + if err != nil { + t.Errorf("Failed to add test part: %+v", err) + } + + // Delete fingerprint state vector from storage + err = rt.receivedParts.delete() + if err != nil { + t.Errorf("Failed to delete part store: %+v", err) + } + + expectedErr := strings.Split(loadReceivePartStoreErr, "%")[0] + _, err = loadReceivedTransfer(tid, kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("loadReceivedTransfer did not return the expected error when "+ + "the part store was deleted from storage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that ReceivedTransfer.saveInfo saves the expected data to storage. +func TestReceivedTransfer_saveInfo(t *testing.T) { + rt := &ReceivedTransfer{ + key: ftCrypto.UnmarshalTransferKey([]byte("key")), + transferMAC: []byte("transferMAC"), + numParts: 16, + kv: versioned.NewKV(make(ekv.Memstore)), + } + + err := rt.saveInfo() + if err != nil { + t.Fatalf("saveInfo returned an error: %v", err) + } + + vo, err := rt.kv.Get(receivedTransferKey, receivedTransferVersion) + if err != nil { + t.Fatalf("Failed to load ReceivedTransfer from storage: %+v", err) + } + + if !bytes.Equal(rt.marshal(), vo.Data) { + t.Errorf("Marshalled data loaded from storage does not match expected."+ + "\nexpected: %+v\nreceived: %+v", rt.marshal(), vo.Data) + } +} + +// Tests that ReceivedTransfer.loadInfo loads a saved ReceivedTransfer from +// storage. +func TestReceivedTransfer_loadInfo(t *testing.T) { + rt := &ReceivedTransfer{ + key: ftCrypto.UnmarshalTransferKey([]byte("key")), + transferMAC: []byte("transferMAC"), + numParts: 16, + kv: versioned.NewKV(make(ekv.Memstore)), + } + + err := rt.saveInfo() + if err != nil { + t.Errorf("failed to save new ReceivedTransfer to storage: %+v", err) + } + + loadedRT := &ReceivedTransfer{kv: rt.kv} + err = loadedRT.loadInfo() + if err != nil { + t.Errorf("load returned an error: %+v", err) + } + + if !reflect.DeepEqual(rt, loadedRT) { + t.Errorf("Loaded ReceivedTransfer does not match expected."+ + "\nexpected: %+v\nreceived: %+v", rt, loadedRT) + } +} + +// Error path: tests that ReceivedTransfer.loadInfo returns an error when there +// is no object in storage to load +func TestReceivedTransfer_loadInfo_Error(t *testing.T) { + loadedRT := &ReceivedTransfer{kv: versioned.NewKV(make(ekv.Memstore))} + err := loadedRT.loadInfo() + if err == nil { + t.Errorf("Loaded object that should not be in storage: %+v", err) + } +} + +// Tests that ReceivedTransfer.delete removes all data from storage. +func TestReceivedTransfer_delete(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, rt, _ := newRandomReceivedTransfer(16, 20, kv, t) + + // Create encrypted part + expectedData := []byte("test") + partNum, fpNum := uint16(1), uint16(1) + encryptedPart, mac, padding := newEncryptedPartData( + rt.key, expectedData, fpNum, t) + + // Add encrypted part + err := rt.AddPart(encryptedPart, padding, mac, partNum, fpNum) + if err != nil { + t.Fatalf("Failed to add test part: %+v", err) + } + + // Delete everything from storage + err = rt.delete() + if err != nil { + t.Errorf("delete returned an error: %+v", err) + } + + // Check that the SentTransfer info was deleted + err = rt.loadInfo() + if err == nil { + t.Error("Successfully loaded SentTransfer info from storage when it " + + "should have been deleted.") + } + + // Check that the parts store were deleted + _, err = loadPartStore(rt.kv) + if err == nil { + t.Error("Successfully loaded file parts from storage when it should " + + "have been deleted.") + } + + // Check that the fingerprint vector was deleted + _, err = utility.LoadStateVector(rt.kv, receivedFpVectorKey) + if err == nil { + t.Error("Successfully loaded fingerprint vector from storage when it " + + "should have been deleted.") + } +} + +// Tests that ReceivedTransfer.deleteInfo removes the saved ReceivedTransfer +// data from storage. +func TestReceivedTransfer_deleteInfo(t *testing.T) { + rt := &ReceivedTransfer{ + key: ftCrypto.UnmarshalTransferKey([]byte("key")), + transferMAC: []byte("transferMAC"), + numParts: 16, + kv: versioned.NewKV(make(ekv.Memstore)), + } + + // Save from storage + err := rt.saveInfo() + if err != nil { + t.Errorf("failed to save new ReceivedTransfer to storage: %+v", err) + } + + // Delete from storage + err = rt.deleteInfo() + if err != nil { + t.Errorf("deleteInfo returned an error: %+v", err) + } + + // Make sure deleted object cannot be loaded from storage + _, err = rt.kv.Get(receivedTransferKey, receivedTransferVersion) + if err == nil { + t.Error("Loaded object that should be deleted from storage.") + } +} + +// Tests that a ReceivedTransfer marshalled with ReceivedTransfer.marshal and +// then unmarshalled with unmarshalReceivedTransfer matches the original. +func TestReceivedTransfer_marshal_unmarshalReceivedTransfer(t *testing.T) { + rt := &ReceivedTransfer{ + key: ftCrypto.UnmarshalTransferKey([]byte("key")), + transferMAC: []byte("transferMAC"), + fileSize: 256, + numParts: 16, + numFps: 20, + } + + marshaledData := rt.marshal() + key, mac, fileSize, numParts, numFps := unmarshalReceivedTransfer(marshaledData) + + if rt.key != key { + t.Errorf("Failed to get expected key.\nexpected: %s\nreceived: %s", + rt.key, key) + } + + if !bytes.Equal(rt.transferMAC, mac) { + t.Errorf("Failed to get expected transfer MAC."+ + "\nexpected: %v\nreceived: %v", rt.transferMAC, mac) + } + + if rt.fileSize != fileSize { + t.Errorf("Failed to get expected file size."+ + "\nexpected: %d\nreceived: %d", rt.fileSize, fileSize) + } + + if rt.numParts != numParts { + t.Errorf("Failed to get expected number of parts."+ + "\nexpected: %d\nreceived: %d", rt.numParts, numParts) + } + + if rt.numFps != numFps { + t.Errorf("Failed to get expected number of fingerprints."+ + "\nexpected: %d\nreceived: %d", rt.numFps, numFps) + } +} + +// Consistency test: tests that makeReceivedTransferPrefix returns the expected +// prefixes for the provided transfer IDs. +func Test_makeReceivedTransferPrefix_Consistency(t *testing.T) { + prng := NewPrng(42) + expectedPrefixes := []string{ + "FileTransferReceivedTransferStoreU4x/lrFkvxuXu59LtHLon1sUhPJSCcnZND6SugndnVI=", + "FileTransferReceivedTransferStore39ebTXZCm2F6DJ+fDTulWwzA1hRMiIU1hBrL4HCbB1g=", + "FileTransferReceivedTransferStoreCD9h03W8ArQd9PkZKeGP2p5vguVOdI6B555LvW/jTNw=", + "FileTransferReceivedTransferStoreuoQ+6NY+jE/+HOvqVG2PrBPdGqwEzi6ih3xVec+ix44=", + "FileTransferReceivedTransferStoreGwuvrogbgqdREIpC7TyQPKpDRlp4YgYWl4rtDOPGxPM=", + "FileTransferReceivedTransferStorernvD4ElbVxL+/b4MECiH4QDazS2IX2kstgfaAKEcHHA=", + "FileTransferReceivedTransferStoreceeWotwtwlpbdLLhKXBeJz8FySMmgo4rBW44F2WOEGE=", + "FileTransferReceivedTransferStoreSYlH/fNEQQ7UwRYCP6jjV2tv7Sf/iXS6wMr9mtBWkrE=", + "FileTransferReceivedTransferStoreNhnnOJZN/ceejVNDc2Yc/WbXT+weG4lJGrcjbkt1IWI=", + "FileTransferReceivedTransferStorekM8r60LDyicyhWDxqsBnzqbov0bUqytGgEAsX7KCDog=", + } + + for i, expected := range expectedPrefixes { + tid, _ := ftCrypto.NewTransferID(prng) + prefix := makeReceivedTransferPrefix(tid) + + if expected != prefix { + t.Errorf("New ReceivedTransfer prefix does not match expected (%d)."+ + "\nexpected: %s\nreceived: %s", i, expected, prefix) + } + } +} + +// newRandomReceivedTransfer generates a new ReceivedTransfer with random data. +// Returns the generated transfer ID, new ReceivedTransfer, and the full file. +func newRandomReceivedTransfer(numParts, numFps uint16, kv *versioned.KV, + t *testing.T) (ftCrypto.TransferID, *ReceivedTransfer, []byte) { + + prng := NewPrng(42) + tid, _ := ftCrypto.NewTransferID(prng) + key, _ := ftCrypto.NewTransferKey(prng) + parts, fileData := newRandomPartStore(numParts, kv, prng, t) + fileSize := uint32(len(fileData)) + mac := ftCrypto.CreateTransferMAC(fileData, key) + + rt, err := NewReceivedTransfer(tid, key, mac, fileSize, numParts, numFps, kv) + if err != nil { + t.Errorf("Failed to create new ReceivedTransfer: %+v", err) + } + + for partNum, part := range parts.parts { + encryptedPart, mac, padding := newEncryptedPartData(key, part, partNum, t) + err := rt.AddPart(encryptedPart, padding, mac, partNum, partNum) + if err != nil { + t.Errorf("Failed to add part #%d: %+v", partNum, err) + } + } + + return tid, rt, fileData +} + +// newRandomReceivedTransfer generates a new empty ReceivedTransfer. Returns the +// generated transfer ID, new ReceivedTransfer, and the full file. +func newEmptyReceivedTransfer(numParts, numFps uint16, kv *versioned.KV, + t *testing.T) (ftCrypto.TransferID, *ReceivedTransfer, []byte) { + + prng := NewPrng(42) + tid, _ := ftCrypto.NewTransferID(prng) + key, _ := ftCrypto.NewTransferKey(prng) + _, fileData := newRandomPartStore(numParts, kv, prng, t) + fileSize := uint32(len(fileData)) + + mac := ftCrypto.CreateTransferMAC(fileData, key) + + rt, err := NewReceivedTransfer(tid, key, mac, fileSize, numParts, numFps, kv) + if err != nil { + t.Errorf("Failed to create new ReceivedTransfer: %+v", err) + } + + return tid, rt, fileData +} + +// newEncryptedPartData encrypts the part data and returns the encrypted part +// its MAC, and its padding. +func newEncryptedPartData(key ftCrypto.TransferKey, part []byte, fpNum uint16, + t *testing.T) ([]byte, []byte, []byte) { + // Create encrypted part + prng := NewPrng(42) + encPart, mac, padding, err := ftCrypto.EncryptPart(key, part, fpNum, prng) + if err != nil { + t.Fatalf("Failed to encrypt data: %+v", err) + } + + return encPart, mac, padding +} diff --git a/storage/fileTransfer/receivedCallbackTracker.go b/storage/fileTransfer/receivedCallbackTracker.go new file mode 100644 index 0000000000000000000000000000000000000000..67b454b5d21f073dec450c7d5f3d3e84103d4367 --- /dev/null +++ b/storage/fileTransfer/receivedCallbackTracker.go @@ -0,0 +1,95 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/xx_network/primitives/netTime" + "sync" + "time" +) + +// receivedCallbackTracker tracks the interfaces.ReceivedProgressCallback and +// information on when to call it. The callback will be called on each file part +// reception, unless the time since the lastCall is smaller than the period. In +// that case, a callback is marked as scheduled and waits to be called at the +// end of the period. A callback is called once every period, regardless of the +// number of receptions that occur. +type receivedCallbackTracker struct { + period time.Duration // How often to call the callback + lastCall time.Time // Timestamp of the last call + scheduled bool // Denotes if callback call is scheduled + cb interfaces.ReceivedProgressCallback + mux sync.RWMutex +} + +// newReceivedCallbackTracker creates a new and unused receivedCallbackTracker. +func newReceivedCallbackTracker(cb interfaces.ReceivedProgressCallback, + period time.Duration) *receivedCallbackTracker { + return &receivedCallbackTracker{ + period: period, + lastCall: time.Time{}, + scheduled: false, + cb: cb, + } +} + +// call triggers the progress callback with the most recent progress from the +// receivedProgressTracker. If a callback has been called within the last +// period, then a new call is scheduled to occur at the beginning of the next +// period. If a call is already scheduled, then nothing happens; when the +// callback is finally called, it will do so with the most recent changes. +func (rct *receivedCallbackTracker) call(tracker receivedProgressTracker, err error) { + rct.mux.RLock() + // Exit if a callback is already scheduled + if rct.scheduled { + rct.mux.RUnlock() + return + } + + rct.mux.RUnlock() + rct.mux.Lock() + defer rct.mux.Unlock() + + if rct.scheduled { + return + } + + // Check if a callback has occurred within the last period + timeSinceLastCall := netTime.Since(rct.lastCall) + if timeSinceLastCall > rct.period { + // If no callback occurred, then trigger the callback now + rct.callNow(tracker, err) + rct.lastCall = netTime.Now() + } else { + // If a callback did occur, then schedule a new callback to occur at the + // start of the next period + rct.scheduled = true + go func() { + select { + case <-time.NewTimer(rct.period - timeSinceLastCall).C: + rct.mux.Lock() + rct.callNow(tracker, err) + rct.lastCall = netTime.Now() + rct.scheduled = false + rct.mux.Unlock() + } + }() + } +} + +// callNow calls the callback immediately regardless of the schedule or period. +func (rct *receivedCallbackTracker) callNow(tracker receivedProgressTracker, err error) { + completed, received, total := tracker.GetProgress() + go rct.cb(completed, received, total, err) +} + +// receivedProgressTracker interface tracks the progress of a transfer. +type receivedProgressTracker interface { + GetProgress() (completed bool, received, total uint16) +} diff --git a/storage/fileTransfer/receivedCallbackTracker_test.go b/storage/fileTransfer/receivedCallbackTracker_test.go new file mode 100644 index 0000000000000000000000000000000000000000..bd3e8a479b0d6d75392321880d872c7c3d79fc03 --- /dev/null +++ b/storage/fileTransfer/receivedCallbackTracker_test.go @@ -0,0 +1,130 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "reflect" + "testing" + "time" +) + +// Tests that newReceivedCallbackTracker returns the expected +// receivedCallbackTracker and that the callback triggers correctly. +func Test_newReceivedCallbackTracker(t *testing.T) { + type cbFields struct { + completed bool + received, total uint16 + err error + } + + cbChan := make(chan cbFields) + cbFunc := func(completed bool, received, total uint16, err error) { + cbChan <- cbFields{completed, received, total, err} + } + + expectedRCT := &receivedCallbackTracker{ + period: time.Millisecond, + lastCall: time.Time{}, + scheduled: false, + cb: cbFunc, + } + + receivedSCT := newReceivedCallbackTracker(expectedRCT.cb, expectedRCT.period) + + go receivedSCT.cb(false, 0, 0, nil) + + select { + case <-time.NewTimer(time.Millisecond).C: + t.Error("Timed out waiting for callback to be called.") + case r := <-cbChan: + err := checkReceivedProgress(r.completed, r.received, r.total, false, 0, 0) + if err != nil { + t.Error(err) + } + } + + // Nil the callbacks so that DeepEqual works + receivedSCT.cb = nil + expectedRCT.cb = nil + + if !reflect.DeepEqual(expectedRCT, receivedSCT) { + t.Errorf("New receivedCallbackTracker does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expectedRCT, receivedSCT) + } +} + +// Tests that +func Test_receivedCallbackTracker_call(t *testing.T) { + type cbFields struct { + completed bool + received, total uint16 + err error + } + + cbChan := make(chan cbFields) + cbFunc := func(completed bool, received, total uint16, err error) { + cbChan <- cbFields{completed, received, total, err} + } + + sct := newReceivedCallbackTracker(cbFunc, 50*time.Millisecond) + + tracker := testReceiveTrack{false, 1, 3} + sct.call(tracker, nil) + + select { + case <-time.NewTimer(10 * time.Millisecond).C: + t.Error("Timed out waiting for callback to be called.") + case r := <-cbChan: + err := checkReceivedProgress(r.completed, r.received, r.total, false, 1, 3) + if err != nil { + t.Error(err) + } + } + + tracker = testReceiveTrack{true, 3, 3} + sct.call(tracker, nil) + + select { + case <-time.NewTimer(10 * time.Millisecond).C: + if !sct.scheduled { + t.Error("Callback should be scheduled.") + } + case r := <-cbChan: + t.Errorf("Received message when period of %s should not have been "+ + "reached: %+v", sct.period, r) + } + + sct.call(tracker, nil) + + select { + case <-time.NewTimer(60 * time.Millisecond).C: + t.Error("Timed out waiting for callback to be called.") + case r := <-cbChan: + err := checkReceivedProgress(r.completed, r.received, r.total, true, 3, 3) + if err != nil { + t.Error(err) + } + } +} + +// Tests that ReceivedTransfer satisfies the receivedProgressTracker interface. +func TestReceivedTransfer_ReceivedProgressTrackerInterface(t *testing.T) { + var _ receivedProgressTracker = &ReceivedTransfer{} +} + +// testReceiveTrack is a test structure that satisfies the +// receivedProgressTracker interface. +type testReceiveTrack struct { + completed bool + received, total uint16 +} + +// GetProgress returns the values in the testTrack. +func (trt testReceiveTrack) GetProgress() (completed bool, received, total uint16) { + return trt.completed, trt.received, trt.total +} diff --git a/storage/fileTransfer/sentCallbackTracker.go b/storage/fileTransfer/sentCallbackTracker.go new file mode 100644 index 0000000000000000000000000000000000000000..3309eaef70af1026b09b4f0cbfeec344f71387f0 --- /dev/null +++ b/storage/fileTransfer/sentCallbackTracker.go @@ -0,0 +1,109 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/xx_network/primitives/netTime" + "sync" + "time" +) + +// sentCallbackTracker tracks the interfaces.SentProgressCallback and +// information on when to call it. The callback will be called on each send, +// unless the time since the lastCall is smaller than the period. In that case, +// a callback is marked as scheduled and waits to be called at the end of the +// period. A callback is called once every period, regardless of the number of +// sends that occur. +type sentCallbackTracker struct { + period time.Duration // How often to call the callback + lastCall time.Time // Timestamp of the last call + scheduled bool // Denotes if callback call is scheduled + cb interfaces.SentProgressCallback + mux sync.RWMutex +} + +// newSentCallbackTracker creates a new and unused sentCallbackTracker. +func newSentCallbackTracker(cb interfaces.SentProgressCallback, + period time.Duration) *sentCallbackTracker { + return &sentCallbackTracker{ + period: period, + lastCall: time.Time{}, + scheduled: false, + cb: cb, + } +} + +// call triggers the progress callback with the most recent progress from the +// sentProgressTracker. If a callback has been called within the last period, +// then a new call is scheduled to occur at the beginning of the next period. If +// a call is already scheduled, then nothing happens; when the callback is +// finally called, it will do so with the most recent changes. +func (sct *sentCallbackTracker) call(tracker sentProgressTracker, err error) { + sct.mux.RLock() + // Exit if a callback is already scheduled + if sct.scheduled { + sct.mux.RUnlock() + return + } + + sct.mux.RUnlock() + sct.mux.Lock() + defer sct.mux.Unlock() + + if sct.scheduled { + return + } + + // Check if a callback has occurred within the last period + timeSinceLastCall := netTime.Since(sct.lastCall) + if timeSinceLastCall > sct.period { + // If no callback occurred, then trigger the callback now + sct.callNow(tracker, err) + sct.lastCall = netTime.Now() + } else { + // If a callback did occur, then schedule a new callback to occur at the + // start of the next period + sct.scheduled = true + go func() { + select { + case <-time.NewTimer(sct.period - timeSinceLastCall).C: + sct.mux.Lock() + sct.callNow(tracker, err) + sct.lastCall = netTime.Now() + sct.scheduled = false + sct.mux.Unlock() + } + }() + } +} + +// callNow calls the callback immediately regardless of the schedule or period. +func (sct *sentCallbackTracker) callNow(tracker sentProgressTracker, err error) { + completed, sent, arrived, total := tracker.GetProgress() + go sct.cb(completed, sent, arrived, total, err) +} + +// callNowUnsafe calls the callback immediately regardless of the schedule or +// period without taking a thread lock. This function should be used if a lock +// is already taken on the sentProgressTracker. +func (sct *sentCallbackTracker) callNowUnsafe(tracker sentProgressTracker, err error) { + completed, sent, arrived, total := tracker.getProgress() + go sct.cb(completed, sent, arrived, total, err) +} + +// sentProgressTracker interface tracks the progress of a transfer. +type sentProgressTracker interface { + // GetProgress returns the sent transfer progress in a thread-safe manner. + GetProgress() (completed bool, sent, arrived, total uint16) + + // getProgress returns the sent transfer progress in a thread-unsafe manner. + // This function should be used if a lock is already taken on the sent + // transfer. + getProgress() (completed bool, sent, arrived, total uint16) +} diff --git a/storage/fileTransfer/sentCallbackTracker_test.go b/storage/fileTransfer/sentCallbackTracker_test.go new file mode 100644 index 0000000000000000000000000000000000000000..925268928c8ae0ad89539af0600d16e20a183533 --- /dev/null +++ b/storage/fileTransfer/sentCallbackTracker_test.go @@ -0,0 +1,137 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "reflect" + "testing" + "time" +) + +// Tests that newSentCallbackTracker returns the expected sentCallbackTracker +// and that the callback triggers correctly. +func Test_newSentCallbackTracker(t *testing.T) { + type cbFields struct { + completed bool + sent, arrived, total uint16 + err error + } + + cbChan := make(chan cbFields) + cbFunc := func(completed bool, sent, arrived, total uint16, err error) { + cbChan <- cbFields{completed, sent, arrived, total, err} + } + + expectedSCT := &sentCallbackTracker{ + period: time.Millisecond, + lastCall: time.Time{}, + scheduled: false, + cb: cbFunc, + } + + receivedSCT := newSentCallbackTracker(expectedSCT.cb, expectedSCT.period) + + go receivedSCT.cb(false, 0, 0, 0, nil) + + select { + case <-time.NewTimer(time.Millisecond).C: + t.Error("Timed out waiting for callback to be called.") + case r := <-cbChan: + err := checkSentProgress( + r.completed, r.sent, r.arrived, r.total, false, 0, 0, 0) + if err != nil { + t.Error(err) + } + } + + // Nil the callbacks so that DeepEqual works + receivedSCT.cb = nil + expectedSCT.cb = nil + + if !reflect.DeepEqual(expectedSCT, receivedSCT) { + t.Errorf("New sentCallbackTracker does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expectedSCT, receivedSCT) + } +} + +// Tests that +func Test_sentCallbackTracker_call(t *testing.T) { + type cbFields struct { + completed bool + sent, arrived, total uint16 + err error + } + + cbChan := make(chan cbFields) + cbFunc := func(completed bool, sent, arrived, total uint16, err error) { + cbChan <- cbFields{completed, sent, arrived, total, err} + } + + sct := newSentCallbackTracker(cbFunc, 50*time.Millisecond) + + tracker := testSentTrack{false, 1, 2, 3} + sct.call(tracker, nil) + + select { + case <-time.NewTimer(10 * time.Millisecond).C: + t.Error("Timed out waiting for callback to be called.") + case r := <-cbChan: + err := checkSentProgress( + r.completed, r.sent, r.arrived, r.total, false, 1, 2, 3) + if err != nil { + t.Error(err) + } + } + + tracker = testSentTrack{false, 1, 2, 3} + sct.call(tracker, nil) + + select { + case <-time.NewTimer(10 * time.Millisecond).C: + if !sct.scheduled { + t.Error("Callback should be scheduled.") + } + case r := <-cbChan: + t.Errorf("Received message when period of %s should not have been "+ + "reached: %+v", sct.period, r) + } + + sct.call(tracker, nil) + + select { + case <-time.NewTimer(60 * time.Millisecond).C: + t.Error("Timed out waiting for callback to be called.") + case r := <-cbChan: + err := checkSentProgress( + r.completed, r.sent, r.arrived, r.total, false, 1, 2, 3) + if err != nil { + t.Error(err) + } + } +} + +// Tests that SentTransfer satisfies the sentProgressTracker interface. +func TestSentTransfer_SentProgressTrackerInterface(t *testing.T) { + var _ sentProgressTracker = &SentTransfer{} +} + +// testSentTrack is a test structure that satisfies the sentProgressTracker +// interface. +type testSentTrack struct { + completed bool + sent, arrived, total uint16 +} + +func (tst testSentTrack) getProgress() (completed bool, sent, arrived, total uint16) { + return tst.completed, tst.sent, tst.arrived, tst.total +} + +// GetProgress returns the values in the testTrack. +func (tst testSentTrack) GetProgress() (completed bool, sent, arrived, total uint16) { + return tst.completed, tst.sent, tst.arrived, tst.total +} diff --git a/storage/fileTransfer/sentFileTransfers.go b/storage/fileTransfer/sentFileTransfers.go new file mode 100644 index 0000000000000000000000000000000000000000..96aa1886cc4313b6acb43218adfb9eae63bdb307 --- /dev/null +++ b/storage/fileTransfer/sentFileTransfers.go @@ -0,0 +1,254 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "github.com/pkg/errors" + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/xx_network/crypto/csprng" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" + "sync" + "time" +) + +// Storage keys and versions. +const ( + sentFileTransfersPrefix = "SentFileTransfersStore" + sentFileTransfersKey = "SentFileTransfers" + sentFileTransfersVersion = 0 +) + +// Error messages. +const ( + saveSentTransfersListErr = "failed to save list of sent items in transfer map to storage: %+v" + loadSentTransfersListErr = "failed to load list of sent items in transfer map from storage: %+v" + loadSentTransfersErr = "failed to load sent transfers from storage: %+v" + + newSentTransferErr = "failed to create new sent transfer: %+v" + getSentTransferErr = "sent file transfer not found" + deleteSentTransferErr = "failed to delete sent transfer with ID %s from store: %+v" +) + +// SentFileTransfers contains information for tracking sent file transfers. +type SentFileTransfers struct { + transfers map[ftCrypto.TransferID]*SentTransfer + mux sync.Mutex + kv *versioned.KV +} + +// NewSentFileTransfers creates a new SentFileTransfers with an empty map. +func NewSentFileTransfers(kv *versioned.KV) (*SentFileTransfers, error) { + sft := &SentFileTransfers{ + transfers: make(map[ftCrypto.TransferID]*SentTransfer), + kv: kv.Prefix(sentFileTransfersPrefix), + } + + return sft, sft.saveTransfersList() +} + +// AddTransfer creates a new empty SentTransfer and adds it to the transfers +// map. +func (sft *SentFileTransfers) AddTransfer(recipient *id.ID, + key ftCrypto.TransferKey, parts [][]byte, numFps uint16, + progressCB interfaces.SentProgressCallback, period time.Duration, + rng csprng.Source) (ftCrypto.TransferID, error) { + + sft.mux.Lock() + defer sft.mux.Unlock() + + // Generate new transfer ID + tid, err := ftCrypto.NewTransferID(rng) + if err != nil { + return tid, errors.Errorf(addTransferNewIdErr, err) + } + + // Generate a new SentTransfer and add it to the map + sft.transfers[tid], err = NewSentTransfer( + recipient, tid, key, parts, numFps, progressCB, period, sft.kv) + if err != nil { + return tid, errors.Errorf(newSentTransferErr, err) + } + + // Update list of transfers in storage + err = sft.saveTransfersList() + if err != nil { + return tid, errors.Errorf(saveSentTransfersListErr, err) + } + + return tid, nil +} + +// GetTransfer returns the SentTransfer with the given transfer ID. An error is +// returned if no corresponding transfer is found. +func (sft *SentFileTransfers) GetTransfer(tid ftCrypto.TransferID) ( + *SentTransfer, error) { + sft.mux.Lock() + defer sft.mux.Unlock() + + rt, exists := sft.transfers[tid] + if !exists { + return nil, errors.New(getSentTransferErr) + } + + return rt, nil +} + +// DeleteTransfer removes the SentTransfer with the associated transfer ID +// from memory and storage. +func (sft *SentFileTransfers) DeleteTransfer(tid ftCrypto.TransferID) error { + sft.mux.Lock() + defer sft.mux.Unlock() + + // Return an error if the transfer does not exist + _, exists := sft.transfers[tid] + if !exists { + return errors.New(getSentTransferErr) + } + + // Delete all data the transfer saved to storage + err := sft.transfers[tid].delete() + if err != nil { + return errors.Errorf(deleteSentTransferErr, tid, err) + } + + // Delete the transfer from memory + delete(sft.transfers, tid) + + // Update the transfers list for the removed transfer + err = sft.saveTransfersList() + if err != nil { + return errors.Errorf(saveSentTransfersListErr, err) + } + + return nil +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// LoadSentFileTransfers loads all SentFileTransfers from storage. +func LoadSentFileTransfers(kv *versioned.KV) (*SentFileTransfers, error) { + sft := &SentFileTransfers{ + transfers: make(map[ftCrypto.TransferID]*SentTransfer), + kv: kv.Prefix(sentFileTransfersPrefix), + } + + // Get the list of transfer IDs corresponding to each sent transfer from + // storage + transfersList, err := sft.loadTransfersList() + if err != nil { + return nil, errors.Errorf(loadSentTransfersListErr, err) + } + + // Load each transfer in the list from storage into the map + err = sft.loadTransfers(transfersList) + if err != nil { + return nil, errors.Errorf(loadSentTransfersErr, err) + } + + return sft, nil +} + +// NewOrLoadSentFileTransfers loads all SentFileTransfers from storage, if they +// exist. Otherwise, a new SentFileTransfers is returned. +func NewOrLoadSentFileTransfers(kv *versioned.KV) (*SentFileTransfers, error) { + sft := &SentFileTransfers{ + transfers: make(map[ftCrypto.TransferID]*SentTransfer), + kv: kv.Prefix(sentFileTransfersPrefix), + } + + // If the transfer list cannot be loaded from storage, then create a new + // SentFileTransfers + vo, err := sft.kv.Get(sentFileTransfersKey, sentFileTransfersVersion) + if err != nil { + return NewSentFileTransfers(kv) + } + + // Unmarshal data into list of saved transfer IDs + transfersList := unmarshalTransfersList(vo.Data) + + // Load each transfer in the list from storage into the map + err = sft.loadTransfers(transfersList) + if err != nil { + return nil, errors.Errorf(loadSentTransfersErr, err) + } + + return sft, nil +} + +// saveTransfersList saves a list of items in the transfers map to storage. +func (sft *SentFileTransfers) saveTransfersList() error { + // Create new versioned object with a list of items in the transfers map + obj := &versioned.Object{ + Version: sentFileTransfersVersion, + Timestamp: netTime.Now(), + Data: sft.marshalTransfersList(), + } + + // Save list of items in the transfers map to storage + return sft.kv.Set(sentFileTransfersKey, sentFileTransfersVersion, obj) +} + +// loadTransfersList gets the list of transfer IDs corresponding to each saved +// sent transfer from storage. +func (sft *SentFileTransfers) loadTransfersList() ([]ftCrypto.TransferID, error) { + // Get transfers list from storage + vo, err := sft.kv.Get(sentFileTransfersKey, sentFileTransfersVersion) + if err != nil { + return nil, err + } + + // Unmarshal data into list of saved transfer IDs + return unmarshalTransfersList(vo.Data), nil +} + +// loadTransfers loads each SentTransfer from the list and adds them to the map. +func (sft *SentFileTransfers) loadTransfers(list []ftCrypto.TransferID) error { + var err error + + // Load each sentTransfer from storage into the map + for _, tid := range list { + sft.transfers[tid], err = loadSentTransfer(tid, sft.kv) + if err != nil { + return err + } + } + + return nil +} + +// marshalTransfersList creates a list of all transfer IDs in the transfers map +// and serialises it. +func (sft *SentFileTransfers) marshalTransfersList() []byte { + buff := bytes.NewBuffer(nil) + buff.Grow(ftCrypto.TransferIdLength * len(sft.transfers)) + + for tid := range sft.transfers { + buff.Write(tid.Bytes()) + } + + return buff.Bytes() +} + +// unmarshalTransfersList deserializes a byte slice into a list of transfer IDs. +func unmarshalTransfersList(b []byte) []ftCrypto.TransferID { + buff := bytes.NewBuffer(b) + list := make([]ftCrypto.TransferID, 0, buff.Len()/ftCrypto.TransferIdLength) + + const size = ftCrypto.TransferIdLength + for n := buff.Next(size); len(n) == size; n = buff.Next(size) { + list = append(list, ftCrypto.UnmarshalTransferID(n)) + } + + return list +} diff --git a/storage/fileTransfer/sentFileTransfers_test.go b/storage/fileTransfer/sentFileTransfers_test.go new file mode 100644 index 0000000000000000000000000000000000000000..cae6426eddf056a9f4a14a3d4158f6e6c2c189c8 --- /dev/null +++ b/storage/fileTransfer/sentFileTransfers_test.go @@ -0,0 +1,484 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" + "reflect" + "strings" + "testing" +) + +// Tests that NewSentFileTransfers creates a new object with empty maps and that +// it is saved to storage +func TestNewSentFileTransfers(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + expectedSFT := &SentFileTransfers{ + transfers: make(map[ftCrypto.TransferID]*SentTransfer), + kv: kv.Prefix(sentFileTransfersPrefix), + } + + sft, err := NewSentFileTransfers(kv) + + // Check that the new SentFileTransfers matches the expected + if !reflect.DeepEqual(expectedSFT, sft) { + t.Errorf("New SentFileTransfers does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expectedSFT, sft) + } + + // Ensure that the transfer list is saved to storage + _, err = expectedSFT.kv.Get(sentFileTransfersKey, sentFileTransfersVersion) + if err != nil { + t.Errorf("Failed to load transfer list from storage: %+v", err) + } +} + +// Tests that SentFileTransfers.AddTransfer adds a new transfer to the map and +// that its ID is saved to the list in storage. +func TestSentFileTransfers_AddTransfer(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + sft, err := NewSentFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new SentFileTransfers: %+v", err) + } + + // Generate info for new transfer + recipient := id.NewIdFromString("recipientID", id.User, t) + key, _ := ftCrypto.NewTransferKey(prng) + parts, _ := newRandomPartSlice(16, prng, t) + numFps := uint16(24) + + // Add the transfer + tid, err := sft.AddTransfer(recipient, key, parts, numFps, nil, 0, prng) + if err != nil { + t.Errorf("AddTransfer returned an error: %+v", err) + } + + _, exists := sft.transfers[tid] + if !exists { + t.Errorf("New transfer %s does not exist in map.", tid) + } + + list, err := sft.loadTransfersList() + if err != nil { + t.Errorf("Failed to load transfer list from storage: %+v", err) + } + + if list[0] != tid { + t.Errorf("Transfer ID saved to storage does not match ID in memory."+ + "\nexpected: %s\nreceived: %s", tid, list[0]) + } +} + +// Error path: tests that SentFileTransfers.AddTransfer returns the expected +// error when the PRNG returns an error. +func TestSentFileTransfers_AddTransfer_NewTransferIdRngError(t *testing.T) { + prng := NewPrngErr() + kv := versioned.NewKV(make(ekv.Memstore)) + sft, err := NewSentFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new SentFileTransfers: %+v", err) + } + + // Add the transfer + expectedErr := strings.Split(addTransferNewIdErr, "%")[0] + _, err = sft.AddTransfer(nil, ftCrypto.TransferKey{}, nil, 0, nil, 0, prng) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("AddTransfer did not return the expected error when the PRNG "+ + "should have errored.\nexpected: %s\nrecieved: %+v", expectedErr, err) + } +} + +// Tests that SentFileTransfers.GetTransfer returns the expected transfer from +// the map. +func TestSentFileTransfers_GetTransfer(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + sft, err := NewSentFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new SentFileTransfers: %+v", err) + } + + // Generate info for new transfer + recipient := id.NewIdFromString("recipientID", id.User, t) + key, _ := ftCrypto.NewTransferKey(prng) + parts, _ := newRandomPartSlice(16, prng, t) + numFps := uint16(24) + + tid, err := sft.AddTransfer(recipient, key, parts, numFps, nil, 0, prng) + if err != nil { + t.Errorf("AddTransfer returned an error: %+v", err) + } + + transfer, err := sft.GetTransfer(tid) + if err != nil { + t.Errorf("GetTransfer returned an error: %+v", err) + } + + if !reflect.DeepEqual(sft.transfers[tid], transfer) { + t.Errorf("Received transfer does not match expected."+ + "\nexpected: %+v\nreceived: %+v", sft.transfers[tid], transfer) + } +} + +// Error path: tests that SentFileTransfers.GetTransfer returns the expected +// error when the map is empty/there is no transfer with the given transfer ID. +func TestSentFileTransfers_GetTransfer_NoTransferError(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + sft, err := NewSentFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new SentFileTransfers: %+v", err) + } + + tid, _ := ftCrypto.NewTransferID(prng) + + _, err = sft.GetTransfer(tid) + if err == nil || err.Error() != getSentTransferErr { + t.Errorf("GetTransfer did not return the expected error when it is "+ + "empty.\nexpected: %s\nreceived: %+v", getSentTransferErr, err) + } +} + +// Tests that SentFileTransfers.DeleteTransfer removes the transfer from the map +// in memory and from the list in storage. +func TestSentFileTransfers_DeleteTransfer(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + sft, err := NewSentFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new SentFileTransfers: %+v", err) + } + + // Generate info for new transfer + recipient := id.NewIdFromString("recipientID", id.User, t) + key, _ := ftCrypto.NewTransferKey(prng) + parts, _ := newRandomPartSlice(16, prng, t) + numFps := uint16(24) + + tid, err := sft.AddTransfer(recipient, key, parts, numFps, nil, 0, prng) + if err != nil { + t.Errorf("AddTransfer returned an error: %+v", err) + } + + err = sft.DeleteTransfer(tid) + if err != nil { + t.Errorf("DeleteTransfer returned an error: %+v", err) + } + + transfer, err := sft.GetTransfer(tid) + if err == nil { + t.Errorf("No error getting transfer that should be deleted: %+v", transfer) + } + + list, err := sft.loadTransfersList() + if err != nil { + t.Errorf("Failed to load transfer list from storage: %+v", err) + } + + if len(list) > 0 { + t.Errorf("Transfer ID list in storage not empty: %+v", list) + } +} + +// Error path: tests that SentFileTransfers.DeleteTransfer returns the expected +// error when the map is empty/there is no transfer with the given transfer ID. +func TestSentFileTransfers_DeleteTransfer_NoTransferError(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + sft, err := NewSentFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to create new SentFileTransfers: %+v", err) + } + + tid, _ := ftCrypto.NewTransferID(prng) + + err = sft.DeleteTransfer(tid) + if err == nil || err.Error() != getSentTransferErr { + t.Errorf("DeleteTransfer did not return the expected error when it is "+ + "empty.\nexpected: %s\nreceived: %+v", getSentTransferErr, err) + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// Tests that the SentFileTransfers loaded from storage by LoadSentFileTransfers +// matches the original in memory. +func TestLoadSentFileTransfers(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + sft, err := NewSentFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to make new SentFileTransfers: %+v", err) + } + + // Add 10 transfers to map in memory + list := make([]ftCrypto.TransferID, 10) + for i := range list { + tid, st := newRandomSentTransfer(16, 24, sft.kv, t) + sft.transfers[tid] = st + list[i] = tid + } + + // Save list to storage + if err = sft.saveTransfersList(); err != nil { + t.Errorf("Faileds to save transfers list: %+v", err) + } + + // Load SentFileTransfers from storage + loadedSFT, err := LoadSentFileTransfers(kv) + if err != nil { + t.Errorf("LoadSentFileTransfers returned an error: %+v", err) + } + + // Equalize all progressCallbacks because reflect.DeepEqual does not seem + // to work on function pointers + for _, tid := range list { + loadedSFT.transfers[tid].progressCallbacks = sft.transfers[tid].progressCallbacks + } + + if !reflect.DeepEqual(sft, loadedSFT) { + t.Errorf("Loaded SentFileTransfers does not match original in memory."+ + "\nexpected: %+v\nreceived: %+v", sft, loadedSFT) + } +} + +// Error path: tests that LoadSentFileTransfers returns the expected error when +// the transfer list cannot be loaded from storage. +func TestLoadSentFileTransfers_NoListInStorageError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + expectedErr := strings.Split(loadSentTransfersListErr, "%")[0] + + // Load SentFileTransfers from storage + _, err := LoadSentFileTransfers(kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("LoadSentFileTransfers did not return the expected error "+ + "when there is no transfer list saved in storage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: tests that LoadSentFileTransfers returns the expected error when +// the first transfer loaded from storage does not exist. +func TestLoadSentFileTransfers_NoTransferInStorageError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + expectedErr := strings.Split(loadSentTransfersErr, "%")[0] + + // Save list of one transfer ID to storage + obj := &versioned.Object{ + Version: sentFileTransfersVersion, + Timestamp: netTime.Now(), + Data: ftCrypto.UnmarshalTransferID([]byte("testID_01")).Bytes(), + } + err := kv.Prefix(sentFileTransfersPrefix).Set( + sentFileTransfersKey, sentFileTransfersVersion, obj) + + // Load SentFileTransfers from storage + _, err = LoadSentFileTransfers(kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("LoadSentFileTransfers did not return the expected error "+ + "when there is no transfer saved in storage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that the SentFileTransfers loaded from storage by +// NewOrLoadSentFileTransfers matches the original in memory. +func TestNewOrLoadSentFileTransfers(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + sft, err := NewSentFileTransfers(kv) + if err != nil { + t.Fatalf("Failed to make new SentFileTransfers: %+v", err) + } + + // Add 10 transfers to map in memory + list := make([]ftCrypto.TransferID, 10) + for i := range list { + tid, st := newRandomSentTransfer(16, 24, sft.kv, t) + sft.transfers[tid] = st + list[i] = tid + } + + // Save list to storage + if err = sft.saveTransfersList(); err != nil { + t.Errorf("Faileds to save transfers list: %+v", err) + } + + // Load SentFileTransfers from storage + loadedSFT, err := NewOrLoadSentFileTransfers(kv) + if err != nil { + t.Errorf("NewOrLoadSentFileTransfers returned an error: %+v", err) + } + + // Equalize all progressCallbacks because reflect.DeepEqual does not seem + // to work on function pointers + for _, tid := range list { + loadedSFT.transfers[tid].progressCallbacks = sft.transfers[tid].progressCallbacks + } + + if !reflect.DeepEqual(sft, loadedSFT) { + t.Errorf("Loaded SentFileTransfers does not match original in memory."+ + "\nexpected: %+v\nreceived: %+v", sft, loadedSFT) + } +} + +// Tests that NewOrLoadSentFileTransfers returns a new SentFileTransfers when +// there is none in storage. +func TestNewOrLoadSentFileTransfers_NewSentFileTransfers(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + // Load SentFileTransfers from storage + loadedSFT, err := NewOrLoadSentFileTransfers(kv) + if err != nil { + t.Errorf("NewOrLoadSentFileTransfers returned an error: %+v", err) + } + + newSFT, _ := NewSentFileTransfers(kv) + + if !reflect.DeepEqual(newSFT, loadedSFT) { + t.Errorf("Returned SentFileTransfers does not match new."+ + "\nexpected: %+v\nreceived: %+v", newSFT, loadedSFT) + } +} + +// Error path: tests that the NewOrLoadSentFileTransfers returns the expected +// error when the first transfer loaded from storage does not exist. +func TestNewOrLoadSentFileTransfers_NoTransferInStorageError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + expectedErr := strings.Split(loadSentTransfersErr, "%")[0] + + // Save list of one transfer ID to storage + obj := &versioned.Object{ + Version: sentFileTransfersVersion, + Timestamp: netTime.Now(), + Data: ftCrypto.UnmarshalTransferID([]byte("testID_01")).Bytes(), + } + err := kv.Prefix(sentFileTransfersPrefix).Set( + sentFileTransfersKey, sentFileTransfersVersion, obj) + + // Load SentFileTransfers from storage + _, err = NewOrLoadSentFileTransfers(kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("NewOrLoadSentFileTransfers did not return the expected "+ + "error when there is no transfer saved in storage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that SentFileTransfers.saveTransfersList saves all the transfer IDs to +// storage by loading them from storage via SentFileTransfers.loadTransfersList +// and comparing the list to the list in memory. +func TestSentFileTransfers_saveTransfersList_loadTransfersList(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + sft := &SentFileTransfers{ + transfers: make(map[ftCrypto.TransferID]*SentTransfer), + kv: kv.Prefix(sentFileTransfersPrefix), + } + + // Add 10 transfers to map in memory + for i := 0; i < 10; i++ { + tid, st := newRandomSentTransfer(16, 24, sft.kv, t) + sft.transfers[tid] = st + } + + // Save transfer ID list to storage + err := sft.saveTransfersList() + if err != nil { + t.Errorf("saveTransfersList returned an error: %+v", err) + } + + // Get list from storage + list, err := sft.loadTransfersList() + if err != nil { + t.Errorf("loadTransfersList returned an error: %+v", err) + } + + // Check that the list has all the transfer IDs in memory + for _, tid := range list { + if _, exists := sft.transfers[tid]; !exists { + t.Errorf("No transfer for ID %s exists.", tid) + } else { + delete(sft.transfers, tid) + } + } +} + +// Tests that the transfer loaded by SentFileTransfers.loadTransfers from +// storage matches the original in memory +func TestSentFileTransfers_loadTransfers(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + sft := &SentFileTransfers{ + transfers: make(map[ftCrypto.TransferID]*SentTransfer), + kv: kv.Prefix(sentFileTransfersPrefix), + } + + // Add 10 transfers to map in memory + list := make([]ftCrypto.TransferID, 10) + for i := range list { + tid, st := newRandomSentTransfer(16, 24, sft.kv, t) + sft.transfers[tid] = st + list[i] = tid + } + + // Load the transfers into a new SentFileTransfers + loadedSft := &SentFileTransfers{ + transfers: make(map[ftCrypto.TransferID]*SentTransfer), + kv: kv.Prefix(sentFileTransfersPrefix), + } + err := loadedSft.loadTransfers(list) + if err != nil { + t.Errorf("loadTransfers returned an error: %+v", err) + } + + // Equalize all progressCallbacks because reflect.DeepEqual does not seem + // to work on function pointers + for _, tid := range list { + loadedSft.transfers[tid].progressCallbacks = sft.transfers[tid].progressCallbacks + } + + if !reflect.DeepEqual(sft.transfers, loadedSft.transfers) { + t.Errorf("Transfers loaded from storage does not match transfers in memory."+ + "\nexpected: %+v\nreceived: %+v", sft.transfers, loadedSft.transfers) + } +} + +// Tests that a transfer list marshalled with +// SentFileTransfers.marshalTransfersList and unmarshalled with +// unmarshalTransfersList matches the original. +func TestSentFileTransfers_marshalTransfersList_unmarshalTransfersList(t *testing.T) { + prng := NewPrng(42) + sft := &SentFileTransfers{ + transfers: make(map[ftCrypto.TransferID]*SentTransfer), + } + + // Add 10 transfers to map in memory + for i := 0; i < 10; i++ { + tid, _ := ftCrypto.NewTransferID(prng) + sft.transfers[tid] = &SentTransfer{} + } + + // Marshal into byte slice + marshalledBytes := sft.marshalTransfersList() + + // Unmarshal marshalled bytes into transfer ID list + list := unmarshalTransfersList(marshalledBytes) + + // Check that the list has all the transfer IDs in memory + for _, tid := range list { + if _, exists := sft.transfers[tid]; !exists { + t.Errorf("No transfer for ID %s exists.", tid) + } else { + delete(sft.transfers, tid) + } + } +} diff --git a/storage/fileTransfer/sentTransfer.go b/storage/fileTransfer/sentTransfer.go new file mode 100644 index 0000000000000000000000000000000000000000..2b05f84d5e62b549708c651072f559b9d1841bf7 --- /dev/null +++ b/storage/fileTransfer/sentTransfer.go @@ -0,0 +1,631 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "encoding/binary" + "github.com/pkg/errors" + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/storage/utility" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/crypto/csprng" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" + "sync" + "time" +) + +// Storage keys and versions. +const ( + sentTransferPrefix = "FileTransferSentTransferStore" + sentTransferKey = "SentTransfer" + sentTransferVersion = 0 + sentFpVectorKey = "SentFingerprintVector" +) + +// Error messages. +const ( + // NewSentTransfer + newSentTransferFpVectorErr = "failed to create new state vector for fingerprints: %+v" + newSentTransferPartStoreErr = "failed to create new part store: %+v" + newInProgressTransfersErr = "failed to create new in-progress transfers bundle: %+v" + newFinishedTransfersErr = "failed to create new finished transfers bundle: %+v" + + // SentTransfer.ReInit + reInitSentTransferFpVectorErr = "failed to overwrite fingerprint state vector with new vector: %+v" + reInitInProgressTransfersErr = "failed to overwrite in-progress transfers bundle: %+v" + reInitFinishedTransfersErr = "failed to overwrite finished transfers bundle: %+v" + + // loadSentTransfer + loadSentStoreErr = "failed to load sent transfer info from storage: %+v" + loadSentFpVectorErr = "failed to load sent fingerprint vector from storage: %+v" + loadSentPartStoreErr = "failed to load sent part store from storage: %+v" + loadInProgressTransfersErr = "failed to load in-progress transfers bundle from storage: %+v" + loadFinishedTransfersErr = "failed to load finished transfers bundle from storage: %+v" + + // SentTransfer.delete + deleteSentTransferInfoErr = "failed to delete sent transfer info from storage: %+v" + deleteSentFpVectorErr = "failed to delete sent fingerprint vector from storage: %+v" + deleteSentFilePartsErr = "failed to delete sent file parts from storage: %+v" + deleteInProgressTransfersErr = "failed to delete in-progress transfers from storage: %+v" + deleteFinishedTransfersErr = "failed to delete finished transfers from storage: %+v" + + // SentTransfer.FinishTransfer + noPartsForRoundErr = "no file parts in-progress on round %d" + deleteInProgressPartsErr = "failed to remove file parts on round %d from in-progress: %+v" + + // SentTransfer.GetEncryptedPart + noPartNumErr = "no part with part number %d exists" + maxRetriesErr = "maximum number of retries reached" + fingerprintErr = "could not get fingerprint: %+v" + encryptPartErr = "failed to encrypt file part #%d: %+v" +) + +// MaxRetriesErr is returned as an error when number of file part sending +// retries runs out. This occurs when all the fingerprints in a transfer have +// been used. +var MaxRetriesErr = errors.New(maxRetriesErr) + +// SentTransfer contains information and progress data for sending and in- +// progress file transfer. +type SentTransfer struct { + // ID of the recipient of the file transfer + recipient *id.ID + + // The transfer key is a randomly generated key created by the sender and + // used to generate MACs and fingerprints + key ftCrypto.TransferKey + + // The number of file parts in the file + numParts uint16 + + // The number of fingerprints to generate (function of numParts and the + // retry rate) + numFps uint16 + + // Stores the state of a fingerprint (used/unused) in a bitstream format + // (has its own storage backend) + fpVector *utility.StateVector + + // List of all file parts in order to send (has its own storage backend) + sentParts *partStore + + // List of parts per round that are currently transferring + inProgressTransfers *transferredBundle + + // List of parts per round that finished transferring + finishedTransfers *transferredBundle + + // List of callbacks to call for every send + progressCallbacks []*sentCallbackTracker + + // status indicates that the transfer is either done or errored out and + // that no more callbacks should be called + status transferStatus + + mux sync.RWMutex + kv *versioned.KV +} + +type transferStatus int + +const ( + running transferStatus = iota + stopping + stopped +) + +// NewSentTransfer generates a new SentTransfer with the specified transfer key, +// transfer ID, and number of parts. +func NewSentTransfer(recipient *id.ID, tid ftCrypto.TransferID, + key ftCrypto.TransferKey, parts [][]byte, numFps uint16, + progressCB interfaces.SentProgressCallback, period time.Duration, + kv *versioned.KV) (*SentTransfer, error) { + + // Create the SentTransfer object + st := &SentTransfer{ + recipient: recipient, + key: key, + numParts: uint16(len(parts)), + numFps: numFps, + progressCallbacks: []*sentCallbackTracker{}, + status: running, + kv: kv.Prefix(makeSentTransferPrefix(tid)), + } + + var err error + + // Create new StateVector for storing fingerprint usage + st.fpVector, err = utility.NewStateVector( + st.kv, sentFpVectorKey, uint32(numFps)) + if err != nil { + return nil, errors.Errorf(newSentTransferFpVectorErr, err) + } + + // Create new part store + st.sentParts, err = newPartStoreFromParts(st.kv, parts...) + if err != nil { + return nil, errors.Errorf(newSentTransferPartStoreErr, err) + } + + // Create new in-progress transfer bundle + st.inProgressTransfers, err = newTransferredBundle(inProgressKey, st.kv) + if err != nil { + return nil, errors.Errorf(newInProgressTransfersErr, err) + } + + // Create new finished transfer bundle + st.finishedTransfers, err = newTransferredBundle(finishedKey, st.kv) + if err != nil { + return nil, errors.Errorf(newFinishedTransfersErr, err) + } + + // Add first progress callback + if progressCB != nil { + st.AddProgressCB(progressCB, period) + } + + return st, st.saveInfo() +} + +// ReInit resets the SentTransfer to its initial state so that sending can +// resume from the beginning. ReInit is used when the sent transfer runs out of +// retries and a user wants to attempt to resend the entire file again. +func (st *SentTransfer) ReInit(numFps uint16, + progressCB interfaces.SentProgressCallback, period time.Duration) error { + st.mux.Lock() + defer st.mux.Unlock() + var err error + + // Mark the status as running + st.status = running + + // Update number of fingerprints and overwrite old fingerprint vector + st.numFps = numFps + st.fpVector, err = utility.NewStateVector( + st.kv, sentFpVectorKey, uint32(numFps)) + if err != nil { + return errors.Errorf(reInitSentTransferFpVectorErr, err) + } + + // Overwrite in-progress transfer bundle + st.inProgressTransfers, err = newTransferredBundle(inProgressKey, st.kv) + if err != nil { + return errors.Errorf(reInitInProgressTransfersErr, err) + } + + // Overwrite finished transfer bundle + st.finishedTransfers, err = newTransferredBundle(finishedKey, st.kv) + if err != nil { + return errors.Errorf(reInitFinishedTransfersErr, err) + } + + // Clear callbacks + st.progressCallbacks = []*sentCallbackTracker{} + + // Add first progress callback + if progressCB != nil { + // Add callback + sct := newSentCallbackTracker(progressCB, period) + st.progressCallbacks = append(st.progressCallbacks, sct) + + // Trigger the initial call + sct.callNowUnsafe(st, nil) + } + + return nil +} + +// GetRecipient returns the ID of the recipient of the transfer. +func (st *SentTransfer) GetRecipient() *id.ID { + st.mux.RLock() + defer st.mux.RUnlock() + + return st.recipient +} + +// GetTransferKey returns the transfer Key for this sent transfer. +func (st *SentTransfer) GetTransferKey() ftCrypto.TransferKey { + st.mux.RLock() + defer st.mux.RUnlock() + + return st.key +} + +// GetNumParts returns the number of file parts in this transfer. +func (st *SentTransfer) GetNumParts() uint16 { + st.mux.RLock() + defer st.mux.RUnlock() + + return st.numParts +} + +// GetNumFps returns the number of fingerprints. +func (st *SentTransfer) GetNumFps() uint16 { + st.mux.RLock() + defer st.mux.RUnlock() + + return st.numFps +} + +// GetNumAvailableFps returns the number of unused fingerprints. +func (st *SentTransfer) GetNumAvailableFps() uint16 { + st.mux.RLock() + defer st.mux.RUnlock() + + return uint16(st.fpVector.GetNumAvailable()) +} + +// GetProgress returns the current progress of the transfer. Completed is true +// when all parts have arrived, sent is the number of in-progress parts, arrived +// is the number of finished parts and total is the total number of parts being +// sent. +func (st *SentTransfer) GetProgress() (completed bool, sent, arrived, total uint16) { + st.mux.RLock() + defer st.mux.RUnlock() + + return st.getProgress() +} + +// getProgress is the thread-unsafe helper function for GetProgress. +func (st *SentTransfer) getProgress() (completed bool, sent, arrived, total uint16) { + arrived = st.finishedTransfers.getNumParts() + sent = st.inProgressTransfers.getNumParts() + total = st.numParts + + if sent == 0 && arrived == total { + completed = true + } + + return completed, sent, arrived, total +} + +// CallProgressCB calls all the progress callbacks with the most recent progress +// information. +func (st *SentTransfer) CallProgressCB(err error) { + st.mux.Lock() + + switch st.status { + case stopped: + st.mux.Unlock() + return + case stopping: + st.status = stopped + } + + st.mux.Unlock() + st.mux.RLock() + defer st.mux.RUnlock() + + for _, cb := range st.progressCallbacks { + cb.call(st, err) + } +} + +// AddProgressCB appends a new interfaces.SentProgressCallback to the list of +// progress callbacks to be called and calls it. The period is how often the +// callback should be called when there are updates. +func (st *SentTransfer) AddProgressCB(cb interfaces.SentProgressCallback, + period time.Duration) { + st.mux.Lock() + + // Add callback + sct := newSentCallbackTracker(cb, period) + st.progressCallbacks = append(st.progressCallbacks, sct) + + st.mux.Unlock() + + // Trigger the initial call + sct.callNow(st, nil) +} + +// GetEncryptedPart gets the specified part, encrypts it, and returns the +// encrypted part along with its MAC, padding, and fingerprint. +func (st *SentTransfer) GetEncryptedPart(partNum uint16, partSize int, + rng csprng.Source) (encPart, mac, padding []byte, fp format.Fingerprint, + err error) { + st.mux.Lock() + defer st.mux.Unlock() + + // Lookup part + part, exists := st.sentParts.getPart(partNum) + if !exists { + return nil, nil, nil, format.Fingerprint{}, + errors.Errorf(noPartNumErr, partNum) + } + + // If all fingerprints have been used but parts still remain, then change + // the status to stopping and return an error specifying that all the + // retries have been used + if st.fpVector.GetNumAvailable() < 1 { + st.status = stopping + return nil, nil, nil, format.Fingerprint{}, MaxRetriesErr + } + + // Get next unused fingerprint number and mark it as used + nextKey, err := st.fpVector.Next() + if err != nil { + return nil, nil, nil, format.Fingerprint{}, + errors.Errorf(fingerprintErr, err) + } + fpNum := uint16(nextKey) + + // Generate fingerprint + fp = ftCrypto.GenerateFingerprint(st.key, fpNum) + + // Encrypt the file part and generate the file part MAC and padding (nonce) + maxLengthPart := make([]byte, partSize) + copy(maxLengthPart, part) + encPart, mac, padding, err = ftCrypto.EncryptPart( + st.key, maxLengthPart, fpNum, rng) + if err != nil { + return nil, nil, nil, format.Fingerprint{}, + errors.Errorf(encryptPartErr, partNum, err) + } + + return encPart, mac, padding, fp, err +} + +// SetInProgress adds the specified file part numbers to the in-progress +// transfers for the given round ID. Returns whether the round already exists in +// the list. +func (st *SentTransfer) SetInProgress(rid id.Round, partNums ...uint16) (error, bool) { + st.mux.Lock() + defer st.mux.Unlock() + + _, exists := st.inProgressTransfers.getPartNums(rid) + + return st.inProgressTransfers.addPartNums(rid, partNums...), exists +} + +// GetInProgress returns a list of all part number in the in-progress transfers +// list. +func (st *SentTransfer) GetInProgress(rid id.Round) ([]uint16, bool) { + st.mux.Lock() + defer st.mux.Unlock() + + return st.inProgressTransfers.getPartNums(rid) +} + +// UnsetInProgress removed the file part numbers from the in-progress transfers +// for the given round ID. Returns the list of part numbers that were removed +// from the list. +func (st *SentTransfer) UnsetInProgress(rid id.Round) ([]uint16, error) { + st.mux.Lock() + defer st.mux.Unlock() + + // Get the list of part numbers to be removed from list + partNums, _ := st.inProgressTransfers.getPartNums(rid) + + return partNums, st.inProgressTransfers.deletePartNums(rid) +} + +// FinishTransfer moves the in-progress file parts for the given round to the +// finished list. +func (st *SentTransfer) FinishTransfer(rid id.Round) error { + st.mux.Lock() + defer st.mux.Unlock() + + // Get the parts in-progress for the round ID or return an error if none + // exist + partNums, exists := st.inProgressTransfers.getPartNums(rid) + if !exists { + return errors.Errorf(noPartsForRoundErr, rid) + } + + // Delete the parts from the in-progress list + err := st.inProgressTransfers.deletePartNums(rid) + if err != nil { + return errors.Errorf(deleteInProgressPartsErr, rid, err) + } + + // Add the parts to the finished list + err = st.finishedTransfers.addPartNums(rid, partNums...) + if err != nil { + return err + } + + // If all parts have been moved to the finished list, then set the status + // to stopping + if st.finishedTransfers.getNumParts() == st.numParts && + st.inProgressTransfers.getNumParts() == 0 { + st.status = stopping + } + + return nil +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// loadSentTransfer loads the SentTransfer with the given transfer ID from +// storage. +func loadSentTransfer(tid ftCrypto.TransferID, kv *versioned.KV) (*SentTransfer, + error) { + st := &SentTransfer{ + kv: kv.Prefix(makeSentTransferPrefix(tid)), + } + + // Load transfer key and number of sent parts from storage + err := st.loadInfo() + if err != nil { + return nil, errors.Errorf(loadSentStoreErr, err) + } + + // Load the fingerprint vector from storage + st.fpVector, err = utility.LoadStateVector(st.kv, sentFpVectorKey) + if err != nil { + return nil, errors.Errorf(loadSentFpVectorErr, err) + } + + // Load sent part store from storage + st.sentParts, err = loadPartStore(st.kv) + if err != nil { + return nil, errors.Errorf(loadSentPartStoreErr, err) + } + + // Load in-progress transfer bundle from storage + st.inProgressTransfers, err = loadTransferredBundle(inProgressKey, st.kv) + if err != nil { + return nil, errors.Errorf(loadInProgressTransfersErr, err) + } + + // Load finished transfer bundle from storage + st.finishedTransfers, err = loadTransferredBundle(finishedKey, st.kv) + if err != nil { + return nil, errors.Errorf(loadFinishedTransfersErr, err) + } + + return st, nil +} + +// saveInfo saves the SentTransfer info (transfer key, number of file parts, +// and number of fingerprints) to storage. + +// saveInfo saves all fields in SentTransfer that do not have their own storage +// (recipient ID, transfer key, number of file parts, and number of +// fingerprints) to storage. +func (st *SentTransfer) saveInfo() error { + st.mux.Lock() + defer st.mux.Unlock() + + // Create new versioned object for the SentTransfer + obj := &versioned.Object{ + Version: sentTransferVersion, + Timestamp: netTime.Now(), + Data: st.marshal(), + } + + // Save versioned object + return st.kv.Set(sentTransferKey, sentTransferVersion, obj) +} + +// loadInfo gets the recipient ID, transfer key, number of part, and number of +// fingerprints from storage and saves it to the SentTransfer. +func (st *SentTransfer) loadInfo() error { + vo, err := st.kv.Get(sentTransferKey, sentTransferVersion) + if err != nil { + return err + } + + // Unmarshal the transfer key and numParts + st.recipient, st.key, st.numParts, st.numFps = unmarshalSentTransfer(vo.Data) + + return nil +} + +// delete deletes all data in the SentTransfer from storage. +func (st *SentTransfer) delete() error { + st.mux.Lock() + defer st.mux.Unlock() + + // Delete sent transfer info from storage + err := st.deleteInfo() + if err != nil { + return errors.Errorf(deleteSentTransferInfoErr, err) + } + + // Delete fingerprint vector from storage + err = st.fpVector.Delete() + if err != nil { + return errors.Errorf(deleteSentFpVectorErr, err) + } + + // Delete sent file parts from storage + err = st.sentParts.delete() + if err != nil { + return errors.Errorf(deleteSentFilePartsErr, err) + } + + // Delete in-progress transfer bundles from storage + err = st.inProgressTransfers.delete() + if err != nil { + return errors.Errorf(deleteInProgressTransfersErr, err) + } + + // Delete finished transfer bundles from storage + err = st.finishedTransfers.delete() + if err != nil { + return errors.Errorf(deleteFinishedTransfersErr, err) + } + + return nil +} + +// deleteInfo removes received transfer info (recipient, transfer key, number +// of parts, and number of fingerprints) from storage. +func (st *SentTransfer) deleteInfo() error { + return st.kv.Delete(sentTransferKey, sentTransferVersion) +} + +// marshal serializes the transfer key, numParts, and numFps. + +// marshal serializes all primitive fields in SentTransfer (recipient, key, +// numParts, and numFps). +func (st *SentTransfer) marshal() []byte { + // Construct the buffer to the correct size + // (size of ID + size of key + numParts (2 bytes) + numFps (2 bytes)) + buff := bytes.NewBuffer(nil) + buff.Grow(id.ArrIDLen + ftCrypto.TransferKeyLength + 2 + 2) + + // Write the recipient ID to the buffer + if st.recipient != nil { + buff.Write(st.recipient.Marshal()) + } else { + buff.Write((&id.ID{}).Marshal()) + } + + // Write the key to the buffer + buff.Write(st.key.Bytes()) + + // Write the number of parts to the buffer + b := make([]byte, 2) + binary.LittleEndian.PutUint16(b, st.numParts) + buff.Write(b) + + // Write the number of fingerprints to the buffer + b = make([]byte, 2) + binary.LittleEndian.PutUint16(b, st.numFps) + buff.Write(b) + + // Return the serialized data + return buff.Bytes() +} + +// unmarshalSentTransfer deserializes a byte slice into the primitive fields +// of SentTransfer (recipient, key, numParts, and numFps). +func unmarshalSentTransfer(b []byte) ( + recipient *id.ID, key ftCrypto.TransferKey, numParts, numFps uint16) { + + buff := bytes.NewBuffer(b) + + // Read the recipient ID from the buffer + recipient = &id.ID{} + copy(recipient[:], buff.Next(id.ArrIDLen)) + + // Read the transfer key from the buffer + key = ftCrypto.UnmarshalTransferKey(buff.Next(ftCrypto.TransferKeyLength)) + + // Read the number of part from the buffer + numParts = binary.LittleEndian.Uint16(buff.Next(2)) + + // Read the number of fingerprints from the buffer + numFps = binary.LittleEndian.Uint16(buff.Next(2)) + + return recipient, key, numParts, numFps +} + +// makeSentTransferPrefix generates the unique prefix used on the key value +// store to store sent transfers for the given transfer ID. +func makeSentTransferPrefix(tid ftCrypto.TransferID) string { + return sentTransferPrefix + tid.String() +} diff --git a/storage/fileTransfer/sentTransfer_test.go b/storage/fileTransfer/sentTransfer_test.go new file mode 100644 index 0000000000000000000000000000000000000000..14e8e76ad91801e024eed945341e25a3141df7b2 --- /dev/null +++ b/storage/fileTransfer/sentTransfer_test.go @@ -0,0 +1,1304 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "fmt" + "github.com/pkg/errors" + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/storage/utility" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/ekv" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" + "reflect" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +// Tests that NewSentTransfer creates the expected SentTransfer and that it is +// saved to storage. +func Test_NewSentTransfer(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + recipient, _ := id.NewRandomID(prng, id.User) + tid, _ := ftCrypto.NewTransferID(prng) + key, _ := ftCrypto.NewTransferKey(prng) + kvPrefixed := kv.Prefix(makeSentTransferPrefix(tid)) + parts := [][]byte{ + []byte("test0"), []byte("test1"), []byte("test2"), + []byte("test3"), []byte("test4"), []byte("test5"), + } + numParts, numFps := uint16(len(parts)), uint16(float64(len(parts))*1.5) + fpVector, _ := utility.NewStateVector( + kvPrefixed, sentFpVectorKey, uint32(numFps)) + + type cbFields struct { + completed bool + sent, arrived, total uint16 + err error + } + + expectedCB := cbFields{ + completed: false, + sent: 0, + arrived: 0, + total: numParts, + err: nil, + } + + cbChan := make(chan cbFields) + cb := func(completed bool, sent, arrived, total uint16, err error) { + cbChan <- cbFields{ + completed: completed, + sent: sent, + arrived: arrived, + total: total, + err: err, + } + } + + expectedPeriod := time.Second + + expected := &SentTransfer{ + recipient: recipient, + key: key, + numParts: numParts, + numFps: numFps, + fpVector: fpVector, + sentParts: &partStore{ + parts: partSliceToMap(parts...), + numParts: uint16(len(parts)), + kv: kvPrefixed, + }, + inProgressTransfers: &transferredBundle{ + list: make(map[id.Round][]uint16), + key: inProgressKey, + kv: kvPrefixed, + }, + finishedTransfers: &transferredBundle{ + list: make(map[id.Round][]uint16), + key: finishedKey, + kv: kvPrefixed, + }, + progressCallbacks: []*sentCallbackTracker{ + newSentCallbackTracker(cb, expectedPeriod), + }, + kv: kvPrefixed, + } + + // Create new SentTransfer + st, err := NewSentTransfer( + recipient, tid, key, parts, numFps, cb, expectedPeriod, kv) + if err != nil { + t.Errorf("NewSentTransfer returned an error: %+v", err) + } + + // Check that the callback is called when added + select { + case <-time.NewTimer(10 * time.Millisecond).C: + t.Error("Timed out waiting fpr progress callback to be called.") + case cbResults := <-cbChan: + if !reflect.DeepEqual(expectedCB, cbResults) { + t.Errorf("Did not receive correct results from callback."+ + "\nexpected: %+v\nreceived: %+v", expectedCB, cbResults) + } + } + + st.progressCallbacks = expected.progressCallbacks + + // Check that the new object matches the expected + if !reflect.DeepEqual(expected, st) { + t.Errorf("New SentTransfer does not match expected."+ + "\nexpected: %#v\nreceived: %#v", expected, st) + } + + // Make sure it is saved to storage + _, err = kvPrefixed.Get(sentTransferKey, sentTransferVersion) + if err != nil { + t.Errorf("Failed to get new SentTransfer from storage: %+v", err) + } + + // Check that the fingerprint vector has correct values + if st.fpVector.GetNumAvailable() != uint32(numFps) { + t.Errorf("Incorrect number of available keys in fingerprint list."+ + "\nexpected: %d\nreceived: %d", numFps, st.fpVector.GetNumAvailable()) + } + if st.fpVector.GetNumKeys() != uint32(numFps) { + t.Errorf("Incorrect number of keys in fingerprint list."+ + "\nexpected: %d\nreceived: %d", numFps, st.fpVector.GetNumKeys()) + } + if st.fpVector.GetNumUsed() != 0 { + t.Errorf("Incorrect number of used keys in fingerprint list."+ + "\nexpected: %d\nreceived: %d", 0, st.fpVector.GetNumUsed()) + } +} + +// Tests that SentTransfer.ReInit overwrites the fingerprint vector, in-progress +// transfer, finished transfers, and progress callbacks with new and empty +// objects. +func TestSentTransfer_ReInit(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + recipient, _ := id.NewRandomID(prng, id.User) + tid, _ := ftCrypto.NewTransferID(prng) + key, _ := ftCrypto.NewTransferKey(prng) + kvPrefixed := kv.Prefix(makeSentTransferPrefix(tid)) + parts := [][]byte{ + []byte("test0"), []byte("test1"), []byte("test2"), + []byte("test3"), []byte("test4"), []byte("test5"), + } + numParts, numFps1 := uint16(len(parts)), uint16(float64(len(parts))*1.5) + numFps2 := 2 * numFps1 + fpVector, _ := utility.NewStateVector( + kvPrefixed, sentFpVectorKey, uint32(numFps2)) + + type cbFields struct { + completed bool + sent, arrived, total uint16 + err error + } + + expectedCB := cbFields{ + completed: false, + sent: 0, + arrived: 0, + total: numParts, + err: nil, + } + + cbChan := make(chan cbFields) + cb := func(completed bool, sent, arrived, total uint16, err error) { + cbChan <- cbFields{ + completed: completed, + sent: sent, + arrived: arrived, + total: total, + err: err, + } + } + + expectedPeriod := time.Millisecond + + expected := &SentTransfer{ + recipient: recipient, + key: key, + numParts: numParts, + numFps: numFps2, + fpVector: fpVector, + sentParts: &partStore{ + parts: partSliceToMap(parts...), + numParts: uint16(len(parts)), + kv: kvPrefixed, + }, + inProgressTransfers: &transferredBundle{ + list: make(map[id.Round][]uint16), + key: inProgressKey, + kv: kvPrefixed, + }, + finishedTransfers: &transferredBundle{ + list: make(map[id.Round][]uint16), + key: finishedKey, + kv: kvPrefixed, + }, + progressCallbacks: []*sentCallbackTracker{ + newSentCallbackTracker(cb, expectedPeriod), + }, + kv: kvPrefixed, + } + + // Create new SentTransfer + st, err := NewSentTransfer( + recipient, tid, key, parts, numFps1, nil, 2*expectedPeriod, kv) + if err != nil { + t.Errorf("NewSentTransfer returned an error: %+v", err) + } + + // Re-initialize SentTransfer with new number of fingerprints and callback + err = st.ReInit(numFps2, cb, expectedPeriod) + if err != nil { + t.Errorf("ReInit returned an error: %+v", err) + } + + // Check that the callback is called when added + select { + case <-time.NewTimer(10 * time.Millisecond).C: + t.Error("Timed out waiting fpr progress callback to be called.") + case cbResults := <-cbChan: + if !reflect.DeepEqual(expectedCB, cbResults) { + t.Errorf("Did not receive correct results from callback."+ + "\nexpected: %+v\nreceived: %+v", expectedCB, cbResults) + } + } + + st.progressCallbacks = expected.progressCallbacks + + // Check that the new object matches the expected + if !reflect.DeepEqual(expected, st) { + t.Errorf("New SentTransfer does not match expected."+ + "\nexpected: %#v\nreceived: %#v", expected, st) + } + + // Make sure it is saved to storage + _, err = kvPrefixed.Get(sentTransferKey, sentTransferVersion) + if err != nil { + t.Errorf("Failed to get new SentTransfer from storage: %+v", err) + } + + // Check that the fingerprint vector has correct values + if st.fpVector.GetNumAvailable() != uint32(numFps2) { + t.Errorf("Incorrect number of available keys in fingerprint list."+ + "\nexpected: %d\nreceived: %d", numFps2, st.fpVector.GetNumAvailable()) + } + if st.fpVector.GetNumKeys() != uint32(numFps2) { + t.Errorf("Incorrect number of keys in fingerprint list."+ + "\nexpected: %d\nreceived: %d", numFps2, st.fpVector.GetNumKeys()) + } + if st.fpVector.GetNumUsed() != 0 { + t.Errorf("Incorrect number of used keys in fingerprint list."+ + "\nexpected: %d\nreceived: %d", 0, st.fpVector.GetNumUsed()) + } +} + +// Tests that SentTransfer.GetRecipient returns the expected ID. +func TestSentTransfer_GetRecipient(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + expectedRecipient, _ := id.NewRandomID(prng, id.User) + tid, _ := ftCrypto.NewTransferID(prng) + key, _ := ftCrypto.NewTransferKey(prng) + + // Create new SentTransfer + st, err := NewSentTransfer( + expectedRecipient, tid, key, [][]byte{}, 5, nil, 0, kv) + if err != nil { + t.Errorf("Failed to create new SentTransfer: %+v", err) + } + + if expectedRecipient != st.GetRecipient() { + t.Errorf("Failed to get expected transfer key."+ + "\nexpected: %s\nreceived: %s", expectedRecipient, st.GetRecipient()) + } +} + +// Tests that SentTransfer.GetTransferKey returns the expected transfer key. +func TestSentTransfer_GetTransferKey(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + recipient, _ := id.NewRandomID(prng, id.User) + tid, _ := ftCrypto.NewTransferID(prng) + expectedKey, _ := ftCrypto.NewTransferKey(prng) + + // Create new SentTransfer + st, err := NewSentTransfer( + recipient, tid, expectedKey, [][]byte{}, 5, nil, 0, kv) + if err != nil { + t.Errorf("Failed to create new SentTransfer: %+v", err) + } + + if expectedKey != st.GetTransferKey() { + t.Errorf("Failed to get expected transfer key."+ + "\nexpected: %s\nreceived: %s", expectedKey, st.GetTransferKey()) + } +} + +// Tests that SentTransfer.GetNumParts returns the expected number of parts. +func TestSentTransfer_GetNumParts(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + expectedNumParts := uint16(16) + _, st := newRandomSentTransfer(expectedNumParts, 24, kv, t) + + if expectedNumParts != st.GetNumParts() { + t.Errorf("Failed to get expected number of parts."+ + "\nexpected: %d\nreceived: %d", expectedNumParts, st.GetNumParts()) + } +} + +// Tests that SentTransfer.GetNumFps returns the expected number of +// fingerprints. +func TestSentTransfer_GetNumFps(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + expectedNumFps := uint16(24) + _, st := newRandomSentTransfer(16, expectedNumFps, kv, t) + + if expectedNumFps != st.GetNumFps() { + t.Errorf("Failed to get expected number of fingerprints."+ + "\nexpected: %d\nreceived: %d", expectedNumFps, st.GetNumFps()) + } +} + +// Tests that SentTransfer.GetNumAvailableFps returns the expected number of +// available fingerprints. +func TestSentTransfer_GetNumAvailableFps(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + numParts, numFps := uint16(16), uint16(24) + _, st := newRandomSentTransfer(numParts, numFps, kv, t) + + if numFps != st.GetNumAvailableFps() { + t.Errorf("Failed to get expected number of available fingerprints."+ + "\nexpected: %d\nreceived: %d", + numFps, st.GetNumAvailableFps()) + } + + for i := uint16(0); i < numParts; i++ { + _, _ = st.fpVector.Next() + } + + if numFps-numParts != st.GetNumAvailableFps() { + t.Errorf("Failed to get expected number of available fingerprints."+ + "\nexpected: %d\nreceived: %d", + numFps-numParts, st.GetNumAvailableFps()) + } +} + +// Tests that SentTransfer.GetProgress returns the expected progress metrics for +// various transfer states. +func TestSentTransfer_GetProgress(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + numParts := uint16(16) + _, st := newRandomSentTransfer(16, 24, kv, t) + + completed, sent, arrived, total := st.GetProgress() + err := checkSentProgress( + completed, sent, arrived, total, false, 0, 0, numParts) + if err != nil { + t.Error(err) + } + + _, _ = st.SetInProgress(1, 0, 1, 2) + + completed, sent, arrived, total = st.GetProgress() + err = checkSentProgress(completed, sent, arrived, total, false, 3, 0, numParts) + if err != nil { + t.Error(err) + } + + _, _ = st.SetInProgress(2, 3, 4, 5) + + completed, sent, arrived, total = st.GetProgress() + err = checkSentProgress(completed, sent, arrived, total, false, 6, 0, numParts) + if err != nil { + t.Error(err) + } + + _ = st.FinishTransfer(1) + _, _ = st.UnsetInProgress(2) + + completed, sent, arrived, total = st.GetProgress() + err = checkSentProgress(completed, sent, arrived, total, false, 0, 3, numParts) + if err != nil { + t.Error(err) + } + + _, _ = st.SetInProgress(3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) + + completed, sent, arrived, total = st.GetProgress() + err = checkSentProgress( + completed, sent, arrived, total, false, 10, 3, numParts) + if err != nil { + t.Error(err) + } + + _ = st.FinishTransfer(3) + _, _ = st.SetInProgress(4, 3, 4, 5) + + completed, sent, arrived, total = st.GetProgress() + err = checkSentProgress( + completed, sent, arrived, total, false, 3, 13, numParts) + if err != nil { + t.Error(err) + } + + _ = st.FinishTransfer(4) + + completed, sent, arrived, total = st.GetProgress() + err = checkSentProgress(completed, sent, arrived, total, true, 0, 16, numParts) + if err != nil { + t.Error(err) + } +} + +// Tests that 5 different callbacks all receive the expected data when +// SentTransfer.CallProgressCB is called at different stages of transfer. +func TestSentTransfer_CallProgressCB(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + + type progressResults struct { + completed bool + sent, arrived, total uint16 + err error + } + + period := time.Millisecond + + wg := sync.WaitGroup{} + var step0, step1, step2, step3 uint64 + numCallbacks := 5 + + for i := 0; i < numCallbacks; i++ { + progressChan := make(chan progressResults) + + cbFunc := func(completed bool, sent, arrived, total uint16, err error) { + progressChan <- progressResults{completed, sent, arrived, total, err} + } + wg.Add(1) + + go func(i int) { + defer wg.Done() + n := 0 + for { + select { + case <-time.NewTimer(time.Second).C: + t.Errorf("Timed out after %s waiting for callback (%d).", + period*5, i) + return + case r := <-progressChan: + switch n { + case 0: + if err := checkSentProgress(r.completed, r.sent, r.arrived, + r.total, false, 0, 0, st.numParts); err != nil { + t.Errorf("%2d: %+v", i, err) + } + atomic.AddUint64(&step0, 1) + case 1: + if err := checkSentProgress(r.completed, r.sent, r.arrived, + r.total, false, 0, 0, st.numParts); err != nil { + t.Errorf("%2d: %+v", i, err) + } + atomic.AddUint64(&step1, 1) + case 2: + if err := checkSentProgress(r.completed, r.sent, r.arrived, + r.total, false, 0, 6, st.numParts); err != nil { + t.Errorf("%2d: %+v", i, err) + } + atomic.AddUint64(&step2, 1) + case 3: + if err := checkSentProgress(r.completed, r.sent, r.arrived, + r.total, true, 0, 16, st.numParts); err != nil { + t.Errorf("%2d: %+v", i, err) + } + atomic.AddUint64(&step3, 1) + return + default: + t.Errorf("n (%d) is great than 3 (%d)", n, i) + return + } + n++ + } + } + }(i) + + st.AddProgressCB(cbFunc, period) + } + + for !atomic.CompareAndSwapUint64(&step0, uint64(numCallbacks), 0) { + } + + st.CallProgressCB(nil) + + for !atomic.CompareAndSwapUint64(&step1, uint64(numCallbacks), 0) { + } + + _, _ = st.SetInProgress(0, 0, 1, 2) + _, _ = st.SetInProgress(1, 3, 4, 5) + _, _ = st.SetInProgress(2, 6, 7, 8) + _, _ = st.UnsetInProgress(1) + _ = st.FinishTransfer(0) + _ = st.FinishTransfer(2) + + st.CallProgressCB(nil) + + for !atomic.CompareAndSwapUint64(&step2, uint64(numCallbacks), 0) { + } + + _, _ = st.SetInProgress(4, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15) + _ = st.FinishTransfer(4) + + st.CallProgressCB(nil) + + wg.Wait() +} + +// Tests that SentTransfer.AddProgressCB adds an item to the progress callback +// list. +func TestSentTransfer_AddProgressCB(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + + type callbackResults struct { + completed bool + sent, arrived, total uint16 + err error + } + cbChan := make(chan callbackResults) + cbFunc := interfaces.SentProgressCallback( + func(completed bool, sent, arrived, total uint16, err error) { + cbChan <- callbackResults{completed, sent, arrived, total, err} + }) + + done := make(chan bool) + go func() { + select { + case <-time.NewTimer(time.Millisecond).C: + t.Error("Timed out waiting for progress callback to be called.") + case r := <-cbChan: + err := checkSentProgress( + r.completed, r.sent, r.arrived, r.total, false, 0, 0, 16) + if err != nil { + t.Error(err) + } + if r.err != nil { + t.Errorf("Callback returned an error: %+v", err) + } + } + done <- true + }() + + period := time.Millisecond + st.AddProgressCB(cbFunc, period) + + if len(st.progressCallbacks) != 1 { + t.Errorf("Callback list should only have one item."+ + "\nexpected: %d\nreceived: %d", 1, len(st.progressCallbacks)) + } + + if st.progressCallbacks[0].period != period { + t.Errorf("Callback has wrong lastCall.\nexpected: %s\nreceived: %s", + period, st.progressCallbacks[0].period) + } + + if st.progressCallbacks[0].lastCall != (time.Time{}) { + t.Errorf("Callback has wrong time.\nexpected: %s\nreceived: %s", + time.Time{}, st.progressCallbacks[0].lastCall) + } + + if st.progressCallbacks[0].scheduled { + t.Errorf("Callback has wrong scheduled.\nexpected: %t\nreceived: %t", + false, st.progressCallbacks[0].scheduled) + } + <-done +} + +// Loops through each file part encrypting it with SentTransfer.GetEncryptedPart +// and tests that it returns an encrypted part, MAC, and padding (nonce) that +// can be used to successfully decrypt and get the original part. Also tests +// that fingerprints are valid and not used more than once. +// It also tests that +func TestSentTransfer_GetEncryptedPart(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + prng := NewPrng(42) + + // Create and fill fingerprint map used to check fingerprint validity + // The first item in the uint16 slice is the fingerprint number and the + // second item is the number of times it has been used + fpMap := make(map[format.Fingerprint][]uint16, st.numFps) + for num, fp := range ftCrypto.GenerateFingerprints(st.key, st.numFps) { + fpMap[fp] = []uint16{uint16(num), 0} + } + + for i := uint16(0); i < st.numFps; i++ { + partNum := i % st.numParts + + encPart, mac, padding, fp, err := st.GetEncryptedPart(partNum, 16, prng) + if err != nil { + t.Fatalf("GetEncryptedPart returned an error for part number "+ + "%d (%d): %+v", partNum, i, err) + } + + // Check that the fingerprint is valid + fpNum, exists := fpMap[fp] + if !exists { + t.Errorf("Fingerprint %s invalid for part number %d (%d).", + fp, partNum, i) + } + + // Check that the fingerprint has not been used + if fpNum[1] > 0 { + t.Errorf("Fingerprint %s for part number %d already used by %d "+ + "other parts (%d).", fp, partNum, fpNum[1], i) + } + + // Attempt to decrypt the part + part, err := ftCrypto.DecryptPart(st.key, encPart, padding, mac, fpNum[0]) + if err != nil { + t.Errorf("Failed to decrypt file part number %d (%d): %+v", + partNum, i, err) + } + + // Make sure the decrypted part matches the original + expectedPart, _ := st.sentParts.getPart(i % st.numParts) + if !bytes.Equal(expectedPart, part) { + t.Errorf("Decyrpted part number %d does not match expected (%d)."+ + "\nexpected: %+v\nreceived: %+v", partNum, i, expectedPart, part) + } + } +} + +// Error path: tests that SentTransfer.GetEncryptedPart returns the expected +// error when no part for the given part number exists. +func TestSentTransfer_GetEncryptedPart_NoPartError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + prng := NewPrng(42) + + partNum := st.numParts + 1 + expectedErr := fmt.Sprintf(noPartNumErr, partNum) + + _, _, _, _, err := st.GetEncryptedPart(partNum, 16, prng) + if err == nil || err.Error() != expectedErr { + t.Errorf("GetEncryptedPart did not return the expected error for a "+ + "nonexistent part number %d.\nexpected: %s\nreceived: %+v", + partNum, expectedErr, err) + } +} + +// Error path: tests that SentTransfer.GetEncryptedPart returns the expected +// error when no fingerprints are available. +func TestSentTransfer_GetEncryptedPart_NoFingerprintsError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + prng := NewPrng(42) + + // Use up all the fingerprints + for i := uint16(0); i < st.numFps; i++ { + partNum := i % st.numParts + _, _, _, _, err := st.GetEncryptedPart(partNum, 16, prng) + if err != nil { + t.Errorf("Error when encyrpting part number %d (%d): %+v", + partNum, i, err) + } + } + + // Try to encrypt without any fingerprints + _, _, _, _, err := st.GetEncryptedPart(5, 16, prng) + if err != MaxRetriesErr { + t.Errorf("GetEncryptedPart did not return MaxRetriesErr when all "+ + "fingerprints have been used.\nexpected: %s\nreceived: %+v", + MaxRetriesErr, err) + } +} + +// Error path: tests that SentTransfer.GetEncryptedPart returns the expected +// error when encrypting the part fails due to a PRNG error. +func TestSentTransfer_GetEncryptedPart_EncryptPartError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + prng := NewPrngErr() + + // Create and fill fingerprint map used to check fingerprint validity + // The first item in the uint16 slice is the fingerprint number and the + // second item is the number of times it has been used + fpMap := make(map[format.Fingerprint][]uint16, st.numFps) + for num, fp := range ftCrypto.GenerateFingerprints(st.key, st.numFps) { + fpMap[fp] = []uint16{uint16(num), 0} + } + + partNum := uint16(0) + expectedErr := fmt.Sprintf(encryptPartErr, partNum, "") + + _, _, _, _, err := st.GetEncryptedPart(partNum, 16, prng) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("GetEncryptedPart did not return the expected error when "+ + "the PRNG should have errored.\nexpected: %s\nreceived: %+v", + expectedErr, err) + } +} + +// Tests that SentTransfer.SetInProgress correctly adds the part numbers for the +// given round ID to the in-progress map. +func TestSentTransfer_SetInProgress(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + + rid := id.Round(5) + expectedPartNums := []uint16{1, 2, 3} + + // Add parts to the in-progress list + err, exists := st.SetInProgress(rid, expectedPartNums...) + if err != nil { + t.Errorf("SetInProgress returned an error: %+v", err) + } + + // Check that the round does not already exist + if exists { + t.Errorf("Round %d already exists.", rid) + } + + // Check that the round ID is in the map + partNums, exists := st.inProgressTransfers.list[rid] + if !exists { + t.Errorf("Part numbers for round %d not found.", rid) + } + + // Check that the returned part numbers are correct + if !reflect.DeepEqual(expectedPartNums, partNums) { + t.Errorf("Received part numbers do not match expected."+ + "\nexpected: %v\nreceived: %v", expectedPartNums, partNums) + } + + // Check that only one item was added to the list + if len(st.inProgressTransfers.list) > 1 { + t.Errorf("Extra items in in-progress list."+ + "\nexpected: %d\nreceived: %d", 1, len(st.inProgressTransfers.list)) + } + + // Add more parts to the in-progress list + err, exists = st.SetInProgress(rid, expectedPartNums...) + if err != nil { + t.Errorf("SetInProgress returned an error: %+v", err) + } + + // Check that the round already exists + if !exists { + t.Errorf("Round %d should already exist.", rid) + } +} + +// Tests that SentTransfer.GetInProgress returns the correct part numbers for +// the given round ID in the in-progress map. +func TestSentTransfer_GetInProgress(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + + rid := id.Round(5) + expectedPartNums := []uint16{1, 2, 3, 4, 5, 6} + + // Add parts to the in-progress list + err, _ := st.SetInProgress(rid, expectedPartNums[:3]...) + if err != nil { + t.Errorf("Failed to set parts %v to in-progress: %+v", + expectedPartNums[3:], err) + } + + // Add parts to the in-progress list + err, _ = st.SetInProgress(rid, expectedPartNums[3:]...) + if err != nil { + t.Errorf("Failed to set parts %v to in-progress: %+v", + expectedPartNums[:3], err) + } + + // Get the in-progress parts + receivedPartNums, exists := st.GetInProgress(rid) + if !exists { + t.Errorf("Failed to find parts for round %d that should exist.", rid) + } + + // Check that the returned part numbers are correct + if !reflect.DeepEqual(expectedPartNums, receivedPartNums) { + t.Errorf("Received part numbers do not match expected."+ + "\nexpected: %v\nreceived: %v", expectedPartNums, receivedPartNums) + } +} + +// Tests that SentTransfer.UnsetInProgress correctly removes the part numbers +// for the given round ID from the in-progress map. +func TestSentTransfer_UnsetInProgress(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + + rid := id.Round(5) + expectedPartNums := []uint16{1, 2, 3, 4, 5, 6} + + // Add parts to the in-progress list + if err, _ := st.SetInProgress(rid, expectedPartNums[:3]...); err != nil { + t.Errorf("Failed to set parts in-progress: %+v", err) + } + if err, _ := st.SetInProgress(rid, expectedPartNums[3:]...); err != nil { + t.Errorf("Failed to set parts in-progress: %+v", err) + } + + // Remove parts from in-progress list + receivedPartNums, err := st.UnsetInProgress(rid) + if err != nil { + t.Errorf("UnsetInProgress returned an error: %+v", err) + } + + if !reflect.DeepEqual(expectedPartNums, receivedPartNums) { + t.Errorf("Received part numbers do not match expected."+ + "\nexpected: %v\nreceived: %v", expectedPartNums, receivedPartNums) + } + + // Check that the round ID is not the map + partNums, exists := st.inProgressTransfers.list[rid] + if exists { + t.Errorf("Part numbers for round %d found: %v", rid, partNums) + } + + // Check that the list is empty + if len(st.inProgressTransfers.list) != 0 { + t.Errorf("Extra items in in-progress list."+ + "\nexpected: %d\nreceived: %d", 0, len(st.inProgressTransfers.list)) + } +} + +// Tests that SentTransfer.FinishTransfer removes the parts from the in-progress +// list and moved them to the finished list. +func TestSentTransfer_FinishTransfer(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + + rid := id.Round(5) + expectedPartNums := []uint16{1, 2, 3} + + // Add parts to the in-progress list + err, _ := st.SetInProgress(rid, expectedPartNums...) + if err != nil { + t.Errorf("Failed to add parts to in-progress list: %+v", err) + } + + // Move transfers to the finished list + err = st.FinishTransfer(rid) + if err != nil { + t.Errorf("FinishTransfer returned an error: %+v", err) + } + + // Check that the round ID is not in the in-progress map + _, exists := st.inProgressTransfers.list[rid] + if exists { + t.Errorf("Found parts for round %d that should not be in map.", rid) + } + + // Check that the round ID is in the finished map + partNums, exists := st.finishedTransfers.list[rid] + if !exists { + t.Errorf("Part numbers for round %d not found.", rid) + } + + // Check that the returned part numbers are correct + if !reflect.DeepEqual(expectedPartNums, partNums) { + t.Errorf("Received part numbers do not match expected."+ + "\nexpected: %+v\nreceived: %+v", expectedPartNums, partNums) + } + + // Check that only one item was added to the list + if len(st.finishedTransfers.list) > 1 { + t.Errorf("Extra items in finished list."+ + "\nexpected: %d\nreceived: %d", 1, len(st.finishedTransfers.list)) + } +} + +// Error path: tests that SentTransfer.FinishTransfer returns the expected error +// when the round ID is found in the in-progress map. +func TestSentTransfer_FinishTransfer_NoRoundErr(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + + rid := id.Round(5) + expectedErr := fmt.Sprintf(noPartsForRoundErr, rid) + + // Move transfers to the finished list + err := st.FinishTransfer(rid) + if err == nil || err.Error() != expectedErr { + t.Errorf("Did not get expected error when round ID not in in-progress "+ + "map.\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Function Testing // +//////////////////////////////////////////////////////////////////////////////// + +// Tests that loadSentTransfer returns a SentTransfer that matches the original +// object in memory. +func Test_loadSentTransfer(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + tid, expectedST := newRandomSentTransfer(16, 24, kv, t) + err, _ := expectedST.SetInProgress(5, 3, 4, 5) + if err != nil { + t.Errorf("Failed to add parts to in-progress transfer: %+v", err) + } + err, _ = expectedST.SetInProgress(10, 10, 11, 12) + if err != nil { + t.Errorf("Failed to add parts to in-progress transfer: %+v", err) + } + + err = expectedST.FinishTransfer(10) + if err != nil { + t.Errorf("Failed to move parts to finished transfer: %+v", err) + } + + loadedST, err := loadSentTransfer(tid, kv) + if err != nil { + t.Errorf("loadSentTransfer returned an error: %+v", err) + } + + loadedST.progressCallbacks = expectedST.progressCallbacks + + if !reflect.DeepEqual(expectedST, loadedST) { + t.Errorf("Loaded SentTransfer does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expectedST, loadedST) + } +} + +// Error path: tests that loadSentTransfer returns the expected error when no +// transfer with the given ID exists in storage. +func Test_loadSentTransfer_LoadInfoError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + tid := ftCrypto.UnmarshalTransferID([]byte("invalidTransferID")) + + expectedErr := strings.Split(loadSentStoreErr, "%")[0] + _, err := loadSentTransfer(tid, kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("loadSentTransfer did not return the expected error when no "+ + "transfer with the ID %s exists in storage."+ + "\nexpected: %s\nreceived: %+v", tid, expectedErr, err) + } +} + +// Error path: tests that loadSentTransfer returns the expected error when the +// fingerprint state vector was deleted from storage. +func Test_loadSentTransfer_LoadStateVectorError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + tid, st := newRandomSentTransfer(16, 24, kv, t) + + // Delete the fingerprint state vector from storage + err := st.fpVector.Delete() + if err != nil { + t.Errorf("Failed to delete the fingerprint vector: %+v", err) + } + + expectedErr := strings.Split(loadSentFpVectorErr, "%")[0] + _, err = loadSentTransfer(tid, kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("loadSentTransfer did not return the expected error when "+ + "the fingerprint vector was delete from storage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: tests that loadSentTransfer returns the expected error when the +// part store was deleted from storage. +func Test_loadSentTransfer_LoadPartStoreError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + tid, st := newRandomSentTransfer(16, 24, kv, t) + + // Delete the part store from storage + err := st.sentParts.delete() + if err != nil { + t.Errorf("Failed to delete the part store: %+v", err) + } + + expectedErr := strings.Split(loadSentPartStoreErr, "%")[0] + _, err = loadSentTransfer(tid, kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("loadSentTransfer did not return the expected error when "+ + "the part store was delete from storage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: tests that loadSentTransfer returns the expected error when the +// in-progress transfers bundle was deleted from storage. +func Test_loadSentTransfer_LoadInProgressTransfersError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + tid, st := newRandomSentTransfer(16, 24, kv, t) + + // Delete the in-progress transfers bundle from storage + err := st.inProgressTransfers.delete() + if err != nil { + t.Errorf("Failed to delete the in-progress transfers bundle: %+v", err) + } + + expectedErr := strings.Split(loadInProgressTransfersErr, "%")[0] + _, err = loadSentTransfer(tid, kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("loadSentTransfer did not return the expected error when "+ + "the in-progress transfers bundle was delete from storage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: tests that loadSentTransfer returns the expected error when the +// finished transfers bundle was deleted from storage. +func Test_loadSentTransfer_LoadFinishedTransfersError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + tid, st := newRandomSentTransfer(16, 24, kv, t) + + // Delete the finished transfers bundle from storage + err := st.finishedTransfers.delete() + if err != nil { + t.Errorf("Failed to delete the finished transfers bundle: %+v", err) + } + + expectedErr := strings.Split(loadFinishedTransfersErr, "%")[0] + _, err = loadSentTransfer(tid, kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("loadSentTransfer did not return the expected error when "+ + "the finished transfers bundle was delete from storage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that SentTransfer.saveInfo saves the expected data to storage. +func TestSentTransfer_saveInfo(t *testing.T) { + st := &SentTransfer{ + key: ftCrypto.UnmarshalTransferKey([]byte("key")), + numParts: 16, + kv: versioned.NewKV(make(ekv.Memstore)), + } + + err := st.saveInfo() + if err != nil { + t.Errorf("saveInfo returned an error: %+v", err) + } + + vo, err := st.kv.Get(sentTransferKey, sentTransferVersion) + if err != nil { + t.Errorf("Failed to load SentTransfer from storage: %+v", err) + } + + if !bytes.Equal(st.marshal(), vo.Data) { + t.Errorf("Marshalled data loaded from storage does not match expected."+ + "\nexpected: %+v\nreceived: %+v", st.marshal(), vo.Data) + } +} + +// Tests that SentTransfer.loadInfo loads a saved SentTransfer from storage. +func TestSentTransfer_loadInfo(t *testing.T) { + st := &SentTransfer{ + recipient: id.NewIdFromString("recipient", id.User, t), + key: ftCrypto.UnmarshalTransferKey([]byte("key")), + numParts: 16, + kv: versioned.NewKV(make(ekv.Memstore)), + } + + err := st.saveInfo() + if err != nil { + t.Errorf("failed to save new SentTransfer to storage: %+v", err) + } + + loadedST := &SentTransfer{kv: st.kv} + err = loadedST.loadInfo() + if err != nil { + t.Errorf("load returned an error: %+v", err) + } + + if !reflect.DeepEqual(st, loadedST) { + t.Errorf("Loaded SentTransfer does not match expected."+ + "\nexpected: %+v\nreceived: %+v", st, loadedST) + } +} + +// Error path: tests that SentTransfer.loadInfo returns an error when there is +// no object in storage to load +func TestSentTransfer_loadInfo_Error(t *testing.T) { + loadedST := &SentTransfer{kv: versioned.NewKV(make(ekv.Memstore))} + err := loadedST.loadInfo() + if err == nil { + t.Errorf("Loaded object that should not be in storage: %+v", err) + } +} + +// Tests that SentTransfer.delete removes all data from storage. +func TestSentTransfer_delete(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + + // Add in-progress transfers + err, _ := st.SetInProgress(5, 3, 4, 5) + if err != nil { + t.Errorf("Failed to add parts to in-progress transfer: %+v", err) + } + err, _ = st.SetInProgress(10, 10, 11, 12) + if err != nil { + t.Errorf("Failed to add parts to in-progress transfer: %+v", err) + } + + // Mark last in-progress transfer and finished + err = st.FinishTransfer(10) + if err != nil { + t.Errorf("Failed to move parts to finished transfer: %+v", err) + } + + // Delete everything from storage + err = st.delete() + if err != nil { + t.Errorf("delete returned an error: %+v", err) + } + + // Check that the SentTransfer info was deleted + err = st.loadInfo() + if err == nil { + t.Error("Successfully loaded SentTransfer info from storage when it " + + "should have been deleted.") + } + + // Check that the parts store were deleted + _, err = loadPartStore(st.kv) + if err == nil { + t.Error("Successfully loaded file parts from storage when it should " + + "have been deleted.") + } + + // Check that the in-progress transfers were deleted + _, err = loadTransferredBundle(inProgressKey, st.kv) + if err == nil { + t.Error("Successfully loaded in-progress transfers from storage when " + + "it should have been deleted.") + } + + // Check that the finished transfers were deleted + _, err = loadTransferredBundle(finishedKey, st.kv) + if err == nil { + t.Error("Successfully loaded finished transfers from storage when " + + "it should have been deleted.") + } + + // Check that the fingerprint vector was deleted + _, err = utility.LoadStateVector(st.kv, sentFpVectorKey) + if err == nil { + t.Error("Successfully loaded fingerprint vector from storage when it " + + "should have been deleted.") + } +} + +// Tests that SentTransfer.deleteInfo removes the saved SentTransfer data from +// storage. +func TestSentTransfer_deleteInfo(t *testing.T) { + st := &SentTransfer{ + key: ftCrypto.UnmarshalTransferKey([]byte("key")), + numParts: 16, + kv: versioned.NewKV(make(ekv.Memstore)), + } + + // Save from storage + err := st.saveInfo() + if err != nil { + t.Errorf("failed to save new SentTransfer to storage: %+v", err) + } + + // Delete from storage + err = st.deleteInfo() + if err != nil { + t.Errorf("deleteInfo returned an error: %+v", err) + } + + // Make sure deleted object cannot be loaded from storage + _, err = st.kv.Get(sentTransferKey, sentTransferVersion) + if err == nil { + t.Error("Loaded object that should be deleted from storage.") + } +} + +// Tests that a SentTransfer marshalled with SentTransfer.marshal and then +// unmarshalled with unmarshalSentTransfer matches the original. +func TestSentTransfer_marshal_unmarshalSentTransfer(t *testing.T) { + st := &SentTransfer{ + recipient: id.NewIdFromString("testRecipient", id.User, t), + key: ftCrypto.UnmarshalTransferKey([]byte("key")), + numParts: 16, + numFps: 20, + } + + marshaledData := st.marshal() + + recipient, key, numParts, numFps := unmarshalSentTransfer(marshaledData) + + if !st.recipient.Cmp(recipient) { + t.Errorf("Failed to get recipient ID.\nexpected: %s\nreceived: %s", + st.recipient, recipient) + } + + if st.key != key { + t.Errorf("Failed to get expected key.\nexpected: %s\nreceived: %s", + st.key, key) + } + + if st.numParts != numParts { + t.Errorf("Failed to get expected number of parts."+ + "\nexpected: %d\nreceived: %d", st.numParts, numParts) + } + + if st.numFps != numFps { + t.Errorf("Failed to get expected number of fingerprints."+ + "\nexpected: %d\nreceived: %d", st.numFps, numFps) + } +} + +// Consistency test: tests that makeSentTransferPrefix returns the expected +// prefixes for the provided transfer IDs. +func Test_makeSentTransferPrefix_Consistency(t *testing.T) { + prng := NewPrng(42) + expectedPrefixes := []string{ + "FileTransferSentTransferStoreU4x/lrFkvxuXu59LtHLon1sUhPJSCcnZND6SugndnVI=", + "FileTransferSentTransferStore39ebTXZCm2F6DJ+fDTulWwzA1hRMiIU1hBrL4HCbB1g=", + "FileTransferSentTransferStoreCD9h03W8ArQd9PkZKeGP2p5vguVOdI6B555LvW/jTNw=", + "FileTransferSentTransferStoreuoQ+6NY+jE/+HOvqVG2PrBPdGqwEzi6ih3xVec+ix44=", + "FileTransferSentTransferStoreGwuvrogbgqdREIpC7TyQPKpDRlp4YgYWl4rtDOPGxPM=", + "FileTransferSentTransferStorernvD4ElbVxL+/b4MECiH4QDazS2IX2kstgfaAKEcHHA=", + "FileTransferSentTransferStoreceeWotwtwlpbdLLhKXBeJz8FySMmgo4rBW44F2WOEGE=", + "FileTransferSentTransferStoreSYlH/fNEQQ7UwRYCP6jjV2tv7Sf/iXS6wMr9mtBWkrE=", + "FileTransferSentTransferStoreNhnnOJZN/ceejVNDc2Yc/WbXT+weG4lJGrcjbkt1IWI=", + "FileTransferSentTransferStorekM8r60LDyicyhWDxqsBnzqbov0bUqytGgEAsX7KCDog=", + } + + for i, expected := range expectedPrefixes { + tid, _ := ftCrypto.NewTransferID(prng) + prefix := makeSentTransferPrefix(tid) + + if expected != prefix { + t.Errorf("New SentTransfer prefix does not match expected (%d)."+ + "\nexpected: %s\nreceived: %s", i, expected, prefix) + } + } +} + +// newRandomSentTransfer generates a new SentTransfer with random data. +func newRandomSentTransfer(numParts, numFps uint16, kv *versioned.KV, + t *testing.T) (ftCrypto.TransferID, *SentTransfer) { + // Generate new PRNG with the seed generated by multiplying the pointer for + // numParts with the current UNIX time in nanoseconds + seed, _ := strconv.ParseInt(fmt.Sprintf("%d", &numParts), 10, 64) + seed *= netTime.Now().UnixNano() + prng := NewPrng(seed) + + recipient, _ := id.NewRandomID(prng, id.User) + tid, _ := ftCrypto.NewTransferID(prng) + key, _ := ftCrypto.NewTransferKey(prng) + parts := make([][]byte, numParts) + for i := uint16(0); i < numParts; i++ { + parts[i] = make([]byte, 16) + _, err := prng.Read(parts[i]) + if err != nil { + t.Errorf("Failed to generate random part: %+v", err) + } + } + + st, err := NewSentTransfer(recipient, tid, key, parts, numFps, nil, 0, kv) + if err != nil { + t.Errorf("Failed to create new SentTansfer: %+v", err) + } + + return tid, st +} + +// checkSentProgress compares the output of SentTransfer.GetProgress to expected +// values. +func checkSentProgress(completed bool, sent, arrived, total uint16, + eCompleted bool, eSent, eArrived, eTotal uint16) error { + if eCompleted != completed || eSent != sent || eArrived != arrived || + eTotal != total { + return errors.Errorf("Returned progress does not match expected."+ + "\n completed sent arrived total"+ + "\nexpected: %5t %3d %3d %3d"+ + "\nreceived: %5t %3d %3d %3d", + eCompleted, eSent, eArrived, eTotal, + completed, sent, arrived, total) + } + + return nil +} diff --git a/storage/fileTransfer/transferredBundle.go b/storage/fileTransfer/transferredBundle.go new file mode 100644 index 0000000000000000000000000000000000000000..58e0e72f787d64dcda024ce90b1799c388c5ae68 --- /dev/null +++ b/storage/fileTransfer/transferredBundle.go @@ -0,0 +1,190 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "encoding/binary" + "github.com/pkg/errors" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" +) + +// Storage keys and versions. +const ( + transferredBundleVersion = 0 + transferredBundleKey = "FileTransferBundle" + inProgressKey = "inProgressTransfers" + finishedKey = "finishedTransfers" +) + +// Error messages. +const ( + loadTransferredBundleErr = "failed to get transferredBundle from storage: %+v" +) + +// transferredBundle lists the file parts sent per round ID. +type transferredBundle struct { + list map[id.Round][]uint16 + numParts uint16 + key string + kv *versioned.KV +} + +// newTransferredBundle generates a new transferredBundle and saves it to +// storage. +func newTransferredBundle(key string, kv *versioned.KV) (*transferredBundle, error) { + tb := &transferredBundle{ + list: make(map[id.Round][]uint16), + key: key, + kv: kv, + } + + return tb, tb.save() +} + +// addPartNums adds a round to the map with the specified part numbers. +func (tb *transferredBundle) addPartNums(rid id.Round, partNums ...uint16) error { + tb.list[rid] = append(tb.list[rid], partNums...) + + // Increment number of parts + tb.numParts += uint16(len(partNums)) + + return tb.save() +} + +// getPartNums returns the list of part numbers for the given round ID. If there +// are no part numbers for the round ID, then it returns false. +func (tb *transferredBundle) getPartNums(rid id.Round) ([]uint16, bool) { + partNums, exists := tb.list[rid] + return partNums, exists +} + +// getNumParts returns the number of file parts stored in list. +func (tb *transferredBundle) getNumParts() uint16 { + return tb.numParts +} + +// deletePartNums deletes the round and its part numbers from the map. +func (tb *transferredBundle) deletePartNums(rid id.Round) error { + // Decrement number of parts + tb.numParts -= uint16(len(tb.list[rid])) + + // Remove from list + delete(tb.list, rid) + + // Remove from storage + return tb.save() +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// loadTransferredBundle loads a transferredBundle from storage. +func loadTransferredBundle(key string, kv *versioned.KV) (*transferredBundle, error) { + vo, err := kv.Get(makeTransferredBundleKey(key), transferredBundleVersion) + if err != nil { + return nil, errors.Errorf(loadTransferredBundleErr, err) + } + + tb := &transferredBundle{ + list: make(map[id.Round][]uint16), + key: key, + kv: kv, + } + + tb.unmarshal(vo.Data) + + return tb, nil +} + +// save stores the transferredBundle to storage. +func (tb *transferredBundle) save() error { + obj := &versioned.Object{ + Version: transferredBundleVersion, + Timestamp: netTime.Now(), + Data: tb.marshal(), + } + + return tb.kv.Set( + makeTransferredBundleKey(tb.key), transferredBundleVersion, obj) +} + +// delete remove the transferredBundle from storage. +func (tb *transferredBundle) delete() error { + return tb.kv.Delete( + makeTransferredBundleKey(tb.key), transferredBundleVersion) +} + +// marshal serialises the map into a byte slice. +func (tb *transferredBundle) marshal() []byte { + // Create buffer + buff := bytes.NewBuffer(nil) + + for rid, partNums := range tb.list { + // Write round ID to buffer + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(rid)) + buff.Write(b) + + // Write number of part numbers to buffer + b = make([]byte, 2) + binary.LittleEndian.PutUint16(b, uint16(len(partNums))) + buff.Write(b) + + // Write list of part numbers to buffer + for _, partNum := range partNums { + b = make([]byte, 2) + binary.LittleEndian.PutUint16(b, partNum) + buff.Write(b) + } + } + + return buff.Bytes() +} + +// unmarshal deserializes the byte slice into the transferredBundle. +func (tb *transferredBundle) unmarshal(b []byte) { + buff := bytes.NewBuffer(b) + + // Iterate over all map entries + for n := buff.Next(8); len(n) == 8; n = buff.Next(8) { + // Get the round ID from the first 8 bytes + rid := id.Round(binary.LittleEndian.Uint64(n)) + + // Get number of part numbers listed + partNumsLen := binary.LittleEndian.Uint16(buff.Next(2)) + + // Increment number of parts + tb.numParts += partNumsLen + + // Initialize part number list to the correct size + tb.list[rid] = make([]uint16, 0, partNumsLen) + + // Add all part numbers to list + for i := uint16(0); i < partNumsLen; i++ { + partNum := binary.LittleEndian.Uint16(buff.Next(2)) + tb.list[rid] = append(tb.list[rid], partNum) + } + } +} + +// makeTransferredBundleKey concatenates a unique key with a constant to create +// a key for saving a transferredBundle to storage. +func makeTransferredBundleKey(key string) string { + return transferredBundleKey + key +} diff --git a/storage/fileTransfer/transferredBundle_test.go b/storage/fileTransfer/transferredBundle_test.go new file mode 100644 index 0000000000000000000000000000000000000000..e5198d493738565618af2122b0a1268b6aff2386 --- /dev/null +++ b/storage/fileTransfer/transferredBundle_test.go @@ -0,0 +1,317 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "bytes" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/primitives/id" + "math/rand" + "reflect" + "sort" + "strings" + "testing" +) + +// Tests that newTransferredBundle returns the expected transferredBundle and +// that it can be loaded from storage. +func Test_newTransferredBundle(t *testing.T) { + expectedTB := &transferredBundle{ + list: make(map[id.Round][]uint16), + key: "testKey1", + kv: versioned.NewKV(make(ekv.Memstore)), + } + + tb, err := newTransferredBundle(expectedTB.key, expectedTB.kv) + if err != nil { + t.Errorf("newTransferredBundle produced an error: %+v", err) + } + + if !reflect.DeepEqual(expectedTB, tb) { + t.Errorf("New transferredBundle does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expectedTB, tb) + } + + _, err = expectedTB.kv.Get( + makeTransferredBundleKey(expectedTB.key), transferredBundleVersion) + if err != nil { + t.Errorf("Failed to load transferredBundle from storage: %+v", err) + } +} + +// Tests that transferredBundle.addPartNums adds the part numbers for the round +// ID to the map correctly. +func Test_transferredBundle_addPartNums(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + key := "testKey" + tb, err := newTransferredBundle(key, kv) + if err != nil { + t.Errorf("Failed to create new transferredBundle: %+v", err) + } + + rid := id.Round(10) + expectedPartNums := []uint16{5, 128, 23, 1} + + err = tb.addPartNums(rid, expectedPartNums...) + if err != nil { + t.Errorf("addPartNums returned an error: %+v", err) + } + + partNums, exists := tb.list[rid] + if !exists || !reflect.DeepEqual(expectedPartNums, partNums) { + t.Errorf("Part numbers in memory does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expectedPartNums, partNums) + } +} + +// Tests that transferredBundle.getPartNums returns the expected part numbers +func Test_transferredBundle_getPartNums(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + key := "testKey" + tb, err := newTransferredBundle(key, kv) + if err != nil { + t.Errorf("Failed to create new transferredBundle: %+v", err) + } + + rid := id.Round(10) + expectedPartNums := []uint16{5, 128, 23, 1} + + err = tb.addPartNums(rid, expectedPartNums...) + if err != nil { + t.Errorf("failed to add part numbers: %+v", err) + } + + partNums, exists := tb.getPartNums(rid) + if !exists || !reflect.DeepEqual(expectedPartNums, partNums) { + t.Errorf("Part numbers in memory does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expectedPartNums, partNums) + } +} + +// Tests that transferredBundle.getNumParts returns the correct number of parts +// after parts are added and removed from the list. +func Test_transferredBundle_getNumParts(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + tb, err := newTransferredBundle("testKey", versioned.NewKV(make(ekv.Memstore))) + if err != nil { + t.Errorf("Failed to create new transferredBundle: %+v", err) + } + + // Add 10 random lists of part numbers to the map + var expectedNumParts uint16 + for i := 0; i < 10; i++ { + partNums := make([]uint16, prng.Intn(16)) + for j := range partNums { + partNums[j] = uint16(prng.Uint32()) + } + + // Add number of parts for odd numbered rounds + if i%2 == 1 { + expectedNumParts += uint16(len(partNums)) + } + + err = tb.addPartNums(id.Round(i), partNums...) + if err != nil { + t.Errorf("Failed to add part #%d: %+v", i, err) + } + } + + // Delete num parts for even numbered rounds + for i := 0; i < 10; i += 2 { + err = tb.deletePartNums(id.Round(i)) + if err != nil { + t.Errorf("Failed to delete part #%d: %+v", i, err) + } + } + + // Get number of parts + receivedNumParts := tb.getNumParts() + + if expectedNumParts != receivedNumParts { + t.Errorf("Failed to get expected number of parts."+ + "\nexpected: %d\nreceived: %d", expectedNumParts, receivedNumParts) + } +} + +// Tests that transferredBundle.deletePartNums deletes the part number from +// memory. +func Test_transferredBundle_deletePartNums(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + key := "testKey" + tb, err := newTransferredBundle(key, kv) + if err != nil { + t.Errorf("Failed to create new transferredBundle: %+v", err) + } + + rid := id.Round(10) + expectedPartNums := []uint16{5, 128, 23, 1} + + err = tb.addPartNums(rid, expectedPartNums...) + if err != nil { + t.Errorf("failed to add part numbers: %+v", err) + } + + err = tb.deletePartNums(rid) + if err != nil { + t.Errorf("deletePartNums returned an error: %+v", err) + } + + _, exists := tb.list[rid] + if exists { + t.Error("Found part numbers that should have been deleted.") + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// Tests that loadTransferredBundle returns a transferredBundle from storage +// that matches the original in memory. +func Test_loadTransferredBundle(t *testing.T) { + expectedTB := &transferredBundle{ + list: map[id.Round][]uint16{ + 1: {1, 2, 3}, + 2: {4, 5, 6}, + 3: {7, 8, 9}, + }, + numParts: 9, + key: "testKey2", + kv: versioned.NewKV(make(ekv.Memstore)), + } + + err := expectedTB.save() + if err != nil { + t.Errorf("Failed to save transferredBundle to storage: %+v", err) + } + + tb, err := loadTransferredBundle(expectedTB.key, expectedTB.kv) + if err != nil { + t.Errorf("loadTransferredBundle returned an error: %+v", err) + } + + if !reflect.DeepEqual(expectedTB, tb) { + t.Errorf("transferredBundle loaded from storage does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expectedTB, tb) + } +} + +// Error path: tests that loadTransferredBundle returns the expected error when +// there is no transferredBundle in storage. +func Test_loadTransferredBundle_LoadError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + expectedErr := strings.Split(loadTransferredBundleErr, "%")[0] + + _, err := loadTransferredBundle("", kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("loadTransferredBundle did not returned the expected error "+ + "when no transferredBundle exists in storage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that transferredBundle.save saves the correct data to storage. +func Test_transferredBundle_save(t *testing.T) { + tb := &transferredBundle{ + list: map[id.Round][]uint16{ + 1: {1, 2, 3}, + 2: {4, 5, 6}, + 3: {7, 8, 9}, + }, + key: "testKey3", + kv: versioned.NewKV(make(ekv.Memstore)), + } + expectedData := tb.marshal() + + err := tb.save() + if err != nil { + t.Errorf("save returned an error: %+v", err) + } + + vo, err := tb.kv.Get( + makeTransferredBundleKey(tb.key), transferredBundleVersion) + if err != nil { + t.Errorf("Failed to load transferredBundle from storage: %+v", err) + } + + sort.SliceStable(expectedData, func(i, j int) bool { + return expectedData[i] > expectedData[j] + }) + + sort.SliceStable(vo.Data, func(i, j int) bool { + return vo.Data[i] > vo.Data[j] + }) + + if !bytes.Equal(expectedData, vo.Data) { + t.Errorf("Loaded transferredBundle does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expectedData, vo.Data) + } +} + +// Tests that transferredBundle.delete removes a saved transferredBundle from +// storage. +func Test_transferredBundle_delete(t *testing.T) { + tb := &transferredBundle{ + list: map[id.Round][]uint16{ + 1: {1, 2, 3}, + 2: {4, 5, 6}, + 3: {7, 8, 9}, + }, + key: "testKey4", + kv: versioned.NewKV(make(ekv.Memstore)), + } + + err := tb.save() + if err != nil { + t.Errorf("Failed to save transferredBundle to storage: %+v", err) + } + + err = tb.delete() + if err != nil { + t.Errorf("delete returned an error: %+v", err) + } + + _, err = tb.kv.Get(makeTransferredBundleKey(tb.key), transferredBundleVersion) + if err == nil { + t.Error("Read transferredBundleVersion from storage when it should" + + "have been deleted.") + } +} + +// Tests that a transferredBundle that is marshalled via +// transferredBundle.marshal and unmarshalled via transferredBundle.unmarshal +// matches the original. +func Test_transferredBundle_marshal_unmarshal(t *testing.T) { + expectedTB := &transferredBundle{ + list: map[id.Round][]uint16{ + 1: {1, 2, 3}, + 2: {4, 5, 6}, + 3: {7, 8, 9}, + }, + numParts: 9, + } + + b := expectedTB.marshal() + + tb := &transferredBundle{list: make(map[id.Round][]uint16)} + + tb.unmarshal(b) + + if !reflect.DeepEqual(expectedTB, tb) { + t.Errorf("Failed to marshal and unmarshal transferredBundle into original."+ + "\nexpected: %+v\nreceived: %+v", expectedTB, tb) + } +} diff --git a/storage/fileTransfer/utils_test.go b/storage/fileTransfer/utils_test.go new file mode 100644 index 0000000000000000000000000000000000000000..cd59306b455bc2533b1f264ab73687a26db69b76 --- /dev/null +++ b/storage/fileTransfer/utils_test.go @@ -0,0 +1,34 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "github.com/pkg/errors" + "gitlab.com/xx_network/crypto/csprng" + "io" + "math/rand" +) + +//////////////////////////////////////////////////////////////////////////////// +// PRNG // +//////////////////////////////////////////////////////////////////////////////// + +// Prng is a PRNG that satisfies the csprng.Source interface. +type Prng struct{ prng io.Reader } + +func NewPrng(seed int64) csprng.Source { return &Prng{rand.New(rand.NewSource(seed))} } +func (s *Prng) Read(b []byte) (int, error) { return s.prng.Read(b) } +func (s *Prng) SetSeed([]byte) error { return nil } + +// PrngErr is a PRNG that satisfies the csprng.Source interface. However, it +// always returns an error +type PrngErr struct{} + +func NewPrngErr() csprng.Source { return &PrngErr{} } +func (s *PrngErr) Read([]byte) (int, error) { return 0, errors.New("ReadFailure") } +func (s *PrngErr) SetSeed([]byte) error { return errors.New("SetSeedFailure") } diff --git a/storage/utility/stateVector.go b/storage/utility/stateVector.go new file mode 100644 index 0000000000000000000000000000000000000000..3c458145aeef57b6dd5b4ce8ba3f38c4d6aa5cad --- /dev/null +++ b/storage/utility/stateVector.go @@ -0,0 +1,418 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package utility + +import ( + "encoding/json" + "fmt" + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/xx_network/primitives/netTime" + "sync" + "testing" +) + +// Storage key and version. +const ( + stateVectorKey = "stateVector" + currentStateVectorVersion = 0 +) + +// Error messages. +const ( + saveUsedKeyErr = "Failed to save %s after marking key %d as used: %+v" + saveNextErr = "failed to save %s after getting next available key: %+v" + noKeysErr = "all keys used" + loadUnmarshalErr = "failed to unmarshal from storage: %+v" + testInterfaceErr = "%s can only be used for testing." +) + +// StateVector stores a list of a set number of items and their binary state. +// It is storage backed. +type StateVector struct { + // Bitfield for key states; if a key is unused, then it is set to 0; + // otherwise, it is used/not available and is set 1 + vect []uint64 + + firstAvailable uint32 // Sequentially, the first unused key (equal to 0) + numKeys uint32 // Total number of keys + numAvailable uint32 // Number of unused keys + + key string // Unique string used to save/load object from storage + kv *versioned.KV + mux sync.RWMutex +} + +// NewStateVector generates a new StateVector with the specified number of keys. +func NewStateVector(kv *versioned.KV, key string, numKeys uint32) ( + *StateVector, error) { + + // Calculate the number of 64-bit blocks needed to store numKeys + numBlocks := (numKeys + 63) / 64 + + sv := &StateVector{ + vect: make([]uint64, numBlocks), + firstAvailable: 0, + numKeys: numKeys, + numAvailable: numKeys, + key: makeStateVectorKey(key), + kv: kv, + } + + return sv, sv.save() +} + +// Use marks the key as used (sets it to 1). +func (sv *StateVector) Use(keyNum uint32) { + sv.mux.Lock() + defer sv.mux.Unlock() + + // Mark the key as used + sv.use(keyNum) + + // Save changes to storage + if err := sv.save(); err != nil { + jww.FATAL.Printf(saveUsedKeyErr, sv, keyNum, err) + } +} + +// use marks the key as used (sets it to 1). It is not thread-safe and does not +// save to storage. +func (sv *StateVector) use(keyNum uint32) { + // If the key is already used, then exit + if sv.used(keyNum) { + return + } + + // Calculate block and position of the key + block, pos := getBlockAndPos(keyNum) + + // Set the key to used (1) + sv.vect[block] |= 1 << pos + + // Decrement number available unused keys + sv.numAvailable-- + + // If this is the first available key, then advanced to the next available + if keyNum == sv.firstAvailable { + sv.nextAvailable() + } +} + +// Used returns true if the key is used (set to 1) or false if the key is unused +// (set to 0). +func (sv *StateVector) Used(keyNum uint32) bool { + sv.mux.RLock() + defer sv.mux.RUnlock() + + return sv.used(keyNum) +} + +// used determines if the key is used or unused. This function is not thread- +// safe. +func (sv *StateVector) used(keyNum uint32) bool { + // Calculate block and position of the keyNum + block, pos := getBlockAndPos(keyNum) + + return (sv.vect[block]>>pos)&1 == 1 +} + +// Next marks the first unused key as used. An error is returned if all keys are +// used or if the save to storage fails. +func (sv *StateVector) Next() (uint32, error) { + sv.mux.Lock() + defer sv.mux.Unlock() + + // Return an error if all keys are used + if sv.firstAvailable >= sv.numKeys { + return sv.numKeys, errors.New(noKeysErr) + } + + // Mark the first available as used (which also advanced firstAvailable) + nextKey := sv.firstAvailable + sv.use(nextKey) + + // Save to storage + if err := sv.save(); err != nil { + return 0, errors.Errorf(saveNextErr, sv, err) + } + + return nextKey, nil +} + +// nextAvailable finds the next unused key and sets it as the firstAvailable. It +// is not thread-safe and does not save to storage. +func (sv *StateVector) nextAvailable() { + // Add one to start at the next position + pos := sv.firstAvailable + 1 + block := pos / 64 + + // Loop through each key until the first unused key is found + for block < uint32(len(sv.vect)) && (sv.vect[block]>>(pos%64))&1 == 1 { + pos++ + block = pos / 64 + } + + sv.firstAvailable = pos +} + +// GetNumAvailable returns the number of unused keys. +func (sv *StateVector) GetNumAvailable() uint32 { + sv.mux.RLock() + defer sv.mux.RUnlock() + return sv.numAvailable +} + +// GetNumUsed returns the number of used keys. +func (sv *StateVector) GetNumUsed() uint32 { + sv.mux.RLock() + defer sv.mux.RUnlock() + return sv.numKeys - sv.numAvailable +} + +// GetNumKeys returns the total number of keys. +func (sv *StateVector) GetNumKeys() uint32 { + sv.mux.RLock() + defer sv.mux.RUnlock() + return sv.numKeys +} + +// GetUnusedKeyNums returns a list of all unused keys. +func (sv *StateVector) GetUnusedKeyNums() []uint32 { + sv.mux.RLock() + defer sv.mux.RUnlock() + + // Initialise list with capacity set to number of unused keys + keyNums := make([]uint32, 0, sv.numAvailable) + + // Loop through each key and add any unused to the list + for keyNum := sv.firstAvailable; keyNum < sv.numKeys; keyNum++ { + if !sv.used(keyNum) { + keyNums = append(keyNums, keyNum) + } + } + + return keyNums +} + +// GetUsedKeyNums returns a list of all used keys. +func (sv *StateVector) GetUsedKeyNums() []uint32 { + sv.mux.RLock() + defer sv.mux.RUnlock() + + // Initialise list with capacity set to the number of used keys + keyNums := make([]uint32, 0, sv.numKeys-sv.numAvailable) + + // Loop through each key and add any used key numbers to the list + for keyNum := uint32(0); keyNum < sv.numKeys; keyNum++ { + if sv.used(keyNum) { + keyNums = append(keyNums, keyNum) + } + } + + return keyNums +} + +// getBlockAndPos calculates the block index and the position within that block +// of a key number. +func getBlockAndPos(keyNum uint32) (block, pos uint32) { + block = keyNum / 64 + pos = keyNum % 64 + + return block, pos +} + +// String returns a unique string representing the StateVector. This functions +// satisfies the fmt.Stringer interface. +func (sv *StateVector) String() string { + return "stateVector: " + sv.key +} + +// GoString returns the fields of the StateVector. This functions satisfies the +// fmt.GoStringer interface. +func (sv *StateVector) GoString() string { + return fmt.Sprintf( + "{vect:%v firstAvailable:%d numKeys:%d numAvailable:%d key:%s kv:%p}", + sv.vect, sv.firstAvailable, sv.numKeys, sv.numAvailable, sv.key, sv.kv) +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// stateVectorDisk is used to save the data from a StateVector so that it can be +// JSON marshalled. +type stateVectorDisk struct { + Vect []uint64 + FirstAvailable uint32 + NumKeys uint32 + NumAvailable uint32 +} + +// LoadStateVector loads a StateVector with the specified key from the given +// versioned storage. +func LoadStateVector(kv *versioned.KV, key string) (*StateVector, error) { + sv := &StateVector{ + key: makeStateVectorKey(key), + kv: kv, + } + + // Load StateVector data from storage + obj, err := kv.Get(sv.key, currentStateVectorVersion) + if err != nil { + return nil, err + } + + // Unmarshal data + err = sv.unmarshal(obj.Data) + if err != nil { + return nil, errors.Errorf(loadUnmarshalErr, err) + } + + return sv, nil +} + +// save stores the StateVector in storage. +func (sv *StateVector) save() error { + // Marshal the StateVector + data, err := sv.marshal() + if err != nil { + return err + } + + // Create the versioned object + obj := versioned.Object{ + Version: currentStateVectorVersion, + Timestamp: netTime.Now(), + Data: data, + } + + return sv.kv.Set(sv.key, currentStateVectorVersion, &obj) +} + +// Delete remove the StateVector from storage. +func (sv *StateVector) Delete() error { + return sv.kv.Delete(sv.key, currentStateVectorVersion) +} + +// marshal serialises the StateVector. +func (sv *StateVector) marshal() ([]byte, error) { + svd := stateVectorDisk{} + + svd.FirstAvailable = sv.firstAvailable + svd.NumKeys = sv.numKeys + svd.NumAvailable = sv.numAvailable + svd.Vect = sv.vect + + return json.Marshal(&svd) +} + +// unmarshal deserializes the byte slice into a StateVector. +func (sv *StateVector) unmarshal(b []byte) error { + var svd stateVectorDisk + err := json.Unmarshal(b, &svd) + if err != nil { + return err + } + + sv.firstAvailable = svd.FirstAvailable + sv.numKeys = svd.NumKeys + sv.numAvailable = svd.NumAvailable + sv.vect = svd.Vect + + return nil +} + +// makeStateVectorKey generates the unique key used to save a StateVector to +// storage. +func makeStateVectorKey(key string) string { + return stateVectorKey + key +} + +//////////////////////////////////////////////////////////////////////////////// +// Testing Functions // +//////////////////////////////////////////////////////////////////////////////// + +// SaveTEST saves the StateVector to storage. This should only be used for +// testing. +func (sv *StateVector) SaveTEST(x interface{}) error { + switch x.(type) { + case *testing.T, *testing.M, *testing.B, *testing.PB: + break + default: + jww.FATAL.Panicf(testInterfaceErr, "SaveTEST") + } + + sv.mux.Lock() + defer sv.mux.Unlock() + + return sv.save() +} + +// SetFirstAvailableTEST sets the firstAvailable. This should only be used for +// testing. +func (sv *StateVector) SetFirstAvailableTEST(keyNum uint32, x interface{}) { + switch x.(type) { + case *testing.T, *testing.M, *testing.B, *testing.PB: + break + default: + jww.FATAL.Panicf(testInterfaceErr, "SetFirstAvailableTEST") + } + + sv.mux.Lock() + defer sv.mux.Unlock() + + sv.firstAvailable = keyNum +} + +// SetNumKeysTEST sets the numKeys. This should only be used for testing. +func (sv *StateVector) SetNumKeysTEST(numKeys uint32, x interface{}) { + switch x.(type) { + case *testing.T, *testing.M, *testing.B, *testing.PB: + break + default: + jww.FATAL.Panicf(testInterfaceErr, "SetNumKeysTEST") + } + + sv.mux.Lock() + defer sv.mux.Unlock() + + sv.numKeys = numKeys +} + +// SetNumAvailableTEST sets the numAvailable. This should only be used for +// testing. +func (sv *StateVector) SetNumAvailableTEST(numAvailable uint32, x interface{}) { + switch x.(type) { + case *testing.T, *testing.M, *testing.B, *testing.PB: + break + default: + jww.FATAL.Panicf(testInterfaceErr, "SetNumAvailableTEST") + } + + sv.mux.Lock() + defer sv.mux.Unlock() + + sv.numAvailable = numAvailable +} + +// SetKvTEST sets the kv. This should only be used for testing. +func (sv *StateVector) SetKvTEST(kv *versioned.KV, x interface{}) { + switch x.(type) { + case *testing.T, *testing.M, *testing.B, *testing.PB: + break + default: + jww.FATAL.Panicf(testInterfaceErr, "SetKvTEST") + } + + sv.mux.Lock() + defer sv.mux.Unlock() + + sv.kv = kv +} diff --git a/storage/utility/stateVector_test.go b/storage/utility/stateVector_test.go new file mode 100644 index 0000000000000000000000000000000000000000..90a637173df59e18e6613684317d931805295e00 --- /dev/null +++ b/storage/utility/stateVector_test.go @@ -0,0 +1,744 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package utility + +import ( + "bytes" + "encoding/base64" + "fmt" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/primitives/netTime" + "math/rand" + "reflect" + "strings" + "testing" +) + +// Tests that NewStateVector creates the expected new StateVector and that it is +// saved to storage. +func TestNewStateVector(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + key := "myTestKey" + numKeys := uint32(275) + expected := &StateVector{ + vect: make([]uint64, (numKeys+63)/64), + firstAvailable: 0, + numKeys: numKeys, + numAvailable: numKeys, + key: makeStateVectorKey(key), + kv: kv, + } + + received, err := NewStateVector(kv, key, numKeys) + if err != nil { + t.Errorf("NewStateVector returned an error: %+v", err) + } + + if !reflect.DeepEqual(expected, received) { + t.Errorf("New StateVector does not match expected."+ + "\nexpected: %#v\nreceived: %#v", expected, received) + } + + _, err = kv.Get(key, currentStateVectorVersion) + if err == nil { + t.Error("New StateVector not saved to storage.") + } +} + +// Tests that StateVector.Use sets the correct keys to used and does not modify +// others keys and that numAvailable is correctly set. +func TestStateVector_Use(t *testing.T) { + sv := newTestStateVector("StateVectorUse", 138, t) + + // Set some keys to used + usedKeys := []uint32{0, 2, 3, 4, 6, 39, 62, 70, 98, 100} + usedKeysMap := make(map[uint32]bool, len(usedKeys)) + for i, keyNum := range usedKeys { + + sv.Use(keyNum) + + // Check if numAvailable is correct + if sv.numAvailable != sv.numKeys-uint32(i+1) { + t.Errorf("numAvailable incorrect (%d).\nexpected: %d\nreceived: %d", + i, sv.numKeys-uint32(i+1), sv.numAvailable) + } + + usedKeysMap[keyNum] = true + } + + // Check all keys for their expected states + for i := uint32(0); i < sv.numKeys; i++ { + if usedKeysMap[i] { + if !sv.Used(i) { + t.Errorf("Key #%d should have been marked used.", i) + } + } else if sv.Used(i) { + t.Errorf("Key #%d should have been marked unused.", i) + } + } + + // Make sure numAvailable is not modified when the key is already used + sv.Use(usedKeys[0]) + if sv.numAvailable != sv.numKeys-uint32(len(usedKeys)) { + t.Errorf("numAvailable incorrect.\nexpected: %d\nreceived: %d", + sv.numKeys-uint32(len(usedKeys)), sv.numAvailable) + } +} + +// Tests StateVector.Used by creating a vector with known used and unused keys +// and making sure it returns the expected state for all keys in the vector. +func TestStateVector_Used(t *testing.T) { + numKeys := uint32(128) + sv := newTestStateVector("StateVectorNext", numKeys, t) + + // Set all keys to used + for i := range sv.vect { + sv.vect[i] = 0xFFFFFFFFFFFFFFFF + } + + // Set some keys to unused + unusedKeys := []uint32{0, 2, 3, 4, 6, 39, 62, 70, 98, 100} + unusedKeysMap := make(map[uint32]bool, len(unusedKeys)) + for _, keyNum := range unusedKeys { + block, pos := getBlockAndPos(keyNum) + sv.vect[block] &= ^(1 << pos) + unusedKeysMap[keyNum] = true + } + + // Check all keys for their expected states + for i := uint32(0); i < numKeys; i++ { + if unusedKeysMap[i] { + if sv.Used(i) { + t.Errorf("Key #%d should have been marked unused.", i) + } + } else if !sv.Used(i) { + t.Errorf("Key #%d should have been marked used.", i) + } + } +} + +// Tests that StateVector.Next returns the expected unused keys and marks them +// as used and that numAvailable is correctly set. +func TestStateVector_Next(t *testing.T) { + numKeys := uint32(128) + sv := newTestStateVector("StateVectorNext", numKeys, t) + + // Set all keys to used + for i := range sv.vect { + sv.vect[i] = 0xFFFFFFFFFFFFFFFF + } + + // Set some keys to unused + unusedKeys := []uint32{0, 2, 3, 4, 6, 39, 62, 70, 98, 100} + for _, keyNum := range unusedKeys { + block, pos := getBlockAndPos(keyNum) + sv.vect[block] &= ^(1 << pos) + } + + sv.numAvailable = uint32(len(unusedKeys)) + sv.firstAvailable = unusedKeys[0] + + // Check that each call of Next returns the expected key and that it is + // marked as used + for i, expected := range unusedKeys { + numKey, err := sv.Next() + if err != nil { + t.Errorf("Next returned an error (%d): %+v", i, err) + } + + if expected != numKey { + t.Errorf("Received key does not match expected."+ + "\nexpected: %d\nreceived: %d", expected, numKey) + } + + if !sv.used(numKey) { + t.Errorf("Key #%d not marked as used.", numKey) + } + + expectedNumAvailable := uint32(len(unusedKeys) - (i + 1)) + if sv.numAvailable != expectedNumAvailable { + t.Errorf("numAvailable incorrectly set.\nexpected: %d\nreceived: %d", + expectedNumAvailable, sv.numAvailable) + } + } + + // One more call should cause an error + _, err := sv.Next() + if err == nil || err.Error() != noKeysErr { + t.Errorf("Next did not return the expected error when no keys are "+ + "available.\nexpected: %s\nreceived: %+v", noKeysErr, err) + } + + // firstAvailable should now be beyond the end of the vector + if sv.firstAvailable < numKeys { + t.Errorf("firstAvailable should be beyond numKeys."+ + "\nfirstAvailable: %d\nnumKeys: %d", + sv.firstAvailable, numKeys) + } +} + +// Tests the StateVector.nextAvailable sets firstAvailable correctly for a known +// key vector. +func TestStateVector_nextAvailable(t *testing.T) { + numKeys := uint32(128) + sv := newTestStateVector("StateVectorNext", numKeys, t) + + // Set all keys to used + for i := range sv.vect { + sv.vect[i] = 0xFFFFFFFFFFFFFFFF + } + + // Set some keys to unused + unusedKeys := []uint32{0, 2, 3, 4, 6, 39, 62, 70, 98, 100} + for _, keyNum := range unusedKeys { + block, pos := getBlockAndPos(keyNum) + sv.vect[block] &= ^(1 << pos) + } + + for i, keyNum := range unusedKeys { + if sv.firstAvailable != keyNum { + t.Errorf("firstAvailable incorrect (%d)."+ + "\nexpected: %d\nreceived: %d", i, keyNum, sv.firstAvailable) + } + sv.nextAvailable() + } +} + +// Tests that StateVector.GetNumAvailable returns the expected number of +// available keys after a set number of keys are used. +func TestStateVector_GetNumAvailable(t *testing.T) { + numKeys := uint32(500) + sv := newTestStateVector("StateVectorGetNumAvailable", numKeys, t) + + i, n, used := uint32(0), uint32(5), uint32(0) + for ; i < n; i++ { + sv.use(i) + } + used = n + + if sv.GetNumAvailable() != numKeys-used { + t.Errorf("Got incorrect number of used keys."+ + "\nexpected: %d\nreceived: %d", numKeys-used, sv.GetNumAvailable()) + } + + n = uint32(112) + for ; i < n; i++ { + sv.use(i) + } + used = n + + if sv.GetNumAvailable() != numKeys-used { + t.Errorf("Got incorrect number of used keys."+ + "\nexpected: %d\nreceived: %d", numKeys-used, sv.GetNumAvailable()) + } + + i, n = uint32(400), uint32(456) + for ; i < n; i++ { + sv.use(i) + } + used += n - 400 + + if sv.GetNumAvailable() != numKeys-used { + t.Errorf("Got incorrect number of used keys."+ + "\nexpected: %d\nreceived: %d", numKeys-used, sv.GetNumAvailable()) + } +} + +// Tests that StateVector.GetNumUsed returns the expected number of used keys +// after a set number of keys are used. +func TestStateVector_GetNumUsed(t *testing.T) { + numKeys := uint32(500) + sv := newTestStateVector("StateVectorGetNumUsed", numKeys, t) + + i, n, used := uint32(0), uint32(5), uint32(0) + for ; i < n; i++ { + sv.use(i) + } + used = n + + if sv.GetNumUsed() != used { + t.Errorf("Got incorrect number of used keys."+ + "\nexpected: %d\nreceived: %d", used, sv.GetNumUsed()) + } + + n = uint32(112) + for ; i < n; i++ { + sv.use(i) + } + used = n + + if sv.GetNumUsed() != used { + t.Errorf("Got incorrect number of used keys."+ + "\nexpected: %d\nreceived: %d", used, sv.GetNumUsed()) + } + + i, n = uint32(400), uint32(456) + for ; i < n; i++ { + sv.use(i) + } + used += n - 400 + + if sv.GetNumUsed() != used { + t.Errorf("Got incorrect number of used keys."+ + "\nexpected: %d\nreceived: %d", used, sv.GetNumUsed()) + } +} + +// Tests that StateVector.GetNumKeys returns the correct number of keys. +func TestStateVector_GetNumKeys(t *testing.T) { + numKeys := uint32(32) + sv := newTestStateVector("StateVectorGetNumKeys", numKeys, t) + + if sv.GetNumKeys() != numKeys { + t.Errorf("Got incorrect number of keys.\nexpected: %d\nreceived: %d", + numKeys, sv.GetNumKeys()) + } +} + +// Tests that StateVector.GetUnusedKeyNums returns a list of all odd-numbered +// keys when all even-numbered keys are used. +func TestStateVector_GetUnusedKeyNums(t *testing.T) { + numKeys := uint32(1000) + sv := newTestStateVector("StateVectorGetUnusedKeyNums", numKeys, t) + + // Use every other key + for i := uint32(0); i < numKeys; i += 2 { + sv.use(i) + } + + sv.firstAvailable = 1 + + // Check that every other key is in the list + for i, keyNum := range sv.GetUnusedKeyNums() { + if keyNum != uint32(2*i)+1 { + t.Errorf("Key number #%d incorrect."+ + "\nexpected: %d\nreceived: %d", i, 2*i+1, keyNum) + } + } +} + +// Tests that StateVector.GetUsedKeyNums returns a list of all even-numbered +// keys when all even-numbered keys are used. +func TestStateVector_GetUsedKeyNums(t *testing.T) { + numKeys := uint32(1000) + sv := newTestStateVector("StateVectorGetUsedKeyNums", numKeys, t) + + // Use every other key + for i := uint32(0); i < numKeys; i += 2 { + sv.use(i) + } + + // Check that every other key is in the list + for i, keyNum := range sv.GetUsedKeyNums() { + if keyNum != uint32(2*i) { + t.Errorf("Key number #%d incorrect."+ + "\nexpected: %d\nreceived: %d", i, 2*i, keyNum) + } + } +} + +// Tests that StateVector.String returns the expected string. +func TestStateVector_String(t *testing.T) { + key := "StateVectorString" + expected := "stateVector: " + makeStateVectorKey(key) + sv := newTestStateVector(key, 500, t) + // Use every other key + for i := uint32(0); i < sv.numKeys; i += 2 { + sv.use(i) + } + + if expected != sv.String() { + t.Errorf("String does not match expected.\nexpected: %q\nreceived: %q", + expected, sv.String()) + } +} + +// Tests that StateVector.GoString returns the expected string. +func TestStateVector_GoString(t *testing.T) { + expected := "{vect:[6148914691236517205 6148914691236517205 " + + "6148914691236517205 6148914691236517205 6148914691236517205 " + + "6148914691236517205 6148914691236517205 1501199875790165] " + + "firstAvailable:1 numKeys:500 numAvailable:250 " + + "key:stateVectorStateVectorGoString" + sv := newTestStateVector("StateVectorGoString", 500, t) + // Use every other key + for i := uint32(0); i < sv.numKeys; i += 2 { + sv.use(i) + } + + received := strings.Split(sv.GoString(), " kv:")[0] + if expected != received { + t.Errorf("String does not match expected.\nexpected: %q\nreceived: %q", + expected, received) + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// Tests that a StateVector loaded from storage via LoadStateVector matches the +// original. +func TestLoadStateVector(t *testing.T) { + numKeys := uint32(1000) + key := "StateVectorLoadStateVector" + sv := newTestStateVector(key, numKeys, t) + + // Use every other key + for i := uint32(0); i < numKeys; i += 2 { + sv.Use(i) + } + + // Attempt to load StateVector from storage + loadedSV, err := LoadStateVector(sv.kv, key) + if err != nil { + t.Fatalf("LoadStateVector returned an error: %+v", err) + } + + if !reflect.DeepEqual(sv.vect, loadedSV.vect) { + t.Errorf("Loaded StateVector does not match original saved."+ + "\nexpected: %#v\nreceived: %#v", sv.vect, loadedSV.vect) + } +} + +// Tests that a StateVector loaded from storage via LoadStateVector matches the +// original. +func TestLoadStateVector_GetError(t *testing.T) { + key := "StateVectorLoadStateVector" + kv := versioned.NewKV(make(ekv.Memstore)) + expectedErr := "object not found" + + _, err := LoadStateVector(kv, key) + if err == nil || err.Error() != expectedErr { + t.Fatalf("LoadStateVector did not return the expected error when no "+ + "object exists in storage.\nexpected: %s\nreceived: %+v", + expectedErr, err) + } +} + +// Tests that a StateVector loaded from storage via LoadStateVector matches the +// original. +func TestLoadStateVector_UnmarshalError(t *testing.T) { + key := "StateVectorLoadStateVector" + kv := versioned.NewKV(make(ekv.Memstore)) + + // Save invalid StateVector to storage + obj := versioned.Object{ + Version: currentStateVectorVersion, + Timestamp: netTime.Now(), + Data: []byte("invalidStateVector"), + } + err := kv.Set(makeStateVectorKey(key), currentStateVectorVersion, &obj) + if err != nil { + t.Errorf("Failed to save invalid StateVector to storage: %+v", err) + } + + expectedErr := strings.Split(loadUnmarshalErr, "%")[0] + _, err = LoadStateVector(kv, key) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Fatalf("LoadStateVector did not return the expected error when the "+ + "data in storage is invalid.\nexpected: %s\nreceived: %+v", + expectedErr, err) + } +} + +// Tests that StateVector.save saves the correct data to storage and that it can +// be loaded. +func TestStateVector_save(t *testing.T) { + key := "StateVectorSave" + sv := &StateVector{ + vect: make([]uint64, (1000+63)/64), + firstAvailable: 0, + numKeys: 1000, + numAvailable: 1000, + key: makeStateVectorKey(key), + kv: versioned.NewKV(make(ekv.Memstore)), + } + expectedData, err := sv.marshal() + if err != nil { + t.Errorf("Failed to marshal StateVector: %+v", err) + } + + // Save to storage + err = sv.save() + if err != nil { + t.Errorf("save returned an error: %+v", err) + } + + // Check that the object can be loaded + loadedData, err := sv.kv.Get(sv.key, currentStateVectorVersion) + if err != nil { + t.Errorf("Failed to load StateVector from storage: %+v", err) + } + + if !bytes.Equal(expectedData, loadedData.Data) { + t.Errorf("Loaded data does not match expected."+ + "\nexpected: %v\nreceived: %v", expectedData, loadedData) + } +} + +// Tests that StateVector.Delete removes the StateVector from storage. +func TestStateVector_Delete(t *testing.T) { + sv := newTestStateVector("StateVectorDelete", 1000, t) + + err := sv.Delete() + if err != nil { + t.Errorf("Delete returned an error: %+v", err) + } + + // Check that the object can be loaded + loadedData, err := sv.kv.Get(sv.key, currentStateVectorVersion) + if err == nil { + t.Errorf("Loaded StateVector from storage when it should be deleted: %v", + loadedData) + } +} + +func TestStateVector_marshal_unmarshal(t *testing.T) { + // Generate new StateVector and use ever other key + sv1 := newTestStateVector("StateVectorMarshalUnmarshal", 224, t) + for i := uint32(0); i < sv1.GetNumKeys(); i += 2 { + sv1.Use(i) + } + + // Marshal and unmarshal the StateVector + marshalledData, err := sv1.marshal() + if err != nil { + t.Errorf("marshal returned an error: %+v", err) + } + + // Unmarshal into new StateVector + sv2 := &StateVector{key: sv1.key, kv: sv1.kv} + err = sv2.unmarshal(marshalledData) + if err != nil { + t.Errorf("unmarshal returned an error: %+v", err) + } + + // Make sure that the unmarshalled StateVector matches the original + if !reflect.DeepEqual(sv1, sv2) { + t.Errorf("Marshalled and unmarshalled StateVector does not match "+ + "original.\nexpected: %#v\nreceived: %#v", sv1, sv2) + } +} + +// Consistency test of makeStateVectorKey. +func Test_makeStateVectorKey(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + expectedStrings := []string{ + "stateVectorU4x/lrFkvxuXu59LtHLon1sU", + "stateVectorhPJSCcnZND6SugndnVLf15tN", + "stateVectordkKbYXoMn58NO6VbDMDWFEyI", + "stateVectorhTWEGsvgcJsHWAg/YdN1vAK0", + "stateVectorHfT5GSnhj9qeb4LlTnSOgeee", + "stateVectorS71v40zcuoQ+6NY+jE/+HOvq", + "stateVectorVG2PrBPdGqwEzi6ih3xVec+i", + "stateVectorx44bC6+uiBuCp1EQikLtPJA8", + "stateVectorqkNGWnhiBhaXiu0M48bE8657", + "stateVectorw+BJW1cS/v2+DBAoh+EA2s0t", + } + + for i, expected := range expectedStrings { + b := make([]byte, 18) + prng.Read(b) + key := makeStateVectorKey(base64.StdEncoding.EncodeToString(b)) + + if expected != key { + t.Errorf("New StateVector key does not match expected (%d)."+ + "\nexpected: %q\nreceived: %q", i, expected, key) + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Testing Functions // +//////////////////////////////////////////////////////////////////////////////// + +// Tests that StateVector.SaveTEST saves the correct data to storage and that it +// can be loaded. +func TestStateVector_SaveTEST(t *testing.T) { + key := "StateVectorSaveTEST" + sv := &StateVector{ + vect: make([]uint64, (1000+63)/64), + firstAvailable: 0, + numKeys: 1000, + numAvailable: 1000, + key: makeStateVectorKey(key), + kv: versioned.NewKV(make(ekv.Memstore)), + } + expectedData, err := sv.marshal() + if err != nil { + t.Errorf("Failed to marshal StateVector: %+v", err) + } + + // Save to storage + err = sv.SaveTEST(t) + if err != nil { + t.Errorf("save returned an error: %+v", err) + } + + // Check that the object can be loaded + loadedData, err := sv.kv.Get(sv.key, currentStateVectorVersion) + if err != nil { + t.Errorf("Failed to load StateVector from storage: %+v", err) + } + + if !bytes.Equal(expectedData, loadedData.Data) { + t.Errorf("Loaded data does not match expected."+ + "\nexpected: %v\nreceived: %v", expectedData, loadedData) + } +} + +// Panic path: tests that StateVector.SaveTEST panics when provided a non- +// testing interface. +func TestStateVector_SaveTEST_InvalidInterfaceError(t *testing.T) { + sv := &StateVector{} + expectedErr := fmt.Sprintf(testInterfaceErr, "SaveTEST") + + defer func() { + if r := recover(); r == nil || r.(string) != expectedErr { + t.Errorf("Failed to panic with expected error when provided a "+ + "non-testing interface.\nexpected: %s\nreceived: %+v", + expectedErr, r) + } + }() + + _ = sv.SaveTEST(struct{}{}) +} + +// Tests that StateVector.SetFirstAvailableTEST correctly sets firstAvailable. +func TestStateVector_SetFirstAvailableTEST(t *testing.T) { + sv := newTestStateVector("StateVectorSetFirstAvailableTEST", 1000, t) + + firstAvailable := uint32(78) + sv.SetFirstAvailableTEST(firstAvailable, t) + + if sv.firstAvailable != firstAvailable { + t.Errorf("Failed to set firstAvailable.\nexpected: %d\nreceived: %d", + firstAvailable, sv.firstAvailable) + } +} + +// Panic path: tests that StateVector.SetFirstAvailableTEST panics when provided +// a non-testing interface. +func TestStateVector_SetFirstAvailableTEST_InvalidInterfaceError(t *testing.T) { + sv := &StateVector{} + expectedErr := fmt.Sprintf(testInterfaceErr, "SetFirstAvailableTEST") + + defer func() { + if r := recover(); r == nil || r.(string) != expectedErr { + t.Errorf("Failed to panic with expected error when provided a "+ + "non-testing interface.\nexpected: %s\nreceived: %+v", + expectedErr, r) + } + }() + + sv.SetFirstAvailableTEST(0, struct{}{}) +} + +// Tests that StateVector.SetNumKeysTEST correctly sets numKeys. +func TestStateVector_SetNumKeysTEST(t *testing.T) { + sv := newTestStateVector("StateVectorSetNumKeysTEST", 1000, t) + + numKeys := uint32(78) + sv.SetNumKeysTEST(numKeys, t) + + if sv.numKeys != numKeys { + t.Errorf("Failed to set numKeys.\nexpected: %d\nreceived: %d", + numKeys, sv.numKeys) + } +} + +// Panic path: tests that StateVector.SetNumKeysTEST panics when provided a non- +// testing interface. +func TestStateVector_SetNumKeysTEST_InvalidInterfaceError(t *testing.T) { + sv := &StateVector{} + expectedErr := fmt.Sprintf(testInterfaceErr, "SetNumKeysTEST") + + defer func() { + if r := recover(); r == nil || r.(string) != expectedErr { + t.Errorf("Failed to panic with expected error when provided a "+ + "non-testing interface.\nexpected: %s\nreceived: %+v", + expectedErr, r) + } + }() + + sv.SetNumKeysTEST(0, struct{}{}) +} + +// Tests that StateVector.SetNumAvailableTEST correctly sets numKeys. +func TestStateVector_SetNumAvailableTEST(t *testing.T) { + sv := newTestStateVector("StateVectorSetNumAvailableTEST", 1000, t) + + numAvailable := uint32(78) + sv.SetNumAvailableTEST(numAvailable, t) + + if sv.numAvailable != numAvailable { + t.Errorf("Failed to set numAvailable.\nexpected: %d\nreceived: %d", + numAvailable, sv.numAvailable) + } +} + +// Panic path: tests that StateVector.SetNumAvailableTEST panics when provided a +// non-testing interface. +func TestStateVector_SetNumAvailableTEST_InvalidInterfaceError(t *testing.T) { + sv := &StateVector{} + expectedErr := fmt.Sprintf(testInterfaceErr, "SetNumAvailableTEST") + + defer func() { + if r := recover(); r == nil || r.(string) != expectedErr { + t.Errorf("Failed to panic with expected error when provided a "+ + "non-testing interface.\nexpected: %s\nreceived: %+v", + expectedErr, r) + } + }() + + sv.SetNumAvailableTEST(0, struct{}{}) +} + +// Tests that StateVector.SetKvTEST correctly sets the versioned.KV. +func TestStateVector_SetKvTEST(t *testing.T) { + sv := newTestStateVector("SetKvTEST", 1000, t) + + kv := versioned.NewKV(make(ekv.Memstore)).Prefix("NewKV") + sv.SetKvTEST(kv, t) + + if sv.kv != kv { + t.Errorf("Failed to set the KV.\nexpected: %v\nreceived: %v", kv, sv.kv) + } +} + +// Panic path: tests that StateVector.SetKvTEST panics when provided a non- +// testing interface. +func TestStateVector_SetKvTEST_InvalidInterfaceError(t *testing.T) { + sv := &StateVector{} + expectedErr := fmt.Sprintf(testInterfaceErr, "SetKvTEST") + + defer func() { + if r := recover(); r == nil || r.(string) != expectedErr { + t.Errorf("Failed to panic with expected error when provided a "+ + "non-testing interface.\nexpected: %s\nreceived: %+v", + expectedErr, r) + } + }() + + sv.SetKvTEST(nil, struct{}{}) +} + +// newTestStateVector produces a new StateVector using the specified number of +// keys and key string for testing. +func newTestStateVector(key string, numKeys uint32, t *testing.T) *StateVector { + kv := versioned.NewKV(make(ekv.Memstore)) + + sv, err := NewStateVector(kv, key, numKeys) + if err != nil { + t.Fatalf("Failed to create new StateVector: %+v", err) + } + + return sv +} diff --git a/ud/utils_test.go b/ud/utils_test.go index c04b9cb1579d65efae12c5191db0f07eaa8c8d07..fc224802e63a1b489013f7e07aac82293e249665 100644 --- a/ud/utils_test.go +++ b/ud/utils_test.go @@ -68,7 +68,7 @@ func (tnm *testNetworkManager) SendCMIX(format.Message, *id.ID, params.CMIX) (id return 0, ephemeral.Id{}, nil } -func (tnm *testNetworkManager) SendManyCMIX(map[id.ID]format.Message, params.CMIX) (id.Round, []ephemeral.Id, error) { +func (tnm *testNetworkManager) SendManyCMIX([]message.TargetedCmixMessage, params.CMIX) (id.Round, []ephemeral.Id, error) { return 0, nil, nil }