diff --git a/api/authenticatedChannel.go b/api/authenticatedChannel.go index b302f78d46b81b9ebdb20d9f3e7e08f9ce2abb47..4b5f15c16a7a8d1e03fdea41e5a9e2338bb0ceee 100644 --- a/api/authenticatedChannel.go +++ b/api/authenticatedChannel.go @@ -126,6 +126,13 @@ func (c *Client) MakePrecannedAuthenticatedChannel(precannedID uint) (contact.Co Source: precan.ID[:], }, me) + // File transfer end + c.storage.GetEdge().Add(edge.Preimage{ + Data: sessionPartner.GetFileTransferPreimage(), + Type: preimage.EndFT, + Source: precan.ID[:], + }, me) + return precan, err } diff --git a/api/client.go b/api/client.go index c0c5e5133803a66610ecfa21792c525643969502..4f31af1e887c5f3cbddea6090c22b5742cdaadf5 100644 --- a/api/client.go +++ b/api/client.go @@ -626,6 +626,7 @@ func (c *Client) DeleteContact(partnerId *id.ID) error { } e2ePreimage := partner.GetE2EPreimage() rekeyPreimage := partner.GetRekeyPreimage() + fileTransferPreimage := partner.GetFileTransferPreimage() //delete the partner if err = c.storage.E2e().DeletePartner(partnerId); err != nil { @@ -650,6 +651,15 @@ func (c *Client) DeleteContact(partnerId *id.ID) error { "from %s on contact deletion: %+v", partnerId, err) } + if err = c.storage.GetEdge().Remove(edge.Preimage{ + Data: fileTransferPreimage, + Type: preimage.EndFT, + Source: partnerId[:], + }, c.storage.GetUser().ReceptionID); err != nil { + jww.WARN.Printf("Failed delete the preimage for file transfer "+ + "from %s on contact deletion: %+v", partnerId, err) + } + if err = c.storage.Auth().Delete(partnerId); err != nil { return err } diff --git a/auth/callback.go b/auth/callback.go index e3d53fb47f4aa868ee23b27c757dd9d05145bb39..7f10aa06b948ba405da33e64ccea615d6ee09b1d 100644 --- a/auth/callback.go +++ b/auth/callback.go @@ -347,6 +347,13 @@ func (m *Manager) doConfirm(sr *auth.SentRequest, grp *cyclic.Group, Source: sr.GetPartner()[:], }, me) + // File transfer end + m.storage.GetEdge().Add(edge.Preimage{ + Data: sessionPartner.GetFileTransferPreimage(), + Type: preimage.EndFT, + Source: sr.GetPartner()[:], + }, me) + // delete the in progress negotiation // this undoes the request lock if err := m.storage.Auth().Delete(sr.GetPartner()); err != nil { diff --git a/auth/confirm.go b/auth/confirm.go index 62216ed28935abe65db07f3fe6c313f3bd273b0c..1afe54b9cfbd9f541731c89fc8de422c4efff692 100644 --- a/auth/confirm.go +++ b/auth/confirm.go @@ -138,6 +138,13 @@ func ConfirmRequestAuth(partner contact.Contact, rng io.Reader, Source: partner.ID[:], }, me) + // File transfer end + storage.GetEdge().Add(edge.Preimage{ + Data: sessionPartner.GetFileTransferPreimage(), + Type: preimage.EndFT, + Source: partner.ID[:], + }, me) + // delete the in progress negotiation // this unlocks the request lock //fixme - do these deletes at a later date diff --git a/bindings/fileTransfer.go b/bindings/fileTransfer.go index 037686cb42eaceba92a967bd133cfe05e100e0da..1bf1258271867e25cdf5a0d93893e1416f6c2938 100644 --- a/bindings/fileTransfer.go +++ b/bindings/fileTransfer.go @@ -154,11 +154,11 @@ func (f *FileTransfer) RegisterSendProgressCallback(transferID []byte, // Convert period to time.Duration period := time.Duration(periodMS) * time.Millisecond - return f.m.RegisterSendProgressCallback(tid, progressCB, period) + return f.m.RegisterSentProgressCallback(tid, progressCB, period) } // Resend resends a file if sending fails. This function should only be called -//if the interfaces.SentProgressCallback returns an error. +// if the interfaces.SentProgressCallback returns an error. func (f *FileTransfer) Resend(transferID []byte) error { // Unmarshal transfer ID tid := ftCrypto.UnmarshalTransferID(transferID) @@ -212,7 +212,7 @@ func (f *FileTransfer) RegisterReceiveProgressCallback(transferID []byte, // Convert period to time.Duration period := time.Duration(periodMS) * time.Millisecond - return f.m.RegisterReceiveProgressCallback(tid, progressCB, period) + return f.m.RegisterReceivedProgressCallback(tid, progressCB, period) } //////////////////////////////////////////////////////////////////////////////// @@ -258,7 +258,7 @@ type FilePartTracker struct { // 2 = arrived (sender has sent a part, and it has arrived) // 3 = received (receiver has received a part) func (fpt *FilePartTracker) GetPartStatus(partNum int) int { - return fpt.m.GetPartStatus(uint16(partNum)) + return int(fpt.m.GetPartStatus(uint16(partNum))) } // GetNumParts returns the total number of file parts in the transfer. diff --git a/bindings/notifications.go b/bindings/notifications.go index 55d262acaa2d36393e9cc89fab3e4b0ba814cd7a..c7d4b605578b2cafe98f2042f06d64adbbd48a10 100644 --- a/bindings/notifications.go +++ b/bindings/notifications.go @@ -36,13 +36,14 @@ func (nfmr *NotificationForMeReport) Source() []byte { // NotificationForMe Check if a notification received is for me // It returns a NotificationForMeReport which contains a ForMe bool stating if it is for the caller, // a Type, and a source. These are as follows: -// TYPE SOURCE DESCRIPTION +// TYPE SOURCE DESCRIPTION // "default" recipient user ID A message with no association // "request" sender user ID A channel request has been received -// "confirm" sender user ID A channel request has been accepted -// "rekey" sender user ID keys with a user have been rotated -// "e2e" sender user ID reception of an E2E message -// "group" group ID reception of a group chat message +// "confirm" sender user ID A channel request has been accepted +// "rekey" sender user ID keys with a user have been rotated +// "e2e" sender user ID reception of an E2E message +// "group" group ID reception of a group chat message +// "endFT" sender user ID Last message sent confirming end of file transfer func NotificationForMe(messageHash, idFP string, preimages string) (*NotificationForMeReport, error) { //handle message hash and idFP messageHashBytes, err := base64.StdEncoding.DecodeString(messageHash) diff --git a/cmd/fileTransfer.go b/cmd/fileTransfer.go index a80ef51e0e9b1c9dfa07aa6139a09fb8ebf60008..254e788abf70a69b85df28dc3e5e15ca750ab826 100644 --- a/cmd/fileTransfer.go +++ b/cmd/fileTransfer.go @@ -247,7 +247,7 @@ func receiveNewFileTransfers(receive chan receivedFtResults, done, "bytes with preview: %q\n", r.fileName, r.size, r.preview) cb := newReceiveProgressCB(r.tid, done, m) - err := m.RegisterReceiveProgressCallback(r.tid, cb, callbackPeriod) + err := m.RegisterReceivedProgressCallback(r.tid, cb, callbackPeriod) if err != nil { jww.FATAL.Panicf("Failed to register new receive progress "+ "callback for transfer %s: %+v", r.tid, err) diff --git a/fileTransfer/ftMessages.pb.go b/fileTransfer/ftMessages.pb.go index 4b6f510eb08daeafc86c3944291ea4159a540ca6..4bc094d03ea4cd8e3776298f8c216972c286e45e 100644 --- a/fileTransfer/ftMessages.pb.go +++ b/fileTransfer/ftMessages.pb.go @@ -27,6 +27,8 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +// NewFileTransfer is transmitted first on the initialization of a file transfer +// to inform the receiver about the incoming file. type NewFileTransfer struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache diff --git a/fileTransfer/ftMessages.proto b/fileTransfer/ftMessages.proto index 2802df2b580932c34651b9b43801cd4ae07abcf2..22976e3697ba8e34692f0a94f0225098f36ec33c 100644 --- a/fileTransfer/ftMessages.proto +++ b/fileTransfer/ftMessages.proto @@ -10,6 +10,8 @@ syntax = "proto3"; package parse; option go_package = "fileTransfer/"; +// NewFileTransfer is transmitted first on the initialization of a file transfer +// to inform the receiver about the incoming file. message NewFileTransfer { string fileName = 1; // Name of the file; max 48 characters string fileType = 2; // Type of file; max 8 characters diff --git a/fileTransfer/manager.go b/fileTransfer/manager.go index c63d8eb8de15d795fa0defb598a2e79f4188bfdd..9785856651ddc65e4dc13fdef42ec679c0c2a6f8 100644 --- a/fileTransfer/manager.go +++ b/fileTransfer/manager.go @@ -59,12 +59,13 @@ const ( newManagerReceivedErr = "failed to load or create new list of received file transfers: %+v" // Manager.Send - fileNameSizeErr = "length of filename (%d) greater than max allowed length (%d)" - fileTypeSizeErr = "length of file type (%d) greater than max allowed length (%d)" - fileSizeErr = "size of file (%d bytes) greater than max allowed size (%d bytes)" - previewSizeErr = "size of preview (%d bytes) greater than max allowed size (%d bytes)" - getPartSizeErr = "failed to get file part size: %+v" - sendInitMsgErr = "failed to send initial file transfer message: %+v" + sendNetworkHealthErr = "cannot initiate file transfer of %q to %s when network is not healthy." + fileNameSizeErr = "length of filename (%d) greater than max allowed length (%d)" + fileTypeSizeErr = "length of file type (%d) greater than max allowed length (%d)" + fileSizeErr = "size of file (%d bytes) greater than max allowed size (%d bytes)" + previewSizeErr = "size of preview (%d bytes) greater than max allowed size (%d bytes)" + getPartSizeErr = "failed to get file part size: %+v" + sendInitMsgErr = "failed to send initial file transfer message: %+v" // Manager.Resend transferNotFailedErr = "transfer %s has not failed" @@ -90,10 +91,10 @@ type Manager struct { receiveCB interfaces.ReceiveCallback // Storage-backed structure for tracking sent file transfers - sent *ftStorage.SentFileTransfers + sent *ftStorage.SentFileTransfersStore // Storage-backed structure for tracking received file transfers - received *ftStorage.ReceivedFileTransfers + received *ftStorage.ReceivedFileTransfersStore // Queue of parts to send sendQueue chan queuedPart @@ -101,12 +102,15 @@ type Manager struct { // Maximum data transfer speed in bytes per second maxThroughput int + // Indicates if old transfers saved to storage have been recovered after + // file transfer is closed and reopened + oldTransfersRecovered bool + // Client interfaces client *api.Client store *storage.Session swb interfaces.Switchboard net interfaces.NetworkManager - healthy chan bool rng *fastRNG.StreamGenerator getRoundResults getRoundResultsFunc } @@ -141,31 +145,31 @@ func newManager(client *api.Client, store *storage.Session, // Create a new list of sent file transfers or load one if it exists in // storage - sent, err := ftStorage.NewOrLoadSentFileTransfers(kv) + sent, err := ftStorage.NewOrLoadSentFileTransfersStore(kv) if err != nil { return nil, errors.Errorf(newManagerSentErr, err) } // Create a new list of received file transfers or load one if it exists in // storage - received, err := ftStorage.NewOrLoadReceivedFileTransfers(kv) + received, err := ftStorage.NewOrLoadReceivedFileTransfersStore(kv) if err != nil { return nil, errors.Errorf(newManagerReceivedErr, err) } return &Manager{ - receiveCB: receiveCB, - sent: sent, - received: received, - sendQueue: make(chan queuedPart, sendQueueBuffLen), - maxThroughput: p.MaxThroughput, - client: client, - store: store, - swb: swb, - net: net, - healthy: make(chan bool, networkHealthBuffLen), - rng: rng, - getRoundResults: getRoundResults, + receiveCB: receiveCB, + sent: sent, + received: received, + sendQueue: make(chan queuedPart, sendQueueBuffLen), + maxThroughput: p.MaxThroughput, + oldTransfersRecovered: false, + client: client, + store: store, + swb: swb, + net: net, + rng: rng, + getRoundResults: getRoundResults, }, nil } @@ -188,7 +192,13 @@ func (m *Manager) startProcesses(newFtChan, filePartChan chan message.Receive) ( // Register network health channel that is used by the sending thread to // ensure the network is healthy before sending - m.net.GetHealthTracker().AddChannel(m.healthy) + healthyRecover := make(chan bool, networkHealthBuffLen) + m.net.GetHealthTracker().AddChannel(healthyRecover) + healthySend := make(chan bool, networkHealthBuffLen) + m.net.GetHealthTracker().AddChannel(healthySend) + + // Recover unsent parts from storage + m.oldTransferRecovery(healthyRecover) // Start the new file transfer message reception thread newFtStop := stoppable.NewSingle(newFtStoppableName) @@ -204,7 +214,7 @@ func (m *Manager) startProcesses(newFtChan, filePartChan chan message.Receive) ( // Start the file part sending thread sendStop := stoppable.NewSingle(sendStoppableName) - go m.sendThread(sendStop, getRandomNumParts) + go m.sendThread(sendStop, healthySend, getRandomNumParts) // Create a multi stoppable multiStoppable := stoppable.NewMulti(fileTransferStoppableName) @@ -224,6 +234,12 @@ func (m Manager) Send(fileName, fileType string, fileData []byte, progressCB interfaces.SentProgressCallback, period time.Duration) ( ftCrypto.TransferID, error) { + // Return an error if the network is not healthy + if !m.net.GetHealthTracker().IsHealthy() { + return ftCrypto.TransferID{}, errors.Errorf( + sendNetworkHealthErr, fileName, recipient) + } + // Return an error if the file name is too long if len(fileName) > FileNameMaxLen { return ftCrypto.TransferID{}, errors.Errorf( @@ -290,16 +306,17 @@ func (m Manager) Send(fileName, fileType string, fileData []byte, } rng.Close() - m.queueParts(transferID, numParts) + // Add all parts to queue + m.queueParts(transferID, makeListOfPartNums(numParts)) return transferID, nil } -// RegisterSendProgressCallback adds the sent progress callback to the sent +// RegisterSentProgressCallback adds the sent progress callback to the sent // transfer so that it will be called when updates for the transfer occur. The // progress callback is called when initially added and on transfer updates, at // most once per period. -func (m Manager) RegisterSendProgressCallback(tid ftCrypto.TransferID, +func (m Manager) RegisterSentProgressCallback(tid ftCrypto.TransferID, progressCB interfaces.SentProgressCallback, period time.Duration) error { // Get the transfer for the given ID transfer, err := m.sent.GetTransfer(tid) @@ -317,10 +334,10 @@ func (m Manager) RegisterSendProgressCallback(tid ftCrypto.TransferID, // was already called or if the transfer did not run out of retries. This // function should only be called if the interfaces.SentProgressCallback returns // an error. -// TODO: add test -// TODO: write test -// TODO: can you resend? Can you reuse fingerprints? -// TODO: what to do if sendE2E fails? +// TODO: Need to implement Resend but there are some unanswered questions. +// - Can you resend? +// - Can you reuse fingerprints? +// - What to do if sendE2E fails? func (m Manager) Resend(tid ftCrypto.TransferID) error { // Get the transfer for the given ID transfer, err := m.sent.GetTransfer(tid) @@ -379,11 +396,11 @@ func (m Manager) Receive(tid ftCrypto.TransferID) ([]byte, error) { return file, m.received.DeleteTransfer(tid) } -// RegisterReceiveProgressCallback adds the reception progress callback to the +// RegisterReceivedProgressCallback adds the reception progress callback to the // received transfer so that it will be called when updates for the transfer // occur. The progress callback is called when initially added and on transfer // updates, at most once per period. -func (m Manager) RegisterReceiveProgressCallback(tid ftCrypto.TransferID, +func (m Manager) RegisterReceivedProgressCallback(tid ftCrypto.TransferID, progressCB interfaces.ReceivedProgressCallback, period time.Duration) error { // Get the transfer for the given ID transfer, err := m.received.GetTransfer(tid) diff --git a/fileTransfer/manager_test.go b/fileTransfer/manager_test.go index eaec128086306cd3579b7925a0c637a3d4eb999b..21f1bbd5068709d8ed92162b0b9bdcfe7030c55d 100644 --- a/fileTransfer/manager_test.go +++ b/fileTransfer/manager_test.go @@ -14,8 +14,10 @@ import ( "github.com/golang/protobuf/proto" "gitlab.com/elixxir/client/interfaces" "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/interfaces/params" ftStorage "gitlab.com/elixxir/client/storage/fileTransfer" "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/crypto/diffieHellman" ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" "gitlab.com/elixxir/ekv" "gitlab.com/xx_network/primitives/id" @@ -41,17 +43,17 @@ func Test_newManager(t *testing.T) { t.Errorf("newManager returned an error: %+v", err) } - // Check that the SentFileTransfers is new and correct - expectedSent, _ := ftStorage.NewSentFileTransfers(kv) + // Check that the SentFileTransfersStore is new and correct + expectedSent, _ := ftStorage.NewSentFileTransfersStore(kv) if !reflect.DeepEqual(expectedSent, m.sent) { - t.Errorf("SentFileTransfers in manager incorrect."+ + t.Errorf("SentFileTransfersStore in manager incorrect."+ "\nexpected: %+v\nreceived: %+v", expectedSent, m.sent) } - // Check that the ReceivedFileTransfers is new and correct - expectedReceived, _ := ftStorage.NewReceivedFileTransfers(kv) + // Check that the ReceivedFileTransfersStore is new and correct + expectedReceived, _ := ftStorage.NewReceivedFileTransfersStore(kv) if !reflect.DeepEqual(expectedReceived, m.received) { - t.Errorf("ReceivedFileTransfers in manager incorrect."+ + t.Errorf("ReceivedFileTransfersStore in manager incorrect."+ "\nexpected: %+v\nreceived: %+v", expectedReceived, m.received) } @@ -67,14 +69,14 @@ func Test_newManager(t *testing.T) { // Tests that Manager.Send adds a new sent transfer, sends the NewFileTransfer // E2E message, and adds all the file parts to the queue. func TestManager_Send(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) prng := NewPrng(42) recipient := id.NewIdFromString("recipient", id.User, t) fileName := "testFile" fileType := "txt" numParts := uint16(16) partSize, _ := m.getPartSize() - fileData, _ := newFile(numParts, uint32(partSize), prng, t) + fileData, _ := newFile(numParts, partSize, prng, t) preview := []byte("filePreview") retry := float32(1.5) numFps := calcNumberOfFingerprints(numParts, retry) @@ -147,10 +149,28 @@ func TestManager_Send(t *testing.T) { } } +// Error path: tests that Manager.Send returns the expected error when the +// network is not healthy. +func TestManager_Send_NetworkHealthError(t *testing.T) { + m := newTestManager(false, nil, nil, nil, nil, t) + + fileName := "MySentFile" + recipient := id.NewIdFromString("recipient", id.User, t) + expectedErr := fmt.Sprintf(sendNetworkHealthErr, fileName, recipient) + + m.net.(*testNetworkManager).health.healthy = false + + _, err := m.Send(fileName, "", nil, recipient, 0, nil, nil, 0) + if err == nil || err.Error() != expectedErr { + t.Errorf("Send did not return the expected error when the network is "+ + "not healthy.\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + // Error path: tests that Manager.Send returns the expected error when the // provided file name is longer than FileNameMaxLen. func TestManager_Send_FileNameLengthError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) fileName := strings.Repeat("A", FileNameMaxLen+1) expectedErr := fmt.Sprintf(fileNameSizeErr, len(fileName), FileNameMaxLen) @@ -165,7 +185,7 @@ func TestManager_Send_FileNameLengthError(t *testing.T) { // Error path: tests that Manager.Send returns the expected error when the // provided file type is longer than FileTypeMaxLen. func TestManager_Send_FileTypeLengthError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) fileType := strings.Repeat("A", FileTypeMaxLen+1) expectedErr := fmt.Sprintf(fileTypeSizeErr, len(fileType), FileTypeMaxLen) @@ -180,7 +200,7 @@ func TestManager_Send_FileTypeLengthError(t *testing.T) { // Error path: tests that Manager.Send returns the expected error when the // provided file is larger than FileMaxSize. func TestManager_Send_FileSizeError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) fileData := make([]byte, FileMaxSize+1) expectedErr := fmt.Sprintf(fileSizeErr, len(fileData), FileMaxSize) @@ -195,7 +215,7 @@ func TestManager_Send_FileSizeError(t *testing.T) { // Error path: tests that Manager.Send returns the expected error when the // provided preview is larger than PreviewMaxSize. func TestManager_Send_PreviewSizeError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) previewData := make([]byte, PreviewMaxSize+1) expectedErr := fmt.Sprintf(previewSizeErr, len(previewData), PreviewMaxSize) @@ -210,18 +230,18 @@ func TestManager_Send_PreviewSizeError(t *testing.T) { // Error path: tests that Manager.Send returns the expected error when the E2E // message fails to send. func TestManager_Send_SendE2eError(t *testing.T) { - m := newTestManager(true, nil, nil, nil, t) + m := newTestManager(true, nil, nil, nil, nil, t) prng := NewPrng(42) recipient := id.NewIdFromString("recipient", id.User, t) fileName := "testFile" fileType := "bytes" numParts := uint16(16) partSize, _ := m.getPartSize() - fileData, _ := newFile(numParts, uint32(partSize), prng, t) + fileData, _ := newFile(numParts, partSize, prng, t) preview := []byte("filePreview") retry := float32(1.5) - expectedErr := fmt.Sprintf(sendE2eErr, recipient, "") + expectedErr := fmt.Sprintf(newFtSendE2eErr, recipient, "") _, err := m.Send( fileName, fileType, fileData, recipient, retry, preview, nil, 0) @@ -231,11 +251,12 @@ func TestManager_Send_SendE2eError(t *testing.T) { } } -// Tests that Manager.RegisterSendProgressCallback calls the callback when it is +// Tests that Manager.RegisterSentProgressCallback calls the callback when it is // added to the transfer and that the callback is associated with the expected // transfer and is called when calling from the transfer. -func TestManager_RegisterSendProgressCallback(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers([]uint16{12, 4, 1}, false, nil, t) +func TestManager_RegisterSentProgressCallback(t *testing.T) { + m, sti, _ := newTestManagerWithTransfers( + []uint16{12, 4, 1}, false, false, nil, nil, t) expectedErr := errors.New("CallbackError") // Create new callback and channel for the callback to trigger @@ -275,9 +296,9 @@ func TestManager_RegisterSendProgressCallback(t *testing.T) { } }() - err := m.RegisterSendProgressCallback(sti[0].tid, cb, time.Millisecond) + err := m.RegisterSentProgressCallback(sti[0].tid, cb, 1*time.Millisecond) if err != nil { - t.Errorf("RegisterSendProgressCallback returned an error: %+v", err) + t.Errorf("RegisterSentProgressCallback returned an error: %+v", err) } <-done0 @@ -288,15 +309,15 @@ func TestManager_RegisterSendProgressCallback(t *testing.T) { <-done1 } -// Error path: tests that Manager.RegisterSendProgressCallback returns an error +// Error path: tests that Manager.RegisterSentProgressCallback returns an error // when no transfer with the ID exists. -func TestManager_RegisterSendProgressCallback_NoTransferError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) +func TestManager_RegisterSentProgressCallback_NoTransferError(t *testing.T) { + m := newTestManager(false, nil, nil, nil, nil, t) tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) - err := m.RegisterSendProgressCallback(tid, nil, 0) + err := m.RegisterSentProgressCallback(tid, nil, 0) if err == nil { - t.Error("RegisterSendProgressCallback did not return an error when " + + t.Error("RegisterSentProgressCallback did not return an error when " + "no transfer with the ID exists.") } } @@ -308,7 +329,7 @@ func TestManager_Resend(t *testing.T) { // Error path: tests that Manager.Resend returns an error when no transfer with // the ID exists. func TestManager_Resend_NoTransferError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) err := m.Resend(tid) @@ -321,7 +342,8 @@ func TestManager_Resend_NoTransferError(t *testing.T) { // Error path: tests that Manager.Resend returns the error when the transfer has // not run out of fingerprints. func TestManager_Resend_NoFingerprints(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers([]uint16{16}, false, nil, t) + m, sti, _ := newTestManagerWithTransfers( + []uint16{16}, false, false, nil, nil, t) expectedErr := fmt.Sprintf(transferNotFailedErr, sti[0].tid) // Delete the transfer err := m.Resend(sti[0].tid) @@ -335,7 +357,8 @@ func TestManager_Resend_NoFingerprints(t *testing.T) { // Tests that Manager.CloseSend deletes the transfer when it has run out of // fingerprints but is not complete. func TestManager_CloseSend_NoFingerprints(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers([]uint16{16}, false, nil, t) + m, sti, _ := newTestManagerWithTransfers( + []uint16{16}, false, false, nil, nil, t) prng := NewPrng(42) partSize, _ := m.getPartSize() @@ -365,7 +388,8 @@ func TestManager_CloseSend_NoFingerprints(t *testing.T) { // Tests that Manager.CloseSend deletes the transfer when it completed but has // fingerprints. func TestManager_CloseSend_Complete(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers([]uint16{3}, false, nil, t) + m, sti, _ := newTestManagerWithTransfers( + []uint16{3}, false, false, nil, nil, t) // Set all parts to finished transfer, _ := m.sent.GetTransfer(sti[0].tid) @@ -373,11 +397,16 @@ func TestManager_CloseSend_Complete(t *testing.T) { if err != nil { t.Errorf("Failed to set parts to in-progress: %+v", err) } - err = transfer.FinishTransfer(0) + complete, err := transfer.FinishTransfer(0) if err != nil { t.Errorf("Failed to set parts to finished: %+v", err) } + // Ensure that FinishTransfer reported the transfer as complete + if !complete { + t.Error("FinishTransfer did not report the transfer as complete.") + } + // Delete the transfer err = m.CloseSend(sti[0].tid) if err != nil { @@ -394,7 +423,7 @@ func TestManager_CloseSend_Complete(t *testing.T) { // Error path: tests that Manager.CloseSend returns an error when no transfer // with the ID exists. func TestManager_CloseSend_NoTransferError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) err := m.CloseSend(tid) @@ -407,7 +436,8 @@ func TestManager_CloseSend_NoTransferError(t *testing.T) { // Error path: tests that Manager.CloseSend returns an error when the transfer // has not run out of fingerprints and is not complete func TestManager_CloseSend_NotCompleteErr(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers([]uint16{16}, false, nil, t) + m, sti, _ := newTestManagerWithTransfers( + []uint16{16}, false, false, nil, nil, t) expectedErr := fmt.Sprintf(transferInProgressErr, sti[0].tid) err := m.CloseSend(sti[0].tid) @@ -420,7 +450,7 @@ func TestManager_CloseSend_NotCompleteErr(t *testing.T) { // Error path: tests that Manager.Receive returns an error when no transfer with // the ID exists. func TestManager_Receive_NoTransferError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) _, err := m.Receive(tid) @@ -433,7 +463,8 @@ func TestManager_Receive_NoTransferError(t *testing.T) { // Error path: tests that Manager.Receive returns an error when the file is // incomplete. func TestManager_Receive_GetFileError(t *testing.T) { - m, _, rti := newTestManagerWithTransfers([]uint16{12, 4, 1}, false, nil, t) + m, _, rti := newTestManagerWithTransfers( + []uint16{12, 4, 1}, false, false, nil, nil, t) _, err := m.Receive(rti[0].tid) if err == nil || !strings.Contains(err.Error(), "missing") { @@ -442,11 +473,12 @@ func TestManager_Receive_GetFileError(t *testing.T) { } } -// Tests that Manager.RegisterReceiveProgressCallback calls the callback when it -// is added to the transfer and that the callback is associated with the +// Tests that Manager.RegisterReceivedProgressCallback calls the callback when +// it is added to the transfer and that the callback is associated with the // expected transfer and is called when calling from the transfer. -func TestManager_RegisterReceiveProgressCallback(t *testing.T) { - m, _, rti := newTestManagerWithTransfers([]uint16{12, 4, 1}, false, nil, t) +func TestManager_RegisterReceivedProgressCallback(t *testing.T) { + m, _, rti := newTestManagerWithTransfers( + []uint16{12, 4, 1}, false, false, nil, nil, t) expectedErr := errors.New("CallbackError") // Create new callback and channel for the callback to trigger @@ -486,9 +518,9 @@ func TestManager_RegisterReceiveProgressCallback(t *testing.T) { } }() - err := m.RegisterReceiveProgressCallback(rti[0].tid, cb, time.Millisecond) + err := m.RegisterReceivedProgressCallback(rti[0].tid, cb, time.Millisecond) if err != nil { - t.Errorf("RegisterReceiveProgressCallback returned an error: %+v", err) + t.Errorf("RegisterReceivedProgressCallback returned an error: %+v", err) } <-done0 @@ -499,15 +531,15 @@ func TestManager_RegisterReceiveProgressCallback(t *testing.T) { <-done1 } -// Error path: tests that Manager.RegisterReceiveProgressCallback returns an +// Error path: tests that Manager.RegisterReceivedProgressCallback returns an // error when no transfer with the ID exists. -func TestManager_RegisterReceiveProgressCallback_NoTransferError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) +func TestManager_RegisterReceivedProgressCallback_NoTransferError(t *testing.T) { + m := newTestManager(false, nil, nil, nil, nil, t) tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) - err := m.RegisterReceiveProgressCallback(tid, nil, 0) + err := m.RegisterReceivedProgressCallback(tid, nil, 0) if err == nil { - t.Error("RegisterReceiveProgressCallback did not return an error " + + t.Error("RegisterReceivedProgressCallback did not return an error " + "when no transfer with the ID exists.") } } @@ -563,8 +595,18 @@ func Test_FileTransfer(t *testing.T) { filePartChan2 := make(chan message.Receive, rawMessageBuffSize) // Generate sending and receiving managers - m1 := newTestManager(false, filePartChan2, newFtChan2, nil, t) - m2 := newTestManager(false, filePartChan1, newFtChan1, receiveNewCB, t) + m1 := newTestManager(false, filePartChan2, newFtChan2, nil, nil, t) + m2 := newTestManager(false, filePartChan1, newFtChan1, receiveNewCB, nil, t) + + // Add partner + dhKey := m1.store.E2e().GetGroup().NewInt(42) + pubKey := diffieHellman.GeneratePublicKey(dhKey, m1.store.E2e().GetGroup()) + p := params.GetDefaultE2ESessionParams() + recipient := id.NewIdFromString("recipient", id.User, t) + err := m1.store.E2e().AddPartner(recipient, pubKey, dhKey, p, p) + if err != nil { + t.Errorf("Failed to add partner %s: %+v", recipient, err) + } stop1, err := m1.startProcesses(newFtChan1, filePartChan1) if err != nil { @@ -605,9 +647,8 @@ func Test_FileTransfer(t *testing.T) { partSize, _ := m1.getPartSize() fileName := "testFile" fileType := "file" - file, parts := newFile(32, uint32(partSize), prng, t) + file, parts := newFile(32, partSize, prng, t) preview := parts[0] - recipient := id.NewIdFromString("recipient", id.User, t) // Send file sendTid, err := m1.Send(fileName, fileType, file, recipient, 0.5, preview, @@ -654,8 +695,7 @@ func Test_FileTransfer(t *testing.T) { // Count the number of parts marked as received count := 0 for j := uint16(0); j < r.total; j++ { - status := r.tracker.GetPartStatus(j) - if status == 3 { + if r.tracker.GetPartStatus(j) == interfaces.FpReceived { count++ } } @@ -663,9 +703,9 @@ func Test_FileTransfer(t *testing.T) { // Ensure that the number of parts received reported by the // callback matches the number marked received if count != int(r.received) { - t.Errorf("Number of parts marked received does not match "+ - "number reported by callback.\nmarked: %d\ncallback: %d", - count, r.received) + t.Errorf("Number of parts marked received does not "+ + "match number reported by callback."+ + "\nmarked: %d\ncallback: %d", count, r.received) } return @@ -675,7 +715,7 @@ func Test_FileTransfer(t *testing.T) { t.Error("Receive progress callback never reported file finishing to receive.") }() - err = m2.RegisterReceiveProgressCallback(receiveTid, receiveCb, time.Millisecond) + err = m2.RegisterReceivedProgressCallback(receiveTid, receiveCb, time.Millisecond) if err != nil { t.Errorf("Failed to register receive progress callback: %+v", err) } diff --git a/fileTransfer/oldTransferRecovery.go b/fileTransfer/oldTransferRecovery.go new file mode 100644 index 0000000000000000000000000000000000000000..3d6c6c0d0f9a733317b4f1ab3d723607b86feb90 --- /dev/null +++ b/fileTransfer/oldTransferRecovery.go @@ -0,0 +1,109 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/xx_network/primitives/id" +) + +// Error messages. +const ( + oldTransfersRoundResultsErr = "failed to recover round information for " + + "%d rounds for old file transfers after %d attempts" +) + +// roundResultsMaxAttempts is the maximum number of attempts to get round +// results via api.RoundEventCallback before stopping to try +const roundResultsMaxAttempts = 5 + +// oldTransferRecovery adds all unsent file parts back into the queue and +// updates the in-progress file parts by getting round updates. +func (m Manager) oldTransferRecovery(healthyChan chan bool) { + + // Exit if old transfers have already been recovered + if m.oldTransfersRecovered { + jww.DEBUG.Printf("Old file transfer recovery thread not starting: " + + "none to recover (app was not closed)") + return + } + + // Get list of unsent parts and rounds that parts were sent on + unsentParts, sentRounds := m.sent.GetUnsentPartsAndSentRounds() + + // Add all unsent parts to the queue + for tid, partNums := range unsentParts { + m.queueParts(tid, partNums) + } + + // Update parts that were sent by looking up the status of the rounds they + // were sent on + go func() { + err := m.updateSentRounds(healthyChan, sentRounds) + if err != nil { + jww.ERROR.Print(err) + } + }() +} + +// updateSentRounds looks up the status of each round that parts were sent on +// but never arrived. It updates the status of each part depending on if the +// round failed or succeeded. +func (m Manager) updateSentRounds(healthyChan chan bool, + sentRounds map[id.Round][]ftCrypto.TransferID) error { + // Tracks the number of attempts to get round results + var getRoundResultsAttempts int + + jww.DEBUG.Print("Starting old file transfer recovery thread.") + + // Wait for network to be healthy to attempt to get round states + for getRoundResultsAttempts < roundResultsMaxAttempts { + select { + case healthy := <-healthyChan: + // If the network is unhealthy, wait until it becomes healthy + if !healthy { + jww.DEBUG.Print("Suspending old file transfer recovery " + + "thread: network is unhealthy.") + } + for !healthy { + healthy = <-healthyChan + } + jww.DEBUG.Print("Old file transfer recovery thread: " + + "network is healthy.") + + // Register callback to get Round results and retry on error + roundList := roundIdMapToList(sentRounds) + err := m.getRoundResults(roundList, roundResultsTimeout, + m.makeRoundEventCallback(sentRounds)) + if err != nil { + jww.WARN.Printf("Failed to get round results for old "+ + "transfers for rounds %d (attempt %d/%d): %+v", + getRoundResultsAttempts, roundResultsMaxAttempts, + roundList, err) + } else { + jww.INFO.Printf("Successfully recovered old file transfers.") + return nil + } + getRoundResultsAttempts++ + } + } + + return errors.Errorf( + oldTransfersRoundResultsErr, len(sentRounds), getRoundResultsAttempts) +} + +// roundIdMapToList returns a list of all round IDs in the map. +func roundIdMapToList(roundMap map[id.Round][]ftCrypto.TransferID) []id.Round { + roundSlice := make([]id.Round, 0, len(roundMap)) + for rid := range roundMap { + roundSlice = append(roundSlice, rid) + } + return roundSlice +} diff --git a/fileTransfer/oldTransferRecovery_test.go b/fileTransfer/oldTransferRecovery_test.go new file mode 100644 index 0000000000000000000000000000000000000000..4379843c84d7743929aad122b3e0cd8aa689519b --- /dev/null +++ b/fileTransfer/oldTransferRecovery_test.go @@ -0,0 +1,337 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "gitlab.com/elixxir/client/api" + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/primitives/id" + "math/rand" + "reflect" + "sort" + "sync" + "testing" + "time" +) + +// Tests that Manager.oldTransferRecovery adds all unsent parts to the queue. +func TestManager_oldTransferRecovery(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + m, sti, _ := newTestManagerWithTransfers( + []uint16{6, 12, 18}, false, true, nil, kv, t) + + finishedRounds := make(map[id.Round][]ftCrypto.TransferID) + expectedStatus := make( + map[ftCrypto.TransferID]map[uint16]interfaces.FpStatus, len(sti)) + numCbCalls := make(map[ftCrypto.TransferID]int, len(sti)) + var numUnsent int + + for i, st := range sti { + transfer, err := m.sent.GetTransfer(st.tid) + if err != nil { + t.Fatalf("Failed to get transfer #%d %s: %+v", i, st.tid, err) + } + + expectedStatus[st.tid] = make(map[uint16]interfaces.FpStatus, st.numParts) + + // Loop through each part and set it individually + for j, k := uint16(0), 0; j < transfer.GetNumParts(); j++ { + rid := id.Round(j) + switch j % 3 { + case 0: + // Part is sent (in-progress) + _, _ = transfer.SetInProgress(rid, j) + if k%2 == 0 { + finishedRounds[rid] = append(finishedRounds[rid], st.tid) + expectedStatus[st.tid][j] = 2 + } else { + expectedStatus[st.tid][j] = 0 + numUnsent++ + } + numCbCalls[st.tid]++ + k++ + case 1: + // Part is sent and arrived (finished) + _, _ = transfer.SetInProgress(rid, j) + _, _ = transfer.FinishTransfer(rid) + finishedRounds[rid] = append(finishedRounds[rid], st.tid) + expectedStatus[st.tid][j] = 2 + case 2: + // Part is unsent (neither in-progress nor arrived) + expectedStatus[st.tid][j] = 0 + numUnsent++ + } + } + } + + // Returns an error on function and round failure on callback if sendErr is + // set; otherwise, it reports round successes and returns nil + rr := func(rIDs []id.Round, _ time.Duration, cb api.RoundEventCallback) error { + rounds := make(map[id.Round]api.RoundResult, len(rIDs)) + for _, rid := range rIDs { + if finishedRounds[rid] != nil { + rounds[rid] = api.Succeeded + } else { + rounds[rid] = api.Failed + } + } + cb(true, false, rounds) + + return nil + } + + // Load new manager from the original manager's storage + loadedManager, err := newManager( + nil, nil, nil, nil, nil, rr, kv, nil, DefaultParams()) + if err != nil { + t.Errorf("Failed to create new manager from KV: %+v", err) + } + + // Register new sent progress callbacks + cbChans := make([]chan sentProgressResults, len(sti)) + for i, st := range sti { + // Create sent progress callback and channel + cbChan := make(chan sentProgressResults, 8) + cb := func(completed bool, sent, arrived, total uint16, + tr interfaces.FilePartTracker, err error) { + cbChan <- sentProgressResults{completed, sent, arrived, total, tr, err} + } + cbChans[i] = cbChan + + err = loadedManager.RegisterSentProgressCallback(st.tid, cb, 0) + if err != nil { + t.Errorf("Failed to register SentProgressCallback for transfer "+ + "%d %s: %+v", i, st.tid, err) + } + } + + // Create health chan + healthyRecover := make(chan bool, networkHealthBuffLen) + healthyRecover <- true + + // Wait until callbacks have been called to know the transfers have been + // recovered + var wg sync.WaitGroup + for i, st := range sti { + wg.Add(1) + go func(i, callNum int, cbChan chan sentProgressResults, st sentTransferInfo) { + defer wg.Done() + for j := 0; j < callNum; j++ { + select { + case <-time.NewTimer(250 * time.Millisecond).C: + t.Errorf("Timed out waiting for SentProgressCallback #%d "+ + "for transfer #%d %s", j, i, st.tid) + return + case <-cbChan: + } + } + }(i, numCbCalls[st.tid], cbChans[i], st) + } + + loadedManager.oldTransferRecovery(healthyRecover) + + wg.Wait() + + // Check the status of each round in each transfer + for i, st := range sti { + transfer, err := loadedManager.sent.GetTransfer(st.tid) + if err != nil { + t.Fatalf("Failed to get transfer #%d %s: %+v", i, st.tid, err) + } + + _, _, _, _, track := transfer.GetProgress() + for j := uint16(0); j < track.GetNumParts(); j++ { + if track.GetPartStatus(j) != expectedStatus[st.tid][j] { + t.Errorf("Unexpected part #%d status for transfer #%d %s."+ + "\nexpected: %d\nreceived: %d", j, i, st.tid, + expectedStatus[st.tid][j], track.GetPartStatus(j)) + } + } + } + + // Check that each item in the queue is unsent + var queueCount int + for done := false; !done; { + select { + case <-time.NewTimer(5 * time.Millisecond).C: + done = true + case p := <-loadedManager.sendQueue: + queueCount++ + if expectedStatus[p.tid][p.partNum] != 0 { + t.Errorf("Part #%d for transfer %s not expected in qeueu."+ + "\nexpected: %d\nreceived: %d", p.partNum, p.tid, + expectedStatus[p.tid][p.partNum], 0) + } + } + } + + // Check that the number of items in the queue is correct + if queueCount != numUnsent { + t.Errorf("Number of items incorrect.\nexpected: %d\nreceived: %d", + numUnsent, queueCount) + } +} + +// Tests that Manager.updateSentRounds updates the status of each round +// correctly by using the part tracker and checks that all the correct parts +// were added to the queue. +func TestManager_updateSentRounds(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + m, sti, _ := newTestManagerWithTransfers( + []uint16{6, 12, 18}, false, true, nil, kv, t) + + finishedRounds := make(map[id.Round][]ftCrypto.TransferID) + expectedStatus := make( + map[ftCrypto.TransferID]map[uint16]interfaces.FpStatus, len(sti)) + var numUnsent int + + for i, st := range sti { + transfer, err := m.sent.GetTransfer(st.tid) + if err != nil { + t.Fatalf("Failed to get transfer #%d %s: %+v", i, st.tid, err) + } + + expectedStatus[st.tid] = make( + map[uint16]interfaces.FpStatus, st.numParts) + + // Loop through each part and set it individually + for j, k := uint16(0), 0; j < transfer.GetNumParts(); j++ { + rid := id.Round(j) + switch j % 3 { + case 0: + // Part is sent (in-progress) + _, _ = transfer.SetInProgress(rid, j) + if k%2 == 0 { + finishedRounds[rid] = append(finishedRounds[rid], st.tid) + expectedStatus[st.tid][j] = 2 + } else { + expectedStatus[st.tid][j] = 0 + numUnsent++ + } + k++ + case 1: + // Part is sent and arrived (finished) + _, _ = transfer.SetInProgress(rid, j) + _, _ = transfer.FinishTransfer(rid) + finishedRounds[rid] = append(finishedRounds[rid], st.tid) + expectedStatus[st.tid][j] = 2 + case 2: + // Part is unsent (neither in-progress nor arrived) + expectedStatus[st.tid][j] = 0 + } + } + } + + // Returns an error on function and round failure on callback if sendErr is + // set; otherwise, it reports round successes and returns nil + rr := func(rIDs []id.Round, _ time.Duration, cb api.RoundEventCallback) error { + rounds := make(map[id.Round]api.RoundResult, len(rIDs)) + for _, rid := range rIDs { + if finishedRounds[rid] != nil { + rounds[rid] = api.Succeeded + } else { + rounds[rid] = api.Failed + } + } + cb(true, false, rounds) + + return nil + } + + loadedManager, err := newManager( + nil, nil, nil, nil, nil, rr, kv, nil, DefaultParams()) + if err != nil { + t.Errorf("Failed to create new manager from KV: %+v", err) + } + + // Create health chan + healthyRecover := make(chan bool, networkHealthBuffLen) + healthyRecover <- true + + // Get list of rounds that parts were sent on + _, loadedSentRounds := m.sent.GetUnsentPartsAndSentRounds() + + err = loadedManager.updateSentRounds(healthyRecover, loadedSentRounds) + if err != nil { + t.Errorf("updateSentRounds returned an error: %+v", err) + } + + // Check the status of each round in each transfer + for i, st := range sti { + transfer, err := loadedManager.sent.GetTransfer(st.tid) + if err != nil { + t.Fatalf("Failed to get transfer #%d %s: %+v", i, st.tid, err) + } + + _, _, _, _, track := transfer.GetProgress() + for j := uint16(0); j < track.GetNumParts(); j++ { + if track.GetPartStatus(j) != expectedStatus[st.tid][j] { + t.Errorf("Unexpected part #%d status for transfer #%d %s."+ + "\nexpected: %d\nreceived: %d", j, i, st.tid, + expectedStatus[st.tid][j], track.GetPartStatus(j)) + } + } + } + + // Check that each item in the queue is unsent + var queueCount int + for done := false; !done; { + select { + case <-time.NewTimer(5 * time.Millisecond).C: + done = true + case p := <-loadedManager.sendQueue: + queueCount++ + if expectedStatus[p.tid][p.partNum] != 0 { + t.Errorf("Part #%d for transfer %s not expected in qeueu."+ + "\nexpected: %d\nreceived: %d", p.partNum, p.tid, + expectedStatus[p.tid][p.partNum], 0) + } + } + } + + // Check that the number of items in the queue is correct + if queueCount != numUnsent { + t.Errorf("Number of items incorrect.\nexpected: %d\nreceived: %d", + numUnsent, queueCount) + } +} + +// Tests that roundIdMapToList returns all the round IDs in the map. +func Test_roundIdMapToList(t *testing.T) { + n := 10 + roundMap := make(map[id.Round][]ftCrypto.TransferID, n) + expectedRoundList := make([]id.Round, n) + + csPrng := NewPrng(42) + prng := rand.New(rand.NewSource(42)) + + for i := 0; i < n; i++ { + rid := id.Round(i) + roundMap[rid] = make([]ftCrypto.TransferID, prng.Intn(10)) + for j := range roundMap[rid] { + roundMap[rid][j], _ = ftCrypto.NewTransferID(csPrng) + } + + expectedRoundList[i] = rid + } + + receivedRoundList := roundIdMapToList(roundMap) + + sort.SliceStable(receivedRoundList, func(i, j int) bool { + return receivedRoundList[i] < receivedRoundList[j] + }) + + if !reflect.DeepEqual(expectedRoundList, receivedRoundList) { + t.Errorf("Round list does not match expected."+ + "\nexpected: %v\nreceived: %v", + expectedRoundList, receivedRoundList) + } +} diff --git a/fileTransfer/receiveNew_test.go b/fileTransfer/receiveNew_test.go index 5bbfe8e523b7fb6c5ba0ddd840416f5e4a4c42f1..872012beee4ed019a035212c85f45a16ea7a1926 100644 --- a/fileTransfer/receiveNew_test.go +++ b/fileTransfer/receiveNew_test.go @@ -31,7 +31,7 @@ func TestManager_receiveNewFileTransfer(t *testing.T) { } // Create new manager, stoppable, and channel to receive messages - m := newTestManager(false, nil, nil, receiveCB, t) + m := newTestManager(false, nil, nil, receiveCB, nil, t) stop := stoppable.NewSingle(newFtStoppableName) rawMsgs := make(chan message.Receive, rawMessageBuffSize) @@ -100,7 +100,7 @@ func TestManager_receiveNewFileTransfer_Stop(t *testing.T) { } // Create new manager, stoppable, and channel to receive messages - m := newTestManager(false, nil, nil, receiveCB, t) + m := newTestManager(false, nil, nil, receiveCB, nil, t) stop := stoppable.NewSingle(newFtStoppableName) rawMsgs := make(chan message.Receive, rawMessageBuffSize) @@ -161,7 +161,7 @@ func TestManager_receiveNewFileTransfer_InvalidMessageError(t *testing.T) { } // Create new manager, stoppable, and channel to receive messages - m := newTestManager(false, nil, nil, receiveCB, t) + m := newTestManager(false, nil, nil, receiveCB, nil, t) stop := stoppable.NewSingle(newFtStoppableName) rawMsgs := make(chan message.Receive, rawMessageBuffSize) @@ -194,7 +194,7 @@ func TestManager_receiveNewFileTransfer_InvalidMessageError(t *testing.T) { // Tests that Manager.readNewFileTransferMessage returns the expected sender ID, // file size, and preview. func TestManager_readNewFileTransferMessage(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) // Create new message.Send containing marshalled NewFileTransfer recipient := id.NewIdFromString("recipient", id.User, t) @@ -254,7 +254,7 @@ func TestManager_readNewFileTransferMessage(t *testing.T) { // Error path: tests that Manager.readNewFileTransferMessage returns the // expected error when the message.Receive has the wrong MessageType. func TestManager_readNewFileTransferMessage_MessageTypeError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) expectedErr := receiveMessageTypeErr // Create message.Receive with marshalled NewFileTransfer @@ -275,7 +275,7 @@ func TestManager_readNewFileTransferMessage_MessageTypeError(t *testing.T) { // expected error when the payload of the message.Receive cannot be // unmarshalled. func TestManager_readNewFileTransferMessage_ProtoUnmarshalError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) expectedErr := strings.Split(protoUnmarshalErr, "%")[0] // Create message.Receive with marshalled NewFileTransfer diff --git a/fileTransfer/receive_test.go b/fileTransfer/receive_test.go index 9d56109e8c0b891aac12e091aec6ec5bd3b7fdaf..2b51e234471f4478828dc63d63f688cc76746a8d 100644 --- a/fileTransfer/receive_test.go +++ b/fileTransfer/receive_test.go @@ -24,8 +24,8 @@ import ( // receiving a single message. func TestManager_receive(t *testing.T) { // Build a manager for sending and a manger for receiving - m1 := newTestManager(false, nil, nil, nil, t) - m2 := newTestManager(false, nil, nil, nil, t) + m1 := newTestManager(false, nil, nil, nil, nil, t) + m2 := newTestManager(false, nil, nil, nil, nil, t) // Create transfer components prng := NewPrng(42) @@ -34,7 +34,7 @@ func TestManager_receive(t *testing.T) { numParts := uint16(16) numFps := calcNumberOfFingerprints(numParts, 0.5) partSize, _ := m1.getPartSize() - file, parts := newFile(numParts, uint32(partSize), prng, t) + file, parts := newFile(numParts, partSize, prng, t) fileSize := uint32(len(file)) mac := ftCrypto.CreateTransferMAC(file, key) @@ -142,8 +142,8 @@ func TestManager_receive(t *testing.T) { // stoppable is triggered. func TestManager_receive_Stop(t *testing.T) { // Build a manager for sending and a manger for receiving - m1 := newTestManager(false, nil, nil, nil, t) - m2 := newTestManager(false, nil, nil, nil, t) + m1 := newTestManager(false, nil, nil, nil, nil, t) + m2 := newTestManager(false, nil, nil, nil, nil, t) // Create transfer components prng := NewPrng(42) @@ -152,7 +152,7 @@ func TestManager_receive_Stop(t *testing.T) { numParts := uint16(16) numFps := calcNumberOfFingerprints(numParts, 0.5) partSize, _ := m1.getPartSize() - file, parts := newFile(numParts, uint32(partSize), prng, t) + file, parts := newFile(numParts, partSize, prng, t) fileSize := uint32(len(file)) mac := ftCrypto.CreateTransferMAC(file, key) @@ -248,8 +248,8 @@ func TestManager_receive_Stop(t *testing.T) { func TestManager_readMessage(t *testing.T) { // Build a manager for sending and a manger for receiving - m1 := newTestManager(false, nil, nil, nil, t) - m2 := newTestManager(false, nil, nil, nil, t) + m1 := newTestManager(false, nil, nil, nil, nil, t) + m2 := newTestManager(false, nil, nil, nil, nil, t) // Create transfer components prng := NewPrng(42) @@ -258,7 +258,7 @@ func TestManager_readMessage(t *testing.T) { numParts := uint16(16) numFps := calcNumberOfFingerprints(numParts, 0.5) partSize, _ := m1.getPartSize() - file, parts := newFile(numParts, uint32(partSize), prng, t) + file, parts := newFile(numParts, partSize, prng, t) fileSize := uint32(len(file)) mac := ftCrypto.CreateTransferMAC(file, key) diff --git a/fileTransfer/send.go b/fileTransfer/send.go index 99bfd34b5a9ab34e0de538408ff831fe3b055d66..01ea18ae207e4526e2edb4cb182c0b7c1b5a1ae3 100644 --- a/fileTransfer/send.go +++ b/fileTransfer/send.go @@ -15,11 +15,14 @@ import ( "gitlab.com/elixxir/client/api" "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/interfaces/utility" "gitlab.com/elixxir/client/stoppable" ftStorage "gitlab.com/elixxir/client/storage/fileTransfer" + ds "gitlab.com/elixxir/comms/network/dataStructures" ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" "gitlab.com/elixxir/crypto/shuffle" "gitlab.com/elixxir/primitives/format" + "gitlab.com/elixxir/primitives/states" "gitlab.com/xx_network/crypto/csprng" "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/netTime" @@ -31,7 +34,7 @@ const ( // Manager.sendParts sendManyCmixWarn = "Failed to send %d file parts %v via SendManyCMIX: %+v" setInProgressErr = "Failed to set parts %v to in-progress for transfer %s" - getRoundResultsErr = "Failed to get round results for round %d: %+v" + getRoundResultsErr = "Failed to get round results for round %d for file transfers %v: %+v" // Manager.buildMessages noSentTransferWarn = "Could not get transfer %s for part %d: %+v" @@ -39,12 +42,17 @@ const ( newCmixMessageErr = "Failed to assemble cMix message for file part %d on transfer %s: %+v" // Manager.makeRoundEventCallback - finishTransferPanic = "Failed to set part(s) to finished for transfer %s: %+v" - roundFailureWarn = "Failed to send file parts for transmission %s on round %d: %s" - unsetInProgressPanic = "Failed to remove parts from in-progress list for transfer %s after %s" - roundFailureCbErr = "Failed to send parts %v for transfer %s on round %d: %s" - roundFailErr = "round failed" - trackerTimeoutErr = "tracker timed out" + finishPassNoTransferErr = "Failed to mark in-progress parts as finished on success of round %d for transfer %s: %+v" + finishTransferErr = "Failed to set part(s) to finished for transfer %s: %+v" + finishedEndE2eMsfErr = "Failed to send E2E message to %s on completion of file transfer %s: %+v" + finishFailNoTransferErr = "Failed to requeue in-progress parts on failure of round %d for transfer %s: %+v" + roundFailureWarn = "Failed to send file parts for file transfers %v on round %d: round %s" + unsetInProgressErr = "Failed to remove parts from in-progress list for transfer %s: round %s" + roundFailureCbErr = "Failed to send parts %d for transfer %s on round %d: round %s" + + // Manager.sendEndE2eMessage + endE2eGetPartnerErr = "failed to get file transfer partner %s: %+v" + endE2eSendErr = "failed to send end file transfer message: %+v" // getRandomNumParts getRandomNumPartsRandPanic = "Failed to generate random number of file parts to send: %+v" @@ -60,7 +68,7 @@ const pollSleepDuration = 100 * time.Millisecond // receives a random number between 1 and 11 of file parts, they are encrypted, // put into cMix messages, and sent to their recipients. Failed messages are // added to the end of the queue. -func (m *Manager) sendThread(stop *stoppable.Single, getNumParts getRngNum) { +func (m *Manager) sendThread(stop *stoppable.Single, healthChan chan bool, getNumParts getRngNum) { jww.DEBUG.Print("Starting file part sending thread.") // Calculate the average amount of data sent via SendManyCMIX @@ -91,15 +99,16 @@ func (m *Manager) sendThread(stop *stoppable.Single, getNumParts getRngNum) { m.closeSendThread(partList, stop) return - case healthy := <-m.healthy: + case healthy := <-healthChan: // If the network is unhealthy, wait until it becomes healthy - jww.TRACE.Print("Suspending file part sending thread: network is " + - "unhealthy") + if !healthy { + jww.TRACE.Print("Suspending file part sending thread: " + + "network is unhealthy.") + } for !healthy { - healthy = <-m.healthy + healthy = <-healthChan } - jww.TRACE.Print("Resuming file part sending thread: network is " + - "healthy") + jww.TRACE.Print("File part sending thread: network is healthy.") case part := <-m.sendQueue: // When a part is received from the queue, add it to the list of // parts to be sent @@ -144,7 +153,7 @@ func (m *Manager) sendThread(stop *stoppable.Single, getNumParts getRngNum) { // queue and setting the stoppable to stopped. func (m *Manager) closeSendThread(partList []queuedPart, stop *stoppable.Single) { // Exit the thread if the stoppable is triggered - jww.DEBUG.Print("Stopping file part sending thread: stoppable triggered") + jww.DEBUG.Print("Stopping file part sending thread: stoppable triggered.") // Add all the unsent parts back in the queue for _, part := range partList { @@ -175,13 +184,11 @@ func (m *Manager) handleSend(partList *[]queuedPart, lastSend *time.Time, } // Send all the messages - rng := m.rng.GetStream() - err := m.sendParts(*partList, rng) + err := m.sendParts(*partList) *lastSend = netTime.Now() if err != nil { jww.FATAL.Panic(err) } - rng.Close() // Clear partList once done *partList = nil @@ -191,11 +198,11 @@ func (m *Manager) handleSend(partList *[]queuedPart, lastSend *time.Time, // sendParts handles the composing and sending of a cMix message for each part // in the list. All errors returned are fatal errors. -func (m *Manager) sendParts(partList []queuedPart, rng csprng.Source) error { +func (m *Manager) sendParts(partList []queuedPart) error { // Build cMix messages - messages, transfers, groupedParts, partsToResend, err := m.buildMessages( - partList, rng) + messages, transfers, groupedParts, partsToResend, err := + m.buildMessages(partList) if err != nil { return err } @@ -220,6 +227,9 @@ func (m *Manager) sendParts(partList []queuedPart, rng csprng.Source) error { return nil } + // Create list for transfer IDs to watch with the round results callback + tIDs := make([]ftCrypto.TransferID, 0, len(transfers)) + // Set all parts to in-progress for tid, transfer := range transfers { err, exists := transfer.SetInProgress(rid, groupedParts[tid]...) @@ -229,19 +239,23 @@ func (m *Manager) sendParts(partList []queuedPart, rng csprng.Source) error { transfer.CallProgressCB(nil) - // Set up tracker waiting for the round to end to update state and - // progress; skip if the tracker has already been launched for this - // transfer and round ID + // Add transfer ID to list to be tracked; skip if the tracker has + // already been launched for this transfer and round ID if !exists { - roundResultCB := m.makeRoundEventCallback(rid, tid, transfer) - err = m.getRoundResults( - []id.Round{rid}, roundResultsTimeout, roundResultCB) - if err != nil { - return errors.Errorf(getRoundResultsErr, rid, err) - } + tIDs = append(tIDs, tid) } } + // Set up tracker waiting for the round to end to update state and update + // progress + roundResultCB := m.makeRoundEventCallback( + map[id.Round][]ftCrypto.TransferID{rid: tIDs}) + err = m.getRoundResults( + []id.Round{rid}, roundResultsTimeout, roundResultCB) + if err != nil { + return errors.Errorf(getRoundResultsErr, rid, tIDs, err) + } + return nil } @@ -252,7 +266,7 @@ func (m *Manager) sendParts(partList []queuedPart, rng csprng.Source) error { // of partList index of parts that will be sent. Any part that encounters a // non-fatal error will be skipped and will not be included in an of the lists. // All errors returned are fatal errors. -func (m *Manager) buildMessages(partList []queuedPart, rng csprng.Source) ( +func (m *Manager) buildMessages(partList []queuedPart) ( []message.TargetedCmixMessage, map[ftCrypto.TransferID]*ftStorage.SentTransfer, map[ftCrypto.TransferID][]uint16, []int, error) { messages := make([]message.TargetedCmixMessage, 0, len(partList)) @@ -260,6 +274,9 @@ func (m *Manager) buildMessages(partList []queuedPart, rng csprng.Source) ( groupedParts := map[ftCrypto.TransferID][]uint16{} partsToResend := make([]int, 0, len(partList)) + rng := m.rng.GetStream() + defer rng.Close() + for i, part := range partList { // Lookup the transfer by the ID; if the transfer does not exist, then // print a warning and skip this message @@ -337,80 +354,160 @@ func (m *Manager) newCmixMessage(transfer *ftStorage.SentTransfer, } // makeRoundEventCallback returns an api.RoundEventCallback that is called once -// the round file parts were sent on either succeeds or fails. If the round -// succeeds, then all file parts are marked as finished and the progress -// callback is called with the current progress. If the round fails, then each -// part is removed from the in-progress list, added to the end of the sending -// queue, and the callback called with an error. -func (m *Manager) makeRoundEventCallback(rid id.Round, tid ftCrypto.TransferID, - transfer *ftStorage.SentTransfer) api.RoundEventCallback { - - return func(allSucceeded, timedOut bool, _ map[id.Round]api.RoundResult) { - if allSucceeded { - // If the round succeeded, then set all parts for this round to - // finished and call the progress callback - err := transfer.FinishTransfer(rid) - if err != nil { - jww.FATAL.Panicf(finishTransferPanic, tid, err) +// the round that file parts were sent on either succeeds or fails. If the round +// succeeds, then all file parts for each transfer are marked as finished and +// the progress callback is called with the current progress. If the round +// fails, then each part for each transfer is removed from the in-progress list, +// added to the end of the sending queue, and the callback called with an error. +func (m *Manager) makeRoundEventCallback( + sentRounds map[id.Round][]ftCrypto.TransferID) api.RoundEventCallback { + + return func(allSucceeded, timedOut bool, rounds map[id.Round]api.RoundResult) { + for rid, roundResult := range rounds { + if roundResult == api.Succeeded { + // If the round succeeded, then set all parts for each transfer + // for this round to finished and call the progress callback + for _, tid := range sentRounds[rid] { + transfer, err := m.sent.GetTransfer(tid) + if err != nil { + jww.ERROR.Printf(finishPassNoTransferErr, rid, tid, err) + continue + } + + // Mark as finished + completed, err := transfer.FinishTransfer(rid) + if err != nil { + jww.ERROR.Printf(finishTransferErr, tid, err) + continue + } + + // If the transfer is complete, send an E2E message to the + // recipient informing them + if completed { + go func(tid ftCrypto.TransferID, recipient *id.ID) { + err = m.sendEndE2eMessage(recipient) + if err != nil { + jww.ERROR.Printf(finishedEndE2eMsfErr, + recipient, tid, err) + } + }(tid, transfer.GetRecipient()) + } + + transfer.CallProgressCB(nil) + } + } else { + + jww.WARN.Printf(roundFailureWarn, sentRounds[rid], rid, roundResult) + + // If the round failed, then remove all parts for each transfer + // for this round from the in-progress list, call the progress + // callback with an error, and add the parts back into the queue + for _, tid := range sentRounds[rid] { + transfer, err := m.sent.GetTransfer(tid) + if err != nil { + jww.ERROR.Printf(finishFailNoTransferErr, rid, tid, err) + continue + } + + // Remove parts from in-progress list + partsToResend, err := transfer.UnsetInProgress(rid) + if err != nil { + jww.ERROR.Printf(unsetInProgressErr, tid, roundResult) + } + + // Return the error on the progress callback + cbErr := errors.Errorf( + roundFailureCbErr, partsToResend, tid, rid, roundResult) + transfer.CallProgressCB(cbErr) + + // Add all the unsent parts back in the queue + m.queueParts(tid, partsToResend) + } } + } + } +} + +// sendEndE2eMessage sends an E2E message to the recipient once the transfer +// complete information them that all file parts have been sent. +// TODO: test +func (m *Manager) sendEndE2eMessage(recipient *id.ID) error { + // Get the partner + partner, err := m.store.E2e().GetPartner(recipient) + if err != nil { + return errors.Errorf(endE2eGetPartnerErr, recipient, err) + } - transfer.CallProgressCB(nil) - } else { - // If the round failed, then remove all parts for this round from - // the in-progress list, call the progress callback with an error, - // and add the parts back into the queue + // Build the message + sendMsg := message.Send{ + Recipient: recipient, + MessageType: message.EndFileTransfer, + } - // Determine the correct error message, for logging - roundErr := roundFailErr - if timedOut { - roundErr = trackerTimeoutErr - } + // Send the message under file transfer preimage + e2eParams := params.GetDefaultE2E() + e2eParams.IdentityPreimage = partner.GetFileTransferPreimage() - jww.WARN.Printf(roundFailureWarn, tid, rid, roundErr) + // Store the message in the critical messages buffer first to ensure it is + // present if the send fails + m.store.GetCriticalMessages().AddProcessing(sendMsg, e2eParams) - // Remove parts from in-progress list - partsToResend, err := transfer.UnsetInProgress(rid) - if err != nil { - jww.FATAL.Panicf(unsetInProgressPanic, tid, roundErr) - } + rounds, e2eMsgID, _, err := m.net.SendE2E(sendMsg, e2eParams, nil) + if err != nil { + return errors.Errorf(endE2eSendErr, err) + } - // Return the error on the progress callback - cbErr := errors.Errorf( - roundFailureCbErr, partsToResend, tid, rid, roundErr) - transfer.CallProgressCB(cbErr) + // Register the event for all rounds + sendResults := make(chan ds.EventReturn, len(rounds)) + roundEvents := m.net.GetInstance().GetRoundEvents() + for _, r := range rounds { + roundEvents.AddRoundEventChan(r, sendResults, 10*time.Second, + states.COMPLETED, states.FAILED) + } - // Add all the unsent parts back in the queue - for _, partNum := range partsToResend { - m.sendQueue <- queuedPart{ - tid: tid, - partNum: partNum, - } - } - } + // Wait until the result tracking responds + success, numTimeOut, numRoundFail := utility.TrackResults( + sendResults, len(rounds)) + + // If a single partition of the end file transfer message does not transmit, + // then the partner will not be able to read the confirmation + if !success { + jww.ERROR.Printf("Sending E2E message %s to end file transfer with "+ + "%s failed to transmit %d/%d partitions: %d round failures, %d "+ + "timeouts", recipient, e2eMsgID, numRoundFail+numTimeOut, + len(rounds), numRoundFail, numTimeOut) + m.store.GetCriticalMessages().Failed(sendMsg, e2eParams) + return nil } + + // Otherwise, the transmission is a success and this should be denoted in + // the session and the log + m.store.GetCriticalMessages().Succeeded(sendMsg, e2eParams) + jww.INFO.Printf("Sending of message %s informing %s that a transfer ended"+ + "successful.", e2eMsgID, recipient) + + return nil } -// queueParts adds an entry for each file part in a transfer into the sendQueue +// queueParts adds an entry for each file part in the list into the sendQueue // channel in a random order. -func (m *Manager) queueParts(tid ftCrypto.TransferID, numParts uint16) { - // Add each part number to the buffer in the shuffled order - for _, partNum := range getShuffledPartNumList(numParts) { - m.sendQueue <- queuedPart{tid, uint16(partNum)} +func (m *Manager) queueParts(tid ftCrypto.TransferID, partNums []uint16) { + // Shuffle the list + shuffle.Shuffle16(&partNums) + + // Add each part to the queue + for _, partNum := range partNums { + m.sendQueue <- queuedPart{tid, partNum} } } -// getShuffledPartNumList returns a list of number of file part, from 0 to -// numParts, shuffled in a random order. -func getShuffledPartNumList(numParts uint16) []uint64 { - // Create list of part numbers - partNumList := make([]uint64, numParts) +// makeListOfPartNums returns a list of number of file part, from 0 to numParts. +func makeListOfPartNums(numParts uint16) []uint16 { + partNumList := make([]uint16, numParts) for i := range partNumList { - partNumList[i] = uint64(i) + partNumList[i] = uint16(i) } - // Shuffle list of part numbers - shuffle.Shuffle(&partNumList) - return partNumList } diff --git a/fileTransfer/sendNew.go b/fileTransfer/sendNew.go index 1513453c3af06c0affe40ba9a616b1d8fb4ce800..e78b7f9c4f49ee96b55abf550e46083bf06dc2b6 100644 --- a/fileTransfer/sendNew.go +++ b/fileTransfer/sendNew.go @@ -18,8 +18,8 @@ import ( // Error messages. const ( - protoMarshalErr = "failed to form new file transfer message: %+v" - sendE2eErr = "failed to send new file transfer message via E2E to recipient %s: %+v" + newFtProtoMarshalErr = "failed to form new file transfer message: %+v" + newFtSendE2eErr = "failed to send new file transfer message via E2E to recipient %s: %+v" ) // sendNewFileTransfer sends the initial file transfer message over E2E. @@ -31,13 +31,13 @@ func (m *Manager) sendNewFileTransfer(recipient *id.ID, fileName, sendMsg, err := newNewFileTransferE2eMessage(recipient, fileName, fileType, key, mac, numParts, fileSize, retry, preview) if err != nil { - return errors.Errorf(protoMarshalErr, err) + return errors.Errorf(newFtProtoMarshalErr, err) } // Send E2E message rounds, _, _, err := m.net.SendE2E(sendMsg, params.GetDefaultE2E(), nil) if err != nil && len(rounds) == 0 { - return errors.Errorf(sendE2eErr, recipient, err) + return errors.Errorf(newFtSendE2eErr, recipient, err) } return nil diff --git a/fileTransfer/sendNew_test.go b/fileTransfer/sendNew_test.go index 53adcb934d8dbf3d544087a27ad5780c35f631ad..6465f359c41ce50b65781729777e5be3b6b3d6cc 100644 --- a/fileTransfer/sendNew_test.go +++ b/fileTransfer/sendNew_test.go @@ -21,7 +21,7 @@ import ( // Tests that the E2E message sent via Manager.sendNewFileTransfer matches // expected. func TestManager_sendNewFileTransfer(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) recipient := id.NewIdFromString("recipient", id.User, t) fileName := "testFile" @@ -55,12 +55,12 @@ func TestManager_sendNewFileTransfer(t *testing.T) { // when SendE2E fails. func TestManager_sendNewFileTransfer_E2eError(t *testing.T) { // Create new test manager with a SendE2E error triggered - m := newTestManager(true, nil, nil, nil, t) + m := newTestManager(true, nil, nil, nil, nil, t) recipient := id.NewIdFromString("recipient", id.User, t) key, _ := ftCrypto.NewTransferKey(NewPrng(42)) - expectedErr := fmt.Sprintf(sendE2eErr, recipient, "") + expectedErr := fmt.Sprintf(newFtSendE2eErr, recipient, "") err := m.sendNewFileTransfer(recipient, "", "", key, nil, 16, 256, 1.5, nil) if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Errorf("sendNewFileTransfer di dnot return the expected error when "+ diff --git a/fileTransfer/send_test.go b/fileTransfer/send_test.go index 51ab9f3aaac046cdea5126ca8469248de00803ee..37e65378208963351de4c5b5e48170b23210f3db 100644 --- a/fileTransfer/send_test.go +++ b/fileTransfer/send_test.go @@ -10,6 +10,7 @@ package fileTransfer import ( "bytes" "fmt" + "gitlab.com/elixxir/client/api" "gitlab.com/elixxir/client/interfaces" "gitlab.com/elixxir/client/stoppable" ftStorage "gitlab.com/elixxir/client/storage/fileTransfer" @@ -31,7 +32,8 @@ import ( // Tests that Manager.sendThread successfully sends the parts and reports their // progress on the callback. func TestManager_sendThread(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers([]uint16{12, 4, 1}, false, nil, t) + m, sti, _ := newTestManagerWithTransfers( + []uint16{12, 4, 1}, false, true, nil, nil, t) // Add three transfers partsToSend := [][]uint16{ @@ -46,7 +48,7 @@ func TestManager_sendThread(t *testing.T) { go func(i int, st sentTransferInfo) { for j := 0; j < 2; j++ { select { - case <-time.NewTimer(20 * time.Millisecond).C: + case <-time.NewTimer(160 * time.Millisecond).C: t.Errorf("Timed out waiting for callback #%d", i) case r := <-st.cbChan: if j > 0 { @@ -82,9 +84,13 @@ func TestManager_sendThread(t *testing.T) { return len(queuedParts) } + // Register channel for health tracking + healthyChan := make(chan bool) + m.net.GetHealthTracker().AddChannel(healthyChan) + // Start sending thread stop := stoppable.NewSingle("testSendThreadStoppable") - go m.sendThread(stop, getNumParts) + go m.sendThread(stop, healthyChan, getNumParts) // Add parts to queue for _, part := range queuedParts { @@ -107,7 +113,8 @@ func TestManager_sendThread(t *testing.T) { // Tests that Manager.sendThread successfully sends a partially filled batch // of the correct length when its times out waiting for messages. func TestManager_sendThread_Timeout(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers([]uint16{12, 4, 1}, false, nil, t) + m, sti, _ := newTestManagerWithTransfers( + []uint16{12, 4, 1}, false, false, nil, nil, t) // Add three transfers partsToSend := [][]uint16{ @@ -155,9 +162,13 @@ func TestManager_sendThread_Timeout(t *testing.T) { return len(queuedParts) } + // Register channel for health tracking + healthyChan := make(chan bool) + m.net.GetHealthTracker().AddChannel(healthyChan) + // Start sending thread stop := stoppable.NewSingle("testSendThreadStoppable") - go m.sendThread(stop, getNumParts) + go m.sendThread(stop, healthyChan, getNumParts) // Add parts to queue for _, part := range queuedParts[:5] { @@ -182,8 +193,8 @@ func TestManager_sendThread_Timeout(t *testing.T) { // Tests that Manager.sendParts sends all the correct cMix messages and calls // the progress callbacks with the correct values. func TestManager_sendParts(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers([]uint16{12, 4, 1}, false, nil, t) - prng := NewPrng(42) + m, sti, _ := newTestManagerWithTransfers( + []uint16{12, 4, 1}, false, true, nil, nil, t) // Add three transfers partsToSend := [][]uint16{ @@ -229,7 +240,7 @@ func TestManager_sendParts(t *testing.T) { queuedParts[i], queuedParts[j] = queuedParts[j], queuedParts[i] }) - err := m.sendParts(queuedParts, prng) + err := m.sendParts(queuedParts) if err != nil { t.Errorf("sendParts returned an error: %+v", err) } @@ -267,8 +278,8 @@ func TestManager_sendParts(t *testing.T) { // parts back into the queue, does not call the callback, and does not update // the progress. func TestManager_sendParts_SendManyCmixError(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers([]uint16{12, 4, 1}, true, nil, t) - prng := NewPrng(42) + m, sti, _ := newTestManagerWithTransfers( + []uint16{12, 4, 1}, true, false, nil, nil, t) partsToSend := [][]uint16{ {0, 1, 3, 5, 6, 7}, {1, 2, 3, 0}, @@ -303,7 +314,7 @@ func TestManager_sendParts_SendManyCmixError(t *testing.T) { } } - err := m.sendParts(queuedParts, prng) + err := m.sendParts(queuedParts) if err != nil { t.Errorf("sendParts returned an error: %+v", err) } @@ -324,8 +335,8 @@ func TestManager_sendParts_SendManyCmixError(t *testing.T) { // Tests that Manager.buildMessages returns the expected values for a group // of 11 file parts from three different transfers. func TestManager_buildMessages(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers([]uint16{12, 4, 1}, false, nil, t) - prng := NewPrng(42) + m, sti, _ := newTestManagerWithTransfers( + []uint16{12, 4, 1}, false, false, nil, nil, t) partsToSend := [][]uint16{ {0, 1, 3, 5, 6, 7}, {1, 2, 3, 0}, @@ -349,9 +360,8 @@ func TestManager_buildMessages(t *testing.T) { }) // Build the messages - prng = NewPrng(2) - messages, transfers, groupedParts, partsToResend, err := m.buildMessages( - queuedParts, prng) + messages, transfers, groupedParts, partsToResend, err := + m.buildMessages(queuedParts) if err != nil { t.Errorf("buildMessages returned an error: %+v", err) } @@ -420,7 +430,7 @@ func TestManager_buildMessages(t *testing.T) { // Tests that Manager.buildMessages skips file parts with deleted transfers or // transfers that have run out of fingerprints. func TestManager_buildMessages_MessageBuildFailureError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) // Add transfer @@ -484,7 +494,7 @@ func TestManager_buildMessages_MessageBuildFailureError(t *testing.T) { // Build the messages prng = NewPrng(46) messages, transfers, groupedParts, partsToResend, err := m.buildMessages( - queuedParts, prng) + queuedParts) if err != nil { t.Errorf("buildMessages returned an error: %+v", err) } @@ -523,7 +533,7 @@ func TestManager_buildMessages_MessageBuildFailureError(t *testing.T) { // Tests that Manager.buildMessages returns the expected error when a queued // part has an invalid part number. func TestManager_buildMessages_NewCmixMessageError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) // Add transfer prng := NewPrng(42) @@ -541,7 +551,7 @@ func TestManager_buildMessages_NewCmixMessageError(t *testing.T) { // Build the messages expectedErr := fmt.Sprintf( newCmixMessageErr, queuedParts[0].partNum, tid, "") - _, _, _, _, err = m.buildMessages(queuedParts, prng) + _, _, _, _, err = m.buildMessages(queuedParts) if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Errorf("buildMessages did not return the expected error when the "+ "queuedPart part number is invalid.\nexpected: %s\nreceived: %+v", @@ -553,13 +563,13 @@ func TestManager_buildMessages_NewCmixMessageError(t *testing.T) { // Tests that Manager.newCmixMessage returns a format.Message with the correct // MAC, fingerprint, and contents. func TestManager_newCmixMessage(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) prng := NewPrng(42) tid, _ := ftCrypto.NewTransferID(prng) key, _ := ftCrypto.NewTransferKey(prng) recipient := id.NewIdFromString("recipient", id.User, t) partSize, _ := m.getPartSize() - _, parts := newFile(16, uint32(partSize), prng, t) + _, parts := newFile(16, partSize, prng, t) numFps := calcNumberOfFingerprints(uint16(len(parts)), 1.5) kv := versioned.NewKV(make(ekv.Memstore)) @@ -601,7 +611,7 @@ func TestManager_newCmixMessage(t *testing.T) { // Tests that Manager.makeRoundEventCallback returns a callback that calls the // progress callback when a round succeeds. func TestManager_makeRoundEventCallback(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) // Add transfer type callbackResults struct { @@ -621,7 +631,7 @@ func TestManager_makeRoundEventCallback(t *testing.T) { for i := 0; i < 2; i++ { select { case <-time.NewTimer(10 * time.Millisecond).C: - t.Error("Timed out waiting for callback.") + t.Errorf("Timed out waiting for callback (%d).", i) case r := <-callbackChan: switch i { case 0: @@ -661,16 +671,17 @@ func TestManager_makeRoundEventCallback(t *testing.T) { rid := id.Round(42) - _, transfers, groupedParts, _, err := m.buildMessages(queuedParts, prng) + _, transfers, groupedParts, _, err := m.buildMessages(queuedParts) // Set all parts to in-progress for tid, transfer := range transfers { _, _ = transfer.SetInProgress(rid, groupedParts[tid]...) } - roundEventCB := m.makeRoundEventCallback(rid, tid, transfers[tid]) + roundEventCB := m.makeRoundEventCallback( + map[id.Round][]ftCrypto.TransferID{rid: {tid}}) - roundEventCB(true, false, nil) + roundEventCB(true, false, map[id.Round]api.RoundResult{rid: api.Succeeded}) <-done1 } @@ -678,7 +689,7 @@ func TestManager_makeRoundEventCallback(t *testing.T) { // Tests that Manager.makeRoundEventCallback returns a callback that calls the // progress callback with the correct error when a round fails. func TestManager_makeRoundEventCallback_RoundFailure(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) rid := id.Round(42) @@ -718,7 +729,7 @@ func TestManager_makeRoundEventCallback_RoundFailure(t *testing.T) { done0 <- true case 1: expectedErr := fmt.Sprintf( - roundFailureCbErr, partsToSend, tid, rid, roundFailErr) + roundFailureCbErr, partsToSend, tid, rid, api.Failed) if r.err == nil || !strings.Contains(r.err.Error(), expectedErr) { t.Errorf("Callback received unexpected error when round "+ "failed.\nexpected: %s\nreceived: %v", expectedErr, r.err) @@ -735,7 +746,7 @@ func TestManager_makeRoundEventCallback_RoundFailure(t *testing.T) { queuedParts[i] = queuedPart{tid, uint16(i)} } - _, transfers, groupedParts, _, err := m.buildMessages(queuedParts, prng) + _, transfers, groupedParts, _, err := m.buildMessages(queuedParts) <-done0 @@ -744,53 +755,18 @@ func TestManager_makeRoundEventCallback_RoundFailure(t *testing.T) { _, _ = transfer.SetInProgress(rid, groupedParts[tid]...) } - roundEventCB := m.makeRoundEventCallback(rid, tid, transfers[tid]) + roundEventCB := m.makeRoundEventCallback( + map[id.Round][]ftCrypto.TransferID{rid: {tid}}) - roundEventCB(false, false, nil) + roundEventCB(false, false, map[id.Round]api.RoundResult{rid: api.Failed}) <-done1 } -// Panic path: tests that Manager.makeRoundEventCallback panics when -// SentTransfer.FinishTransfer returns an error because the file parts had not -// previously been set to in-progress. -func TestManager_makeRoundEventCallback_FinishTransferPanic(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) - - prng := NewPrng(42) - recipient := id.NewIdFromString("recipient", id.User, t) - key, _ := ftCrypto.NewTransferKey(prng) - _, parts := newFile(4, 64, prng, t) - tid, err := m.sent.AddTransfer( - recipient, key, parts, 6, nil, time.Millisecond, prng) - if err != nil { - t.Errorf("Failed to add new transfer: %+v", err) - } - - // Create queued part list add parts - queuedParts := []queuedPart{{tid, 0}, {tid, 1}, {tid, 2}, {tid, 3}} - rid := id.Round(42) - - _, transfers, _, _, err := m.buildMessages(queuedParts, prng) - - expectedErr := strings.Split(finishTransferPanic, "%")[0] - defer func() { - err2 := recover() - if err2 == nil || !strings.Contains(err2.(string), expectedErr) { - t.Errorf("makeRoundEventCallback failed to panic or returned the "+ - "wrong error.\nexpected: %s\nreceived: %+v", expectedErr, err2) - } - }() - - roundEventCB := m.makeRoundEventCallback(rid, tid, transfers[tid]) - - roundEventCB(true, false, nil) -} - // Tests that Manager.queueParts adds all the expected parts to the sendQueue // channel. func TestManager_queueParts(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) csPrng := NewPrng(42) prng := rand.New(rand.NewSource(42)) @@ -808,7 +784,8 @@ func TestManager_queueParts(t *testing.T) { // Queue all parts for tid, partList := range parts { - m.queueParts(tid, uint16(len(partList))) + partNums := makeListOfPartNums(uint16(len(partList))) + m.queueParts(tid, partNums) } // Read through all parts in channel ensuring none are missing or duplicate @@ -846,22 +823,20 @@ func TestManager_queueParts(t *testing.T) { } } -// Tests that getShuffledPartNumList returns a list with all the part numbers. -func Test_getShuffledPartNumList(t *testing.T) { - n := 100 - numList := getShuffledPartNumList(uint16(n)) +// Tests that makeListOfPartNums returns a list with all the part numbers. +func Test_makeListOfPartNums(t *testing.T) { + n := uint16(100) + numList := makeListOfPartNums(n) - if len(numList) != n { - t.Errorf("Length of shuffled list incorrect."+ - "\nexpected: %d\nreceived: %d", n, len(numList)) + if len(numList) != int(n) { + t.Errorf("Length of list incorrect.\nexpected: %d\nreceived: %d", + n, len(numList)) } - for i := 0; i < n; i++ { - j := 0 - for ; i != int(numList[j]); j++ { - } - if i != int(numList[j]) { - t.Errorf("Failed to find part number %d in shuffled list.", i) + for i := uint16(0); i < n; i++ { + if numList[i] != i { + t.Errorf("Part number at index %d incorrect."+ + "\nexpected: %d\nreceived: %d", i, i, numList[i]) } } } @@ -869,7 +844,7 @@ func Test_getShuffledPartNumList(t *testing.T) { // Tests that the part size returned by Manager.getPartSize matches the manually // calculated part size. func TestManager_getPartSize(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) + m := newTestManager(false, nil, nil, nil, nil, t) // Calculate the expected part size primeByteLen := m.store.Cmix().GetGroup().GetP().ByteLen() @@ -893,7 +868,7 @@ func TestManager_getPartSize(t *testing.T) { func Test_partitionFile(t *testing.T) { prng := rand.New(rand.NewSource(42)) partSize := 96 - fileData, expectedParts := newFile(24, uint32(partSize), prng, t) + fileData, expectedParts := newFile(24, partSize, prng, t) receivedParts := partitionFile(fileData, partSize) diff --git a/fileTransfer/utils_test.go b/fileTransfer/utils_test.go index 374082112afa7a861fad78183f738bcbcc217a48..b80846b847158c9483efd8a84d8e3def1333ada9 100644 --- a/fileTransfer/utils_test.go +++ b/fileTransfer/utils_test.go @@ -22,6 +22,7 @@ import ( "gitlab.com/elixxir/client/storage/versioned" "gitlab.com/elixxir/client/switchboard" "gitlab.com/elixxir/comms/network" + "gitlab.com/elixxir/crypto/diffieHellman" "gitlab.com/elixxir/crypto/e2e" "gitlab.com/elixxir/crypto/fastRNG" ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" @@ -43,14 +44,14 @@ import ( // newFile generates a file with random data of size numParts * partSize. // Returns the full file and the file parts. If the partSize allows, each part // starts with a "|<[PART_001]" and ends with a ">|". -func newFile(numParts uint16, partSize uint32, prng io.Reader, t *testing.T) ( +func newFile(numParts uint16, partSize int, prng io.Reader, t *testing.T) ( []byte, [][]byte) { const ( prefix = "|<[PART_%3d]" suffix = ">|" ) // Create file buffer of the expected size - fileBuff := bytes.NewBuffer(make([]byte, 0, uint32(numParts)*partSize)) + fileBuff := bytes.NewBuffer(make([]byte, 0, int(numParts)*partSize)) partList := make([][]byte, numParts) // Create new rand.Rand with the seed generated from the io.Reader @@ -63,7 +64,7 @@ func newFile(numParts uint16, partSize uint32, prng io.Reader, t *testing.T) ( randPrng := rand.New(rand.NewSource(int64(seed))) for partNum := range partList { - s := RandStringBytes(int(partSize), randPrng) + s := RandStringBytes(partSize, randPrng) if len(s) >= (len(prefix) + len(suffix)) { partList[partNum] = []byte( prefix + s[:len(s)-(len(prefix)+len(suffix))] + suffix) @@ -148,28 +149,34 @@ func (s *PrngErr) SetSeed([]byte) error { return errors.New("SetSeedFailure" // newTestManager creates a new Manager that has groups stored for testing. One // of the groups in the list is also returned. func newTestManager(sendErr bool, sendChan, sendE2eChan chan message.Receive, - receiveCB interfaces.ReceiveCallback, t *testing.T) *Manager { + receiveCB interfaces.ReceiveCallback, kv *versioned.KV, t *testing.T) *Manager { - kv := versioned.NewKV(make(ekv.Memstore)) - sent, err := ftStorage.NewSentFileTransfers(kv) + if kv == nil { + kv = versioned.NewKV(make(ekv.Memstore)) + } + sent, err := ftStorage.NewSentFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to createw new SentFileTransfers: %+v", err) + t.Fatalf("Failed to createw new SentFileTransfersStore: %+v", err) } - received, err := ftStorage.NewReceivedFileTransfers(kv) + received, err := ftStorage.NewReceivedFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to createw new ReceivedFileTransfers: %+v", err) + t.Fatalf("Failed to createw new ReceivedFileTransfersStore: %+v", err) } net := newTestNetworkManager(sendErr, sendChan, sendE2eChan, t) - // Register channel for health tracking - healthy := make(chan bool) - net.GetHealthTracker().AddChannel(healthy) - // Returns an error on function and round failure on callback if sendErr is // set; otherwise, it reports round successes and returns nil - rr := func(_ []id.Round, _ time.Duration, cb api.RoundEventCallback) error { - cb(!sendErr, false, nil) + rr := func(rIDs []id.Round, _ time.Duration, cb api.RoundEventCallback) error { + rounds := make(map[id.Round]api.RoundResult, len(rIDs)) + for _, rid := range rIDs { + if sendErr { + rounds[rid] = api.Failed + } else { + rounds[rid] = api.Succeeded + } + } + cb(!sendErr, false, rounds) if sendErr { return errors.New("SendError") } @@ -189,7 +196,6 @@ func newTestManager(sendErr bool, sendChan, sendE2eChan chan message.Receive, store: storage.InitTestingSession(t), swb: switchboard.New(), net: net, - healthy: healthy, rng: fastRNG.NewStreamGenerator(1000, 10, csprng.NewSystemRNG), getRoundResults: rr, } @@ -199,10 +205,10 @@ func newTestManager(sendErr bool, sendChan, sendE2eChan chan message.Receive, // newTestManagerWithTransfers creates a new test manager with transfers added // to it. -func newTestManagerWithTransfers(numParts []uint16, sendErr bool, - receiveCB interfaces.ReceiveCallback, t *testing.T) ( +func newTestManagerWithTransfers(numParts []uint16, sendErr, addPartners bool, + receiveCB interfaces.ReceiveCallback, kv *versioned.KV, t *testing.T) ( *Manager, []sentTransferInfo, []receivedTransferInfo) { - m := newTestManager(sendErr, nil, nil, receiveCB, t) + m := newTestManager(sendErr, nil, nil, receiveCB, kv, t) sti := make([]sentTransferInfo, len(numParts)) rti := make([]receivedTransferInfo, len(numParts)) var err error @@ -212,13 +218,17 @@ func newTestManagerWithTransfers(numParts []uint16, sendErr bool, t.Errorf("Failed to get part size: %+v", err) } + // Add sent transfers to manager and populate the sentTransferInfo list for i := range sti { + // Generate PRNG, the file and its parts, and the transfer key prng := NewPrng(int64(42 + i)) - file, parts := newFile(numParts[i], uint32(partSize), prng, t) + file, parts := newFile(numParts[i], partSize, prng, t) key, _ := ftCrypto.NewTransferKey(prng) + recipient := id.NewIdFromString("recipient"+strconv.Itoa(i), id.User, t) + // Create a sentTransferInfo with all the transfer information sti[i] = sentTransferInfo{ - recipient: id.NewIdFromString("recipient"+strconv.Itoa(i), id.User, t), + recipient: recipient, key: key, parts: parts, file: file, @@ -229,28 +239,46 @@ func newTestManagerWithTransfers(numParts []uint16, sendErr bool, prng: prng, } - cbChan := make(chan sentProgressResults, 6) - + // Create sent progress callback and channel + cbChan := make(chan sentProgressResults, 8) cb := func(completed bool, sent, arrived, total uint16, tr interfaces.FilePartTracker, err error) { cbChan <- sentProgressResults{completed, sent, arrived, total, tr, err} } + // Add callback and channel to the sentTransferInfo sti[i].cbChan = cbChan sti[i].cb = cb - sti[i].tid, err = m.sent.AddTransfer(sti[i].recipient, sti[i].key, + // Add the transfer to the manager + sti[i].tid, err = m.sent.AddTransfer(recipient, sti[i].key, sti[i].parts, sti[i].numFps, sti[i].cb, sti[i].period, sti[i].prng) if err != nil { t.Errorf("Failed to add sent transfer #%d: %+v", i, err) } + + // Add recipient as partner + if addPartners { + grp := m.store.E2e().GetGroup() + dhKey := grp.NewInt(int64(i + 42)) + pubKey := diffieHellman.GeneratePublicKey(dhKey, grp) + p := params.GetDefaultE2ESessionParams() + err = m.store.E2e().AddPartner(recipient, pubKey, dhKey, p, p) + if err != nil { + t.Errorf("Failed to add partner #%d %s: %+v", i, recipient, err) + } + } } + // Add received transfers to manager and populate the receivedTransferInfo + // list for i := range rti { + // Generate PRNG, the file and its parts, and the transfer key prng := NewPrng(int64(42 + i)) - file, parts := newFile(numParts[i], uint32(partSize), prng, t) + file, parts := newFile(numParts[i], partSize, prng, t) key, _ := ftCrypto.NewTransferKey(prng) + // Create a receivedTransferInfo with all the transfer information rti[i] = receivedTransferInfo{ key: key, mac: ftCrypto.CreateTransferMAC(file, key), @@ -264,16 +292,18 @@ func newTestManagerWithTransfers(numParts []uint16, sendErr bool, prng: prng, } - cbChan := make(chan receivedProgressResults, 6) - + // Create received progress callback and channel + cbChan := make(chan receivedProgressResults, 8) cb := func(completed bool, received, total uint16, tr interfaces.FilePartTracker, err error) { cbChan <- receivedProgressResults{completed, received, total, tr, err} } + // Add callback and channel to the receivedTransferInfo rti[i].cbChan = cbChan rti[i].cb = cb + // Add the transfer to the manager rti[i].tid, err = m.received.AddTransfer(rti[i].key, rti[i].mac, rti[i].fileSize, rti[i].numParts, rti[i].numFps, rti[i].prng) if err != nil { @@ -439,8 +469,8 @@ func (tnm *testNetworkManager) SendManyCMIX(messages []message.TargetedCmixMessa } else { tnm.updateRid = true } + tnm.Unlock() }() - defer tnm.Unlock() if tnm.sendErr { return 0, nil, errors.New("SendManyCMIX error") diff --git a/go.mod b/go.mod index 5fdf6886398abebda1dc3ebb5008e7ce96a1d863..aca1594f285da11ba13dc2432922f14220300500 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,7 @@ require ( github.com/spf13/viper v1.7.1 gitlab.com/elixxir/bloomfilter v0.0.0-20200930191214-10e9ac31b228 gitlab.com/elixxir/comms v0.0.4-0.20211215181643-55748529a237 - gitlab.com/elixxir/crypto v0.0.7-0.20211215181538-688f96f4c4c0 + gitlab.com/elixxir/crypto v0.0.7-0.20211218190454-ecf7e6e1f41f gitlab.com/elixxir/ekv v0.1.5 gitlab.com/elixxir/primitives v0.0.3-0.20211208211148-752546cf2e46 gitlab.com/xx_network/comms v0.0.4-0.20211215181459-0918c1141509 diff --git a/go.sum b/go.sum index fba3cc4eaafe453c4954475b0c0a3718c6930e85..94723085a30493052bd898c44127b3785906c113 100644 --- a/go.sum +++ b/go.sum @@ -261,6 +261,8 @@ gitlab.com/elixxir/crypto v0.0.0-20200804182833-984246dea2c4/go.mod h1:ucm9SFKJo gitlab.com/elixxir/crypto v0.0.3/go.mod h1:ZNgBOblhYToR4m8tj4cMvJ9UsJAUKq+p0gCp07WQmhA= gitlab.com/elixxir/crypto v0.0.7-0.20211215181538-688f96f4c4c0 h1:WUrLqko7VYPRtxjL0o++RHFpS5Xjtlxed7qHuMu1iAo= gitlab.com/elixxir/crypto v0.0.7-0.20211215181538-688f96f4c4c0/go.mod h1:SQHmwjgX9taGCbzrtHGbIcZmV5iPielNP7c5wzLCUhM= +gitlab.com/elixxir/crypto v0.0.7-0.20211218190454-ecf7e6e1f41f h1:rLZKhzxZcXlm6WvVuEhXSrvXhmzjplcP9fyDl9lgHRQ= +gitlab.com/elixxir/crypto v0.0.7-0.20211218190454-ecf7e6e1f41f/go.mod h1:SQHmwjgX9taGCbzrtHGbIcZmV5iPielNP7c5wzLCUhM= gitlab.com/elixxir/ekv v0.1.5 h1:R8M1PA5zRU1HVnTyrtwybdABh7gUJSCvt1JZwUSeTzk= gitlab.com/elixxir/ekv v0.1.5/go.mod h1:e6WPUt97taFZe5PFLPb1Dupk7tqmDCTQu1kkstqJvw4= gitlab.com/elixxir/primitives v0.0.0-20200731184040-494269b53b4d/go.mod h1:OQgUZq7SjnE0b+8+iIAT2eqQF+2IFHn73tOo+aV11mg= diff --git a/interfaces/fileTransfer.go b/interfaces/fileTransfer.go index 49867124123e8dcbac7123d92d8f3e3ab698c155..dcd5b29b6246bfcef6a43db20b44d2bcfe30f068 100644 --- a/interfaces/fileTransfer.go +++ b/interfaces/fileTransfer.go @@ -10,6 +10,7 @@ package interfaces import ( ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" "gitlab.com/xx_network/primitives/id" + "strconv" "time" ) @@ -43,7 +44,7 @@ type FileTransfer interface { retry float32, preview []byte, progressCB SentProgressCallback, period time.Duration) (ftCrypto.TransferID, error) - // RegisterSendProgressCallback allows for the registration of a callback to + // RegisterSentProgressCallback allows for the registration of a callback to // track the progress of an individual sent file transfer. The callback will // be called immediately when added to report the current status of the // transfer. It will then call every time a file part is sent, a file part @@ -51,7 +52,7 @@ type FileTransfer interface { // once ever period, which means if events occur faster than the period, // then they will not be reported and instead the progress will be reported // once at the end of the period. - RegisterSendProgressCallback(tid ftCrypto.TransferID, + RegisterSentProgressCallback(tid ftCrypto.TransferID, progressCB SentProgressCallback, period time.Duration) error // Resend resends a file if sending fails. Returns an error if CloseSend @@ -70,9 +71,9 @@ type FileTransfer interface { // verified, or if the transfer cannot be found. Receive(tid ftCrypto.TransferID) ([]byte, error) - // RegisterReceiveProgressCallback allows for the registration of a callback - // to track the progress of an individual received file transfer. The - // callback will be called immediately when added to report the current + // RegisterReceivedProgressCallback allows for the registration of a + // callback to track the progress of an individual received file transfer. + // The callback will be called immediately when added to report the current // status of the transfer. It will then call every time a file part is // received, the transfer completes, or an error occurs. It is called at // most once ever period, which means if events occur faster than the @@ -80,7 +81,7 @@ type FileTransfer interface { // reported once at the end of the period. // Once the callback reports that the transfer has completed, the recipient // can get the full file by calling Receive. - RegisterReceiveProgressCallback(tid ftCrypto.TransferID, + RegisterReceivedProgressCallback(tid ftCrypto.TransferID, progressCB ReceivedProgressCallback, period time.Duration) error } @@ -93,8 +94,47 @@ type FilePartTracker interface { // 1 = sent (sender has sent a part, but it has not arrived) // 2 = arrived (sender has sent a part, and it has arrived) // 3 = received (receiver has received a part) - GetPartStatus(partNum uint16) int + GetPartStatus(partNum uint16) FpStatus // GetNumParts returns the total number of file parts in the transfer. GetNumParts() uint16 } + +// FpStatus is the file part status and indicates the status of individual file +// parts in a file transfer. +type FpStatus int + +// Possible values for FpStatus. +const ( + // FpUnsent indicates that the file part has not been sent + FpUnsent FpStatus = iota + + // FpSent indicates that the file part has been sent (sender has sent a + // part, but it has not arrived) + FpSent + + // FpArrived indicates that the file part has arrived (sender has sent a + // part, and it has arrived) + FpArrived + + // FpReceived indicates that the file part has been received (receiver has + // received a part) + FpReceived +) + +// String returns the string representing of the FpStatus. This functions +// satisfies the fmt.Stringer interface. +func (fps FpStatus) String() string { + switch fps { + case FpUnsent: + return "unsent" + case FpSent: + return "sent" + case FpArrived: + return "arrived" + case FpReceived: + return "received" + default: + return "INVALID FpStatus: " + strconv.Itoa(int(fps)) + } +} diff --git a/interfaces/message/type.go b/interfaces/message/type.go index daf1a6dabdc1762774ac10c4b7e87491ba05bbba..4a5115a162e562e04c6e2577d7a80d64589d8b4f 100644 --- a/interfaces/message/type.go +++ b/interfaces/message/type.go @@ -54,7 +54,11 @@ const ( // A group chat request message sent to all members in a group. GroupCreationRequest = 40 - // NewFileTransfer is the initial message that initiates a large file - // transfer. + // NewFileTransfer is transmitted first on the initialization of a file + // transfer to inform the receiver about the incoming file. NewFileTransfer = 50 + + // EndFileTransfer is sent once all file parts have been transmitted to + // inform the receiver that the file transfer has ended. + EndFileTransfer = 51 ) diff --git a/interfaces/preimage/types.go b/interfaces/preimage/types.go index 3588b4f24ca24ecdb5cc6840d35ade097b5468d8..52842446a235b265d53812c1343a12156272e7cf 100644 --- a/interfaces/preimage/types.go +++ b/interfaces/preimage/types.go @@ -7,4 +7,5 @@ const ( Rekey = "rekey" E2e = "e2e" Group = "group" + EndFT = "endFT" ) diff --git a/storage/e2e/manager.go b/storage/e2e/manager.go index 68119395dc81f3c3dff50dfcfdd5b5d8b6da9ea4..f7537fb34f770ea64c3d74a4f88d21cef2394ca6 100644 --- a/storage/e2e/manager.go +++ b/storage/e2e/manager.go @@ -266,3 +266,9 @@ func (m *Manager) GetE2EPreimage() []byte { func (m *Manager) GetRekeyPreimage() []byte { return preimage.Generate(m.GetRelationshipFingerprintBytes(), preimage.Rekey) } + +// GetFileTransferPreimage returns a hash of the unique +// fingerprint for an E2E end file transfer message. +func (m *Manager) GetFileTransferPreimage() []byte { + return preimage.Generate(m.GetRelationshipFingerprintBytes(), preimage.EndFT) +} diff --git a/storage/fileTransfer/receiveFileTransfers.go b/storage/fileTransfer/receiveFileTransfers.go index bd5c84244f11f9251b05c87e2fb9463e8ac9ab0d..c9acc3e54aef9b2562b762fd23e036eeb5357820 100644 --- a/storage/fileTransfer/receiveFileTransfers.go +++ b/storage/fileTransfer/receiveFileTransfers.go @@ -20,12 +20,12 @@ import ( ) const ( - receivedFileTransfersPrefix = "FileTransferReceivedFileTransfersStore" - receivedFileTransfersKey = "ReceivedFileTransfers" - receivedFileTransfersVersion = 0 + receivedFileTransfersStorePrefix = "FileTransferReceivedFileTransfersStore" + receivedFileTransfersStoreKey = "ReceivedFileTransfers" + receivedFileTransfersStoreVersion = 0 ) -// Error messages for ReceivedFileTransfers. +// Error messages for ReceivedFileTransfersStore. const ( saveReceivedTransfersListErr = "failed to save list of received items in transfer map to storage: %+v" loadReceivedTransfersListErr = "failed to load list of received items in transfer map from storage: %+v" @@ -39,22 +39,24 @@ const ( deleteReceivedTransferErr = "failed to delete received transfer with ID %s from store: %+v" ) -// ReceivedFileTransfers contains information for tracking a received +// ReceivedFileTransfersStore contains information for tracking a received // ReceivedTransfer to its transfer ID. It also maps a received part, partInfo, // to the message fingerprint. -type ReceivedFileTransfers struct { +type ReceivedFileTransfersStore struct { transfers map[ftCrypto.TransferID]*ReceivedTransfer info map[format.Fingerprint]*partInfo mux sync.Mutex kv *versioned.KV } -// NewReceivedFileTransfers creates a new ReceivedFileTransfers with empty maps. -func NewReceivedFileTransfers(kv *versioned.KV) (*ReceivedFileTransfers, error) { - rft := &ReceivedFileTransfers{ +// NewReceivedFileTransfersStore creates a new ReceivedFileTransfersStore with +// empty maps. +func NewReceivedFileTransfersStore(kv *versioned.KV) ( + *ReceivedFileTransfersStore, error) { + rft := &ReceivedFileTransfersStore{ transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), info: make(map[format.Fingerprint]*partInfo), - kv: kv.Prefix(receivedFileTransfersPrefix), + kv: kv.Prefix(receivedFileTransfersStorePrefix), } return rft, rft.saveTransfersList() @@ -62,7 +64,7 @@ func NewReceivedFileTransfers(kv *versioned.KV) (*ReceivedFileTransfers, error) // AddTransfer creates a new empty ReceivedTransfer, adds it to the transfers // map, and adds an entry for each file part fingerprint to the fingerprint map. -func (rft *ReceivedFileTransfers) AddTransfer(key ftCrypto.TransferKey, +func (rft *ReceivedFileTransfersStore) AddTransfer(key ftCrypto.TransferKey, transferMAC []byte, fileSize uint32, numParts, numFps uint16, rng csprng.Source) (ftCrypto.TransferID, error) { @@ -96,7 +98,7 @@ func (rft *ReceivedFileTransfers) AddTransfer(key ftCrypto.TransferKey, // addFingerprints generates numFps fingerprints, creates a new partInfo // for each, and adds each to the info map. -func (rft *ReceivedFileTransfers) addFingerprints(key ftCrypto.TransferKey, +func (rft *ReceivedFileTransfersStore) addFingerprints(key ftCrypto.TransferKey, tid ftCrypto.TransferID, numFps uint16) { // Generate list of fingerprints @@ -110,7 +112,7 @@ func (rft *ReceivedFileTransfers) addFingerprints(key ftCrypto.TransferKey, // GetTransfer returns the ReceivedTransfer with the given transfer ID. An error // is returned if no corresponding transfer is found. -func (rft *ReceivedFileTransfers) GetTransfer(tid ftCrypto.TransferID) ( +func (rft *ReceivedFileTransfersStore) GetTransfer(tid ftCrypto.TransferID) ( *ReceivedTransfer, error) { rft.mux.Lock() defer rft.mux.Unlock() @@ -125,7 +127,7 @@ func (rft *ReceivedFileTransfers) GetTransfer(tid ftCrypto.TransferID) ( // DeleteTransfer removes the ReceivedTransfer with the associated transfer ID // from memory and storage. -func (rft *ReceivedFileTransfers) DeleteTransfer(tid ftCrypto.TransferID) error { +func (rft *ReceivedFileTransfersStore) DeleteTransfer(tid ftCrypto.TransferID) error { rft.mux.Lock() defer rft.mux.Unlock() @@ -173,8 +175,8 @@ func (rft *ReceivedFileTransfers) DeleteTransfer(tid ftCrypto.TransferID) error // added to the transfer with the corresponding transfer ID. Returns the // transfer that the part was added to so that a progress callback can be // called. Returns the transfer ID so that it can be used for logging. -func (rft *ReceivedFileTransfers) AddPart(encryptedPart, padding, mac []byte, - partNum uint16, fp format.Fingerprint) (*ReceivedTransfer, +func (rft *ReceivedFileTransfersStore) AddPart(encryptedPart, padding, + mac []byte, partNum uint16, fp format.Fingerprint) (*ReceivedTransfer, ftCrypto.TransferID, error) { rft.mux.Lock() defer rft.mux.Unlock() @@ -207,7 +209,8 @@ func (rft *ReceivedFileTransfers) AddPart(encryptedPart, padding, mac []byte, // GetFile returns the combined file parts as a single byte slice for the given // transfer ID. An error is returned if no file with the given transfer ID is // found or if the file is missing parts. -func (rft *ReceivedFileTransfers) GetFile(tid ftCrypto.TransferID) ([]byte, error) { +func (rft *ReceivedFileTransfersStore) GetFile(tid ftCrypto.TransferID) ([]byte, + error) { rft.mux.Lock() defer rft.mux.Unlock() @@ -224,12 +227,14 @@ func (rft *ReceivedFileTransfers) GetFile(tid ftCrypto.TransferID) ([]byte, erro // Storage Functions // //////////////////////////////////////////////////////////////////////////////// -// LoadReceivedFileTransfers loads all ReceivedFileTransfers from storage. -func LoadReceivedFileTransfers(kv *versioned.KV) (*ReceivedFileTransfers, error) { - rft := &ReceivedFileTransfers{ +// LoadReceivedFileTransfersStore loads all ReceivedFileTransfersStore from +// storage. +func LoadReceivedFileTransfersStore(kv *versioned.KV) ( + *ReceivedFileTransfersStore, error) { + rft := &ReceivedFileTransfersStore{ transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), info: make(map[format.Fingerprint]*partInfo), - kv: kv.Prefix(receivedFileTransfersPrefix), + kv: kv.Prefix(receivedFileTransfersStorePrefix), } // Get the list of transfer IDs corresponding to each received transfer from @@ -248,20 +253,22 @@ func LoadReceivedFileTransfers(kv *versioned.KV) (*ReceivedFileTransfers, error) return rft, nil } -// NewOrLoadReceivedFileTransfers loads all ReceivedFileTransfers from storage, -// if they exist. Otherwise, a new ReceivedFileTransfers is returned. -func NewOrLoadReceivedFileTransfers(kv *versioned.KV) (*ReceivedFileTransfers, error) { - rft := &ReceivedFileTransfers{ +// NewOrLoadReceivedFileTransfersStore loads all ReceivedFileTransfersStore from +// storage, if they exist. Otherwise, a new ReceivedFileTransfersStore is +// returned. +func NewOrLoadReceivedFileTransfersStore(kv *versioned.KV) ( + *ReceivedFileTransfersStore, error) { + rft := &ReceivedFileTransfersStore{ transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), info: make(map[format.Fingerprint]*partInfo), - kv: kv.Prefix(receivedFileTransfersPrefix), + kv: kv.Prefix(receivedFileTransfersStorePrefix), } // If the transfer list cannot be loaded from storage, then create a new - // ReceivedFileTransfers - vo, err := rft.kv.Get(receivedFileTransfersKey, receivedFileTransfersVersion) + // ReceivedFileTransfersStore + vo, err := rft.kv.Get(receivedFileTransfersStoreKey, receivedFileTransfersStoreVersion) if err != nil { - return NewReceivedFileTransfers(kv) + return NewReceivedFileTransfersStore(kv) } // Unmarshal data into list of saved transfer IDs @@ -277,23 +284,26 @@ func NewOrLoadReceivedFileTransfers(kv *versioned.KV) (*ReceivedFileTransfers, e } // saveTransfersList saves a list of items in the transfers map to storage. -func (rft *ReceivedFileTransfers) saveTransfersList() error { +func (rft *ReceivedFileTransfersStore) saveTransfersList() error { // Create new versioned object with a list of items in the transfers map obj := &versioned.Object{ - Version: receivedFileTransfersVersion, + Version: receivedFileTransfersStoreVersion, Timestamp: netTime.Now(), Data: rft.marshalTransfersList(), } // Save list of items in the transfers map to storage - return rft.kv.Set(receivedFileTransfersKey, receivedFileTransfersVersion, obj) + return rft.kv.Set( + receivedFileTransfersStoreKey, receivedFileTransfersStoreVersion, obj) } // loadTransfersList gets the list of transfer IDs corresponding to each saved // received transfer from storage. -func (rft *ReceivedFileTransfers) loadTransfersList() ([]ftCrypto.TransferID, error) { +func (rft *ReceivedFileTransfersStore) loadTransfersList() ( + []ftCrypto.TransferID, error) { // Get transfers list from storage - vo, err := rft.kv.Get(receivedFileTransfersKey, receivedFileTransfersVersion) + vo, err := rft.kv.Get( + receivedFileTransfersStoreKey, receivedFileTransfersStoreVersion) if err != nil { return nil, err } @@ -305,7 +315,7 @@ func (rft *ReceivedFileTransfers) loadTransfersList() ([]ftCrypto.TransferID, er // load gets each ReceivedTransfer in the list from storage and adds them to the // map. Also adds all unused fingerprints in each ReceivedTransfer to the info // map. -func (rft *ReceivedFileTransfers) load(list []ftCrypto.TransferID) error { +func (rft *ReceivedFileTransfersStore) load(list []ftCrypto.TransferID) error { // Load each sentTransfer from storage into the map for _, tid := range list { // Load the transfer with the given transfer ID from storage @@ -336,7 +346,7 @@ func (rft *ReceivedFileTransfers) load(list []ftCrypto.TransferID) error { // marshalTransfersList creates a list of all transfer IDs in the transfers map // and serialises it. -func (rft *ReceivedFileTransfers) marshalTransfersList() []byte { +func (rft *ReceivedFileTransfersStore) marshalTransfersList() []byte { buff := bytes.NewBuffer(nil) buff.Grow(ftCrypto.TransferIdLength * len(rft.transfers)) diff --git a/storage/fileTransfer/receiveFileTransfers_test.go b/storage/fileTransfer/receiveFileTransfers_test.go index e60036a9d51ebfdb84578203aa20067fc384b8b1..d0bde16aa2e75b14bba6d602a7dc30234c818a65 100644 --- a/storage/fileTransfer/receiveFileTransfers_test.go +++ b/storage/fileTransfer/receiveFileTransfers_test.go @@ -22,43 +22,43 @@ import ( "testing" ) -// Tests that NewReceivedFileTransfers creates a new object with empty maps and -// that it is saved to storage -func TestNewReceivedFileTransfers(t *testing.T) { +// Tests that NewReceivedFileTransfersStore creates a new object with empty maps +// and that it is saved to storage +func TestNewReceivedFileTransfersStore(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) - expectedRFT := &ReceivedFileTransfers{ + expectedRFT := &ReceivedFileTransfersStore{ transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), info: make(map[format.Fingerprint]*partInfo), - kv: kv.Prefix(receivedFileTransfersPrefix), + kv: kv.Prefix(receivedFileTransfersStorePrefix), } - rft, err := NewReceivedFileTransfers(kv) + rft, err := NewReceivedFileTransfersStore(kv) if err != nil { - t.Errorf("NewReceivedFileTransfers returned an error: %+v", err) + t.Errorf("NewReceivedFileTransfersStore returned an error: %+v", err) } - // Check that the new ReceivedFileTransfers matches the expected + // Check that the new ReceivedFileTransfersStore matches the expected if !reflect.DeepEqual(expectedRFT, rft) { - t.Errorf("New ReceivedFileTransfers does not match expected."+ + t.Errorf("New ReceivedFileTransfersStore does not match expected."+ "\nexpected: %+v\nreceived: %+v", expectedRFT, rft) } // Ensure that the transfer list is saved to storage _, err = expectedRFT.kv.Get( - receivedFileTransfersKey, receivedFileTransfersVersion) + receivedFileTransfersStoreKey, receivedFileTransfersStoreVersion) if err != nil { t.Errorf("Failed to load transfer list from storage: %+v", err) } } -// Tests that ReceivedFileTransfers.AddTransfer adds a new transfer and adds all -// the fingerprints to the info map. -func TestReceivedFileTransfers_AddTransfer(t *testing.T) { +// Tests that ReceivedFileTransfersStore.AddTransfer adds a new transfer and +// adds all the fingerprints to the info map. +func TestReceivedFileTransfersStore_AddTransfer(t *testing.T) { prng := NewPrng(42) kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfers(kv) + rft, err := NewReceivedFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) } // Generate info for new transfer @@ -121,14 +121,14 @@ func TestReceivedFileTransfers_AddTransfer(t *testing.T) { } } -// Error path: tests that ReceivedFileTransfers.AddTransfer returns the expected -// error when the PRNG returns an error. -func TestReceivedFileTransfers_AddTransfer_NewTransferIdRngError(t *testing.T) { +// Error path: tests that ReceivedFileTransfersStore.AddTransfer returns the +// expected error when the PRNG returns an error. +func TestReceivedFileTransfersStore_AddTransfer_NewTransferIdRngError(t *testing.T) { prng := NewPrngErr() kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfers(kv) + rft, err := NewReceivedFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) } // Add the transfer @@ -141,14 +141,14 @@ func TestReceivedFileTransfers_AddTransfer_NewTransferIdRngError(t *testing.T) { } -// Tests that ReceivedFileTransfers.addFingerprints adds all the fingerprints -// to the map -func TestReceivedFileTransfers_addFingerprints(t *testing.T) { +// Tests that ReceivedFileTransfersStore.addFingerprints adds all the +// fingerprints to the map +func TestReceivedFileTransfersStore_addFingerprints(t *testing.T) { prng := NewPrng(42) kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfers(kv) + rft, err := NewReceivedFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) } key, _ := ftCrypto.NewTransferKey(prng) @@ -176,14 +176,14 @@ func TestReceivedFileTransfers_addFingerprints(t *testing.T) { } } -// Tests that ReceivedFileTransfers.GetTransfer returns the newly added +// Tests that ReceivedFileTransfersStore.GetTransfer returns the newly added // transfer. -func TestReceivedFileTransfers_GetTransfer(t *testing.T) { +func TestReceivedFileTransfersStore_GetTransfer(t *testing.T) { prng := NewPrng(42) kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfers(kv) + rft, err := NewReceivedFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) } // Generate random info for new transfer @@ -225,14 +225,15 @@ func TestReceivedFileTransfers_GetTransfer(t *testing.T) { } } -// Error path: tests that ReceivedFileTransfers.GetTransfer returns the expected -// error when the provided transfer ID does not correlate to any saved transfer. -func TestReceivedFileTransfers_GetTransfer_NoTransferError(t *testing.T) { +// Error path: tests that ReceivedFileTransfersStore.GetTransfer returns the +// expected error when the provided transfer ID does not correlate to any saved +// transfer. +func TestReceivedFileTransfersStore_GetTransfer_NoTransferError(t *testing.T) { prng := NewPrng(42) kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfers(kv) + rft, err := NewReceivedFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) } // Generate random info for new transfer @@ -257,12 +258,12 @@ func TestReceivedFileTransfers_GetTransfer_NoTransferError(t *testing.T) { } // Tests that DeleteTransfer removed a transfer from memory and storage. -func TestReceivedFileTransfers_DeleteTransfer(t *testing.T) { +func TestReceivedFileTransfersStore_DeleteTransfer(t *testing.T) { prng := NewPrng(42) kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfers(kv) + rft, err := NewReceivedFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) } // Add the transfer @@ -304,15 +305,15 @@ func TestReceivedFileTransfers_DeleteTransfer(t *testing.T) { } } -// Error path: tests that ReceivedFileTransfers.DeleteTransfer returns the +// Error path: tests that ReceivedFileTransfersStore.DeleteTransfer returns the // expected error when the provided transfer ID does not correlate to any saved // transfer. -func TestReceivedFileTransfers_DeleteTransfer_NoTransferError(t *testing.T) { +func TestReceivedFileTransfersStore_DeleteTransfer_NoTransferError(t *testing.T) { prng := NewPrng(42) kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfers(kv) + rft, err := NewReceivedFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) } // Delete the transfer @@ -326,13 +327,13 @@ func TestReceivedFileTransfers_DeleteTransfer_NoTransferError(t *testing.T) { } } -// Tests that ReceivedFileTransfers.AddPart modifies the expected transfer in -// memory. -func TestReceivedFileTransfers_AddPart(t *testing.T) { +// Tests that ReceivedFileTransfersStore.AddPart modifies the expected transfer +// in memory. +func TestReceivedFileTransfersStore_AddPart(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfers(kv) + rft, err := NewReceivedFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) } prng := NewPrng(42) @@ -379,13 +380,13 @@ func TestReceivedFileTransfers_AddPart(t *testing.T) { } } -// Error path: tests that ReceivedFileTransfers.AddPart returns the expected -// error when the provided fingerprint does not correlate to any part. -func TestReceivedFileTransfers_AddPart_NoFingerprintError(t *testing.T) { +// Error path: tests that ReceivedFileTransfersStore.AddPart returns the +// expected error when the provided fingerprint does not correlate to any part. +func TestReceivedFileTransfersStore_AddPart_NoFingerprintError(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfers(kv) + rft, err := NewReceivedFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) } // Create encrypted part @@ -401,13 +402,14 @@ func TestReceivedFileTransfers_AddPart_NoFingerprintError(t *testing.T) { } } -// Error path: tests that ReceivedFileTransfers.AddPart returns the expected -// error when the provided transfer ID does not correlate to any saved transfer. -func TestReceivedFileTransfers_AddPart_NoTransferError(t *testing.T) { +// Error path: tests that ReceivedFileTransfersStore.AddPart returns the +// expected error when the provided transfer ID does not correlate to any saved +// transfer. +func TestReceivedFileTransfersStore_AddPart_NoTransferError(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfers(kv) + rft, err := NewReceivedFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) } prng := NewPrng(42) @@ -433,13 +435,13 @@ func TestReceivedFileTransfers_AddPart_NoTransferError(t *testing.T) { } } -// Error path: tests that ReceivedFileTransfers.AddPart returns the expected -// error when the encrypted part data, MAC, and padding are invalid. -func TestReceivedFileTransfers_AddPart_AddPartError(t *testing.T) { +// Error path: tests that ReceivedFileTransfersStore.AddPart returns the +// expected error when the encrypted part data, MAC, and padding are invalid. +func TestReceivedFileTransfersStore_AddPart_AddPartError(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfers(kv) + rft, err := NewReceivedFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) } prng := NewPrng(42) @@ -469,12 +471,12 @@ func TestReceivedFileTransfers_AddPart_AddPartError(t *testing.T) { } } -// Tests that ReceivedFileTransfers.GetFile returns the complete file. -func TestReceivedFileTransfers_GetFile(t *testing.T) { +// Tests that ReceivedFileTransfersStore.GetFile returns the complete file. +func TestReceivedFileTransfersStore_GetFile(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfers(kv) + rft, err := NewReceivedFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) } prng := NewPrng(42) @@ -514,13 +516,13 @@ func TestReceivedFileTransfers_GetFile(t *testing.T) { } } -// Error path: tests that ReceivedFileTransfers.GetFile returns the expected -// error when no transfer with the ID exists. -func TestReceivedFileTransfers_GetFile_NoTransferError(t *testing.T) { +// Error path: tests that ReceivedFileTransfersStore.GetFile returns the +// expected error when no transfer with the ID exists. +func TestReceivedFileTransfersStore_GetFile_NoTransferError(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfers(kv) + rft, err := NewReceivedFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) } tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) @@ -537,13 +539,13 @@ func TestReceivedFileTransfers_GetFile_NoTransferError(t *testing.T) { // Storage Functions // //////////////////////////////////////////////////////////////////////////////// -// Tests that the ReceivedFileTransfers loaded from storage by -// LoadReceivedFileTransfers matches the original in memory. -func TestLoadReceivedFileTransfers(t *testing.T) { +// Tests that the ReceivedFileTransfersStore loaded from storage by +// LoadReceivedFileTransfersStore matches the original in memory. +func TestLoadReceivedFileTransfersStore(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfers(kv) + rft, err := NewReceivedFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to make new ReceivedFileTransfers: %+v", err) + t.Fatalf("Failed to make new ReceivedFileTransfersStore: %+v", err) } // Add 10 transfers to map in memory @@ -571,64 +573,64 @@ func TestLoadReceivedFileTransfers(t *testing.T) { t.Errorf("Faileds to save transfers list: %+v", err) } - // Load ReceivedFileTransfers from storage - loadedRFT, err := LoadReceivedFileTransfers(kv) + // Load ReceivedFileTransfersStore from storage + loadedRFT, err := LoadReceivedFileTransfersStore(kv) if err != nil { - t.Errorf("LoadReceivedFileTransfers returned an error: %+v", err) + t.Errorf("LoadReceivedFileTransfersStore returned an error: %+v", err) } if !reflect.DeepEqual(rft, loadedRFT) { - t.Errorf("Loaded ReceivedFileTransfers does not match original in "+ - "memory.\nexpected: %+v\nreceived: %+v", rft, loadedRFT) + t.Errorf("Loaded ReceivedFileTransfersStore does not match original"+ + "in memory.\nexpected: %+v\nreceived: %+v", rft, loadedRFT) } } -// Error path: tests that ReceivedFileTransfers returns the expected error when -// the transfer list cannot be loaded from storage. -func TestLoadReceivedFileTransfers_NoListInStorageError(t *testing.T) { +// Error path: tests that ReceivedFileTransfersStore returns the expected error +// when the transfer list cannot be loaded from storage. +func TestLoadReceivedFileTransfersStore_NoListInStorageError(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) expectedErr := strings.Split(loadReceivedTransfersListErr, "%")[0] - // Load ReceivedFileTransfers from storage - _, err := LoadReceivedFileTransfers(kv) + // Load ReceivedFileTransfersStore from storage + _, err := LoadReceivedFileTransfersStore(kv) if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("LoadReceivedFileTransfers did not return the expected error "+ - "when there is no transfer list saved in storage."+ + t.Errorf("LoadReceivedFileTransfersStore did not return the expected "+ + "error when there is no transfer list saved in storage."+ "\nexpected: %s\nreceived: %+v", expectedErr, err) } } -// Error path: tests that ReceivedFileTransfers returns the expected error when -// the first transfer loaded from storage does not exist. -func TestLoadReceivedFileTransfers_NoTransferInStorageError(t *testing.T) { +// Error path: tests that ReceivedFileTransfersStore returns the expected error +// when the first transfer loaded from storage does not exist. +func TestLoadReceivedFileTransfersStore_NoTransferInStorageError(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) expectedErr := strings.Split(loadReceivedFileTransfersErr, "%")[0] // Save list of one transfer ID to storage obj := &versioned.Object{ - Version: receivedFileTransfersVersion, + Version: receivedFileTransfersStoreVersion, Timestamp: netTime.Now(), Data: ftCrypto.UnmarshalTransferID([]byte("testID_01")).Bytes(), } - err := kv.Prefix(receivedFileTransfersPrefix).Set( - receivedFileTransfersKey, receivedFileTransfersVersion, obj) + err := kv.Prefix(receivedFileTransfersStorePrefix).Set( + receivedFileTransfersStoreKey, receivedFileTransfersStoreVersion, obj) - // Load ReceivedFileTransfers from storage - _, err = LoadReceivedFileTransfers(kv) + // Load ReceivedFileTransfersStore from storage + _, err = LoadReceivedFileTransfersStore(kv) if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("LoadReceivedFileTransfers did not return the expected error "+ - "when there is no transfer saved in storage."+ + t.Errorf("LoadReceivedFileTransfersStore did not return the expected "+ + "error when there is no transfer saved in storage."+ "\nexpected: %s\nreceived: %+v", expectedErr, err) } } -// Tests that the ReceivedFileTransfers loaded from storage by -// NewOrLoadReceivedFileTransfers matches the original in memory. -func TestNewOrLoadReceivedFileTransfers(t *testing.T) { +// Tests that the ReceivedFileTransfersStore loaded from storage by +// NewOrLoadReceivedFileTransfersStore matches the original in memory. +func TestNewOrLoadReceivedFileTransfersStore(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfers(kv) + rft, err := NewReceivedFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to make new ReceivedFileTransfers: %+v", err) + t.Fatalf("Failed to make new ReceivedFileTransfersStore: %+v", err) } // Add 10 transfers to map in memory @@ -656,67 +658,69 @@ func TestNewOrLoadReceivedFileTransfers(t *testing.T) { t.Errorf("Faileds to save transfers list: %+v", err) } - // Load ReceivedFileTransfers from storage - loadedRFT, err := NewOrLoadReceivedFileTransfers(kv) + // Load ReceivedFileTransfersStore from storage + loadedRFT, err := NewOrLoadReceivedFileTransfersStore(kv) if err != nil { - t.Errorf("NewOrLoadReceivedFileTransfers returned an error: %+v", err) + t.Errorf("NewOrLoadReceivedFileTransfersStore returned an error: %+v", + err) } if !reflect.DeepEqual(rft, loadedRFT) { - t.Errorf("Loaded ReceivedFileTransfers does not match original in "+ - "memory.\nexpected: %+v\nreceived: %+v", rft, loadedRFT) + t.Errorf("Loaded ReceivedFileTransfersStore does not match original "+ + "in memory.\nexpected: %+v\nreceived: %+v", rft, loadedRFT) } } -// Tests that NewOrLoadReceivedFileTransfers returns a new ReceivedFileTransfers -// when there is none in storage. -func TestNewOrLoadReceivedFileTransfers_NewSentFileTransfers(t *testing.T) { +// Tests that NewOrLoadReceivedFileTransfersStore returns a new +// ReceivedFileTransfersStore when there is none in storage. +func TestNewOrLoadReceivedFileTransfersStore_NewReceivedFileTransfersStore(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) - // Load ReceivedFileTransfers from storage - loadedRFT, err := NewOrLoadReceivedFileTransfers(kv) + // Load ReceivedFileTransfersStore from storage + loadedRFT, err := NewOrLoadReceivedFileTransfersStore(kv) if err != nil { - t.Errorf("NewOrLoadReceivedFileTransfers returned an error: %+v", err) + t.Errorf("NewOrLoadReceivedFileTransfersStore returned an error: %+v", + err) } - newRFT, _ := NewReceivedFileTransfers(kv) + newRFT, _ := NewReceivedFileTransfersStore(kv) if !reflect.DeepEqual(newRFT, loadedRFT) { - t.Errorf("Returned ReceivedFileTransfers does not match new."+ + t.Errorf("Returned ReceivedFileTransfersStore does not match new."+ "\nexpected: %+v\nreceived: %+v", newRFT, loadedRFT) } } -// Error path: tests that NewOrLoadReceivedFileTransfers returns the expected -// error when the first transfer loaded from storage does not exist. -func TestNewOrLoadReceivedFileTransfers_NoTransferInStorageError(t *testing.T) { +// Error path: tests that NewOrLoadReceivedFileTransfersStore returns the +// expected error when the first transfer loaded from storage does not exist. +func TestNewOrLoadReceivedFileTransfersStore_NoTransferInStorageError(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) expectedErr := strings.Split(loadReceivedFileTransfersErr, "%")[0] // Save list of one transfer ID to storage obj := &versioned.Object{ - Version: receivedFileTransfersVersion, + Version: receivedFileTransfersStoreVersion, Timestamp: netTime.Now(), Data: ftCrypto.UnmarshalTransferID([]byte("testID_01")).Bytes(), } - err := kv.Prefix(receivedFileTransfersPrefix).Set( - receivedFileTransfersKey, receivedFileTransfersVersion, obj) + err := kv.Prefix(receivedFileTransfersStorePrefix).Set( + receivedFileTransfersStoreKey, receivedFileTransfersStoreVersion, obj) - // Load ReceivedFileTransfers from storage - _, err = NewOrLoadReceivedFileTransfers(kv) + // Load ReceivedFileTransfersStore from storage + _, err = NewOrLoadReceivedFileTransfersStore(kv) if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("NewOrLoadReceivedFileTransfers did not return the expected "+ - "error when there is no transfer saved in storage."+ + t.Errorf("NewOrLoadReceivedFileTransfersStore did not return the "+ + "expected error when there is no transfer saved in storage."+ "\nexpected: %s\nreceived: %+v", expectedErr, err) } } -// Tests that the list saved by ReceivedFileTransfers.saveTransfersList matches -// the list loaded by ReceivedFileTransfers.load. -func TestLoadReceivedFileTransfers_saveTransfersList_loadTransfersList(t *testing.T) { - rft, err := NewReceivedFileTransfers(versioned.NewKV(make(ekv.Memstore))) +// Tests that the list saved by ReceivedFileTransfersStore.saveTransfersList +// matches the list loaded by ReceivedFileTransfersStore.load. +func TestReceivedFileTransfersStore_saveTransfersList_loadTransfersList(t *testing.T) { + rft, err := NewReceivedFileTransfersStore(versioned.NewKV(make(ekv.Memstore))) if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) } // Fill map with transfers @@ -761,13 +765,13 @@ func TestLoadReceivedFileTransfers_saveTransfersList_loadTransfersList(t *testin } } -// Tests that the list loaded by ReceivedFileTransfers.load matches the original -// saved to storage. -func TestLoadReceivedFileTransfers_load(t *testing.T) { +// Tests that the list loaded by ReceivedFileTransfersStore.load matches the +// original saved to storage. +func TestReceivedFileTransfersStore_load(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfers(kv) + rft, err := NewReceivedFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfers: %+v", err) + t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) } // Fill map with transfers @@ -789,11 +793,11 @@ func TestLoadReceivedFileTransfers_load(t *testing.T) { t.Errorf("saveTransfersList returned an error: %+v", err) } - // Build new ReceivedFileTransfers - newRFT := &ReceivedFileTransfers{ + // Build new ReceivedFileTransfersStore + newRFT := &ReceivedFileTransfersStore{ transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), info: make(map[format.Fingerprint]*partInfo), - kv: kv.Prefix(receivedFileTransfersPrefix), + kv: kv.Prefix(receivedFileTransfersStorePrefix), } // Load saved transfers from storage @@ -836,11 +840,11 @@ func TestLoadReceivedFileTransfers_load(t *testing.T) { } // Tests that a transfer list marshalled with -// ReceivedFileTransfers.marshalTransfersList and unmarshalled with +// ReceivedFileTransfersStore.marshalTransfersList and unmarshalled with // unmarshalTransfersList matches the original. -func TestReceivedFileTransfers_marshalTransfersList_unmarshalTransfersList(t *testing.T) { +func TestReceivedFileTransfersStore_marshalTransfersList_unmarshalTransfersList(t *testing.T) { prng := NewPrng(42) - rft := &ReceivedFileTransfers{ + rft := &ReceivedFileTransfersStore{ transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), } diff --git a/storage/fileTransfer/receiveTransfer.go b/storage/fileTransfer/receiveTransfer.go index fc5b48fac25406b9edae66a65ac368ed45ddce90..233278c0846685c7a3c68740c10a456e1c1682cd 100644 --- a/storage/fileTransfer/receiveTransfer.go +++ b/storage/fileTransfer/receiveTransfer.go @@ -270,7 +270,7 @@ func (rt *ReceivedTransfer) AddProgressCB( rt.mux.Unlock() // Trigger the initial call - rct.callNow(rt, nil) + rct.callNow(true, rt, nil) } // AddPart decrypts an encrypted file part, adds it to the list of received diff --git a/storage/fileTransfer/receiveTransfer_test.go b/storage/fileTransfer/receiveTransfer_test.go index e3cff77508170c0b02153977e6c028247d472499..ce8516a655df79bc707bce1af325b4009bf41895 100644 --- a/storage/fileTransfer/receiveTransfer_test.go +++ b/storage/fileTransfer/receiveTransfer_test.go @@ -232,10 +232,10 @@ func checkReceivedTracker(track ReceivedPartTracker, numParts uint16, var done bool for _, receivedNum := range received { if receivedNum == partNum { - if track.GetPartStatus(partNum) != receivedStatus { + if track.GetPartStatus(partNum) != interfaces.FpReceived { t.Errorf("Part number %d has unexpected status."+ - "\nexpected: %d\nreceived: %d", partNum, receivedStatus, - track.GetPartStatus(partNum)) + "\nexpected: %d\nreceived: %d", partNum, + interfaces.FpReceived, track.GetPartStatus(partNum)) } done = true break @@ -245,10 +245,10 @@ func checkReceivedTracker(track ReceivedPartTracker, numParts uint16, continue } - if track.GetPartStatus(partNum) != unsentStatus { + if track.GetPartStatus(partNum) != interfaces.FpUnsent { t.Errorf("Part number %d has incorrect status."+ "\nexpected: %d\nreceived: %d", - partNum, unsentStatus, track.GetPartStatus(partNum)) + partNum, interfaces.FpUnsent, track.GetPartStatus(partNum)) } } } diff --git a/storage/fileTransfer/receivedCallbackTracker.go b/storage/fileTransfer/receivedCallbackTracker.go index a0b3f558c212d8ee1603b01c80eacfc00bfa10e2..9b49a06bf39cd780c134e0a51af685eb10a4eaf6 100644 --- a/storage/fileTransfer/receivedCallbackTracker.go +++ b/storage/fileTransfer/receivedCallbackTracker.go @@ -12,6 +12,7 @@ import ( "gitlab.com/elixxir/client/stoppable" "gitlab.com/xx_network/primitives/netTime" "sync" + "sync/atomic" "time" ) @@ -28,6 +29,7 @@ type receivedCallbackTracker struct { period time.Duration // How often to call the callback lastCall time.Time // Timestamp of the last call scheduled bool // Denotes if callback call is scheduled + completed uint64 // Atomic that tells if transfer is completed stop *stoppable.Single // Stops the scheduled callback from triggering cb interfaces.ReceivedProgressCallback mux sync.RWMutex @@ -40,6 +42,7 @@ func newReceivedCallbackTracker(cb interfaces.ReceivedProgressCallback, period: period, lastCall: time.Time{}, scheduled: false, + completed: 0, stop: stoppable.NewSingle(receivedCallbackTrackerStoppable), cb: cb, } @@ -53,7 +56,7 @@ func newReceivedCallbackTracker(cb interfaces.ReceivedProgressCallback, func (rct *receivedCallbackTracker) call(tracker receivedProgressTracker, err error) { rct.mux.RLock() // Exit if a callback is already scheduled - if rct.scheduled { + if rct.scheduled || atomic.LoadUint64(&rct.completed) == 1 { rct.mux.RUnlock() return } @@ -70,7 +73,7 @@ func (rct *receivedCallbackTracker) call(tracker receivedProgressTracker, err er timeSinceLastCall := netTime.Since(rct.lastCall) if timeSinceLastCall > rct.period { // If no callback occurred, then trigger the callback now - rct.callNowUnsafe(tracker, err) + rct.callNowUnsafe(false, tracker, err) rct.lastCall = netTime.Now() } else { // If a callback did occur, then schedule a new callback to occur at the @@ -83,7 +86,7 @@ func (rct *receivedCallbackTracker) call(tracker receivedProgressTracker, err er return case <-time.NewTimer(rct.period - timeSinceLastCall).C: rct.mux.Lock() - rct.callNow(tracker, err) + rct.callNow(false, tracker, err) rct.lastCall = netTime.Now() rct.scheduled = false rct.mux.Unlock() @@ -98,19 +101,25 @@ func (rct *receivedCallbackTracker) stopThread() error { } // callNow calls the callback immediately regardless of the schedule or period. -func (rct *receivedCallbackTracker) callNow( +func (rct *receivedCallbackTracker) callNow(skipCompletedCheck bool, tracker receivedProgressTracker, err error) { completed, received, total, t := tracker.GetProgress() - go rct.cb(completed, received, total, t, err) + if skipCompletedCheck || !completed || + atomic.CompareAndSwapUint64(&rct.completed, 0, 1) { + go rct.cb(completed, received, total, t, err) + } } // callNowUnsafe calls the callback immediately regardless of the schedule or // period without taking a thread lock. This function should be used if a lock // is already taken on the receivedProgressTracker. -func (rct *receivedCallbackTracker) callNowUnsafe( +func (rct *receivedCallbackTracker) callNowUnsafe(skipCompletedCheck bool, tracker receivedProgressTracker, err error) { completed, received, total, t := tracker.getProgress() - go rct.cb(completed, received, total, t, err) + if skipCompletedCheck || !completed || + atomic.CompareAndSwapUint64(&rct.completed, 0, 1) { + go rct.cb(completed, received, total, t, err) + } } // receivedProgressTracker interface tracks the progress of a transfer. diff --git a/storage/fileTransfer/receivedPartTracker.go b/storage/fileTransfer/receivedPartTracker.go index b31d61c2f0f0db2193ee4bd1cb521ede4ba08213..78df88c987af35d04d56265e88784c0c37006d67 100644 --- a/storage/fileTransfer/receivedPartTracker.go +++ b/storage/fileTransfer/receivedPartTracker.go @@ -7,7 +7,10 @@ package fileTransfer -import "gitlab.com/elixxir/client/storage/utility" +import ( + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/storage/utility" +) // ReceivedPartTracker tracks the status of individual received file parts. type ReceivedPartTracker struct { @@ -31,11 +34,11 @@ func NewReceivedPartTracker(received *utility.StateVector) ReceivedPartTracker { // number. The possible values for the status are: // 0 = unreceived // 3 = received (receiver has received a part) -func (rpt ReceivedPartTracker) GetPartStatus(partNum uint16) int { +func (rpt ReceivedPartTracker) GetPartStatus(partNum uint16) interfaces.FpStatus { if rpt.receivedStatus.Used(uint32(partNum)) { - return receivedStatus + return interfaces.FpReceived } else { - return unsentStatus + return interfaces.FpUnsent } } diff --git a/storage/fileTransfer/receivedPartTracker_test.go b/storage/fileTransfer/receivedPartTracker_test.go index 1e8269a0254f6f4ec9d10ead1c779cf381a6edea..379c1d75f8e1f7cccf73038a7d8680e7a125ccd7 100644 --- a/storage/fileTransfer/receivedPartTracker_test.go +++ b/storage/fileTransfer/receivedPartTracker_test.go @@ -50,11 +50,11 @@ func TestReceivedPartTracker_GetPartStatus(t *testing.T) { // Set statuses of parts in the ReceivedTransfer and a map randomly prng := rand.New(rand.NewSource(42)) - partStatuses := make(map[uint16]int, rt.numParts) + partStatuses := make(map[uint16]interfaces.FpStatus, rt.numParts) for partNum := uint16(0); partNum < rt.numParts; partNum++ { - partStatuses[partNum] = prng.Intn(2) * receivedStatus + partStatuses[partNum] = interfaces.FpStatus(prng.Intn(2)) * interfaces.FpReceived - if partStatuses[partNum] == receivedStatus { + if partStatuses[partNum] == interfaces.FpReceived { rt.receivedStatus.Use(uint32(partNum)) } } diff --git a/storage/fileTransfer/sentCallbackTracker.go b/storage/fileTransfer/sentCallbackTracker.go index 51b25ed453392e46600276dd0c638f5508425d0d..f99540707ebd8b3ab7793264bdf93bfa2af329fa 100644 --- a/storage/fileTransfer/sentCallbackTracker.go +++ b/storage/fileTransfer/sentCallbackTracker.go @@ -12,6 +12,7 @@ import ( "gitlab.com/elixxir/client/stoppable" "gitlab.com/xx_network/primitives/netTime" "sync" + "sync/atomic" "time" ) @@ -28,6 +29,7 @@ type sentCallbackTracker struct { period time.Duration // How often to call the callback lastCall time.Time // Timestamp of the last call scheduled bool // Denotes if callback call is scheduled + completed uint64 // Atomic that tells if transfer is completed stop *stoppable.Single // Stops the scheduled callback from triggering cb interfaces.SentProgressCallback mux sync.RWMutex @@ -40,6 +42,7 @@ func newSentCallbackTracker(cb interfaces.SentProgressCallback, period: period, lastCall: time.Time{}, scheduled: false, + completed: 0, stop: stoppable.NewSingle(sentCallbackTrackerStoppable), cb: cb, } @@ -53,7 +56,7 @@ func newSentCallbackTracker(cb interfaces.SentProgressCallback, func (sct *sentCallbackTracker) call(tracker sentProgressTracker, err error) { sct.mux.RLock() // Exit if a callback is already scheduled - if sct.scheduled { + if sct.scheduled || atomic.LoadUint64(&sct.completed) == 1 { sct.mux.RUnlock() return } @@ -70,7 +73,7 @@ func (sct *sentCallbackTracker) call(tracker sentProgressTracker, err error) { timeSinceLastCall := netTime.Since(sct.lastCall) if timeSinceLastCall > sct.period { // If no callback occurred, then trigger the callback now - sct.callNowUnsafe(tracker, err) + sct.callNowUnsafe(false, tracker, err) sct.lastCall = netTime.Now() } else { // If a callback did occur, then schedule a new callback to occur at the @@ -83,7 +86,7 @@ func (sct *sentCallbackTracker) call(tracker sentProgressTracker, err error) { return case <-time.NewTimer(sct.period - timeSinceLastCall).C: sct.mux.Lock() - sct.callNow(tracker, err) + sct.callNow(false, tracker, err) sct.lastCall = netTime.Now() sct.scheduled = false sct.mux.Unlock() @@ -98,19 +101,25 @@ func (sct *sentCallbackTracker) stopThread() error { } // callNow calls the callback immediately regardless of the schedule or period. -func (sct *sentCallbackTracker) callNow( +func (sct *sentCallbackTracker) callNow(skipCompletedCheck bool, tracker sentProgressTracker, err error) { completed, sent, arrived, total, t := tracker.GetProgress() - go sct.cb(completed, sent, arrived, total, t, err) + if skipCompletedCheck || !completed || + atomic.CompareAndSwapUint64(&sct.completed, 0, 1) { + go sct.cb(completed, sent, arrived, total, t, err) + } } // callNowUnsafe calls the callback immediately regardless of the schedule or // period without taking a thread lock. This function should be used if a lock // is already taken on the sentProgressTracker. -func (sct *sentCallbackTracker) callNowUnsafe( +func (sct *sentCallbackTracker) callNowUnsafe(skipCompletedCheck bool, tracker sentProgressTracker, err error) { completed, sent, arrived, total, t := tracker.getProgress() - go sct.cb(completed, sent, arrived, total, t, err) + if skipCompletedCheck || !completed || + atomic.CompareAndSwapUint64(&sct.completed, 0, 1) { + go sct.cb(completed, sent, arrived, total, t, err) + } } // sentProgressTracker interface tracks the progress of a transfer. diff --git a/storage/fileTransfer/sentFileTransfers.go b/storage/fileTransfer/sentFileTransfers.go index c8a858f017ac62b63cef4af327cb83212519d654..8e095745bd57aad4097619b59cb41e59beaea2ba 100644 --- a/storage/fileTransfer/sentFileTransfers.go +++ b/storage/fileTransfer/sentFileTransfers.go @@ -23,9 +23,9 @@ import ( // Storage keys and versions. const ( - sentFileTransfersPrefix = "SentFileTransfersStore" - sentFileTransfersKey = "SentFileTransfers" - sentFileTransfersVersion = 0 + sentFileTransfersStorePrefix = "SentFileTransfersStore" + sentFileTransfersStoreKey = "SentFileTransfers" + sentFileTransfersStoreVersion = 0 ) // Error messages. @@ -40,18 +40,19 @@ const ( deleteSentTransferErr = "failed to delete sent transfer with ID %s from store: %+v" ) -// SentFileTransfers contains information for tracking sent file transfers. -type SentFileTransfers struct { +// SentFileTransfersStore contains information for tracking sent file transfers. +type SentFileTransfersStore struct { transfers map[ftCrypto.TransferID]*SentTransfer mux sync.Mutex kv *versioned.KV } -// NewSentFileTransfers creates a new SentFileTransfers with an empty map. -func NewSentFileTransfers(kv *versioned.KV) (*SentFileTransfers, error) { - sft := &SentFileTransfers{ +// NewSentFileTransfersStore creates a new SentFileTransfersStore with an empty +// map. +func NewSentFileTransfersStore(kv *versioned.KV) (*SentFileTransfersStore, error) { + sft := &SentFileTransfersStore{ transfers: make(map[ftCrypto.TransferID]*SentTransfer), - kv: kv.Prefix(sentFileTransfersPrefix), + kv: kv.Prefix(sentFileTransfersStorePrefix), } return sft, sft.saveTransfersList() @@ -59,7 +60,7 @@ func NewSentFileTransfers(kv *versioned.KV) (*SentFileTransfers, error) { // AddTransfer creates a new empty SentTransfer and adds it to the transfers // map. -func (sft *SentFileTransfers) AddTransfer(recipient *id.ID, +func (sft *SentFileTransfersStore) AddTransfer(recipient *id.ID, key ftCrypto.TransferKey, parts [][]byte, numFps uint16, progressCB interfaces.SentProgressCallback, period time.Duration, rng csprng.Source) (ftCrypto.TransferID, error) { @@ -91,7 +92,7 @@ func (sft *SentFileTransfers) AddTransfer(recipient *id.ID, // GetTransfer returns the SentTransfer with the given transfer ID. An error is // returned if no corresponding transfer is found. -func (sft *SentFileTransfers) GetTransfer(tid ftCrypto.TransferID) ( +func (sft *SentFileTransfersStore) GetTransfer(tid ftCrypto.TransferID) ( *SentTransfer, error) { sft.mux.Lock() defer sft.mux.Unlock() @@ -106,7 +107,7 @@ func (sft *SentFileTransfers) GetTransfer(tid ftCrypto.TransferID) ( // DeleteTransfer removes the SentTransfer with the associated transfer ID // from memory and storage. -func (sft *SentFileTransfers) DeleteTransfer(tid ftCrypto.TransferID) error { +func (sft *SentFileTransfersStore) DeleteTransfer(tid ftCrypto.TransferID) error { sft.mux.Lock() defer sft.mux.Unlock() @@ -140,15 +141,78 @@ func (sft *SentFileTransfers) DeleteTransfer(tid ftCrypto.TransferID) error { return nil } +// GetUnsentParts returns a map of all transfers and a list of their parts that +// have not been sent (parts that were never marked as in-progress). +func (sft *SentFileTransfersStore) GetUnsentParts() map[ftCrypto.TransferID][]uint16 { + sft.mux.Lock() + defer sft.mux.Unlock() + unsentParts := map[ftCrypto.TransferID][]uint16{} + + // Get list of unsent part numbers for each transfer + for tid, st := range sft.transfers { + unsentParts[tid] = st.GetUnsentPartNums() + } + + return unsentParts +} + +// GetSentRounds returns a map of all round IDs and which transfers have parts +// sent on those rounds (parts marked in-progress). +func (sft *SentFileTransfersStore) GetSentRounds() map[id.Round][]ftCrypto.TransferID { + sft.mux.Lock() + defer sft.mux.Unlock() + sentRounds := map[id.Round][]ftCrypto.TransferID{} + + // Get list of round IDs that transfers have in-progress rounds on + for tid, st := range sft.transfers { + for _, rid := range st.GetSentRounds() { + sentRounds[rid] = append(sentRounds[rid], tid) + } + } + + return sentRounds +} + +// GetUnsentPartsAndSentRounds returns two maps. The first is a map of all +// transfers and a list of their parts that have not been sent (parts that were +// never marked as in-progress). The seconds is a map of all round IDs and which +// transfers have parts sent on those rounds (parts marked in-progress). This +// function performs the same operations as GetUnsentParts and GetSentRounds but +// in a single loop. +func (sft *SentFileTransfersStore) GetUnsentPartsAndSentRounds() ( + map[ftCrypto.TransferID][]uint16, map[id.Round][]ftCrypto.TransferID) { + sft.mux.Lock() + defer sft.mux.Unlock() + + unsentParts := map[ftCrypto.TransferID][]uint16{} + sentRounds := map[id.Round][]ftCrypto.TransferID{} + + for tid, st := range sft.transfers { + // Get list of round IDs that transfers have in-progress rounds on + for _, rid := range st.GetSentRounds() { + sentRounds[rid] = append(sentRounds[rid], tid) + } + + // Get list of unsent part numbers for each transfer + stUnsentParts := st.GetUnsentPartNums() + if len(stUnsentParts) > 0 { + unsentParts[tid] = stUnsentParts + } + } + + return unsentParts, sentRounds +} + //////////////////////////////////////////////////////////////////////////////// // Storage Functions // //////////////////////////////////////////////////////////////////////////////// -// LoadSentFileTransfers loads all SentFileTransfers from storage. -func LoadSentFileTransfers(kv *versioned.KV) (*SentFileTransfers, error) { - sft := &SentFileTransfers{ +// LoadSentFileTransfersStore loads all SentFileTransfersStore from storage. +// Returns a list of unsent file parts. +func LoadSentFileTransfersStore(kv *versioned.KV) (*SentFileTransfersStore, error) { + sft := &SentFileTransfersStore{ transfers: make(map[ftCrypto.TransferID]*SentTransfer), - kv: kv.Prefix(sentFileTransfersPrefix), + kv: kv.Prefix(sentFileTransfersStorePrefix), } // Get the list of transfer IDs corresponding to each sent transfer from @@ -167,19 +231,23 @@ func LoadSentFileTransfers(kv *versioned.KV) (*SentFileTransfers, error) { return sft, nil } -// NewOrLoadSentFileTransfers loads all SentFileTransfers from storage, if they -// exist. Otherwise, a new SentFileTransfers is returned. -func NewOrLoadSentFileTransfers(kv *versioned.KV) (*SentFileTransfers, error) { - sft := &SentFileTransfers{ +// NewOrLoadSentFileTransfersStore loads all SentFileTransfersStore from storage +// and returns a list of unsent file parts, if they exist. Otherwise, a new +// SentFileTransfersStore is returned. +func NewOrLoadSentFileTransfersStore(kv *versioned.KV) (*SentFileTransfersStore, + error) { + sft := &SentFileTransfersStore{ transfers: make(map[ftCrypto.TransferID]*SentTransfer), - kv: kv.Prefix(sentFileTransfersPrefix), + kv: kv.Prefix(sentFileTransfersStorePrefix), } // If the transfer list cannot be loaded from storage, then create a new - // SentFileTransfers - vo, err := sft.kv.Get(sentFileTransfersKey, sentFileTransfersVersion) + // SentFileTransfersStore + vo, err := sft.kv.Get( + sentFileTransfersStoreKey, sentFileTransfersStoreVersion) if err != nil { - return NewSentFileTransfers(kv) + newSFT, err := NewSentFileTransfersStore(kv) + return newSFT, err } // Unmarshal data into list of saved transfer IDs @@ -195,23 +263,26 @@ func NewOrLoadSentFileTransfers(kv *versioned.KV) (*SentFileTransfers, error) { } // saveTransfersList saves a list of items in the transfers map to storage. -func (sft *SentFileTransfers) saveTransfersList() error { +func (sft *SentFileTransfersStore) saveTransfersList() error { // Create new versioned object with a list of items in the transfers map obj := &versioned.Object{ - Version: sentFileTransfersVersion, + Version: sentFileTransfersStoreVersion, Timestamp: netTime.Now(), Data: sft.marshalTransfersList(), } // Save list of items in the transfers map to storage - return sft.kv.Set(sentFileTransfersKey, sentFileTransfersVersion, obj) + return sft.kv.Set( + sentFileTransfersStoreKey, sentFileTransfersStoreVersion, obj) } // loadTransfersList gets the list of transfer IDs corresponding to each saved // sent transfer from storage. -func (sft *SentFileTransfers) loadTransfersList() ([]ftCrypto.TransferID, error) { +func (sft *SentFileTransfersStore) loadTransfersList() ([]ftCrypto.TransferID, + error) { // Get transfers list from storage - vo, err := sft.kv.Get(sentFileTransfersKey, sentFileTransfersVersion) + vo, err := sft.kv.Get( + sentFileTransfersStoreKey, sentFileTransfersStoreVersion) if err != nil { return nil, err } @@ -221,7 +292,9 @@ func (sft *SentFileTransfers) loadTransfersList() ([]ftCrypto.TransferID, error) } // loadTransfers loads each SentTransfer from the list and adds them to the map. -func (sft *SentFileTransfers) loadTransfers(list []ftCrypto.TransferID) error { +// Returns a map of all transfers and their unsent file part numbers to be used +// to add them back into the queue. +func (sft *SentFileTransfersStore) loadTransfers(list []ftCrypto.TransferID) error { var err error // Load each sentTransfer from storage into the map @@ -237,7 +310,7 @@ func (sft *SentFileTransfers) loadTransfers(list []ftCrypto.TransferID) error { // marshalTransfersList creates a list of all transfer IDs in the transfers map // and serialises it. -func (sft *SentFileTransfers) marshalTransfersList() []byte { +func (sft *SentFileTransfersStore) marshalTransfersList() []byte { buff := bytes.NewBuffer(nil) buff.Grow(ftCrypto.TransferIdLength * len(sft.transfers)) diff --git a/storage/fileTransfer/sentFileTransfers_test.go b/storage/fileTransfer/sentFileTransfers_test.go index cae6426eddf056a9f4a14a3d4158f6e6c2c189c8..6d1658deed15baa29f5cff1f3dafc3e31283f3d4 100644 --- a/storage/fileTransfer/sentFileTransfers_test.go +++ b/storage/fileTransfer/sentFileTransfers_test.go @@ -14,42 +14,44 @@ import ( "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/netTime" "reflect" + "sort" "strings" "testing" ) -// Tests that NewSentFileTransfers creates a new object with empty maps and that -// it is saved to storage -func TestNewSentFileTransfers(t *testing.T) { +// Tests that NewSentFileTransfersStore creates a new object with empty maps and +// that it is saved to storage +func TestNewSentFileTransfersStore(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) - expectedSFT := &SentFileTransfers{ + expectedSFT := &SentFileTransfersStore{ transfers: make(map[ftCrypto.TransferID]*SentTransfer), - kv: kv.Prefix(sentFileTransfersPrefix), + kv: kv.Prefix(sentFileTransfersStorePrefix), } - sft, err := NewSentFileTransfers(kv) + sft, err := NewSentFileTransfersStore(kv) - // Check that the new SentFileTransfers matches the expected + // Check that the new SentFileTransfersStore matches the expected if !reflect.DeepEqual(expectedSFT, sft) { - t.Errorf("New SentFileTransfers does not match expected."+ + t.Errorf("New SentFileTransfersStore does not match expected."+ "\nexpected: %+v\nreceived: %+v", expectedSFT, sft) } // Ensure that the transfer list is saved to storage - _, err = expectedSFT.kv.Get(sentFileTransfersKey, sentFileTransfersVersion) + _, err = expectedSFT.kv.Get( + sentFileTransfersStoreKey, sentFileTransfersStoreVersion) if err != nil { t.Errorf("Failed to load transfer list from storage: %+v", err) } } -// Tests that SentFileTransfers.AddTransfer adds a new transfer to the map and -// that its ID is saved to the list in storage. -func TestSentFileTransfers_AddTransfer(t *testing.T) { +// Tests that SentFileTransfersStore.AddTransfer adds a new transfer to the map +// and that its ID is saved to the list in storage. +func TestSentFileTransfersStore_AddTransfer(t *testing.T) { prng := NewPrng(42) kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfers(kv) + sft, err := NewSentFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new SentFileTransfers: %+v", err) + t.Fatalf("Failed to create new SentFileTransfersStore: %+v", err) } // Generate info for new transfer @@ -80,14 +82,14 @@ func TestSentFileTransfers_AddTransfer(t *testing.T) { } } -// Error path: tests that SentFileTransfers.AddTransfer returns the expected -// error when the PRNG returns an error. -func TestSentFileTransfers_AddTransfer_NewTransferIdRngError(t *testing.T) { +// Error path: tests that SentFileTransfersStore.AddTransfer returns the +// expected error when the PRNG returns an error. +func TestSentFileTransfersStore_AddTransfer_NewTransferIdRngError(t *testing.T) { prng := NewPrngErr() kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfers(kv) + sft, err := NewSentFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new SentFileTransfers: %+v", err) + t.Fatalf("Failed to create new SentFileTransfersStore: %+v", err) } // Add the transfer @@ -99,14 +101,14 @@ func TestSentFileTransfers_AddTransfer_NewTransferIdRngError(t *testing.T) { } } -// Tests that SentFileTransfers.GetTransfer returns the expected transfer from -// the map. -func TestSentFileTransfers_GetTransfer(t *testing.T) { +// Tests that SentFileTransfersStore.GetTransfer returns the expected transfer +// from the map. +func TestSentFileTransfersStore_GetTransfer(t *testing.T) { prng := NewPrng(42) kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfers(kv) + sft, err := NewSentFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new SentFileTransfers: %+v", err) + t.Fatalf("Failed to create new SentFileTransfersStore: %+v", err) } // Generate info for new transfer @@ -131,14 +133,15 @@ func TestSentFileTransfers_GetTransfer(t *testing.T) { } } -// Error path: tests that SentFileTransfers.GetTransfer returns the expected -// error when the map is empty/there is no transfer with the given transfer ID. -func TestSentFileTransfers_GetTransfer_NoTransferError(t *testing.T) { +// Error path: tests that SentFileTransfersStore.GetTransfer returns the +// expected error when the map is empty/there is no transfer with the given +// transfer ID. +func TestSentFileTransfersStore_GetTransfer_NoTransferError(t *testing.T) { prng := NewPrng(42) kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfers(kv) + sft, err := NewSentFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new SentFileTransfers: %+v", err) + t.Fatalf("Failed to create new SentFileTransfersStore: %+v", err) } tid, _ := ftCrypto.NewTransferID(prng) @@ -150,14 +153,14 @@ func TestSentFileTransfers_GetTransfer_NoTransferError(t *testing.T) { } } -// Tests that SentFileTransfers.DeleteTransfer removes the transfer from the map -// in memory and from the list in storage. -func TestSentFileTransfers_DeleteTransfer(t *testing.T) { +// Tests that SentFileTransfersStore.DeleteTransfer removes the transfer from +// the map in memory and from the list in storage. +func TestSentFileTransfersStore_DeleteTransfer(t *testing.T) { prng := NewPrng(42) kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfers(kv) + sft, err := NewSentFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new SentFileTransfers: %+v", err) + t.Fatalf("Failed to create new SentFileTransfersStore: %+v", err) } // Generate info for new transfer @@ -178,7 +181,8 @@ func TestSentFileTransfers_DeleteTransfer(t *testing.T) { transfer, err := sft.GetTransfer(tid) if err == nil { - t.Errorf("No error getting transfer that should be deleted: %+v", transfer) + t.Errorf("No error getting transfer that should be deleted: %+v", + transfer) } list, err := sft.loadTransfersList() @@ -191,14 +195,15 @@ func TestSentFileTransfers_DeleteTransfer(t *testing.T) { } } -// Error path: tests that SentFileTransfers.DeleteTransfer returns the expected -// error when the map is empty/there is no transfer with the given transfer ID. -func TestSentFileTransfers_DeleteTransfer_NoTransferError(t *testing.T) { +// Error path: tests that SentFileTransfersStore.DeleteTransfer returns the +// expected error when the map is empty/there is no transfer with the given +// transfer ID. +func TestSentFileTransfersStore_DeleteTransfer_NoTransferError(t *testing.T) { prng := NewPrng(42) kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfers(kv) + sft, err := NewSentFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to create new SentFileTransfers: %+v", err) + t.Fatalf("Failed to create new SentFileTransfersStore: %+v", err) } tid, _ := ftCrypto.NewTransferID(prng) @@ -210,17 +215,232 @@ func TestSentFileTransfers_DeleteTransfer_NoTransferError(t *testing.T) { } } +// Tests that SentFileTransfersStore.GetUnsentParts returns the expected unsent +// parts for each transfer. Transfers are created with increasing number of +// parts. Each part of each transfer is set as either in-progress, finished, or +// unsent. +func TestSentFileTransfersStore_GetUnsentParts(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + sft, err := NewSentFileTransfersStore(kv) + if err != nil { + t.Fatalf("Failed to create new SentFileTransfersStore: %+v", err) + } + + n := uint16(3) + expectedParts := make(map[ftCrypto.TransferID][]uint16, n) + + // Add new transfers + for i := uint16(0); i < n; i++ { + recipient := id.NewIdFromUInt(uint64(i), id.User, t) + key, _ := ftCrypto.NewTransferKey(prng) + parts, _ := newRandomPartSlice((i+1)*6, prng, t) + numParts := uint16(len(parts)) + numFps := numParts * 3 / 2 + + tid, err := sft.AddTransfer(recipient, key, parts, numFps, nil, 0, prng) + if err != nil { + t.Errorf("Failed to add transfer %d: %+v", i, err) + } + + // Loop through each part and set it individually + for j := uint16(0); j < numParts; j++ { + switch ((j + i) % numParts) % 3 { + case 0: + // Part is sent (in-progress) + _, _ = sft.transfers[tid].SetInProgress(id.Round(j), j) + case 1: + // Part is sent and arrived (finished) + _, _ = sft.transfers[tid].SetInProgress(id.Round(j), j) + _, _ = sft.transfers[tid].FinishTransfer(id.Round(j)) + case 2: + // Part is unsent (neither in-progress nor arrived) + expectedParts[tid] = append(expectedParts[tid], j) + } + } + } + + unsentParts := sft.GetUnsentParts() + + if !reflect.DeepEqual(expectedParts, unsentParts) { + t.Errorf("Unexpected unsent parts map.\nexpected: %+v\nreceived: %+v", + expectedParts, unsentParts) + } +} + +// Tests that SentFileTransfersStore.GetSentRounds returns the expected transfer +// ID for each unfinished round. Transfers are created with increasing number of +// parts. Each part of each transfer is set as either in-progress, finished, or +// unsent. +func TestSentFileTransfersStore_GetSentRounds(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + sft, err := NewSentFileTransfersStore(kv) + if err != nil { + t.Fatalf("Failed to create new SentFileTransfersStore: %+v", err) + } + + n := uint16(3) + expectedRounds := make(map[id.Round][]ftCrypto.TransferID) + + // Add new transfers + for i := uint16(0); i < n; i++ { + recipient := id.NewIdFromUInt(uint64(i), id.User, t) + key, _ := ftCrypto.NewTransferKey(prng) + parts, _ := newRandomPartSlice((i+1)*6, prng, t) + numParts := uint16(len(parts)) + numFps := numParts * 3 / 2 + + tid, err := sft.AddTransfer(recipient, key, parts, numFps, nil, 0, prng) + if err != nil { + t.Errorf("Failed to add transfer %d: %+v", i, err) + } + + // Loop through each part and set it individually + for j := uint16(0); j < numParts; j++ { + rid := id.Round(j) + switch j % 3 { + case 0: + // Part is sent (in-progress) + _, _ = sft.transfers[tid].SetInProgress(id.Round(j), j) + expectedRounds[rid] = append(expectedRounds[rid], tid) + case 1: + // Part is sent and arrived (finished) + _, _ = sft.transfers[tid].SetInProgress(id.Round(j), j) + _, _ = sft.transfers[tid].FinishTransfer(id.Round(j)) + case 2: + // Part is unsent (neither in-progress nor arrived) + } + } + } + + // Sort expected rounds map transfer IDs + for _, tIDs := range expectedRounds { + sort.Slice(tIDs, + func(i, j int) bool { return tIDs[i].String() < tIDs[j].String() }) + } + + sentRounds := sft.GetSentRounds() + + // Sort sent rounds map transfer IDs + for _, tIDs := range sentRounds { + sort.Slice(tIDs, + func(i, j int) bool { return tIDs[i].String() < tIDs[j].String() }) + } + + if !reflect.DeepEqual(expectedRounds, sentRounds) { + t.Errorf("Unexpected sent rounds map.\nexpected: %+v\nreceived: %+v", + expectedRounds, sentRounds) + } +} + +// Tests that SentFileTransfersStore.GetUnsentPartsAndSentRounds +func TestSentFileTransfersStore_GetUnsentPartsAndSentRounds(t *testing.T) { + prng := NewPrng(42) + kv := versioned.NewKV(make(ekv.Memstore)) + sft, err := NewSentFileTransfersStore(kv) + if err != nil { + t.Fatalf("Failed to create new SentFileTransfersStore: %+v", err) + } + + n := uint16(3) + expectedParts := make(map[ftCrypto.TransferID][]uint16, n) + expectedRounds := make(map[id.Round][]ftCrypto.TransferID) + + // Add new transfers + for i := uint16(0); i < n; i++ { + recipient := id.NewIdFromUInt(uint64(i), id.User, t) + key, _ := ftCrypto.NewTransferKey(prng) + parts, _ := newRandomPartSlice((i+1)*6, prng, t) + numParts := uint16(len(parts)) + numFps := numParts * 3 / 2 + + tid, err := sft.AddTransfer(recipient, key, parts, numFps, nil, 0, prng) + if err != nil { + t.Errorf("Failed to add transfer %d: %+v", i, err) + } + + // Loop through each part and set it individually + for j := uint16(0); j < numParts; j++ { + rid := id.Round(j) + switch j % 3 { + case 0: + // Part is sent (in-progress) + _, _ = sft.transfers[tid].SetInProgress(rid, j) + expectedRounds[rid] = append(expectedRounds[rid], tid) + case 1: + // Part is sent and arrived (finished) + _, _ = sft.transfers[tid].SetInProgress(rid, j) + _, _ = sft.transfers[tid].FinishTransfer(rid) + case 2: + // Part is unsent (neither in-progress nor arrived) + expectedParts[tid] = append(expectedParts[tid], j) + } + } + } + + // Sort expected rounds map transfer IDs + for _, tIDs := range expectedRounds { + sort.Slice(tIDs, + func(i, j int) bool { return tIDs[i].String() < tIDs[j].String() }) + } + + unsentParts, sentRounds := sft.GetUnsentPartsAndSentRounds() + + // Sort sent rounds map transfer IDs + for _, tIDs := range sentRounds { + sort.Slice(tIDs, + func(i, j int) bool { return tIDs[i].String() < tIDs[j].String() }) + } + + if !reflect.DeepEqual(expectedParts, unsentParts) { + t.Errorf("Unexpected unsent parts map.\nexpected: %+v\nreceived: %+v", + expectedParts, unsentParts) + } + + if !reflect.DeepEqual(expectedRounds, sentRounds) { + t.Errorf("Unexpected sent rounds map.\nexpected: %+v\nreceived: %+v", + expectedRounds, sentRounds) + } + + unsentParts2 := sft.GetUnsentParts() + + if !reflect.DeepEqual(unsentParts, unsentParts2) { + t.Errorf("Unsent parts from GetUnsentParts and "+ + "GetUnsentPartsAndSentRounds do not match."+ + "\nGetUnsentParts: %+v"+ + "\nGetUnsentPartsAndSentRounds: %+v", + unsentParts, unsentParts2) + } + + sentRounds2 := sft.GetSentRounds() + + // Sort sent rounds map transfer IDs + for _, tIDs := range sentRounds2 { + sort.Slice(tIDs, + func(i, j int) bool { return tIDs[i].String() < tIDs[j].String() }) + } + + if !reflect.DeepEqual(sentRounds, sentRounds2) { + t.Errorf("Sent rounds map from GetSentRounds and "+ + "GetUnsentPartsAndSentRounds do not match."+ + "\nGetSentRounds: %+v"+ + "\nGetUnsentPartsAndSentRounds: %+v", + sentRounds, sentRounds2) + } +} + //////////////////////////////////////////////////////////////////////////////// // Storage Functions // //////////////////////////////////////////////////////////////////////////////// -// Tests that the SentFileTransfers loaded from storage by LoadSentFileTransfers -// matches the original in memory. -func TestLoadSentFileTransfers(t *testing.T) { +// Tests that the SentFileTransfersStore loaded from storage by +// LoadSentFileTransfersStore matches the original in memory. +func TestLoadSentFileTransfersStore(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfers(kv) + sft, err := NewSentFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to make new SentFileTransfers: %+v", err) + t.Fatalf("Failed to make new SentFileTransfersStore: %+v", err) } // Add 10 transfers to map in memory @@ -236,70 +456,70 @@ func TestLoadSentFileTransfers(t *testing.T) { t.Errorf("Faileds to save transfers list: %+v", err) } - // Load SentFileTransfers from storage - loadedSFT, err := LoadSentFileTransfers(kv) + // Load SentFileTransfersStore from storage + loadedSFT, err := LoadSentFileTransfersStore(kv) if err != nil { - t.Errorf("LoadSentFileTransfers returned an error: %+v", err) + t.Errorf("LoadSentFileTransfersStore returned an error: %+v", err) } - // Equalize all progressCallbacks because reflect.DeepEqual does not seem - // to work on function pointers + // Equalize all progressCallbacks because reflect.DeepEqual does not seem to + // work on function pointers for _, tid := range list { loadedSFT.transfers[tid].progressCallbacks = sft.transfers[tid].progressCallbacks } if !reflect.DeepEqual(sft, loadedSFT) { - t.Errorf("Loaded SentFileTransfers does not match original in memory."+ - "\nexpected: %+v\nreceived: %+v", sft, loadedSFT) + t.Errorf("Loaded SentFileTransfersStore does not match original in "+ + "memory.\nexpected: %+v\nreceived: %+v", sft, loadedSFT) } } -// Error path: tests that LoadSentFileTransfers returns the expected error when -// the transfer list cannot be loaded from storage. -func TestLoadSentFileTransfers_NoListInStorageError(t *testing.T) { +// Error path: tests that LoadSentFileTransfersStore returns the expected error +// when the transfer list cannot be loaded from storage. +func TestLoadSentFileTransfersStore_NoListInStorageError(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) expectedErr := strings.Split(loadSentTransfersListErr, "%")[0] - // Load SentFileTransfers from storage - _, err := LoadSentFileTransfers(kv) + // Load SentFileTransfersStore from storage + _, err := LoadSentFileTransfersStore(kv) if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("LoadSentFileTransfers did not return the expected error "+ - "when there is no transfer list saved in storage."+ + t.Errorf("LoadSentFileTransfersStore did not return the expected "+ + "error when there is no transfer list saved in storage."+ "\nexpected: %s\nreceived: %+v", expectedErr, err) } } -// Error path: tests that LoadSentFileTransfers returns the expected error when -// the first transfer loaded from storage does not exist. -func TestLoadSentFileTransfers_NoTransferInStorageError(t *testing.T) { +// Error path: tests that LoadSentFileTransfersStore returns the expected error +// when the first transfer loaded from storage does not exist. +func TestLoadSentFileTransfersStore_NoTransferInStorageError(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) expectedErr := strings.Split(loadSentTransfersErr, "%")[0] // Save list of one transfer ID to storage obj := &versioned.Object{ - Version: sentFileTransfersVersion, + Version: sentFileTransfersStoreVersion, Timestamp: netTime.Now(), Data: ftCrypto.UnmarshalTransferID([]byte("testID_01")).Bytes(), } - err := kv.Prefix(sentFileTransfersPrefix).Set( - sentFileTransfersKey, sentFileTransfersVersion, obj) + err := kv.Prefix(sentFileTransfersStorePrefix).Set( + sentFileTransfersStoreKey, sentFileTransfersStoreVersion, obj) - // Load SentFileTransfers from storage - _, err = LoadSentFileTransfers(kv) + // Load SentFileTransfersStore from storage + _, err = LoadSentFileTransfersStore(kv) if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("LoadSentFileTransfers did not return the expected error "+ - "when there is no transfer saved in storage."+ + t.Errorf("LoadSentFileTransfersStore did not return the expected "+ + "error when there is no transfer saved in storage."+ "\nexpected: %s\nreceived: %+v", expectedErr, err) } } -// Tests that the SentFileTransfers loaded from storage by -// NewOrLoadSentFileTransfers matches the original in memory. -func TestNewOrLoadSentFileTransfers(t *testing.T) { +// Tests that the SentFileTransfersStore loaded from storage by +// NewOrLoadSentFileTransfersStore matches the original in memory. +func TestNewOrLoadSentFileTransfersStore(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfers(kv) + sft, err := NewSentFileTransfersStore(kv) if err != nil { - t.Fatalf("Failed to make new SentFileTransfers: %+v", err) + t.Fatalf("Failed to make new SentFileTransfersStore: %+v", err) } // Add 10 transfers to map in memory @@ -315,74 +535,76 @@ func TestNewOrLoadSentFileTransfers(t *testing.T) { t.Errorf("Faileds to save transfers list: %+v", err) } - // Load SentFileTransfers from storage - loadedSFT, err := NewOrLoadSentFileTransfers(kv) + // Load SentFileTransfersStore from storage + loadedSFT, err := NewOrLoadSentFileTransfersStore(kv) if err != nil { - t.Errorf("NewOrLoadSentFileTransfers returned an error: %+v", err) + t.Errorf("NewOrLoadSentFileTransfersStore returned an error: %+v", err) } // Equalize all progressCallbacks because reflect.DeepEqual does not seem // to work on function pointers for _, tid := range list { - loadedSFT.transfers[tid].progressCallbacks = sft.transfers[tid].progressCallbacks + loadedSFT.transfers[tid].progressCallbacks = + sft.transfers[tid].progressCallbacks } if !reflect.DeepEqual(sft, loadedSFT) { - t.Errorf("Loaded SentFileTransfers does not match original in memory."+ - "\nexpected: %+v\nreceived: %+v", sft, loadedSFT) + t.Errorf("Loaded SentFileTransfersStore does not match original in "+ + "memory.\nexpected: %+v\nreceived: %+v", sft, loadedSFT) } } -// Tests that NewOrLoadSentFileTransfers returns a new SentFileTransfers when -// there is none in storage. -func TestNewOrLoadSentFileTransfers_NewSentFileTransfers(t *testing.T) { +// Tests that NewOrLoadSentFileTransfersStore returns a new +// SentFileTransfersStore when there is none in storage. +func TestNewOrLoadSentFileTransfersStore_NewSentFileTransfersStore(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) - // Load SentFileTransfers from storage - loadedSFT, err := NewOrLoadSentFileTransfers(kv) + // Load SentFileTransfersStore from storage + loadedSFT, err := NewOrLoadSentFileTransfersStore(kv) if err != nil { - t.Errorf("NewOrLoadSentFileTransfers returned an error: %+v", err) + t.Errorf("NewOrLoadSentFileTransfersStore returned an error: %+v", err) } - newSFT, _ := NewSentFileTransfers(kv) + newSFT, _ := NewSentFileTransfersStore(kv) if !reflect.DeepEqual(newSFT, loadedSFT) { - t.Errorf("Returned SentFileTransfers does not match new."+ + t.Errorf("Returned SentFileTransfersStore does not match new."+ "\nexpected: %+v\nreceived: %+v", newSFT, loadedSFT) } } -// Error path: tests that the NewOrLoadSentFileTransfers returns the expected -// error when the first transfer loaded from storage does not exist. -func TestNewOrLoadSentFileTransfers_NoTransferInStorageError(t *testing.T) { +// Error path: tests that the NewOrLoadSentFileTransfersStore returns the +// expected error when the first transfer loaded from storage does not exist. +func TestNewOrLoadSentFileTransfersStore_NoTransferInStorageError(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) expectedErr := strings.Split(loadSentTransfersErr, "%")[0] // Save list of one transfer ID to storage obj := &versioned.Object{ - Version: sentFileTransfersVersion, + Version: sentFileTransfersStoreVersion, Timestamp: netTime.Now(), Data: ftCrypto.UnmarshalTransferID([]byte("testID_01")).Bytes(), } - err := kv.Prefix(sentFileTransfersPrefix).Set( - sentFileTransfersKey, sentFileTransfersVersion, obj) + err := kv.Prefix(sentFileTransfersStorePrefix).Set( + sentFileTransfersStoreKey, sentFileTransfersStoreVersion, obj) - // Load SentFileTransfers from storage - _, err = NewOrLoadSentFileTransfers(kv) + // Load SentFileTransfersStore from storage + _, err = NewOrLoadSentFileTransfersStore(kv) if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("NewOrLoadSentFileTransfers did not return the expected "+ + t.Errorf("NewOrLoadSentFileTransfersStore did not return the expected "+ "error when there is no transfer saved in storage."+ "\nexpected: %s\nreceived: %+v", expectedErr, err) } } -// Tests that SentFileTransfers.saveTransfersList saves all the transfer IDs to -// storage by loading them from storage via SentFileTransfers.loadTransfersList -// and comparing the list to the list in memory. -func TestSentFileTransfers_saveTransfersList_loadTransfersList(t *testing.T) { +// Tests that SentFileTransfersStore.saveTransfersList saves all the transfer +// IDs to storage by loading them from storage via +// SentFileTransfersStore.loadTransfersList and comparing the list to the list +// in memory. +func TestSentFileTransfersStore_saveTransfersList_loadTransfersList(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) - sft := &SentFileTransfers{ + sft := &SentFileTransfersStore{ transfers: make(map[ftCrypto.TransferID]*SentTransfer), - kv: kv.Prefix(sentFileTransfersPrefix), + kv: kv.Prefix(sentFileTransfersStorePrefix), } // Add 10 transfers to map in memory @@ -413,13 +635,13 @@ func TestSentFileTransfers_saveTransfersList_loadTransfersList(t *testing.T) { } } -// Tests that the transfer loaded by SentFileTransfers.loadTransfers from +// Tests that the transfer loaded by SentFileTransfersStore.loadTransfers from // storage matches the original in memory -func TestSentFileTransfers_loadTransfers(t *testing.T) { +func TestSentFileTransfersStore_loadTransfers(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) - sft := &SentFileTransfers{ + sft := &SentFileTransfersStore{ transfers: make(map[ftCrypto.TransferID]*SentTransfer), - kv: kv.Prefix(sentFileTransfersPrefix), + kv: kv.Prefix(sentFileTransfersStorePrefix), } // Add 10 transfers to map in memory @@ -430,10 +652,10 @@ func TestSentFileTransfers_loadTransfers(t *testing.T) { list[i] = tid } - // Load the transfers into a new SentFileTransfers - loadedSft := &SentFileTransfers{ + // Load the transfers into a new SentFileTransfersStore + loadedSft := &SentFileTransfersStore{ transfers: make(map[ftCrypto.TransferID]*SentTransfer), - kv: kv.Prefix(sentFileTransfersPrefix), + kv: kv.Prefix(sentFileTransfersStorePrefix), } err := loadedSft.loadTransfers(list) if err != nil { @@ -443,21 +665,23 @@ func TestSentFileTransfers_loadTransfers(t *testing.T) { // Equalize all progressCallbacks because reflect.DeepEqual does not seem // to work on function pointers for _, tid := range list { - loadedSft.transfers[tid].progressCallbacks = sft.transfers[tid].progressCallbacks + loadedSft.transfers[tid].progressCallbacks = + sft.transfers[tid].progressCallbacks } if !reflect.DeepEqual(sft.transfers, loadedSft.transfers) { - t.Errorf("Transfers loaded from storage does not match transfers in memory."+ - "\nexpected: %+v\nreceived: %+v", sft.transfers, loadedSft.transfers) + t.Errorf("Transfers loaded from storage does not match transfers in "+ + "memory.\nexpected: %+v\nreceived: %+v", + sft.transfers, loadedSft.transfers) } } // Tests that a transfer list marshalled with -// SentFileTransfers.marshalTransfersList and unmarshalled with +// SentFileTransfersStore.marshalTransfersList and unmarshalled with // unmarshalTransfersList matches the original. -func TestSentFileTransfers_marshalTransfersList_unmarshalTransfersList(t *testing.T) { +func TestSentFileTransfersStore_marshalTransfersList_unmarshalTransfersList(t *testing.T) { prng := NewPrng(42) - sft := &SentFileTransfers{ + sft := &SentFileTransfersStore{ transfers: make(map[ftCrypto.TransferID]*SentTransfer), } diff --git a/storage/fileTransfer/sentPartTracker.go b/storage/fileTransfer/sentPartTracker.go index 8c09acd1f5ee5432998beaf50621147792319e7c..3f2e95addeca5b255fffcbf9111c05f003607d96 100644 --- a/storage/fileTransfer/sentPartTracker.go +++ b/storage/fileTransfer/sentPartTracker.go @@ -8,17 +8,10 @@ package fileTransfer import ( + "gitlab.com/elixxir/client/interfaces" "gitlab.com/elixxir/client/storage/utility" ) -// Part statuses. -const ( - unsentStatus = 0 - sentStatus = 1 - arrivedStatus = 2 - receivedStatus = 3 -) - // SentPartTracker tracks the status of individual sent file parts. type SentPartTracker struct { // The number of file parts in the file @@ -46,13 +39,13 @@ func NewSentPartTracker(inProgress, finished *utility.StateVector) SentPartTrack // 0 = unsent // 1 = sent (sender has sent a part, but it has not arrived) // 2 = arrived (sender has sent a part, and it has arrived) -func (spt SentPartTracker) GetPartStatus(partNum uint16) int { +func (spt SentPartTracker) GetPartStatus(partNum uint16) interfaces.FpStatus { if spt.inProgressStatus.Used(uint32(partNum)) { - return sentStatus + return interfaces.FpSent } else if spt.finishedStatus.Used(uint32(partNum)) { - return arrivedStatus + return interfaces.FpArrived } else { - return unsentStatus + return interfaces.FpUnsent } } diff --git a/storage/fileTransfer/sentPartTracker_test.go b/storage/fileTransfer/sentPartTracker_test.go index beb72908cd63a0a11be9785b81215e320b9d9717..89d37fc5213939c47ae92e02391c6a3952d32e95 100644 --- a/storage/fileTransfer/sentPartTracker_test.go +++ b/storage/fileTransfer/sentPartTracker_test.go @@ -51,14 +51,14 @@ func TestSentPartTracker_GetPartStatus(t *testing.T) { // Set statuses of parts in the SentTransfer and a map randomly prng := rand.New(rand.NewSource(42)) - partStatuses := make(map[uint16]int, st.numParts) + partStatuses := make(map[uint16]interfaces.FpStatus, st.numParts) for partNum := uint16(0); partNum < st.numParts; partNum++ { - partStatuses[partNum] = prng.Intn(3) + partStatuses[partNum] = interfaces.FpStatus(prng.Intn(3)) switch partStatuses[partNum] { - case sentStatus: + case interfaces.FpSent: st.inProgressStatus.Use(uint32(partNum)) - case arrivedStatus: + case interfaces.FpArrived: st.finishedStatus.Use(uint32(partNum)) } } diff --git a/storage/fileTransfer/sentTransfer.go b/storage/fileTransfer/sentTransfer.go index f6142743cf41b84842969cae846b918d973ea847..380f49e088cec9f29e67b3161308e090e74c0c42 100644 --- a/storage/fileTransfer/sentTransfer.go +++ b/storage/fileTransfer/sentTransfer.go @@ -129,20 +129,12 @@ type SentTransfer struct { // status indicates that the transfer is either done or errored out and // that no more callbacks should be called - status transferStatus + status TransferStatus mux sync.RWMutex kv *versioned.KV } -type transferStatus int - -const ( - running transferStatus = iota - stopping - stopped -) - // NewSentTransfer generates a new SentTransfer with the specified transfer key, // transfer ID, and number of parts. func NewSentTransfer(recipient *id.ID, tid ftCrypto.TransferID, @@ -157,7 +149,7 @@ func NewSentTransfer(recipient *id.ID, tid ftCrypto.TransferID, numParts: uint16(len(parts)), numFps: numFps, progressCallbacks: []*sentCallbackTracker{}, - status: running, + status: Running, kv: kv.Prefix(makeSentTransferPrefix(tid)), } @@ -221,7 +213,7 @@ func (st *SentTransfer) ReInit(numFps uint16, var err error // Mark the status as running - st.status = running + st.status = Running // Update number of fingerprints and overwrite old fingerprint vector st.numFps = numFps @@ -267,7 +259,7 @@ func (st *SentTransfer) ReInit(numFps uint16, st.progressCallbacks = append(st.progressCallbacks, sct) // Trigger the initial call - sct.callNowUnsafe(st, nil) + sct.callNowUnsafe(true, st, nil) } return nil @@ -313,6 +305,14 @@ func (st *SentTransfer) GetNumAvailableFps() uint16 { return uint16(st.fpVector.GetNumAvailable()) } +// GetStatus returns the status of the sent transfer. +func (st *SentTransfer) GetStatus() TransferStatus { + st.mux.RLock() + defer st.mux.RUnlock() + + return st.status +} + // IsPartInProgress returns true if the part has successfully been sent. Returns // false if the part is unsent or finished sending or if the part number is // invalid. @@ -362,12 +362,8 @@ func (st *SentTransfer) getProgress() (completed bool, sent, arrived, func (st *SentTransfer) CallProgressCB(err error) { st.mux.Lock() - switch st.status { - case stopped: - st.mux.Unlock() - return - case stopping: - st.status = stopped + if st.status == Stopping { + st.status = Stopped } st.mux.Unlock() @@ -417,7 +413,7 @@ func (st *SentTransfer) AddProgressCB(cb interfaces.SentProgressCallback, st.mux.Unlock() // Trigger the initial call - sct.callNow(st, nil) + sct.callNow(true, st, nil) } // GetEncryptedPart gets the specified part, encrypts it, and returns the @@ -439,7 +435,7 @@ func (st *SentTransfer) GetEncryptedPart(partNum uint16, partSize int, // the status to stopping and return an error specifying that all the // retries have been used if st.fpVector.GetNumAvailable() < 1 { - st.status = stopping + st.status = Stopping return nil, nil, nil, format.Fingerprint{}, MaxRetriesErr } @@ -509,8 +505,9 @@ func (st *SentTransfer) UnsetInProgress(rid id.Round) ([]uint16, error) { } // FinishTransfer moves the in-progress file parts for the given round to the -// finished list. -func (st *SentTransfer) FinishTransfer(rid id.Round) error { +// finished list. Returns true if all file parts have been marked as finished +// and false otherwise. +func (st *SentTransfer) FinishTransfer(rid id.Round) (bool, error) { st.mux.Lock() defer st.mux.Unlock() @@ -518,13 +515,13 @@ func (st *SentTransfer) FinishTransfer(rid id.Round) error { // exist partNums, exists := st.inProgressTransfers.getPartNums(rid) if !exists { - return errors.Errorf(noPartsForRoundErr, rid) + return false, errors.Errorf(noPartsForRoundErr, rid) } // Delete the parts from the in-progress list err := st.inProgressTransfers.deletePartNums(rid) if err != nil { - return errors.Errorf(deleteInProgressPartsErr, rid, err) + return false, errors.Errorf(deleteInProgressPartsErr, rid, err) } // Unset parts as in-progress in status vector @@ -533,7 +530,7 @@ func (st *SentTransfer) FinishTransfer(rid id.Round) error { // Add the parts to the finished list err = st.finishedTransfers.addPartNums(rid, partNums...) if err != nil { - return err + return false, err } // Set parts as finished in status vector @@ -543,10 +540,45 @@ func (st *SentTransfer) FinishTransfer(rid id.Round) error { // to stopping if st.finishedTransfers.getNumParts() == st.numParts && st.inProgressTransfers.getNumParts() == 0 { - st.status = stopping + st.status = Stopping + return true, nil } - return nil + return false, nil +} + +// GetUnsentPartNums returns a list of part numbers that have not been sent. +func (st *SentTransfer) GetUnsentPartNums() []uint16 { + st.mux.RLock() + defer st.mux.RUnlock() + + // Initialize list with a capacity for all unsent parts + numUnsentParts := uint32(st.numParts) - st.inProgressStatus.GetNumUsed() - + st.finishedStatus.GetNumUsed() + unsentPartNums := make([]uint16, 0, numUnsentParts) + + // Loop through each part and add it to the list if it is not marked as + // in-progress or finished + for i := uint16(0); i < st.numParts; i++ { + if !st.inProgressStatus.Used(uint32(i)) && + !st.finishedStatus.Used(uint32(i)) { + unsentPartNums = append(unsentPartNums, i) + } + } + + return unsentPartNums +} + +// GetSentRounds returns a list of round IDs that parts were sent on (in- +// progress parts) that were never marked as finished. +func (st *SentTransfer) GetSentRounds() []id.Round { + sentRounds := make([]id.Round, 0, len(st.inProgressTransfers.list)) + + for rid := range st.inProgressTransfers.list { + sentRounds = append(sentRounds, rid) + } + + return sentRounds } //////////////////////////////////////////////////////////////////////////////// @@ -698,8 +730,6 @@ func (st *SentTransfer) deleteInfo() error { return st.kv.Delete(sentTransferKey, sentTransferVersion) } -// marshal serializes the transfer key, numParts, and numFps. - // marshal serializes all primitive fields in SentTransfer (recipient, key, // numParts, numFps, and status). func (st *SentTransfer) marshal() []byte { @@ -729,9 +759,7 @@ func (st *SentTransfer) marshal() []byte { buff.Write(b) // Write the transfer status to the buffer - b = make([]byte, 8) - binary.LittleEndian.PutUint64(b, uint64(st.status)) - buff.Write(b) + buff.Write(st.status.Marshal()) // Return the serialized data return buff.Bytes() @@ -740,7 +768,7 @@ func (st *SentTransfer) marshal() []byte { // unmarshalSentTransfer deserializes a byte slice into the primitive fields // of SentTransfer (recipient, key, numParts, numFps, and status). func unmarshalSentTransfer(b []byte) (recipient *id.ID, - key ftCrypto.TransferKey, numParts, numFps uint16, status transferStatus) { + key ftCrypto.TransferKey, numParts, numFps uint16, status TransferStatus) { buff := bytes.NewBuffer(b) @@ -758,7 +786,7 @@ func unmarshalSentTransfer(b []byte) (recipient *id.ID, numFps = binary.LittleEndian.Uint16(buff.Next(2)) // Read the transfer status from the buffer - status = transferStatus(binary.LittleEndian.Uint64(buff.Next(8))) + status = UnmarshalTransferStatus(buff.Next(8)) return recipient, key, numParts, numFps, status } diff --git a/storage/fileTransfer/sentTransfer_test.go b/storage/fileTransfer/sentTransfer_test.go index e25307c794c1cc6d2f7d140af9faabc9d4dff8bf..5ebd94c7fb1b818a9ed4756defdb9782dc739760 100644 --- a/storage/fileTransfer/sentTransfer_test.go +++ b/storage/fileTransfer/sentTransfer_test.go @@ -21,6 +21,7 @@ import ( "gitlab.com/xx_network/primitives/netTime" "math/rand" "reflect" + "sort" "strconv" "strings" "sync" @@ -104,7 +105,7 @@ func Test_NewSentTransfer(t *testing.T) { progressCallbacks: []*sentCallbackTracker{ newSentCallbackTracker(cb, expectedPeriod), }, - status: running, + status: Running, kv: kvPrefixed, } @@ -232,7 +233,7 @@ func TestSentTransfer_ReInit(t *testing.T) { progressCallbacks: []*sentCallbackTracker{ newSentCallbackTracker(cb, expectedPeriod), }, - status: running, + status: Running, kv: kvPrefixed, } @@ -380,6 +381,37 @@ func TestSentTransfer_GetNumAvailableFps(t *testing.T) { } } +// Tests that SentTransfer.GetStatus returns the expected status at each stage +// of the transfer. +func TestSentTransfer_GetStatus(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + numParts, numFps := uint16(2), uint16(4) + _, st := newRandomSentTransfer(numParts, numFps, kv, t) + + status := st.GetStatus() + if status != Running { + t.Errorf("Unexpected transfer status.\nexpected: %s\nreceived: %s", + Running, status) + } + + _, _ = st.SetInProgress(0, 0, 1) + _, _ = st.FinishTransfer(0) + + status = st.GetStatus() + if status != Stopping { + t.Errorf("Unexpected transfer status.\nexpected: %s\nreceived: %s", + Stopping, status) + } + + st.CallProgressCB(nil) + + status = st.GetStatus() + if status != Stopped { + t.Errorf("Unexpected transfer status.\nexpected: %s\nreceived: %s", + Stopped, status) + } +} + // Tests that SentTransfer.IsPartInProgress returns false before a part is set // as in-progress and true after it is set via SentTransfer.SetInProgress. Also // tests that it returns false after the part has been unset via @@ -431,7 +463,7 @@ func TestSentTransfer_IsPartFinished(t *testing.T) { } // Set the part number to finished - _ = st.FinishTransfer(rid) + _, _ = st.FinishTransfer(rid) // Test that the part has been set to finished if !st.IsPartFinished(partNum) { @@ -472,7 +504,7 @@ func TestSentTransfer_GetProgress(t *testing.T) { } checkSentTracker(track, st.numParts, []uint16{0, 1, 2, 3, 4, 5}, nil, t) - _ = st.FinishTransfer(1) + _, _ = st.FinishTransfer(1) _, _ = st.UnsetInProgress(2) completed, sent, arrived, total, track = st.GetProgress() @@ -493,7 +525,7 @@ func TestSentTransfer_GetProgress(t *testing.T) { checkSentTracker(track, st.numParts, []uint16{6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, []uint16{0, 1, 2}, t) - _ = st.FinishTransfer(3) + _, _ = st.FinishTransfer(3) _, _ = st.SetInProgress(4, 3, 4, 5) completed, sent, arrived, total, track = st.GetProgress() @@ -505,7 +537,7 @@ func TestSentTransfer_GetProgress(t *testing.T) { checkSentTracker(track, st.numParts, []uint16{3, 4, 5}, []uint16{0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, t) - _ = st.FinishTransfer(4) + _, _ = st.FinishTransfer(4) completed, sent, arrived, total, track = st.GetProgress() err = checkSentProgress(completed, sent, arrived, total, true, 0, 16, numParts) @@ -603,8 +635,8 @@ func TestSentTransfer_CallProgressCB(t *testing.T) { _, _ = st.SetInProgress(1, 3, 4, 5) _, _ = st.SetInProgress(2, 6, 7, 8) _, _ = st.UnsetInProgress(1) - _ = st.FinishTransfer(0) - _ = st.FinishTransfer(2) + _, _ = st.FinishTransfer(0) + _, _ = st.FinishTransfer(2) st.CallProgressCB(nil) @@ -612,7 +644,7 @@ func TestSentTransfer_CallProgressCB(t *testing.T) { } _, _ = st.SetInProgress(4, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15) - _ = st.FinishTransfer(4) + _, _ = st.FinishTransfer(4) st.CallProgressCB(nil) @@ -1020,11 +1052,16 @@ func TestSentTransfer_FinishTransfer(t *testing.T) { } // Move transfers to the finished list - err = st.FinishTransfer(rid) + complete, err := st.FinishTransfer(rid) if err != nil { t.Errorf("FinishTransfer returned an error: %+v", err) } + // Ensure the transfer is not reported as complete + if complete { + t.Error("FinishTransfer reported transfer as complete.") + } + // Check that the round ID is not in the in-progress map _, exists := st.inProgressTransfers.list[rid] if exists { @@ -1072,6 +1109,42 @@ func TestSentTransfer_FinishTransfer(t *testing.T) { } } +// Tests that SentTransfer.FinishTransfer returns true and sets the status to +// stopping when all file parts are marked as complete. +func TestSentTransfer_FinishTransfer_Complete(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + + rid := id.Round(5) + expectedPartNums := make([]uint16, st.numParts) + for i := range expectedPartNums { + expectedPartNums[i] = uint16(i) + } + + // Add parts to the in-progress list + err, _ := st.SetInProgress(rid, expectedPartNums...) + if err != nil { + t.Errorf("Failed to add parts to in-progress list: %+v", err) + } + + // Move transfers to the finished list + complete, err := st.FinishTransfer(rid) + if err != nil { + t.Errorf("FinishTransfer returned an error: %+v", err) + } + + // Ensure the transfer is not reported as complete + if !complete { + t.Error("FinishTransfer reported transfer as not complete.") + } + + // Test that the status is correctly set + if st.status != Stopping { + t.Errorf("Status not set to expected value when transfer is complete."+ + "\nexpected: %s\nreceived: %s", Stopping, st.status) + } +} + // Error path: tests that SentTransfer.FinishTransfer returns the expected error // when the round ID is found in the in-progress map. func TestSentTransfer_FinishTransfer_NoRoundErr(t *testing.T) { @@ -1082,11 +1155,83 @@ func TestSentTransfer_FinishTransfer_NoRoundErr(t *testing.T) { expectedErr := fmt.Sprintf(noPartsForRoundErr, rid) // Move transfers to the finished list - err := st.FinishTransfer(rid) + complete, err := st.FinishTransfer(rid) if err == nil || err.Error() != expectedErr { t.Errorf("Did not get expected error when round ID not in in-progress "+ "map.\nexpected: %s\nreceived: %+v", expectedErr, err) } + + // Ensure the transfer is not reported as complete + if complete { + t.Error("FinishTransfer reported transfer as complete.") + } +} + +// Tests that SentTransfer.GetUnsentPartNums returns only part numbers that are +// not marked as in-progress or finished. +func TestSentTransfer_GetUnsentPartNums(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(18, 27, kv, t) + + expectedPartNums := make([]uint16, 0, st.numParts/3) + + // Loop through each part and set it individually + for i := uint16(0); i < st.numParts; i++ { + switch i % 3 { + case 0: + // Part is sent (in-progress) + _, _ = st.SetInProgress(id.Round(i), i) + case 1: + // Part is sent and arrived (finished) + _, _ = st.SetInProgress(id.Round(i), i) + _, _ = st.FinishTransfer(id.Round(i)) + case 2: + // Part is unsent (neither in-progress nor arrived) + expectedPartNums = append(expectedPartNums, i) + } + } + + unsentPartNums := st.GetUnsentPartNums() + if !reflect.DeepEqual(expectedPartNums, unsentPartNums) { + t.Errorf("Unexpected unsent part numbers.\nexpected: %d\nreceived: %d", + expectedPartNums, unsentPartNums) + } +} + +// Tests that SentTransfer.GetSentRounds returns the expected round IDs when +// every round is either in-progress, finished, or unsent. +func TestSentTransfer_GetSentRounds(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(18, 27, kv, t) + + expectedRounds := make([]id.Round, 0, st.numParts/3) + + // Loop through each part and set it individually + for i := uint16(0); i < st.numParts; i++ { + rid := id.Round(i) + switch i % 3 { + case 0: + // Part is sent (in-progress) + _, _ = st.SetInProgress(rid, i) + expectedRounds = append(expectedRounds, rid) + case 1: + // Part is sent and arrived (finished) + _, _ = st.SetInProgress(rid, i) + _, _ = st.FinishTransfer(rid) + case 2: + // Part is unsent (neither in-progress nor arrived) + } + } + + // Get the sent + sentRounds := st.GetSentRounds() + sort.SliceStable(sentRounds, + func(i, j int) bool { return sentRounds[i] < sentRounds[j] }) + + if !reflect.DeepEqual(expectedRounds, sentRounds) { + t.Errorf("Unexpected sent rounds.\nexpected: %d\nreceived: %d", + expectedRounds, sentRounds) + } } //////////////////////////////////////////////////////////////////////////////// @@ -1107,7 +1252,7 @@ func Test_loadSentTransfer(t *testing.T) { t.Errorf("Failed to add parts to in-progress transfer: %+v", err) } - err = expectedST.FinishTransfer(10) + _, err = expectedST.FinishTransfer(10) if err != nil { t.Errorf("Failed to move parts to finished transfer: %+v", err) } @@ -1117,6 +1262,7 @@ func Test_loadSentTransfer(t *testing.T) { t.Errorf("loadSentTransfer returned an error: %+v", err) } + // Progress callbacks cannot be compared loadedST.progressCallbacks = expectedST.progressCallbacks if !reflect.DeepEqual(expectedST, loadedST) { @@ -1342,7 +1488,7 @@ func TestSentTransfer_delete(t *testing.T) { } // Mark last in-progress transfer and finished - err = st.FinishTransfer(10) + _, err = st.FinishTransfer(10) if err != nil { t.Errorf("Failed to move parts to finished transfer: %+v", err) } @@ -1440,7 +1586,7 @@ func TestSentTransfer_marshal_unmarshalSentTransfer(t *testing.T) { key: ftCrypto.UnmarshalTransferKey([]byte("key")), numParts: 16, numFps: 20, - status: stopped, + status: Stopped, } marshaledData := st.marshal() @@ -1469,7 +1615,7 @@ func TestSentTransfer_marshal_unmarshalSentTransfer(t *testing.T) { if st.status != status { t.Errorf("Failed to get expected transfer status."+ - "\nexpected: %d\nreceived: %d", st.status, status) + "\nexpected: %s\nreceived: %s", st.status, status) } } @@ -1583,10 +1729,10 @@ func checkSentTracker(track SentPartTracker, numParts uint16, inProgress, var done bool for _, inProgressNum := range inProgress { if inProgressNum == partNum { - if track.GetPartStatus(partNum) != sentStatus { + if track.GetPartStatus(partNum) != interfaces.FpSent { t.Errorf("Part number %d has unexpected status."+ - "\nexpected: %d\nreceived: %d", partNum, sentStatus, - track.GetPartStatus(partNum)) + "\nexpected: %d\nreceived: %d", partNum, + interfaces.FpSent, track.GetPartStatus(partNum)) } done = true break @@ -1598,10 +1744,10 @@ func checkSentTracker(track SentPartTracker, numParts uint16, inProgress, for _, finishedNum := range finished { if finishedNum == partNum { - if track.GetPartStatus(partNum) != arrivedStatus { + if track.GetPartStatus(partNum) != interfaces.FpArrived { t.Errorf("Part number %d has unexpected status."+ - "\nexpected: %d\nreceived: %d", partNum, arrivedStatus, - track.GetPartStatus(partNum)) + "\nexpected: %d\nreceived: %d", partNum, + interfaces.FpArrived, track.GetPartStatus(partNum)) } done = true break @@ -1611,10 +1757,10 @@ func checkSentTracker(track SentPartTracker, numParts uint16, inProgress, continue } - if track.GetPartStatus(partNum) != unsentStatus { + if track.GetPartStatus(partNum) != interfaces.FpUnsent { t.Errorf("Part number %d has incorrect status."+ "\nexpected: %d\nreceived: %d", - partNum, unsentStatus, track.GetPartStatus(partNum)) + partNum, interfaces.FpUnsent, track.GetPartStatus(partNum)) } } } diff --git a/storage/fileTransfer/transferStatus.go b/storage/fileTransfer/transferStatus.go new file mode 100644 index 0000000000000000000000000000000000000000..f223d735471ada1ec2e942b278e8330e80de4e08 --- /dev/null +++ b/storage/fileTransfer/transferStatus.go @@ -0,0 +1,52 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "encoding/binary" + "strconv" +) + +// TransferStatus indicates the state of the transfer. +type TransferStatus int + +const ( + Running TransferStatus = iota // Sending parts + Stopping // Sent last part but not callback + Stopped // Sent last part and callback +) + +const invalidTransferStatusStringErr = "INVALID TransferStatus: " + +// String prints the string representation of the TransferStatus. This function +// satisfies the fmt.Stringer interface. +func (ts TransferStatus) String() string { + switch ts { + case Running: + return "running" + case Stopping: + return "stopping" + case Stopped: + return "stopped" + default: + return invalidTransferStatusStringErr + strconv.Itoa(int(ts)) + } +} + +// Marshal returns the byte representation of the TransferStatus. +func (ts TransferStatus) Marshal() []byte { + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(ts)) + return b +} + +// UnmarshalTransferStatus unmarshalls the 8-byte byte slice into a +// TransferStatus. +func UnmarshalTransferStatus(b []byte) TransferStatus { + return TransferStatus(binary.LittleEndian.Uint64(b)) +} diff --git a/storage/fileTransfer/transferStatus_test.go b/storage/fileTransfer/transferStatus_test.go new file mode 100644 index 0000000000000000000000000000000000000000..4f8bcd0a583309698cc25ff969539c03b5b2365a --- /dev/null +++ b/storage/fileTransfer/transferStatus_test.go @@ -0,0 +1,47 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "strconv" + "testing" +) + +// Tests that TransferStatus.String returns the expected string for each value +// of TransferStatus. +func Test_TransferStatus_String(t *testing.T) { + testValues := map[TransferStatus]string{ + Running: "running", + Stopping: "stopping", + Stopped: "stopped", + 100: invalidTransferStatusStringErr + strconv.Itoa(100), + } + + for status, expected := range testValues { + if expected != status.String() { + t.Errorf("TransferStatus string incorrect."+ + "\nexpected: %s\nreceived: %s", expected, status.String()) + } + } +} + +// Tests that a marshalled and unmarshalled TransferStatus matches the original. +func Test_TransferStatus_Marshal_UnmarshalTransferStatus(t *testing.T) { + testValues := []TransferStatus{Running, Stopping, Stopped} + + for _, status := range testValues { + marshalledStatus := status.Marshal() + + newStatus := UnmarshalTransferStatus(marshalledStatus) + + if status != newStatus { + t.Errorf("Marshalled and unmarshalled TransferStatus does not "+ + "match original.\nexpected: %s\nreceived: %s", status, newStatus) + } + } +} diff --git a/storage/fileTransfer/transferredBundle_test.go b/storage/fileTransfer/transferredBundle_test.go index e5198d493738565618af2122b0a1268b6aff2386..62b3916e53608f7a60994a8a6a2f9a125d5659e3 100644 --- a/storage/fileTransfer/transferredBundle_test.go +++ b/storage/fileTransfer/transferredBundle_test.go @@ -311,7 +311,7 @@ func Test_transferredBundle_marshal_unmarshal(t *testing.T) { tb.unmarshal(b) if !reflect.DeepEqual(expectedTB, tb) { - t.Errorf("Failed to marshal and unmarshal transferredBundle into original."+ - "\nexpected: %+v\nreceived: %+v", expectedTB, tb) + t.Errorf("Failed to marshal and unmarshal transferredBundle into "+ + "original.\nexpected: %+v\nreceived: %+v", expectedTB, tb) } }