From d6d8445e7baec3b955d50b5f82262d56109465ed Mon Sep 17 00:00:00 2001 From: Jono Wenger <jono@elixxir.io> Date: Wed, 24 Nov 2021 20:38:36 +0000 Subject: [PATCH] File part status and shuffled part send --- bindings/fileTransfer.go | 68 ++-- cmd/fileTransfer.go | 6 +- fileTransfer/manager.go | 83 ++++- fileTransfer/manager_test.go | 82 +++-- fileTransfer/receive_test.go | 10 +- fileTransfer/send.go | 25 +- fileTransfer/send_test.go | 30 +- fileTransfer/utils_test.go | 6 +- interfaces/fileTransfer.go | 31 +- storage/fileTransfer/partStore.go | 5 +- storage/fileTransfer/receiveTransfer.go | 57 +++- storage/fileTransfer/receiveTransfer_test.go | 142 ++++++++- .../fileTransfer/receivedCallbackTracker.go | 6 +- .../receivedCallbackTracker_test.go | 20 +- storage/fileTransfer/receivedPartTracker.go | 45 +++ .../fileTransfer/receivedPartTracker_test.go | 89 ++++++ storage/fileTransfer/sentCallbackTracker.go | 12 +- .../fileTransfer/sentCallbackTracker_test.go | 24 +- storage/fileTransfer/sentPartTracker.go | 62 ++++ storage/fileTransfer/sentPartTracker_test.go | 93 ++++++ storage/fileTransfer/sentTransfer.go | 182 +++++++++-- storage/fileTransfer/sentTransfer_test.go | 299 ++++++++++++++++-- storage/utility/stateVector.go | 56 ++++ storage/utility/stateVector_test.go | 77 ++++- 24 files changed, 1323 insertions(+), 187 deletions(-) create mode 100644 storage/fileTransfer/receivedPartTracker.go create mode 100644 storage/fileTransfer/receivedPartTracker_test.go create mode 100644 storage/fileTransfer/sentPartTracker.go create mode 100644 storage/fileTransfer/sentPartTracker_test.go diff --git a/bindings/fileTransfer.go b/bindings/fileTransfer.go index 0467fc9f8..037686cb4 100644 --- a/bindings/fileTransfer.go +++ b/bindings/fileTransfer.go @@ -10,6 +10,7 @@ package bindings import ( "encoding/json" ft "gitlab.com/elixxir/client/fileTransfer" + "gitlab.com/elixxir/client/interfaces" ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" "gitlab.com/xx_network/primitives/id" "time" @@ -24,14 +25,16 @@ type FileTransfer struct { // progress of sending a file. It is called when a file part is sent, a file // part arrives, the transfer completes, or on error. type FileTransferSentProgressFunc interface { - SentProgressCallback(completed bool, sent, arrived, total int, err error) + SentProgressCallback(completed bool, sent, arrived, total int, + t *FilePartTracker, err error) } // FileTransferReceivedProgressFunc contains a function callback that tracks the // progress of receiving a file. It is called when a file part is received, the // transfer completes, or on error. type FileTransferReceivedProgressFunc interface { - ReceivedProgressCallback(completed bool, received, total int, err error) + ReceivedProgressCallback(completed bool, received, total int, + t *FilePartTracker, err error) } // FileTransferReceiveFunc contains a function callback that notifies the @@ -100,9 +103,10 @@ func (f *FileTransfer) Send(fileName, fileType string, fileData []byte, progressFunc FileTransferSentProgressFunc, periodMS int) ([]byte, error) { // Create SentProgressCallback - progressCB := func(completed bool, sent, arrived, total uint16, err error) { + progressCB := func(completed bool, sent, arrived, total uint16, + t interfaces.FilePartTracker, err error) { progressFunc.SentProgressCallback( - completed, int(sent), int(arrived), int(total), err) + completed, int(sent), int(arrived), int(total), &FilePartTracker{t}, err) } // Convert recipient ID bytes to id.ID @@ -141,9 +145,10 @@ func (f *FileTransfer) RegisterSendProgressCallback(transferID []byte, tid := ftCrypto.UnmarshalTransferID(transferID) // Create SentProgressCallback - progressCB := func(completed bool, sent, arrived, total uint16, err error) { + progressCB := func(completed bool, sent, arrived, total uint16, + t interfaces.FilePartTracker, err error) { progressFunc.SentProgressCallback( - completed, int(sent), int(arrived), int(total), err) + completed, int(sent), int(arrived), int(total), &FilePartTracker{t}, err) } // Convert period to time.Duration @@ -171,6 +176,17 @@ func (f *FileTransfer) CloseSend(transferID []byte) error { return f.m.CloseSend(tid) } +// Receive returns the fully assembled file on the completion of the transfer. +// It deletes the transfer from the received transfer map and from storage. +// Returns an error if the transfer is not complete, the full file cannot be +// verified, or if the transfer cannot be found. +func (f *FileTransfer) Receive(transferID []byte) ([]byte, error) { + // Unmarshal transfer ID + tid := ftCrypto.UnmarshalTransferID(transferID) + + return f.m.Receive(tid) +} + // RegisterReceiveProgressCallback allows for the registration of a callback to // track the progress of an individual received file transfer. The callback will // be called immediately when added to report the current status of the @@ -187,9 +203,10 @@ func (f *FileTransfer) RegisterReceiveProgressCallback(transferID []byte, tid := ftCrypto.UnmarshalTransferID(transferID) // Create ReceivedProgressCallback - progressCB := func(completed bool, received, total uint16, err error) { + progressCB := func(completed bool, received, total uint16, + t interfaces.FilePartTracker, err error) { progressFunc.ReceivedProgressCallback( - completed, int(received), int(total), err) + completed, int(received), int(total), &FilePartTracker{t}, err) } // Convert period to time.Duration @@ -198,17 +215,6 @@ func (f *FileTransfer) RegisterReceiveProgressCallback(transferID []byte, return f.m.RegisterReceiveProgressCallback(tid, progressCB, period) } -// Receive returns the fully assembled file on the completion of the transfer. -// It deletes the transfer from the received transfer map and from storage. -// Returns an error if the transfer is not complete, the full file cannot be -// verified, or if the transfer cannot be found. -func (f *FileTransfer) Receive(transferID []byte) ([]byte, error) { - // Unmarshal transfer ID - tid := ftCrypto.UnmarshalTransferID(transferID) - - return f.m.Receive(tid) -} - //////////////////////////////////////////////////////////////////////////////// // Utility Functions // //////////////////////////////////////////////////////////////////////////////// @@ -235,3 +241,27 @@ func (f *FileTransfer) GetMaxFileTypeByteLength() int { func (f *FileTransfer) GetMaxFileSize() int { return ft.FileMaxSize } + +//////////////////////////////////////////////////////////////////////////////// +// File Part Tracker // +//////////////////////////////////////////////////////////////////////////////// + +// FilePartTracker contains the interfaces.FilePartTracker. +type FilePartTracker struct { + m interfaces.FilePartTracker +} + +// GetPartStatus returns the status of the file part with the given part number. +// The possible values for the status are: +// 0 = unsent +// 1 = sent (sender has sent a part, but it has not arrived) +// 2 = arrived (sender has sent a part, and it has arrived) +// 3 = received (receiver has received a part) +func (fpt *FilePartTracker) GetPartStatus(partNum int) int { + return fpt.m.GetPartStatus(uint16(partNum)) +} + +// GetNumParts returns the total number of file parts in the transfer. +func (fpt *FilePartTracker) GetNumParts() int { + return int(fpt.m.GetNumParts()) +} diff --git a/cmd/fileTransfer.go b/cmd/fileTransfer.go index 71ae08423..17e2a0d7c 100644 --- a/cmd/fileTransfer.go +++ b/cmd/fileTransfer.go @@ -193,7 +193,8 @@ func sendFile(filePath, fileType, filePreviewPath, filePreviewString, filePath, len(fileData), recipient.ID) // Create sent progress callback that prints the results - progressCB := func(completed bool, sent, arrived, total uint16, err error) { + progressCB := func(completed bool, sent, arrived, total uint16, + t interfaces.FilePartTracker, err error) { jww.DEBUG.Printf("Sent progress callback for %q "+ "{completed: %t, sent: %d, arrived: %d, total: %d, err: %v}\n", filePath, completed, sent, arrived, total, err) @@ -253,7 +254,8 @@ func receiveNewFileTransfers(receive chan receivedFtResults, done, // the results to the log. func newReceiveProgressCB(tid ftCrypto.TransferID, done chan struct{}, m *ft.Manager) interfaces.ReceivedProgressCallback { - return func(completed bool, received, total uint16, err error) { + return func(completed bool, received, total uint16, + t interfaces.FilePartTracker, err error) { jww.DEBUG.Printf("Receive progress callback for transfer %s "+ "{completed: %t, received: %d, total: %d, err: %v}", tid, completed, received, total, err) diff --git a/fileTransfer/manager.go b/fileTransfer/manager.go index 00dd39e0c..994a072b8 100644 --- a/fileTransfer/manager.go +++ b/fileTransfer/manager.go @@ -52,6 +52,14 @@ const ( networkHealthBuffLen = 10_000 ) +// Part status constants. +const ( + unsent = 0 + sent = 1 + arrived = 2 + received = 1 +) + // Error messages. const ( // newManager @@ -313,10 +321,33 @@ func (m Manager) RegisterSendProgressCallback(tid ftCrypto.TransferID, return nil } +// GetSentPartStatus returns the status of the sent file part number for the +// given transfer ID. An error is returned if the sent transfer does not +// exist. The possible values for the status are: +// 0 = unsent +// 1 = sent +// 2 = arrived +// TODO: test +func (m Manager) GetSentPartStatus(tid ftCrypto.TransferID, partNum uint16) (int, error) { + // Get the transfer for the given ID + transfer, err := m.sent.GetTransfer(tid) + if err != nil { + return unsent, err + } + + if transfer.IsPartInProgress(partNum) { + return sent, nil + } else if transfer.IsPartFinished(partNum) { + return arrived, nil + } else { + return unsent, nil + } +} + // Resend resends a file if sending fails. Returns an error if CloseSend // was already called or if the transfer did not run out of retries. This -//function should only be called if the interfaces.SentProgressCallback returns -//an error. +// function should only be called if the interfaces.SentProgressCallback returns +// an error. // TODO: add test // TODO: write test // TODO: can you resend? Can you reuse fingerprints? @@ -349,7 +380,7 @@ func (m Manager) CloseSend(tid ftCrypto.TransferID) error { // Check if the transfer has completed or run out of fingerprints, which // occurs when the retry limit is reached - completed, _, _, _ := transfer.GetProgress() + completed, _, _, _, _ := transfer.GetProgress() if transfer.GetNumAvailableFps() > 0 && !completed { return errors.Errorf(transferInProgressErr, tid) } @@ -358,6 +389,27 @@ func (m Manager) CloseSend(tid ftCrypto.TransferID) error { return m.sent.DeleteTransfer(tid) } +// Receive returns the fully assembled file on the completion of the transfer. +// It deletes the transfer from the received transfer map and from storage. +// Returns an error if the transfer is not complete, the full file cannot be +// verified, or if the transfer cannot be found. +func (m Manager) Receive(tid ftCrypto.TransferID) ([]byte, error) { + // Get the transfer for the given ID + transfer, err := m.received.GetTransfer(tid) + if err != nil { + return nil, err + } + + // Get the file from the transfer + file, err := transfer.GetFile() + if err != nil { + return nil, err + } + + // Return the file and delete the transfer from storage + return file, m.received.DeleteTransfer(tid) +} + // RegisterReceiveProgressCallback adds the reception progress callback to the // received transfer so that it will be called when updates for the transfer // occur. The progress callback is called when initially added and on transfer @@ -376,25 +428,24 @@ func (m Manager) RegisterReceiveProgressCallback(tid ftCrypto.TransferID, return nil } -// Receive returns the fully assembled file on the completion of the transfer. -// It deletes the transfer from the received transfer map and from storage. -// Returns an error if the transfer is not complete, the full file cannot be -// verified, or if the transfer cannot be found. -func (m Manager) Receive(tid ftCrypto.TransferID) ([]byte, error) { +// GetReceivedPartStatus returns the status of the received file part number +// for the given transfer ID. An error is returned if the received transfer +// does not exist. The possible values for the status are: +// 0 = unsent +// 1 = received +// TODO: test +func (m Manager) GetReceivedPartStatus(tid ftCrypto.TransferID, partNum uint16) (int, error) { // Get the transfer for the given ID transfer, err := m.received.GetTransfer(tid) if err != nil { - return nil, err + return unsent, err } - // Get the file from the transfer - file, err := transfer.GetFile() - if err != nil { - return nil, err + if transfer.IsPartReceived(partNum) { + return received, nil + } else { + return unsent, nil } - - // Return the file and delete the transfer from storage - return file, m.received.DeleteTransfer(tid) } // calcNumberOfFingerprints is the formula used to calculate the number of diff --git a/fileTransfer/manager_test.go b/fileTransfer/manager_test.go index de5f9d948..8d88f058a 100644 --- a/fileTransfer/manager_test.go +++ b/fileTransfer/manager_test.go @@ -240,7 +240,8 @@ func TestManager_RegisterSendProgressCallback(t *testing.T) { // Create new callback and channel for the callback to trigger cbChan := make(chan sentProgressResults, 6) - cb := func(completed bool, sent, arrived, total uint16, err error) { + cb := func(completed bool, sent, arrived, total uint16, + t interfaces.FilePartTracker, err error) { cbChan <- sentProgressResults{completed, sent, arrived, total, err} } @@ -416,6 +417,31 @@ func TestManager_CloseSend_NotCompleteErr(t *testing.T) { } } +// Error path: tests that Manager.Receive returns an error when no transfer with +// the ID exists. +func TestManager_Receive_NoTransferError(t *testing.T) { + m := newTestManager(false, nil, nil, nil, t) + tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) + + _, err := m.Receive(tid) + if err == nil { + t.Error("Receive did not return an error when no transfer with the ID " + + "exists.") + } +} + +// Error path: tests that Manager.Receive returns an error when the file is +// incomplete. +func TestManager_Receive_GetFileError(t *testing.T) { + m, _, rti := newTestManagerWithTransfers([]uint16{12, 4, 1}, false, nil, t) + + _, err := m.Receive(rti[0].tid) + if err == nil || !strings.Contains(err.Error(), "missing") { + t.Error("Receive did not return the expected error when no transfer " + + "with the ID exists.") + } +} + // Tests that Manager.RegisterReceiveProgressCallback calls the callback when it // is added to the transfer and that the callback is associated with the // expected transfer and is called when calling from the transfer. @@ -425,7 +451,8 @@ func TestManager_RegisterReceiveProgressCallback(t *testing.T) { // Create new callback and channel for the callback to trigger cbChan := make(chan receivedProgressResults, 6) - cb := func(completed bool, received, total uint16, err error) { + cb := func(completed bool, received, total uint16, + t interfaces.FilePartTracker, err error) { cbChan <- receivedProgressResults{completed, received, total, err} } @@ -485,31 +512,6 @@ func TestManager_RegisterReceiveProgressCallback_NoTransferError(t *testing.T) { } } -// Error path: tests that Manager.Receive returns an error when no transfer with -// the ID exists. -func TestManager_Receive_NoTransferError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, t) - tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) - - _, err := m.Receive(tid) - if err == nil { - t.Error("Receive did not return an error when no transfer with the ID " + - "exists.") - } -} - -// Error path: tests that Manager.Receive returns an error when the file is -// incomplete. -func TestManager_Receive_GetFileError(t *testing.T) { - m, _, rti := newTestManagerWithTransfers([]uint16{12, 4, 1}, false, nil, t) - - _, err := m.Receive(rti[0].tid) - if err == nil || !strings.Contains(err.Error(), "missing") { - t.Error("Receive did not return the expected error when no transfer " + - "with the ID exists.") - } -} - // Tests that calcNumberOfFingerprints matches some manually calculated // results. func Test_calcNumberOfFingerprints(t *testing.T) { @@ -576,7 +578,8 @@ func Test_FileTransfer(t *testing.T) { // Create progress tracker for sending sentCbChan := make(chan sentProgressResults, 20) - sentCb := func(completed bool, sent, arrived, total uint16, err error) { + sentCb := func(completed bool, sent, arrived, total uint16, + t interfaces.FilePartTracker, err error) { sentCbChan <- sentProgressResults{completed, sent, arrived, total, err} } @@ -633,7 +636,8 @@ func Test_FileTransfer(t *testing.T) { // Register progress callback with receiving manager receiveCbChan := make(chan receivedProgressResults, 100) - receiveCb := func(completed bool, received, total uint16, err error) { + receiveCb := func(completed bool, received, total uint16, + t interfaces.FilePartTracker, err error) { receiveCbChan <- receivedProgressResults{completed, received, total, err} } @@ -647,6 +651,26 @@ func Test_FileTransfer(t *testing.T) { t.Errorf("Timed out waiting for receive progress callback %d.", i) case r := <-receiveCbChan: if r.completed { + // Count the number of parts marked as received + count := 0 + for j := uint16(0); j < r.total; j++ { + status, err2 := m2.GetReceivedPartStatus(receiveTid, j) + if err2 != nil { + t.Errorf("Failed to get part %d status: %+v", j, err2) + } + if status == received { + count++ + } + } + + // Ensure that the number of parts received reported by the + // callback matches the number marked received + if count != int(r.received) { + t.Errorf("Number of parts marked received does not match "+ + "number reported by callback.\nmarked: %d\ncallback: %d", + count, r.received) + } + return } } diff --git a/fileTransfer/receive_test.go b/fileTransfer/receive_test.go index e8d732d31..9d56109e8 100644 --- a/fileTransfer/receive_test.go +++ b/fileTransfer/receive_test.go @@ -9,6 +9,7 @@ package fileTransfer import ( "bytes" + "gitlab.com/elixxir/client/interfaces" "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/stoppable" ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" @@ -58,7 +59,8 @@ func TestManager_receive(t *testing.T) { err error } cbChan := make(chan progressResults) - cb := func(completed bool, received, total uint16, err error) { + cb := func(completed bool, received, total uint16, + t interfaces.FilePartTracker, err error) { cbChan <- progressResults{completed, received, total, err} } @@ -175,7 +177,8 @@ func TestManager_receive_Stop(t *testing.T) { err error } cbChan := make(chan progressResults) - cb := func(completed bool, received, total uint16, err error) { + cb := func(completed bool, received, total uint16, + t interfaces.FilePartTracker, err error) { cbChan <- progressResults{completed, received, total, err} } @@ -280,7 +283,8 @@ func TestManager_readMessage(t *testing.T) { err error } cbChan := make(chan progressResults, 2) - cb := func(completed bool, received, total uint16, err error) { + cb := func(completed bool, received, total uint16, + t interfaces.FilePartTracker, err error) { cbChan <- progressResults{completed, received, total, err} } diff --git a/fileTransfer/send.go b/fileTransfer/send.go index 7d9f9e9b5..4e9fada3c 100644 --- a/fileTransfer/send.go +++ b/fileTransfer/send.go @@ -18,6 +18,7 @@ import ( "gitlab.com/elixxir/client/stoppable" ftStorage "gitlab.com/elixxir/client/storage/fileTransfer" ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/crypto/shuffle" "gitlab.com/elixxir/primitives/format" "gitlab.com/xx_network/crypto/csprng" "gitlab.com/xx_network/primitives/id" @@ -389,14 +390,30 @@ func (m *Manager) makeRoundEventCallback(rid id.Round, tid ftCrypto.TransferID, } } -// queueParts sends an entry for each part in a transfer into the sendQueue -// channel. +// queueParts adds an entry for each file part in a transfer into the sendQueue +// channel in a random order. func (m *Manager) queueParts(tid ftCrypto.TransferID, numParts uint16) { - for i := uint16(0); i < numParts; i++ { - m.sendQueue <- queuedPart{tid, i} + // Add each part number to the buffer in the shuffled order + for _, partNum := range getShuffledPartNumList(numParts) { + m.sendQueue <- queuedPart{tid, uint16(partNum)} } } +// getShuffledPartNumList returns a list of number of file part, from 0 to +// numParts, shuffled in a random order. +func getShuffledPartNumList(numParts uint16) []uint64 { + // Create list of part numbers + partNumList := make([]uint64, numParts) + for i := range partNumList { + partNumList[i] = uint64(i) + } + + // Shuffle list of part numbers + shuffle.Shuffle(&partNumList) + + return partNumList +} + // getPartSize determines the maximum size for each file part in bytes. The size // is calculated based on the content size of a cMix message. Returns an error // if a file part message cannot fit into a cMix message payload. diff --git a/fileTransfer/send_test.go b/fileTransfer/send_test.go index 09474a13d..29172e752 100644 --- a/fileTransfer/send_test.go +++ b/fileTransfer/send_test.go @@ -10,6 +10,7 @@ package fileTransfer import ( "bytes" "fmt" + "gitlab.com/elixxir/client/interfaces" "gitlab.com/elixxir/client/stoppable" ftStorage "gitlab.com/elixxir/client/storage/fileTransfer" "gitlab.com/elixxir/client/storage/versioned" @@ -430,7 +431,8 @@ func TestManager_buildMessages_MessageBuildFailureError(t *testing.T) { } callbackChan := make(chan callbackResults, 10) - progressCB := func(completed bool, sent, arrived, total uint16, err error) { + progressCB := func(completed bool, sent, arrived, total uint16, + t interfaces.FilePartTracker, err error) { callbackChan <- callbackResults{completed, sent, arrived, total, err} } @@ -609,7 +611,8 @@ func TestManager_makeRoundEventCallback(t *testing.T) { } callbackChan := make(chan callbackResults, 10) - progressCB := func(completed bool, sent, arrived, total uint16, err error) { + progressCB := func(completed bool, sent, arrived, total uint16, + t interfaces.FilePartTracker, err error) { callbackChan <- callbackResults{completed, sent, arrived, total, err} } @@ -687,7 +690,8 @@ func TestManager_makeRoundEventCallback_RoundFailure(t *testing.T) { } callbackChan := make(chan callbackResults, 10) - progressCB := func(completed bool, sent, arrived, total uint16, err error) { + progressCB := func(completed bool, sent, arrived, total uint16, + t interfaces.FilePartTracker, err error) { callbackChan <- callbackResults{completed, sent, arrived, total, err} } @@ -842,6 +846,26 @@ func TestManager_queueParts(t *testing.T) { } } +// Tests that getShuffledPartNumList returns a list with all the part numbers. +func Test_getShuffledPartNumList(t *testing.T) { + n := 100 + numList := getShuffledPartNumList(uint16(n)) + + if len(numList) != n { + t.Errorf("Length of shuffled list incorrect."+ + "\nexpected: %d\nreceived: %d", n, len(numList)) + } + + for i := 0; i < n; i++ { + j := 0 + for ; i != int(numList[j]); j++ { + } + if i != int(numList[j]) { + t.Errorf("Failed to find part number %d in shuffled list.", i) + } + } +} + // Tests that the part size returned by Manager.getPartSize matches the manually // calculated part size. func TestManager_getPartSize(t *testing.T) { diff --git a/fileTransfer/utils_test.go b/fileTransfer/utils_test.go index d31e7f9ce..46eb20f72 100644 --- a/fileTransfer/utils_test.go +++ b/fileTransfer/utils_test.go @@ -231,7 +231,8 @@ func newTestManagerWithTransfers(numParts []uint16, sendErr bool, cbChan := make(chan sentProgressResults, 6) - cb := func(completed bool, sent, arrived, total uint16, err error) { + cb := func(completed bool, sent, arrived, total uint16, + t interfaces.FilePartTracker, err error) { cbChan <- sentProgressResults{completed, sent, arrived, total, err} } @@ -265,7 +266,8 @@ func newTestManagerWithTransfers(numParts []uint16, sendErr bool, cbChan := make(chan receivedProgressResults, 6) - cb := func(completed bool, received, total uint16, err error) { + cb := func(completed bool, received, total uint16, + t interfaces.FilePartTracker, err error) { cbChan <- receivedProgressResults{completed, received, total, err} } diff --git a/interfaces/fileTransfer.go b/interfaces/fileTransfer.go index 382ab098b..498671241 100644 --- a/interfaces/fileTransfer.go +++ b/interfaces/fileTransfer.go @@ -16,12 +16,12 @@ import ( // SentProgressCallback is a callback function that tracks the progress of // sending a file. type SentProgressCallback func(completed bool, sent, arrived, total uint16, - err error) + t FilePartTracker, err error) // ReceivedProgressCallback is a callback function that tracks the progress of // receiving a file. type ReceivedProgressCallback func(completed bool, received, total uint16, - err error) + t FilePartTracker, err error) // ReceiveCallback is a callback function that notifies the receiver of an // incoming file transfer. @@ -63,6 +63,13 @@ type FileTransfer interface { // has not run out of retries. CloseSend(tid ftCrypto.TransferID) error + // Receive returns the full file on the completion of the transfer as + // reported by a registered ReceivedProgressCallback. It deletes internal + // references to the data and unregisters any attached progress callback. + // Returns an error if the transfer is not complete, the full file cannot be + // verified, or if the transfer cannot be found. + Receive(tid ftCrypto.TransferID) ([]byte, error) + // RegisterReceiveProgressCallback allows for the registration of a callback // to track the progress of an individual received file transfer. The // callback will be called immediately when added to report the current @@ -75,11 +82,19 @@ type FileTransfer interface { // can get the full file by calling Receive. RegisterReceiveProgressCallback(tid ftCrypto.TransferID, progressCB ReceivedProgressCallback, period time.Duration) error +} - // Receive returns the full file on the completion of the transfer as - // reported by a registered ReceivedProgressCallback. It deletes internal - // references to the data and unregisters any attached progress callback. - // Returns an error if the transfer is not complete, the full file cannot be - // verified, or if the transfer cannot be found. - Receive(tid ftCrypto.TransferID) ([]byte, error) +// FilePartTracker tracks the status of each file part in a sent or received +// file transfer. +type FilePartTracker interface { + // GetPartStatus returns the status of the file part with the given part + // number. The possible values for the status are: + // 0 = unsent + // 1 = sent (sender has sent a part, but it has not arrived) + // 2 = arrived (sender has sent a part, and it has arrived) + // 3 = received (receiver has received a part) + GetPartStatus(partNum uint16) int + + // GetNumParts returns the total number of file parts in the transfer. + GetNumParts() uint16 } diff --git a/storage/fileTransfer/partStore.go b/storage/fileTransfer/partStore.go index f29655d61..e7c8701a1 100644 --- a/storage/fileTransfer/partStore.go +++ b/storage/fileTransfer/partStore.go @@ -42,7 +42,6 @@ type partStore struct { // newPartStore generates a new empty partStore and saves it to storage. func newPartStore(kv *versioned.KV, numParts uint16) (*partStore, error) { - // Construct empty partStore of the specified size ps := &partStore{ parts: make(map[uint16][]byte, numParts), @@ -131,6 +130,10 @@ func (ps *partStore) len() int { return len(ps.parts) } +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + // loadPartStore loads all the file parts from storage into memory. func loadPartStore(kv *versioned.KV) (*partStore, error) { // Get list of saved file parts diff --git a/storage/fileTransfer/receiveTransfer.go b/storage/fileTransfer/receiveTransfer.go index 7f5a5d897..141a52f4f 100644 --- a/storage/fileTransfer/receiveTransfer.go +++ b/storage/fileTransfer/receiveTransfer.go @@ -26,20 +26,24 @@ const ( receivedTransferKey = "ReceivedTransfer" receivedTransferVersion = 0 receivedFpVectorKey = "ReceivedFingerprintVector" + receivedVectorKey = "ReceivedStatusVector" ) // Error messages for ReceivedTransfer const ( - newReceivedTransferPartStoreErr = "failed to create new part store: %+v" newReceivedTransferFpVectorErr = "failed to create new StateVector for fingerprints: %+v" + newReceivedTransferPartStoreErr = "failed to create new part store: %+v" + newReceivedVectorErr = "failed to create new state vector for received status: %+v" loadReceivedStoreErr = "failed to load received transfer info from storage: %+v" loadReceivePartStoreErr = "failed to load received part store from storage: %+v" + loadReceivedVectorErr = "failed to load new received status state vector from storage: %+v" loadReceiveFpVectorErr = "failed to load received fingerprint vector from storage: %+v" getFileErr = "missing %d/%d parts of the file" getTransferMacErr = "failed to verify transfer MAC" deleteReceivedTransferInfoErr = "failed to delete received transfer info from storage: %+v" deleteReceivedFpVectorErr = "failed to delete received fingerprint vector from storage: %+v" deleteReceivedFilePartsErr = "failed to delete received file parts from storage: %+v" + deleteReceivedVectorErr = "failed to delete received status state vector from storage: %+v" ) // ReceivedTransfer contains information and progress data for receiving an in- @@ -69,6 +73,9 @@ type ReceivedTransfer struct { // Saves each part in order (has its own storage backend) receivedParts *partStore + // Stores the received status for each file part in a bitstream format + receivedStatus *utility.StateVector + // List of callbacks to call for every send progressCallbacks []*receivedCallbackTracker @@ -83,7 +90,7 @@ func NewReceivedTransfer(tid ftCrypto.TransferID, key ftCrypto.TransferKey, kv *versioned.KV) (*ReceivedTransfer, error) { // Create the ReceivedTransfer object - rf := &ReceivedTransfer{ + rt := &ReceivedTransfer{ key: key, transferMAC: transferMAC, fileSize: fileSize, @@ -96,20 +103,27 @@ func NewReceivedTransfer(tid ftCrypto.TransferID, key ftCrypto.TransferKey, var err error // Create new StateVector for storing fingerprint usage - rf.fpVector, err = utility.NewStateVector( - rf.kv, receivedFpVectorKey, uint32(numFps)) + rt.fpVector, err = utility.NewStateVector( + rt.kv, receivedFpVectorKey, uint32(numFps)) if err != nil { return nil, errors.Errorf(newReceivedTransferFpVectorErr, err) } // Create new part store - rf.receivedParts, err = newPartStore(rf.kv, numParts) + rt.receivedParts, err = newPartStore(rt.kv, numParts) if err != nil { return nil, errors.Errorf(newReceivedTransferPartStoreErr, err) } + // Create new StateVector for storing received status + rt.receivedStatus, err = utility.NewStateVector( + rt.kv, receivedVectorKey, uint32(rt.numParts)) + if err != nil { + return nil, errors.Errorf(newReceivedVectorErr, err) + } + // Save all fields without their own storage to storage - return rf, rf.saveInfo() + return rt, rt.saveInfo() } // GetTransferKey returns the transfer Key for this received transfer. @@ -160,10 +174,20 @@ func (rt *ReceivedTransfer) GetFileSize() uint32 { return rt.fileSize } +// IsPartReceived returns true if the part has successfully been received. +// Returns false if the part has not been received or if the part number is +// invalid. +func (rt *ReceivedTransfer) IsPartReceived(partNum uint16) bool { + _, exists := rt.receivedParts.getPart(partNum) + return exists +} + // GetProgress returns the current progress of the transfer. Completed is true // when all parts have been received, received is the number of parts received, -// and total is the total number of parts excepted to be received. -func (rt *ReceivedTransfer) GetProgress() (completed bool, received, total uint16) { +// total is the total number of parts excepted to be received, and t is a part +// status tracker that can be used to get the status of individual file parts. +func (rt *ReceivedTransfer) GetProgress() (completed bool, received, + total uint16, t ReceivedPartTracker) { rt.mux.RLock() defer rt.mux.RUnlock() @@ -174,7 +198,7 @@ func (rt *ReceivedTransfer) GetProgress() (completed bool, received, total uint1 completed = true } - return completed, received, total + return completed, received, total, NewReceivedPartTracker(rt.receivedStatus) } // CallProgressCB calls all the progress callbacks with the most recent progress @@ -228,6 +252,9 @@ func (rt *ReceivedTransfer) AddPart(encryptedPart, padding, mac []byte, partNum, // Mark the fingerprint as used rt.fpVector.Use(uint32(fpNum)) + // Mark part as received + rt.receivedStatus.Use(uint32(partNum)) + return nil } @@ -285,6 +312,12 @@ func loadReceivedTransfer(tid ftCrypto.TransferID, kv *versioned.KV) ( return nil, errors.Errorf(loadReceivePartStoreErr, err) } + // Load the received status StateVector from storage + rt.receivedStatus, err = utility.LoadStateVector(rt.kv, receivedVectorKey) + if err != nil { + return nil, errors.Errorf(loadReceivedVectorErr, err) + } + return rt, nil } @@ -344,6 +377,12 @@ func (rt *ReceivedTransfer) delete() error { return errors.Errorf(deleteReceivedFilePartsErr, err) } + // Delete the received status StateVector from storage + err = rt.receivedStatus.Delete() + if err != nil { + return errors.Errorf(deleteReceivedVectorErr, err) + } + return nil } diff --git a/storage/fileTransfer/receiveTransfer_test.go b/storage/fileTransfer/receiveTransfer_test.go index 9a5130e82..e8e6b1fed 100644 --- a/storage/fileTransfer/receiveTransfer_test.go +++ b/storage/fileTransfer/receiveTransfer_test.go @@ -37,6 +37,8 @@ func Test_NewReceivedTransfer(t *testing.T) { numParts, numFps := uint16(16), uint16(24) fpVector, _ := utility.NewStateVector( kvPrefixed, receivedFpVectorKey, uint32(numFps)) + receivedVector, _ := utility.NewStateVector( + kvPrefixed, receivedVectorKey, uint32(numParts)) expected := &ReceivedTransfer{ key: key, @@ -50,7 +52,9 @@ func Test_NewReceivedTransfer(t *testing.T) { numParts: numParts, kv: kvPrefixed, }, + receivedStatus: receivedVector, progressCallbacks: []*receivedCallbackTracker{}, + mux: sync.RWMutex{}, kv: kvPrefixed, } @@ -178,6 +182,25 @@ func TestReceivedTransfer_GetFileSize(t *testing.T) { } } +// Tests that ReceivedTransfer.IsPartReceived returns false for unreceived file +// parts and true when the file part has been received. +func TestReceivedTransfer_IsPartReceived(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, rt, _ := newEmptyReceivedTransfer(16, 20, kv, t) + + partNum := uint16(5) + + if rt.IsPartReceived(partNum) { + t.Errorf("Part number %d received.", partNum) + } + + _ = rt.receivedParts.addPart([]byte("part"), partNum) + + if !rt.IsPartReceived(partNum) { + t.Errorf("Part number %d not received.", partNum) + } +} + // checkReceivedProgress compares the output of ReceivedTransfer.GetProgress to // expected values. func checkReceivedProgress(completed bool, received, total uint16, @@ -194,6 +217,42 @@ func checkReceivedProgress(completed bool, received, total uint16, return nil } +// checkReceivedTracker checks that the ReceivedPartTracker is reporting the +// correct values for each part. Also checks that ReceivedPartTracker.GetNumParts +// returns the expected value (make sure numParts comes from a correct source). +func checkReceivedTracker(track ReceivedPartTracker, numParts uint16, + received []uint16, t *testing.T) { + if track.GetNumParts() != numParts { + t.Errorf("Tracker reported incorrect number of parts."+ + "\nexpected: %d\nreceived: %d", numParts, track.GetNumParts()) + return + } + + for partNum := uint16(0); partNum < numParts; partNum++ { + var done bool + for _, receivedNum := range received { + if receivedNum == partNum { + if track.GetPartStatus(partNum) != receivedStatus { + t.Errorf("Part number %d has unexpected status."+ + "\nexpected: %d\nreceived: %d", partNum, receivedStatus, + track.GetPartStatus(partNum)) + } + done = true + break + } + } + if done { + continue + } + + if track.GetPartStatus(partNum) != unsentStatus { + t.Errorf("Part number %d has incorrect status."+ + "\nexpected: %d\nreceived: %d", + partNum, unsentStatus, track.GetPartStatus(partNum)) + } + } +} + // Tests that ReceivedTransfer.GetProgress returns the expected progress metrics // for various transfer states. func TestReceivedTransfer_GetProgress(t *testing.T) { @@ -201,57 +260,70 @@ func TestReceivedTransfer_GetProgress(t *testing.T) { numParts := uint16(16) _, rt, _ := newEmptyReceivedTransfer(numParts, 20, kv, t) - completed, received, total := rt.GetProgress() + completed, received, total, track := rt.GetProgress() err := checkReceivedProgress(completed, received, total, false, 0, numParts) if err != nil { t.Error(err) } + checkReceivedTracker(track, rt.numParts, nil, t) _, _ = rt.fpVector.Next() + _, _ = rt.receivedStatus.Next() - completed, received, total = rt.GetProgress() + completed, received, total, track = rt.GetProgress() err = checkReceivedProgress(completed, received, total, false, 1, numParts) if err != nil { t.Error(err) } + checkReceivedTracker(track, rt.numParts, []uint16{0}, t) for i := 0; i < 4; i++ { _, _ = rt.fpVector.Next() + _, _ = rt.receivedStatus.Next() } - completed, received, total = rt.GetProgress() + completed, received, total, track = rt.GetProgress() err = checkReceivedProgress(completed, received, total, false, 5, numParts) if err != nil { t.Error(err) } + checkReceivedTracker(track, rt.numParts, []uint16{0, 1, 2, 3, 4}, t) for i := 0; i < 6; i++ { _, _ = rt.fpVector.Next() + _, _ = rt.receivedStatus.Next() } - completed, received, total = rt.GetProgress() + completed, received, total, track = rt.GetProgress() err = checkReceivedProgress(completed, received, total, false, 11, numParts) if err != nil { t.Error(err) } + checkReceivedTracker(track, rt.numParts, []uint16{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, t) for i := 0; i < 4; i++ { _, _ = rt.fpVector.Next() + _, _ = rt.receivedStatus.Next() } - completed, received, total = rt.GetProgress() + completed, received, total, track = rt.GetProgress() err = checkReceivedProgress(completed, received, total, false, 15, numParts) if err != nil { t.Error(err) } + checkReceivedTracker(track, rt.numParts, []uint16{0, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14}, t) _, _ = rt.fpVector.Next() + _, _ = rt.receivedStatus.Next() - completed, received, total = rt.GetProgress() + completed, received, total, track = rt.GetProgress() err = checkReceivedProgress(completed, received, total, true, 16, numParts) if err != nil { t.Error(err) } + checkReceivedTracker(track, rt.numParts, []uint16{0, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15}, t) } // Tests that 5 different callbacks all receive the expected data when @@ -275,7 +347,8 @@ func TestReceivedTransfer_CallProgressCB(t *testing.T) { for i := 0; i < numCallbacks; i++ { progressChan := make(chan progressResults) - cbFunc := func(completed bool, received, total uint16, err error) { + cbFunc := func(completed bool, received, total uint16, + t interfaces.FilePartTracker, err error) { progressChan <- progressResults{completed, received, total, err} } wg.Add(1) @@ -367,7 +440,8 @@ func TestReceivedTransfer_AddProgressCB(t *testing.T) { } cbChan := make(chan callbackResults) cbFunc := interfaces.ReceivedProgressCallback( - func(completed bool, received, total uint16, err error) { + func(completed bool, received, total uint16, + t interfaces.FilePartTracker, err error) { cbChan <- callbackResults{completed, received, total, err} }) @@ -454,6 +528,17 @@ func TestReceivedTransfer_AddPart(t *testing.T) { t.Errorf("Incorrect number of used keys in fingerprint list."+ "\nexpected: %d\nreceived: %d", 1, rt.fpVector.GetNumUsed()) } + + // Check that the part was properly marked as received on the received + // status vector + if !rt.receivedStatus.Used(uint32(partNum)) { + t.Errorf("Part number %d not marked as used in received status vector.", + partNum) + } + if rt.receivedStatus.GetNumUsed() != 1 { + t.Errorf("Incorrect number of received parts in vector."+ + "\nexpected: %d\nreceived: %d", 1, rt.receivedStatus.GetNumUsed()) + } } // Error path: tests that ReceivedTransfer.AddPart returns the expected error @@ -576,7 +661,7 @@ func Test_loadReceivedTransfer_LoadInfoError(t *testing.T) { // Error path: tests that loadReceivedTransfer returns the expected error when // the fingerprint state vector has been deleted from storage. -func Test_loadReceivedTransfer_LoadStateVectorError(t *testing.T) { +func Test_loadReceivedTransfer_LoadFpVectorError(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) tid, rt, _ := newRandomReceivedTransfer(16, 20, kv, t) @@ -638,6 +723,38 @@ func Test_loadReceivedTransfer_LoadPartStoreError(t *testing.T) { } } +// Error path: tests that loadReceivedTransfer returns the expected error when +// the received status state vector has been deleted from storage. +func Test_loadReceivedTransfer_LoadReceivedVectorError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + tid, rt, _ := newRandomReceivedTransfer(16, 20, kv, t) + + // Create encrypted part + data := []byte("test") + partNum, fpNum := uint16(1), uint16(1) + encryptedPart, mac, padding := newEncryptedPartData(rt.key, data, fpNum, t) + + // Add encrypted part + err := rt.AddPart(encryptedPart, padding, mac, partNum, fpNum) + if err != nil { + t.Errorf("Failed to add test part: %+v", err) + } + + // Delete fingerprint state vector from storage + err = rt.receivedStatus.Delete() + if err != nil { + t.Errorf("Failed to received status vector: %+v", err) + } + + expectedErr := strings.Split(loadReceivedVectorErr, "%")[0] + _, err = loadReceivedTransfer(tid, kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("loadReceivedTransfer did not return the expected error when "+ + "the received status vector was deleted from storage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + // Tests that ReceivedTransfer.saveInfo saves the expected data to storage. func TestReceivedTransfer_saveInfo(t *testing.T) { rt := &ReceivedTransfer{ @@ -743,6 +860,13 @@ func TestReceivedTransfer_delete(t *testing.T) { t.Error("Successfully loaded fingerprint vector from storage when it " + "should have been deleted.") } + + // Check that the received status vector was deleted + _, err = utility.LoadStateVector(rt.kv, receivedVectorKey) + if err == nil { + t.Error("Successfully loaded received status vector from storage when " + + "it should have been deleted.") + } } // Tests that ReceivedTransfer.deleteInfo removes the saved ReceivedTransfer diff --git a/storage/fileTransfer/receivedCallbackTracker.go b/storage/fileTransfer/receivedCallbackTracker.go index 67b454b5d..ac4295c8d 100644 --- a/storage/fileTransfer/receivedCallbackTracker.go +++ b/storage/fileTransfer/receivedCallbackTracker.go @@ -85,11 +85,11 @@ func (rct *receivedCallbackTracker) call(tracker receivedProgressTracker, err er // callNow calls the callback immediately regardless of the schedule or period. func (rct *receivedCallbackTracker) callNow(tracker receivedProgressTracker, err error) { - completed, received, total := tracker.GetProgress() - go rct.cb(completed, received, total, err) + completed, received, total, t := tracker.GetProgress() + go rct.cb(completed, received, total, t, err) } // receivedProgressTracker interface tracks the progress of a transfer. type receivedProgressTracker interface { - GetProgress() (completed bool, received, total uint16) + GetProgress() (completed bool, received, total uint16, t ReceivedPartTracker) } diff --git a/storage/fileTransfer/receivedCallbackTracker_test.go b/storage/fileTransfer/receivedCallbackTracker_test.go index bd3e8a479..80bf43e82 100644 --- a/storage/fileTransfer/receivedCallbackTracker_test.go +++ b/storage/fileTransfer/receivedCallbackTracker_test.go @@ -8,6 +8,7 @@ package fileTransfer import ( + "gitlab.com/elixxir/client/interfaces" "reflect" "testing" "time" @@ -23,7 +24,8 @@ func Test_newReceivedCallbackTracker(t *testing.T) { } cbChan := make(chan cbFields) - cbFunc := func(completed bool, received, total uint16, err error) { + cbFunc := func(completed bool, received, total uint16, + t interfaces.FilePartTracker, err error) { cbChan <- cbFields{completed, received, total, err} } @@ -36,7 +38,7 @@ func Test_newReceivedCallbackTracker(t *testing.T) { receivedSCT := newReceivedCallbackTracker(expectedRCT.cb, expectedRCT.period) - go receivedSCT.cb(false, 0, 0, nil) + go receivedSCT.cb(false, 0, 0, nil, nil) select { case <-time.NewTimer(time.Millisecond).C: @@ -58,7 +60,7 @@ func Test_newReceivedCallbackTracker(t *testing.T) { } } -// Tests that +// Tests that receivedCallbackTracker.call triggers the proper callback. func Test_receivedCallbackTracker_call(t *testing.T) { type cbFields struct { completed bool @@ -67,13 +69,14 @@ func Test_receivedCallbackTracker_call(t *testing.T) { } cbChan := make(chan cbFields) - cbFunc := func(completed bool, received, total uint16, err error) { + cbFunc := func(completed bool, received, total uint16, + t interfaces.FilePartTracker, err error) { cbChan <- cbFields{completed, received, total, err} } sct := newReceivedCallbackTracker(cbFunc, 50*time.Millisecond) - tracker := testReceiveTrack{false, 1, 3} + tracker := testReceiveTrack{false, 1, 3, ReceivedPartTracker{}} sct.call(tracker, nil) select { @@ -86,7 +89,7 @@ func Test_receivedCallbackTracker_call(t *testing.T) { } } - tracker = testReceiveTrack{true, 3, 3} + tracker = testReceiveTrack{true, 3, 3, ReceivedPartTracker{}} sct.call(tracker, nil) select { @@ -122,9 +125,10 @@ func TestReceivedTransfer_ReceivedProgressTrackerInterface(t *testing.T) { type testReceiveTrack struct { completed bool received, total uint16 + t ReceivedPartTracker } // GetProgress returns the values in the testTrack. -func (trt testReceiveTrack) GetProgress() (completed bool, received, total uint16) { - return trt.completed, trt.received, trt.total +func (trt testReceiveTrack) GetProgress() (completed bool, received, total uint16, t ReceivedPartTracker) { + return trt.completed, trt.received, trt.total, trt.t } diff --git a/storage/fileTransfer/receivedPartTracker.go b/storage/fileTransfer/receivedPartTracker.go new file mode 100644 index 000000000..b31d61c2f --- /dev/null +++ b/storage/fileTransfer/receivedPartTracker.go @@ -0,0 +1,45 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import "gitlab.com/elixxir/client/storage/utility" + +// ReceivedPartTracker tracks the status of individual received file parts. +type ReceivedPartTracker struct { + // The number of file parts in the file + numParts uint16 + + // Stores the received status for each file part in a bitstream format + receivedStatus *utility.StateVector +} + +// NewReceivedPartTracker creates a new ReceivedPartTracker with copies of the +// received status state vectors. +func NewReceivedPartTracker(received *utility.StateVector) ReceivedPartTracker { + return ReceivedPartTracker{ + numParts: uint16(received.GetNumKeys()), + receivedStatus: received.DeepCopy(), + } +} + +// GetPartStatus returns the status of the received file part with the given part +// number. The possible values for the status are: +// 0 = unreceived +// 3 = received (receiver has received a part) +func (rpt ReceivedPartTracker) GetPartStatus(partNum uint16) int { + if rpt.receivedStatus.Used(uint32(partNum)) { + return receivedStatus + } else { + return unsentStatus + } +} + +// GetNumParts returns the total number of file parts in the transfer. +func (rpt ReceivedPartTracker) GetNumParts() uint16 { + return rpt.numParts +} diff --git a/storage/fileTransfer/receivedPartTracker_test.go b/storage/fileTransfer/receivedPartTracker_test.go new file mode 100644 index 000000000..1e8269a02 --- /dev/null +++ b/storage/fileTransfer/receivedPartTracker_test.go @@ -0,0 +1,89 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/ekv" + "math/rand" + "reflect" + "testing" +) + +// Tests that ReceivedPartTracker satisfies the interfaces.FilePartTracker +// interface. +func TestReceivedPartTracker_FilePartTrackerInterface(t *testing.T) { + var _ interfaces.FilePartTracker = ReceivedPartTracker{} +} + +// Tests that NewReceivedPartTracker returns a new ReceivedPartTracker with the +// expected values. +func TestNewReceivedPartTracker(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, rt, _ := newRandomReceivedTransfer(16, 24, kv, t) + + expected := ReceivedPartTracker{ + numParts: rt.numParts, + receivedStatus: rt.receivedStatus.DeepCopy(), + } + + newRPT := NewReceivedPartTracker(rt.receivedStatus) + + if !reflect.DeepEqual(expected, newRPT) { + t.Errorf("New ReceivedPartTracker does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expected, newRPT) + } +} + +// Tests that ReceivedPartTracker.GetPartStatus returns the expected status for +// each part loaded from a preconfigured ReceivedTransfer. +func TestReceivedPartTracker_GetPartStatus(t *testing.T) { + // Create new ReceivedTransfer + kv := versioned.NewKV(make(ekv.Memstore)) + _, rt, _ := newEmptyReceivedTransfer(16, 24, kv, t) + + // Set statuses of parts in the ReceivedTransfer and a map randomly + prng := rand.New(rand.NewSource(42)) + partStatuses := make(map[uint16]int, rt.numParts) + for partNum := uint16(0); partNum < rt.numParts; partNum++ { + partStatuses[partNum] = prng.Intn(2) * receivedStatus + + if partStatuses[partNum] == receivedStatus { + rt.receivedStatus.Use(uint32(partNum)) + } + } + + // Create a new ReceivedPartTracker from the ReceivedTransfer + rpt := NewReceivedPartTracker(rt.receivedStatus) + + // Check that the statuses for each part matches the map + for partNum := uint16(0); partNum < rt.numParts; partNum++ { + if rpt.GetPartStatus(partNum) != partStatuses[partNum] { + t.Errorf("Part number %d does not have expected status."+ + "\nexpected: %d\nreceived: %d", + partNum, partStatuses[partNum], rpt.GetPartStatus(partNum)) + } + } +} + +// Tests that ReceivedPartTracker.GetNumParts returns the same number of parts +// as the ReceivedPartTracker it was created from. +func TestReceivedPartTracker_GetNumParts(t *testing.T) { + // Create new ReceivedTransfer + kv := versioned.NewKV(make(ekv.Memstore)) + _, rt, _ := newEmptyReceivedTransfer(16, 24, kv, t) + + // Create a new ReceivedPartTracker from the ReceivedTransfer + rpt := NewReceivedPartTracker(rt.receivedStatus) + + if rpt.GetNumParts() != rt.GetNumParts() { + t.Errorf("Number of parts incorrect.\nexpected: %d\nreceived: %d", + rt.GetNumParts(), rpt.GetNumParts()) + } +} diff --git a/storage/fileTransfer/sentCallbackTracker.go b/storage/fileTransfer/sentCallbackTracker.go index 3309eaef7..53e210b2e 100644 --- a/storage/fileTransfer/sentCallbackTracker.go +++ b/storage/fileTransfer/sentCallbackTracker.go @@ -85,25 +85,25 @@ func (sct *sentCallbackTracker) call(tracker sentProgressTracker, err error) { // callNow calls the callback immediately regardless of the schedule or period. func (sct *sentCallbackTracker) callNow(tracker sentProgressTracker, err error) { - completed, sent, arrived, total := tracker.GetProgress() - go sct.cb(completed, sent, arrived, total, err) + completed, sent, arrived, total, t := tracker.GetProgress() + go sct.cb(completed, sent, arrived, total, t, err) } // callNowUnsafe calls the callback immediately regardless of the schedule or // period without taking a thread lock. This function should be used if a lock // is already taken on the sentProgressTracker. func (sct *sentCallbackTracker) callNowUnsafe(tracker sentProgressTracker, err error) { - completed, sent, arrived, total := tracker.getProgress() - go sct.cb(completed, sent, arrived, total, err) + completed, sent, arrived, total, t := tracker.getProgress() + go sct.cb(completed, sent, arrived, total, t, err) } // sentProgressTracker interface tracks the progress of a transfer. type sentProgressTracker interface { // GetProgress returns the sent transfer progress in a thread-safe manner. - GetProgress() (completed bool, sent, arrived, total uint16) + GetProgress() (completed bool, sent, arrived, total uint16, t SentPartTracker) // getProgress returns the sent transfer progress in a thread-unsafe manner. // This function should be used if a lock is already taken on the sent // transfer. - getProgress() (completed bool, sent, arrived, total uint16) + getProgress() (completed bool, sent, arrived, total uint16, t SentPartTracker) } diff --git a/storage/fileTransfer/sentCallbackTracker_test.go b/storage/fileTransfer/sentCallbackTracker_test.go index 925268928..a99726e2b 100644 --- a/storage/fileTransfer/sentCallbackTracker_test.go +++ b/storage/fileTransfer/sentCallbackTracker_test.go @@ -8,6 +8,7 @@ package fileTransfer import ( + "gitlab.com/elixxir/client/interfaces" "reflect" "testing" "time" @@ -23,7 +24,8 @@ func Test_newSentCallbackTracker(t *testing.T) { } cbChan := make(chan cbFields) - cbFunc := func(completed bool, sent, arrived, total uint16, err error) { + cbFunc := func(completed bool, sent, arrived, total uint16, + t interfaces.FilePartTracker, err error) { cbChan <- cbFields{completed, sent, arrived, total, err} } @@ -36,7 +38,7 @@ func Test_newSentCallbackTracker(t *testing.T) { receivedSCT := newSentCallbackTracker(expectedSCT.cb, expectedSCT.period) - go receivedSCT.cb(false, 0, 0, 0, nil) + go receivedSCT.cb(false, 0, 0, 0, SentPartTracker{}, nil) select { case <-time.NewTimer(time.Millisecond).C: @@ -68,13 +70,14 @@ func Test_sentCallbackTracker_call(t *testing.T) { } cbChan := make(chan cbFields) - cbFunc := func(completed bool, sent, arrived, total uint16, err error) { + cbFunc := func(completed bool, sent, arrived, total uint16, + t interfaces.FilePartTracker, err error) { cbChan <- cbFields{completed, sent, arrived, total, err} } sct := newSentCallbackTracker(cbFunc, 50*time.Millisecond) - tracker := testSentTrack{false, 1, 2, 3} + tracker := testSentTrack{false, 1, 2, 3, SentPartTracker{}} sct.call(tracker, nil) select { @@ -88,7 +91,7 @@ func Test_sentCallbackTracker_call(t *testing.T) { } } - tracker = testSentTrack{false, 1, 2, 3} + tracker = testSentTrack{false, 1, 2, 3, SentPartTracker{}} sct.call(tracker, nil) select { @@ -125,13 +128,16 @@ func TestSentTransfer_SentProgressTrackerInterface(t *testing.T) { type testSentTrack struct { completed bool sent, arrived, total uint16 + t SentPartTracker } -func (tst testSentTrack) getProgress() (completed bool, sent, arrived, total uint16) { - return tst.completed, tst.sent, tst.arrived, tst.total +func (tst testSentTrack) getProgress() (completed bool, sent, arrived, + total uint16, t SentPartTracker) { + return tst.completed, tst.sent, tst.arrived, tst.total, tst.t } // GetProgress returns the values in the testTrack. -func (tst testSentTrack) GetProgress() (completed bool, sent, arrived, total uint16) { - return tst.completed, tst.sent, tst.arrived, tst.total +func (tst testSentTrack) GetProgress() (completed bool, sent, arrived, + total uint16, t SentPartTracker) { + return tst.completed, tst.sent, tst.arrived, tst.total, tst.t } diff --git a/storage/fileTransfer/sentPartTracker.go b/storage/fileTransfer/sentPartTracker.go new file mode 100644 index 000000000..8c09acd1f --- /dev/null +++ b/storage/fileTransfer/sentPartTracker.go @@ -0,0 +1,62 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "gitlab.com/elixxir/client/storage/utility" +) + +// Part statuses. +const ( + unsentStatus = 0 + sentStatus = 1 + arrivedStatus = 2 + receivedStatus = 3 +) + +// SentPartTracker tracks the status of individual sent file parts. +type SentPartTracker struct { + // The number of file parts in the file + numParts uint16 + + // Stores the in-progress status for each file part in a bitstream format + inProgressStatus *utility.StateVector + + // Stores the finished status for each file part in a bitstream format + finishedStatus *utility.StateVector +} + +// NewSentPartTracker creates a new SentPartTracker with copies of the +// in-progress and finished status state vectors. +func NewSentPartTracker(inProgress, finished *utility.StateVector) SentPartTracker { + return SentPartTracker{ + numParts: uint16(inProgress.GetNumKeys()), + inProgressStatus: inProgress.DeepCopy(), + finishedStatus: finished.DeepCopy(), + } +} + +// GetPartStatus returns the status of the sent file part with the given part +// number. The possible values for the status are: +// 0 = unsent +// 1 = sent (sender has sent a part, but it has not arrived) +// 2 = arrived (sender has sent a part, and it has arrived) +func (spt SentPartTracker) GetPartStatus(partNum uint16) int { + if spt.inProgressStatus.Used(uint32(partNum)) { + return sentStatus + } else if spt.finishedStatus.Used(uint32(partNum)) { + return arrivedStatus + } else { + return unsentStatus + } +} + +// GetNumParts returns the total number of file parts in the transfer. +func (spt SentPartTracker) GetNumParts() uint16 { + return spt.numParts +} diff --git a/storage/fileTransfer/sentPartTracker_test.go b/storage/fileTransfer/sentPartTracker_test.go new file mode 100644 index 000000000..beb72908c --- /dev/null +++ b/storage/fileTransfer/sentPartTracker_test.go @@ -0,0 +1,93 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/ekv" + "math/rand" + "reflect" + "testing" +) + +// Tests that SentPartTracker satisfies the interfaces.FilePartTracker +// interface. +func TestSentPartTracker_FilePartTrackerInterface(t *testing.T) { + var _ interfaces.FilePartTracker = SentPartTracker{} +} + +// Tests that NewSentPartTracker returns a new SentPartTracker with the expected +// values. +func TestNewSentPartTracker(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + + expected := SentPartTracker{ + numParts: st.numParts, + inProgressStatus: st.inProgressStatus.DeepCopy(), + finishedStatus: st.finishedStatus.DeepCopy(), + } + + newSPT := NewSentPartTracker(st.inProgressStatus, st.finishedStatus) + + if !reflect.DeepEqual(expected, newSPT) { + t.Errorf("New SentPartTracker does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expected, newSPT) + } +} + +// Tests that SentPartTracker.GetPartStatus returns the expected status for each +// part loaded from a preconfigured SentTransfer. +func TestSentPartTracker_GetPartStatus(t *testing.T) { + // Create new SentTransfer + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + + // Set statuses of parts in the SentTransfer and a map randomly + prng := rand.New(rand.NewSource(42)) + partStatuses := make(map[uint16]int, st.numParts) + for partNum := uint16(0); partNum < st.numParts; partNum++ { + partStatuses[partNum] = prng.Intn(3) + + switch partStatuses[partNum] { + case sentStatus: + st.inProgressStatus.Use(uint32(partNum)) + case arrivedStatus: + st.finishedStatus.Use(uint32(partNum)) + } + } + + // Create a new SentPartTracker from the SentTransfer + spt := NewSentPartTracker(st.inProgressStatus, st.finishedStatus) + + // Check that the statuses for each part matches the map + for partNum := uint16(0); partNum < st.numParts; partNum++ { + if spt.GetPartStatus(partNum) != partStatuses[partNum] { + t.Errorf("Part number %d does not have expected status."+ + "\nexpected: %d\nreceived: %d", + partNum, partStatuses[partNum], spt.GetPartStatus(partNum)) + } + } +} + +// Tests that SentPartTracker.GetNumParts returns the same number of parts as +// the SentTransfer it was created from. +func TestSentPartTracker_GetNumParts(t *testing.T) { + // Create new SentTransfer + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + + // Create a new SentPartTracker from the SentTransfer + spt := NewSentPartTracker(st.inProgressStatus, st.finishedStatus) + + if spt.GetNumParts() != st.GetNumParts() { + t.Errorf("Number of parts incorrect.\nexpected: %d\nreceived: %d", + st.GetNumParts(), spt.GetNumParts()) + } +} diff --git a/storage/fileTransfer/sentTransfer.go b/storage/fileTransfer/sentTransfer.go index 2b05f84d5..3a8de11e2 100644 --- a/storage/fileTransfer/sentTransfer.go +++ b/storage/fileTransfer/sentTransfer.go @@ -25,10 +25,12 @@ import ( // Storage keys and versions. const ( - sentTransferPrefix = "FileTransferSentTransferStore" - sentTransferKey = "SentTransfer" - sentTransferVersion = 0 - sentFpVectorKey = "SentFingerprintVector" + sentTransferPrefix = "FileTransferSentTransferStore" + sentTransferKey = "SentTransfer" + sentTransferVersion = 0 + sentFpVectorKey = "SentFingerprintVector" + sentInProgressVectorKey = "SentInProgressStatusVector" + sentFinishedVectorKey = "SentFinishedStatusVector" ) // Error messages. @@ -38,25 +40,33 @@ const ( newSentTransferPartStoreErr = "failed to create new part store: %+v" newInProgressTransfersErr = "failed to create new in-progress transfers bundle: %+v" newFinishedTransfersErr = "failed to create new finished transfers bundle: %+v" + newSentInProgressVectorErr = "failed to create new state vector for in-progress status: %+v" + newSentFinishedVectorErr = "failed to create new state vector for finished status: %+v" // SentTransfer.ReInit reInitSentTransferFpVectorErr = "failed to overwrite fingerprint state vector with new vector: %+v" reInitInProgressTransfersErr = "failed to overwrite in-progress transfers bundle: %+v" reInitFinishedTransfersErr = "failed to overwrite finished transfers bundle: %+v" + reInitSentInProgressVectorErr = "failed to overwrite in-progress state vector with new vector: %+v" + reInitSentFinishedVectorErr = "failed to overwrite finished state vector with new vector: %+v" // loadSentTransfer - loadSentStoreErr = "failed to load sent transfer info from storage: %+v" - loadSentFpVectorErr = "failed to load sent fingerprint vector from storage: %+v" - loadSentPartStoreErr = "failed to load sent part store from storage: %+v" - loadInProgressTransfersErr = "failed to load in-progress transfers bundle from storage: %+v" - loadFinishedTransfersErr = "failed to load finished transfers bundle from storage: %+v" + loadSentStoreErr = "failed to load sent transfer info from storage: %+v" + loadSentFpVectorErr = "failed to load sent fingerprint vector from storage: %+v" + loadSentPartStoreErr = "failed to load sent part store from storage: %+v" + loadInProgressTransfersErr = "failed to load in-progress transfers bundle from storage: %+v" + loadFinishedTransfersErr = "failed to load finished transfers bundle from storage: %+v" + loadSentInProgressVectorErr = "failed to load new in-progress status state vector from storage: %+v" + loadSentFinishedVectorErr = "failed to load new finished status state vector from storage: %+v" // SentTransfer.delete - deleteSentTransferInfoErr = "failed to delete sent transfer info from storage: %+v" - deleteSentFpVectorErr = "failed to delete sent fingerprint vector from storage: %+v" - deleteSentFilePartsErr = "failed to delete sent file parts from storage: %+v" - deleteInProgressTransfersErr = "failed to delete in-progress transfers from storage: %+v" - deleteFinishedTransfersErr = "failed to delete finished transfers from storage: %+v" + deleteSentTransferInfoErr = "failed to delete sent transfer info from storage: %+v" + deleteSentFpVectorErr = "failed to delete sent fingerprint vector from storage: %+v" + deleteSentFilePartsErr = "failed to delete sent file parts from storage: %+v" + deleteInProgressTransfersErr = "failed to delete in-progress transfers from storage: %+v" + deleteFinishedTransfersErr = "failed to delete finished transfers from storage: %+v" + deleteSentInProgressVectorErr = "failed to delete in-progress status state vector from storage: %+v" + deleteSentFinishedVectorErr = "failed to delete finished status state vector from storage: %+v" // SentTransfer.FinishTransfer noPartsForRoundErr = "no file parts in-progress on round %d" @@ -104,6 +114,12 @@ type SentTransfer struct { // List of parts per round that finished transferring finishedTransfers *transferredBundle + // Stores the in-progress status for each file part in a bitstream format + inProgressStatus *utility.StateVector + + // Stores the finished status for each file part in a bitstream format + finishedStatus *utility.StateVector + // List of callbacks to call for every send progressCallbacks []*sentCallbackTracker @@ -168,6 +184,20 @@ func NewSentTransfer(recipient *id.ID, tid ftCrypto.TransferID, return nil, errors.Errorf(newFinishedTransfersErr, err) } + // Create new StateVector for storing in-progress status + st.inProgressStatus, err = utility.NewStateVector( + st.kv, sentInProgressVectorKey, uint32(st.numParts)) + if err != nil { + return nil, errors.Errorf(newSentInProgressVectorErr, err) + } + + // Create new StateVector for storing in-progress status + st.finishedStatus, err = utility.NewStateVector( + st.kv, sentFinishedVectorKey, uint32(st.numParts)) + if err != nil { + return nil, errors.Errorf(newSentFinishedVectorErr, err) + } + // Add first progress callback if progressCB != nil { st.AddProgressCB(progressCB, period) @@ -177,7 +207,7 @@ func NewSentTransfer(recipient *id.ID, tid ftCrypto.TransferID, } // ReInit resets the SentTransfer to its initial state so that sending can -// resume from the beginning. ReInit is used when the sent transfer runs out of +// restart from the beginning. ReInit is used when the sent transfer runs out of // retries and a user wants to attempt to resend the entire file again. func (st *SentTransfer) ReInit(numFps uint16, progressCB interfaces.SentProgressCallback, period time.Duration) error { @@ -208,6 +238,20 @@ func (st *SentTransfer) ReInit(numFps uint16, return errors.Errorf(reInitFinishedTransfersErr, err) } + // Overwrite in-progress status StateVector + st.inProgressStatus, err = utility.NewStateVector( + st.kv, sentInProgressVectorKey, uint32(st.numParts)) + if err != nil { + return errors.Errorf(reInitSentInProgressVectorErr, err) + } + + // Overwrite finished status StateVector + st.finishedStatus, err = utility.NewStateVector( + st.kv, sentFinishedVectorKey, uint32(st.numParts)) + if err != nil { + return errors.Errorf(reInitSentFinishedVectorErr, err) + } + // Clear callbacks st.progressCallbacks = []*sentCallbackTracker{} @@ -264,11 +308,27 @@ func (st *SentTransfer) GetNumAvailableFps() uint16 { return uint16(st.fpVector.GetNumAvailable()) } +// IsPartInProgress returns true if the part has successfully been sent. Returns +// false if the part is unsent or finished sending or if the part number is +// invalid. +func (st *SentTransfer) IsPartInProgress(partNum uint16) bool { + return st.inProgressStatus.Used(uint32(partNum)) +} + +// IsPartFinished returns true if the part has successfully arrived. Returns +// false if the part is unsent or in the process of sending or if the part +// number is invalid. +func (st *SentTransfer) IsPartFinished(partNum uint16) bool { + return st.finishedStatus.Used(uint32(partNum)) +} + // GetProgress returns the current progress of the transfer. Completed is true // when all parts have arrived, sent is the number of in-progress parts, arrived -// is the number of finished parts and total is the total number of parts being -// sent. -func (st *SentTransfer) GetProgress() (completed bool, sent, arrived, total uint16) { +// is the number of finished parts, total is the total number of parts being +// sent, and t is a part status tracker that can be used to get the status of +// individual file parts. +func (st *SentTransfer) GetProgress() (completed bool, sent, arrived, + total uint16, t SentPartTracker) { st.mux.RLock() defer st.mux.RUnlock() @@ -276,7 +336,8 @@ func (st *SentTransfer) GetProgress() (completed bool, sent, arrived, total uint } // getProgress is the thread-unsafe helper function for GetProgress. -func (st *SentTransfer) getProgress() (completed bool, sent, arrived, total uint16) { +func (st *SentTransfer) getProgress() (completed bool, sent, arrived, + total uint16, t SentPartTracker) { arrived = st.finishedTransfers.getNumParts() sent = st.inProgressTransfers.getNumParts() total = st.numParts @@ -285,7 +346,8 @@ func (st *SentTransfer) getProgress() (completed bool, sent, arrived, total uint completed = true } - return completed, sent, arrived, total + return completed, sent, arrived, total, + NewSentPartTracker(st.inProgressStatus, st.finishedStatus) } // CallProgressCB calls all the progress callbacks with the most recent progress @@ -381,8 +443,14 @@ func (st *SentTransfer) SetInProgress(rid id.Round, partNums ...uint16) (error, st.mux.Lock() defer st.mux.Unlock() + // Set as in-progress in bundle _, exists := st.inProgressTransfers.getPartNums(rid) + // Set parts as in-progress in status vector + for _, partNum := range partNums { + st.inProgressStatus.Use(uint32(partNum)) + } + return st.inProgressTransfers.addPartNums(rid, partNums...), exists } @@ -405,6 +473,11 @@ func (st *SentTransfer) UnsetInProgress(rid id.Round) ([]uint16, error) { // Get the list of part numbers to be removed from list partNums, _ := st.inProgressTransfers.getPartNums(rid) + // Unsets parts as in-progress in status vector + for _, partNum := range partNums { + st.inProgressStatus.Unuse(uint32(partNum)) + } + return partNums, st.inProgressTransfers.deletePartNums(rid) } @@ -427,12 +500,22 @@ func (st *SentTransfer) FinishTransfer(rid id.Round) error { return errors.Errorf(deleteInProgressPartsErr, rid, err) } + // Unsets parts as in-progress in status vector + for _, partNum := range partNums { + st.inProgressStatus.Unuse(uint32(partNum)) + } + // Add the parts to the finished list err = st.finishedTransfers.addPartNums(rid, partNums...) if err != nil { return err } + // Set parts as finished in status vector + for _, partNum := range partNums { + st.finishedStatus.Use(uint32(partNum)) + } + // If all parts have been moved to the finished list, then set the status // to stopping if st.finishedTransfers.getNumParts() == st.numParts && @@ -485,15 +568,26 @@ func loadSentTransfer(tid ftCrypto.TransferID, kv *versioned.KV) (*SentTransfer, return nil, errors.Errorf(loadFinishedTransfersErr, err) } + // Load the in-progress status StateVector from storage + st.inProgressStatus, err = utility.LoadStateVector( + st.kv, sentInProgressVectorKey) + if err != nil { + return nil, errors.Errorf(loadSentInProgressVectorErr, err) + } + + // Load the finished status StateVector from storage + st.finishedStatus, err = utility.LoadStateVector( + st.kv, sentFinishedVectorKey) + if err != nil { + return nil, errors.Errorf(loadSentFinishedVectorErr, err) + } + return st, nil } -// saveInfo saves the SentTransfer info (transfer key, number of file parts, -// and number of fingerprints) to storage. - // saveInfo saves all fields in SentTransfer that do not have their own storage -// (recipient ID, transfer key, number of file parts, and number of -// fingerprints) to storage. +// (recipient ID, transfer key, number of file parts, number of fingerprints, +// and transfer status) to storage. func (st *SentTransfer) saveInfo() error { st.mux.Lock() defer st.mux.Unlock() @@ -509,8 +603,9 @@ func (st *SentTransfer) saveInfo() error { return st.kv.Set(sentTransferKey, sentTransferVersion, obj) } -// loadInfo gets the recipient ID, transfer key, number of part, and number of -// fingerprints from storage and saves it to the SentTransfer. +// loadInfo gets the recipient ID, transfer key, number of part, number of +// fingerprints, and transfer status from storage and saves it to the +// SentTransfer. func (st *SentTransfer) loadInfo() error { vo, err := st.kv.Get(sentTransferKey, sentTransferVersion) if err != nil { @@ -518,7 +613,8 @@ func (st *SentTransfer) loadInfo() error { } // Unmarshal the transfer key and numParts - st.recipient, st.key, st.numParts, st.numFps = unmarshalSentTransfer(vo.Data) + st.recipient, st.key, st.numParts, st.numFps, st.status = + unmarshalSentTransfer(vo.Data) return nil } @@ -558,6 +654,18 @@ func (st *SentTransfer) delete() error { return errors.Errorf(deleteFinishedTransfersErr, err) } + // Delete the in-progress status StateVector from storage + err = st.inProgressStatus.Delete() + if err != nil { + return errors.Errorf(deleteSentInProgressVectorErr, err) + } + + // Delete the finished status StateVector from storage + err = st.finishedStatus.Delete() + if err != nil { + return errors.Errorf(deleteSentFinishedVectorErr, err) + } + return nil } @@ -570,7 +678,7 @@ func (st *SentTransfer) deleteInfo() error { // marshal serializes the transfer key, numParts, and numFps. // marshal serializes all primitive fields in SentTransfer (recipient, key, -// numParts, and numFps). +// numParts, numFps, and status). func (st *SentTransfer) marshal() []byte { // Construct the buffer to the correct size // (size of ID + size of key + numParts (2 bytes) + numFps (2 bytes)) @@ -597,14 +705,19 @@ func (st *SentTransfer) marshal() []byte { binary.LittleEndian.PutUint16(b, st.numFps) buff.Write(b) + // Write the transfer status to the buffer + b = make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(st.status)) + buff.Write(b) + // Return the serialized data return buff.Bytes() } // unmarshalSentTransfer deserializes a byte slice into the primitive fields -// of SentTransfer (recipient, key, numParts, and numFps). -func unmarshalSentTransfer(b []byte) ( - recipient *id.ID, key ftCrypto.TransferKey, numParts, numFps uint16) { +// of SentTransfer (recipient, key, numParts, numFps, and status). +func unmarshalSentTransfer(b []byte) (recipient *id.ID, + key ftCrypto.TransferKey, numParts, numFps uint16, status transferStatus) { buff := bytes.NewBuffer(b) @@ -621,7 +734,10 @@ func unmarshalSentTransfer(b []byte) ( // Read the number of fingerprints from the buffer numFps = binary.LittleEndian.Uint16(buff.Next(2)) - return recipient, key, numParts, numFps + // Read the transfer status from the buffer + status = transferStatus(binary.LittleEndian.Uint64(buff.Next(8))) + + return recipient, key, numParts, numFps, status } // makeSentTransferPrefix generates the unique prefix used on the key value diff --git a/storage/fileTransfer/sentTransfer_test.go b/storage/fileTransfer/sentTransfer_test.go index 14e8e76ad..de3e97a94 100644 --- a/storage/fileTransfer/sentTransfer_test.go +++ b/storage/fileTransfer/sentTransfer_test.go @@ -44,6 +44,10 @@ func Test_NewSentTransfer(t *testing.T) { numParts, numFps := uint16(len(parts)), uint16(float64(len(parts))*1.5) fpVector, _ := utility.NewStateVector( kvPrefixed, sentFpVectorKey, uint32(numFps)) + inProgressStatus, _ := utility.NewStateVector( + kvPrefixed, sentInProgressVectorKey, uint32(numParts)) + finishedStatus, _ := utility.NewStateVector( + kvPrefixed, sentFinishedVectorKey, uint32(numParts)) type cbFields struct { completed bool @@ -60,7 +64,8 @@ func Test_NewSentTransfer(t *testing.T) { } cbChan := make(chan cbFields) - cb := func(completed bool, sent, arrived, total uint16, err error) { + cb := func(completed bool, sent, arrived, total uint16, + t interfaces.FilePartTracker, err error) { cbChan <- cbFields{ completed: completed, sent: sent, @@ -93,10 +98,13 @@ func Test_NewSentTransfer(t *testing.T) { key: finishedKey, kv: kvPrefixed, }, + inProgressStatus: inProgressStatus, + finishedStatus: finishedStatus, progressCallbacks: []*sentCallbackTracker{ newSentCallbackTracker(cb, expectedPeriod), }, - kv: kvPrefixed, + status: running, + kv: kvPrefixed, } // Create new SentTransfer @@ -164,6 +172,10 @@ func TestSentTransfer_ReInit(t *testing.T) { numFps2 := 2 * numFps1 fpVector, _ := utility.NewStateVector( kvPrefixed, sentFpVectorKey, uint32(numFps2)) + inProgressStatus, _ := utility.NewStateVector( + kvPrefixed, sentInProgressVectorKey, uint32(numParts)) + finishedStatus, _ := utility.NewStateVector( + kvPrefixed, sentFinishedVectorKey, uint32(numParts)) type cbFields struct { completed bool @@ -180,7 +192,8 @@ func TestSentTransfer_ReInit(t *testing.T) { } cbChan := make(chan cbFields) - cb := func(completed bool, sent, arrived, total uint16, err error) { + cb := func(completed bool, sent, arrived, total uint16, + t interfaces.FilePartTracker, err error) { cbChan <- cbFields{ completed: completed, sent: sent, @@ -213,10 +226,13 @@ func TestSentTransfer_ReInit(t *testing.T) { key: finishedKey, kv: kvPrefixed, }, + inProgressStatus: inProgressStatus, + finishedStatus: finishedStatus, progressCallbacks: []*sentCallbackTracker{ newSentCallbackTracker(cb, expectedPeriod), }, - kv: kvPrefixed, + status: running, + kv: kvPrefixed, } // Create new SentTransfer @@ -363,6 +379,65 @@ func TestSentTransfer_GetNumAvailableFps(t *testing.T) { } } +// Tests that SentTransfer.IsPartInProgress returns false before a part is set +// as in-progress and true after it is set via SentTransfer.SetInProgress. Also +// tests that it returns false after the part has been unset via +// SentTransfer.UnsetInProgress. +func TestSentTransfer_IsPartInProgress(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + + rid := id.Round(0) + partNum := uint16(7) + + // Test that the part has not been set to in-progress + if st.IsPartInProgress(partNum) { + t.Errorf("Part number %d set as in-progress.", partNum) + } + + // Set the part number to in-progress + _, _ = st.SetInProgress(rid, partNum) + + // Test that the part has been set to in-progress + if !st.IsPartInProgress(partNum) { + t.Errorf("Part number %d not set as in-progress.", partNum) + } + + // Unset the part as in-progress + _, _ = st.UnsetInProgress(rid) + + // Test that the part has been unset + if st.IsPartInProgress(partNum) { + t.Errorf("Part number %d set as in-progress.", partNum) + } +} + +// Tests that SentTransfer.IsPartFinished returns false before a part is set as +// finished and true after it is set via SentTransfer.FinishTransfer. +func TestSentTransfer_IsPartFinished(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + + rid := id.Round(0) + partNum := uint16(7) + + // Set the part number to in-progress + _, _ = st.SetInProgress(rid, partNum) + + // Test that the part has not been set to finished + if st.IsPartFinished(partNum) { + t.Errorf("Part number %d set as finished.", partNum) + } + + // Set the part number to finished + _ = st.FinishTransfer(rid) + + // Test that the part has been set to finished + if !st.IsPartFinished(partNum) { + t.Errorf("Part number %d not set as finished.", partNum) + } +} + // Tests that SentTransfer.GetProgress returns the expected progress metrics for // various transfer states. func TestSentTransfer_GetProgress(t *testing.T) { @@ -370,64 +445,74 @@ func TestSentTransfer_GetProgress(t *testing.T) { numParts := uint16(16) _, st := newRandomSentTransfer(16, 24, kv, t) - completed, sent, arrived, total := st.GetProgress() + completed, sent, arrived, total, track := st.GetProgress() err := checkSentProgress( completed, sent, arrived, total, false, 0, 0, numParts) if err != nil { t.Error(err) } + checkSentTracker(track, st.numParts, nil, nil, t) _, _ = st.SetInProgress(1, 0, 1, 2) - completed, sent, arrived, total = st.GetProgress() + completed, sent, arrived, total, track = st.GetProgress() err = checkSentProgress(completed, sent, arrived, total, false, 3, 0, numParts) if err != nil { t.Error(err) } + checkSentTracker(track, st.numParts, []uint16{0, 1, 2}, nil, t) _, _ = st.SetInProgress(2, 3, 4, 5) - completed, sent, arrived, total = st.GetProgress() + completed, sent, arrived, total, track = st.GetProgress() err = checkSentProgress(completed, sent, arrived, total, false, 6, 0, numParts) if err != nil { t.Error(err) } + checkSentTracker(track, st.numParts, []uint16{0, 1, 2, 3, 4, 5}, nil, t) _ = st.FinishTransfer(1) _, _ = st.UnsetInProgress(2) - completed, sent, arrived, total = st.GetProgress() + completed, sent, arrived, total, track = st.GetProgress() err = checkSentProgress(completed, sent, arrived, total, false, 0, 3, numParts) if err != nil { t.Error(err) } + checkSentTracker(track, st.numParts, nil, []uint16{0, 1, 2}, t) _, _ = st.SetInProgress(3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) - completed, sent, arrived, total = st.GetProgress() + completed, sent, arrived, total, track = st.GetProgress() err = checkSentProgress( completed, sent, arrived, total, false, 10, 3, numParts) if err != nil { t.Error(err) } + checkSentTracker(track, st.numParts, + []uint16{6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, []uint16{0, 1, 2}, t) _ = st.FinishTransfer(3) _, _ = st.SetInProgress(4, 3, 4, 5) - completed, sent, arrived, total = st.GetProgress() + completed, sent, arrived, total, track = st.GetProgress() err = checkSentProgress( completed, sent, arrived, total, false, 3, 13, numParts) if err != nil { t.Error(err) } + checkSentTracker(track, st.numParts, []uint16{3, 4, 5}, + []uint16{0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, t) _ = st.FinishTransfer(4) - completed, sent, arrived, total = st.GetProgress() + completed, sent, arrived, total, track = st.GetProgress() err = checkSentProgress(completed, sent, arrived, total, true, 0, 16, numParts) if err != nil { t.Error(err) } + checkSentTracker(track, st.numParts, nil, + []uint16{0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 3, 4, 5}, t) } // Tests that 5 different callbacks all receive the expected data when @@ -451,7 +536,8 @@ func TestSentTransfer_CallProgressCB(t *testing.T) { for i := 0; i < numCallbacks; i++ { progressChan := make(chan progressResults) - cbFunc := func(completed bool, sent, arrived, total uint16, err error) { + cbFunc := func(completed bool, sent, arrived, total uint16, + t interfaces.FilePartTracker, err error) { progressChan <- progressResults{completed, sent, arrived, total, err} } wg.Add(1) @@ -545,7 +631,8 @@ func TestSentTransfer_AddProgressCB(t *testing.T) { } cbChan := make(chan callbackResults) cbFunc := interfaces.SentProgressCallback( - func(completed bool, sent, arrived, total uint16, err error) { + func(completed bool, sent, arrived, total uint16, + t interfaces.FilePartTracker, err error) { cbChan <- callbackResults{completed, sent, arrived, total, err} }) @@ -719,7 +806,8 @@ func TestSentTransfer_GetEncryptedPart_EncryptPartError(t *testing.T) { } // Tests that SentTransfer.SetInProgress correctly adds the part numbers for the -// given round ID to the in-progress map. +// given round ID to the in-progress map and sets the correct parts as +// in-progress in the state vector. func TestSentTransfer_SetInProgress(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) _, st := newRandomSentTransfer(16, 24, kv, t) @@ -756,6 +844,22 @@ func TestSentTransfer_SetInProgress(t *testing.T) { "\nexpected: %d\nreceived: %d", 1, len(st.inProgressTransfers.list)) } + // Check that the part numbers were set on the in-progress status vector + for i, partNum := range expectedPartNums { + if !st.inProgressStatus.Used(uint32(partNum)) { + t.Errorf("Part number %d not marked as used in status vector (%d).", + partNum, i) + } + } + + // Check that the correct number of parts were marked as in-progress in the + // status vector + if int(st.inProgressStatus.GetNumUsed()) != len(expectedPartNums) { + t.Errorf("Incorrect number of parts marked as in-progress."+ + "\nexpected: %d\nreceived: %d", len(expectedPartNums), + st.inProgressStatus.GetNumUsed()) + } + // Add more parts to the in-progress list err, exists = st.SetInProgress(rid, expectedPartNums...) if err != nil { @@ -766,6 +870,13 @@ func TestSentTransfer_SetInProgress(t *testing.T) { if !exists { t.Errorf("Round %d should already exist.", rid) } + + // Check that the number of parts were marked as in-progress is unchanged + if int(st.inProgressStatus.GetNumUsed()) != len(expectedPartNums) { + t.Errorf("Incorrect number of parts marked as in-progress."+ + "\nexpected: %d\nreceived: %d", len(expectedPartNums), + st.inProgressStatus.GetNumUsed()) + } } // Tests that SentTransfer.GetInProgress returns the correct part numbers for @@ -805,7 +916,8 @@ func TestSentTransfer_GetInProgress(t *testing.T) { } // Tests that SentTransfer.UnsetInProgress correctly removes the part numbers -// for the given round ID from the in-progress map. +// for the given round ID from the in-progress map and unsets the correct parts +// as in-progress in the state vector. func TestSentTransfer_UnsetInProgress(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) _, st := newRandomSentTransfer(16, 24, kv, t) @@ -843,10 +955,17 @@ func TestSentTransfer_UnsetInProgress(t *testing.T) { t.Errorf("Extra items in in-progress list."+ "\nexpected: %d\nreceived: %d", 0, len(st.inProgressTransfers.list)) } + + // Check that there are no set parts in the in-progress status vector + if st.inProgressStatus.GetNumUsed() != 0 { + t.Errorf("Failed to unset all parts in the in-progress vector."+ + "\nexpected: %d\nreceived: %d", 0, st.inProgressStatus.GetNumUsed()) + } } // Tests that SentTransfer.FinishTransfer removes the parts from the in-progress -// list and moved them to the finished list. +// list and moved them to the finished list and that it unsets the correct parts +// in the in-progress vector in the state vector. func TestSentTransfer_FinishTransfer(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) _, st := newRandomSentTransfer(16, 24, kv, t) @@ -889,6 +1008,28 @@ func TestSentTransfer_FinishTransfer(t *testing.T) { t.Errorf("Extra items in finished list."+ "\nexpected: %d\nreceived: %d", 1, len(st.finishedTransfers.list)) } + + // Check that there are no set parts in the in-progress status vector + if st.inProgressStatus.GetNumUsed() != 0 { + t.Errorf("Failed to unset all parts in the in-progress vector."+ + "\nexpected: %d\nreceived: %d", 0, st.inProgressStatus.GetNumUsed()) + } + + // Check that the part numbers were set on the finished status vector + for i, partNum := range expectedPartNums { + if !st.finishedStatus.Used(uint32(partNum)) { + t.Errorf("Part number %d not marked as used in status vector (%d).", + partNum, i) + } + } + + // Check that the correct number of parts were marked as finished in the + // status vector + if int(st.finishedStatus.GetNumUsed()) != len(expectedPartNums) { + t.Errorf("Incorrect number of parts marked as finished."+ + "\nexpected: %d\nreceived: %d", len(expectedPartNums), + st.finishedStatus.GetNumUsed()) + } } // Error path: tests that SentTransfer.FinishTransfer returns the expected error @@ -961,7 +1102,7 @@ func Test_loadSentTransfer_LoadInfoError(t *testing.T) { // Error path: tests that loadSentTransfer returns the expected error when the // fingerprint state vector was deleted from storage. -func Test_loadSentTransfer_LoadStateVectorError(t *testing.T) { +func Test_loadSentTransfer_LoadFingerprintStateVectorError(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) tid, st := newRandomSentTransfer(16, 24, kv, t) @@ -975,7 +1116,7 @@ func Test_loadSentTransfer_LoadStateVectorError(t *testing.T) { _, err = loadSentTransfer(tid, kv) if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Errorf("loadSentTransfer did not return the expected error when "+ - "the fingerprint vector was delete from storage."+ + "the fingerprint vector was deleted from storage."+ "\nexpected: %s\nreceived: %+v", expectedErr, err) } } @@ -996,7 +1137,7 @@ func Test_loadSentTransfer_LoadPartStoreError(t *testing.T) { _, err = loadSentTransfer(tid, kv) if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Errorf("loadSentTransfer did not return the expected error when "+ - "the part store was delete from storage."+ + "the part store was deleted from storage."+ "\nexpected: %s\nreceived: %+v", expectedErr, err) } } @@ -1017,7 +1158,7 @@ func Test_loadSentTransfer_LoadInProgressTransfersError(t *testing.T) { _, err = loadSentTransfer(tid, kv) if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Errorf("loadSentTransfer did not return the expected error when "+ - "the in-progress transfers bundle was delete from storage."+ + "the in-progress transfers bundle was deleted from storage."+ "\nexpected: %s\nreceived: %+v", expectedErr, err) } } @@ -1038,7 +1179,49 @@ func Test_loadSentTransfer_LoadFinishedTransfersError(t *testing.T) { _, err = loadSentTransfer(tid, kv) if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Errorf("loadSentTransfer did not return the expected error when "+ - "the finished transfers bundle was delete from storage."+ + "the finished transfers bundle was deleted from storage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: tests that loadSentTransfer returns the expected error when the +// in-progress status state vector was deleted from storage. +func Test_loadSentTransfer_LoadInProgressStateVectorError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + tid, st := newRandomSentTransfer(16, 24, kv, t) + + // Delete the in-progress state vector from storage + err := st.inProgressStatus.Delete() + if err != nil { + t.Errorf("Failed to delete the in-progress vector: %+v", err) + } + + expectedErr := strings.Split(loadSentInProgressVectorErr, "%")[0] + _, err = loadSentTransfer(tid, kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("loadSentTransfer did not return the expected error when "+ + "the in-progress vector was deleted from storage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: tests that loadSentTransfer returns the expected error when the +// finished status state vector was deleted from storage. +func Test_loadSentTransfer_LoadFinishedStateVectorError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + tid, st := newRandomSentTransfer(16, 24, kv, t) + + // Delete the finished state vector from storage + err := st.finishedStatus.Delete() + if err != nil { + t.Errorf("Failed to delete the finished vector: %+v", err) + } + + expectedErr := strings.Split(loadSentFinishedVectorErr, "%")[0] + _, err = loadSentTransfer(tid, kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("loadSentTransfer did not return the expected error when "+ + "the finished vector was deleted from storage."+ "\nexpected: %s\nreceived: %+v", expectedErr, err) } } @@ -1164,6 +1347,21 @@ func TestSentTransfer_delete(t *testing.T) { t.Error("Successfully loaded fingerprint vector from storage when it " + "should have been deleted.") } + + // Check that the in-progress status vector was deleted + _, err = utility.LoadStateVector(st.kv, sentInProgressVectorKey) + if err == nil { + t.Error("Successfully loaded in-progress vector from storage when it " + + "should have been deleted.") + } + + // Check that the finished status vector was deleted + _, err = utility.LoadStateVector(st.kv, sentFinishedVectorKey) + if err == nil { + t.Error("Successfully loaded finished vector from storage when it " + + "should have been deleted.") + } + } // Tests that SentTransfer.deleteInfo removes the saved SentTransfer data from @@ -1202,11 +1400,12 @@ func TestSentTransfer_marshal_unmarshalSentTransfer(t *testing.T) { key: ftCrypto.UnmarshalTransferKey([]byte("key")), numParts: 16, numFps: 20, + status: stopped, } marshaledData := st.marshal() - recipient, key, numParts, numFps := unmarshalSentTransfer(marshaledData) + recipient, key, numParts, numFps, status := unmarshalSentTransfer(marshaledData) if !st.recipient.Cmp(recipient) { t.Errorf("Failed to get recipient ID.\nexpected: %s\nreceived: %s", @@ -1227,6 +1426,11 @@ func TestSentTransfer_marshal_unmarshalSentTransfer(t *testing.T) { t.Errorf("Failed to get expected number of fingerprints."+ "\nexpected: %d\nreceived: %d", st.numFps, numFps) } + + if st.status != status { + t.Errorf("Failed to get expected transfer status."+ + "\nexpected: %d\nreceived: %d", st.status, status) + } } // Consistency test: tests that makeSentTransferPrefix returns the expected @@ -1302,3 +1506,54 @@ func checkSentProgress(completed bool, sent, arrived, total uint16, return nil } + +// checkSentTracker checks that the SentPartTracker is reporting the correct +// values for each part. Also checks that SentPartTracker.GetNumParts returns +// the expected value (make sure numParts comes from a correct source). +func checkSentTracker(track SentPartTracker, numParts uint16, inProgress, + finished []uint16, t *testing.T) { + if track.GetNumParts() != numParts { + t.Errorf("Tracker reported incorrect number of parts."+ + "\nexpected: %d\nreceived: %d", numParts, track.GetNumParts()) + return + } + + for partNum := uint16(0); partNum < numParts; partNum++ { + var done bool + for _, inProgressNum := range inProgress { + if inProgressNum == partNum { + if track.GetPartStatus(partNum) != sentStatus { + t.Errorf("Part number %d has unexpected status."+ + "\nexpected: %d\nreceived: %d", partNum, sentStatus, + track.GetPartStatus(partNum)) + } + done = true + break + } + } + if done { + continue + } + + for _, finishedNum := range finished { + if finishedNum == partNum { + if track.GetPartStatus(partNum) != arrivedStatus { + t.Errorf("Part number %d has unexpected status."+ + "\nexpected: %d\nreceived: %d", partNum, arrivedStatus, + track.GetPartStatus(partNum)) + } + done = true + break + } + } + if done { + continue + } + + if track.GetPartStatus(partNum) != unsentStatus { + t.Errorf("Part number %d has incorrect status."+ + "\nexpected: %d\nreceived: %d", + partNum, unsentStatus, track.GetPartStatus(partNum)) + } + } +} diff --git a/storage/utility/stateVector.go b/storage/utility/stateVector.go index 3c458145a..19ee95c14 100644 --- a/storage/utility/stateVector.go +++ b/storage/utility/stateVector.go @@ -27,6 +27,7 @@ const ( // Error messages. const ( saveUsedKeyErr = "Failed to save %s after marking key %d as used: %+v" + saveUnusedKeyErr = "Failed to save %s after marking key %d as unused: %+v" saveNextErr = "failed to save %s after getting next available key: %+v" noKeysErr = "all keys used" loadUnmarshalErr = "failed to unmarshal from storage: %+v" @@ -105,6 +106,43 @@ func (sv *StateVector) use(keyNum uint32) { } } +// Unuse marks the key as unused (sets it to 0). +func (sv *StateVector) Unuse(keyNum uint32) { + sv.mux.Lock() + defer sv.mux.Unlock() + + // Mark the key as used + sv.unuse(keyNum) + + // Save changes to storage + if err := sv.save(); err != nil { + jww.FATAL.Printf(saveUnusedKeyErr, sv, keyNum, err) + } +} + +// unuse marks the key as unused (sets it to 0). It is not thread-safe and does +// not save to storage. +func (sv *StateVector) unuse(keyNum uint32) { + // If the key is already unused, then exit + if !sv.used(keyNum) { + return + } + + // Calculate block and position of the key + block, pos := getBlockAndPos(keyNum) + + // Set the key to unused (0) + sv.vect[block] &= ^(1 << pos) + + // Increment number available unused keys + sv.numAvailable++ + + // If this is before the first available key, then set to the next available + if keyNum < sv.firstAvailable { + sv.firstAvailable = keyNum + } +} + // Used returns true if the key is used (set to 1) or false if the key is unused // (set to 0). func (sv *StateVector) Used(keyNum uint32) bool { @@ -219,6 +257,24 @@ func (sv *StateVector) GetUsedKeyNums() []uint32 { return keyNums } +// DeepCopy creates a deep copy of the StateVector without a storage backend. +// The deep copy can only be used for functions that do not access storage. +func (sv *StateVector) DeepCopy() *StateVector { + newSV := &StateVector{ + vect: make([]uint64, len(sv.vect)), + firstAvailable: sv.firstAvailable, + numKeys: sv.numKeys, + numAvailable: sv.numAvailable, + key: sv.key, + } + + for i, val := range sv.vect { + newSV.vect[i] = val + } + + return newSV +} + // getBlockAndPos calculates the block index and the position within that block // of a key number. func getBlockAndPos(keyNum uint32) (block, pos uint32) { diff --git a/storage/utility/stateVector_test.go b/storage/utility/stateVector_test.go index 90a637173..3c92a23b8 100644 --- a/storage/utility/stateVector_test.go +++ b/storage/utility/stateVector_test.go @@ -60,7 +60,6 @@ func TestStateVector_Use(t *testing.T) { usedKeys := []uint32{0, 2, 3, 4, 6, 39, 62, 70, 98, 100} usedKeysMap := make(map[uint32]bool, len(usedKeys)) for i, keyNum := range usedKeys { - sv.Use(keyNum) // Check if numAvailable is correct @@ -91,6 +90,56 @@ func TestStateVector_Use(t *testing.T) { } } +// Tests that StateVector.Unuse sets the correct keys to unused and does not +// modify others keys and that numAvailable is correctly set. +func TestStateVector_Unuse(t *testing.T) { + sv := newTestStateVector("StateVectorUse", 138, t) + + // Set all the keys to used + for keyNum := uint32(0); keyNum < sv.numKeys; keyNum++ { + sv.Use(keyNum) + } + + // Set some keys to unused + unusedKeys := []uint32{0, 2, 3, 4, 6, 39, 62, 70, 98, 100} + unusedKeysMap := make(map[uint32]bool, len(unusedKeys)) + for i, keyNum := range unusedKeys { + sv.Unuse(keyNum) + + // Check if numAvailable is correct + if sv.numAvailable != uint32(i)+1 { + t.Errorf("numAvailable incorrect (%d).\nexpected: %d\nreceived: %d", + i, uint32(i)+1, sv.numAvailable) + } + + // Check if firstAvailable is correct + if sv.firstAvailable != unusedKeys[0] { + t.Errorf("firstAvailable incorrect (%d).\nexpected: %d\nreceived: %d", + i, unusedKeys[0], sv.firstAvailable) + } + + unusedKeysMap[keyNum] = true + } + + // Check all keys for their expected states + for i := uint32(0); i < sv.numKeys; i++ { + if unusedKeysMap[i] { + if sv.Used(i) { + t.Errorf("Key #%d should have been marked unused.", i) + } + } else if !sv.Used(i) { + t.Errorf("Key #%d should have been marked used.", i) + } + } + + // Make sure numAvailable is not modified when the key is already used + sv.Unuse(unusedKeys[0]) + if sv.numAvailable != uint32(len(unusedKeys)) { + t.Errorf("numAvailable incorrect.\nexpected: %d\nreceived: %d", + uint32(len(unusedKeys)), sv.numAvailable) + } +} + // Tests StateVector.Used by creating a vector with known used and unused keys // and making sure it returns the expected state for all keys in the vector. func TestStateVector_Used(t *testing.T) { @@ -343,6 +392,32 @@ func TestStateVector_GetUsedKeyNums(t *testing.T) { } } +// Tests that StateVector.DeepCopy makes a copy of the values and not of the +// pointers. +func TestStateVector_DeepCopy(t *testing.T) { + sv := newTestStateVector("StateVectorGetUsedKeyNums", 1000, t) + + newSV := sv.DeepCopy() + sv.kv = nil + + // Check that the values are the same + if !reflect.DeepEqual(sv, newSV) { + t.Errorf("Original and copy do not match."+ + "\nexpected: %#v\nreceived: %#v", sv, newSV) + } + + // Check that the pointers are different + if sv == newSV { + t.Errorf("Original and copy do not match."+ + "\nexpected: %p\nreceived: %p", sv, newSV) + } + + if &sv.vect == &newSV.vect { + t.Errorf("Original and copy do not match."+ + "\nexpected: %p\nreceived: %p", sv.vect, newSV.vect) + } +} + // Tests that StateVector.String returns the expected string. func TestStateVector_String(t *testing.T) { key := "StateVectorString" -- GitLab