diff --git a/storage/fileTransfer/sentPartTracker_test.go b/storage/fileTransfer/sentPartTracker_test.go index 431abd425b0ae9d66d13300712b7ce0841eb674d..8d5a476ed542563b7524e9f3da9d668680f9ff22 100644 --- a/storage/fileTransfer/sentPartTracker_test.go +++ b/storage/fileTransfer/sentPartTracker_test.go @@ -62,7 +62,12 @@ func Test_sentPartTracker_GetPartStatus(t *testing.T) { partNum, interfaces.FpSent, err) } case interfaces.FpArrived: - err := st.partStats.Set(partNum, uint8(interfaces.FpArrived)) + err := st.partStats.Set(partNum, uint8(interfaces.FpSent)) + if err != nil { + t.Errorf("Failed to set part %d to %s: %+v", + partNum, interfaces.FpSent, err) + } + err = st.partStats.Set(partNum, uint8(interfaces.FpArrived)) if err != nil { t.Errorf("Failed to set part %d to %s: %+v", partNum, interfaces.FpArrived, err) diff --git a/storage/fileTransfer/sentTransfer.go b/storage/fileTransfer/sentTransfer.go index 01e016d8e8e75b853aab1b5739fbe353abc47224..351c50a78b607af43edde1b95c0470c4f5ba4f3e 100644 --- a/storage/fileTransfer/sentTransfer.go +++ b/storage/fileTransfer/sentTransfer.go @@ -91,6 +91,13 @@ const ( // been used. var MaxRetriesErr = errors.New(maxRetriesErr) +// sentTransferStateMap prevents illegal state changes for part statuses. +var sentTransferStateMap = [][]bool{ + {false, true, false}, + {true, false, true}, + {false, false, false}, +} + // SentTransfer contains information and progress data for sending and in- // progress file transfer. type SentTransfer struct { @@ -182,7 +189,7 @@ func NewSentTransfer(recipient *id.ID, tid ftCrypto.TransferID, // Create new MultiStateVector for storing part statuses st.partStats, err = utility.NewMultiStateVector( - st.numParts, 3, nil, sentPartStatsVectorKey, st.kv) + st.numParts, 3, sentTransferStateMap, sentPartStatsVectorKey, st.kv) if err != nil { return nil, errors.Errorf(newSentPartStatusVectorErr, err) } @@ -230,7 +237,7 @@ func (st *SentTransfer) ReInit(numFps uint16, // Overwrite new part status MultiStateVector st.partStats, err = utility.NewMultiStateVector( - st.numParts, 3, nil, sentPartStatsVectorKey, st.kv) + st.numParts, 3, sentTransferStateMap, sentPartStatsVectorKey, st.kv) if err != nil { return errors.Errorf(reInitSentPartStatusVectorErr, err) } @@ -618,7 +625,7 @@ func loadSentTransfer(tid ftCrypto.TransferID, kv *versioned.KV) (*SentTransfer, // Load the part status MultiStateVector from storage st.partStats, err = utility.LoadMultiStateVector( - nil, sentPartStatsVectorKey, st.kv) + sentTransferStateMap, sentPartStatsVectorKey, st.kv) if err != nil { return nil, errors.Errorf(loadSentPartStatusVectorErr, err) } diff --git a/storage/fileTransfer/sentTransfer_test.go b/storage/fileTransfer/sentTransfer_test.go index 9d2a72eaf420fb539ce2a738196e5e4e2ccdee52..79415befe0632d48092d75f078969f28f86f5693 100644 --- a/storage/fileTransfer/sentTransfer_test.go +++ b/storage/fileTransfer/sentTransfer_test.go @@ -46,7 +46,8 @@ func Test_NewSentTransfer(t *testing.T) { numParts, numFps := uint16(len(parts)), uint16(float64(len(parts))*1.5) fpVector, _ := utility.NewStateVector( kvPrefixed, sentFpVectorKey, uint32(numFps)) - partStats, _ := utility.NewMultiStateVector(numParts, 3, nil, sentPartStatsVectorKey, kvPrefixed) + partStats, _ := utility.NewMultiStateVector( + numParts, 3, sentTransferStateMap, sentPartStatsVectorKey, kvPrefixed) type cbFields struct { completed bool @@ -170,7 +171,8 @@ func TestSentTransfer_ReInit(t *testing.T) { numFps2 := 2 * numFps1 fpVector, _ := utility.NewStateVector( kvPrefixed, sentFpVectorKey, uint32(numFps2)) - partStats, _ := utility.NewMultiStateVector(numParts, 3, nil, sentPartStatsVectorKey, kvPrefixed) + partStats, _ := utility.NewMultiStateVector( + numParts, 3, sentTransferStateMap, sentPartStatsVectorKey, kvPrefixed) type cbFields struct { completed bool @@ -975,7 +977,8 @@ func TestSentTransfer_SetInProgress(t *testing.T) { } // Add more parts to the in-progress list - exists, err = st.SetInProgress(rid, expectedPartNums...) + expectedPartNums2 := []uint16{4, 5, 6} + exists, err = st.SetInProgress(rid, expectedPartNums2...) if err != nil { t.Errorf("SetInProgress returned an error: %+v", err) } @@ -987,9 +990,10 @@ func TestSentTransfer_SetInProgress(t *testing.T) { // Check that the number of parts were marked as in-progress is unchanged count, _ = st.partStats.GetCount(1) - if int(count) != len(expectedPartNums) { + if int(count) != len(expectedPartNums2)+len(expectedPartNums) { t.Errorf("Incorrect number of parts marked as in-progress."+ - "\nexpected: %d\nreceived: %d", len(expectedPartNums), count) + "\nexpected: %d\nreceived: %d", + len(expectedPartNums2)+len(expectedPartNums), count) } }