diff --git a/bindings/client.go b/bindings/client.go index 91293ba3d20f0b98a68f4691e3cb0e2ea589bdcd..0c0a588d32e0f98786defb2b0a37de8e41fe6482 100644 --- a/bindings/client.go +++ b/bindings/client.go @@ -23,6 +23,7 @@ import ( "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/netTime" "google.golang.org/grpc/grpclog" + "runtime/pprof" "strings" "sync" "time" @@ -523,3 +524,13 @@ func (c *Client) getSingle() (*single.Manager, error) { return c.single, nil } + +// DumpStack returns a string with the stack trace of every running thread. +func DumpStack() (string, error) { + buf := new(bytes.Buffer) + err := pprof.Lookup("goroutine").WriteTo(buf, 2) + if err!=nil{ + return "", err + } + return buf.String(), nil +} diff --git a/fileTransfer/ftMessages.pb.go b/fileTransfer/ftMessages.pb.go index b874f6fcf71cc4ade1b4fbc3ef315298e1bd58ed..4b6f510eb08daeafc86c3944291ea4159a540ca6 100644 --- a/fileTransfer/ftMessages.pb.go +++ b/fileTransfer/ftMessages.pb.go @@ -32,7 +32,7 @@ type NewFileTransfer struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - FileName string `protobuf:"bytes,1,opt,name=fileName,proto3" json:"fileName,omitempty"` // Name of the file; max 32 characters + FileName string `protobuf:"bytes,1,opt,name=fileName,proto3" json:"fileName,omitempty"` // Name of the file; max 48 characters FileType string `protobuf:"bytes,2,opt,name=fileType,proto3" json:"fileType,omitempty"` // Type of file; max 8 characters TransferKey []byte `protobuf:"bytes,3,opt,name=transferKey,proto3" json:"transferKey,omitempty"` // 256 bit encryption key to identify the transfer TransferMac []byte `protobuf:"bytes,4,opt,name=transferMac,proto3" json:"transferMac,omitempty"` // 256 bit MAC of the entire file diff --git a/fileTransfer/manager_test.go b/fileTransfer/manager_test.go index 59c93026be48b464557b24cdb64d2e7fb07b0154..eaec128086306cd3579b7925a0c637a3d4eb999b 100644 --- a/fileTransfer/manager_test.go +++ b/fileTransfer/manager_test.go @@ -647,7 +647,7 @@ func Test_FileTransfer(t *testing.T) { defer wg.Done() for i := 0; i < 20; i++ { select { - case <-time.NewTimer(350 * time.Millisecond).C: + case <-time.NewTimer(450 * time.Millisecond).C: t.Errorf("Timed out waiting for receive progress callback %d.", i) case r := <-receiveCbChan: if r.completed { diff --git a/fileTransfer/receive.go b/fileTransfer/receive.go index 84ff35a626a5a75c19894ffa1bd1392cb2ab908f..f1b2257dc43d384332f3d970caa4fda0b118d317 100644 --- a/fileTransfer/receive.go +++ b/fileTransfer/receive.go @@ -31,7 +31,7 @@ func (m *Manager) receive(rawMsgs chan message.Receive, stop *stoppable.Single) select { case <-stop.Quit(): jww.DEBUG.Print("Stopping file part reception thread: stoppable " + - "triggered") + "triggered.") stop.ToStopped() return case receiveMsg := <-rawMsgs: diff --git a/fileTransfer/send.go b/fileTransfer/send.go index 4e9fada3c0c6f9d73c050f5f82517fa36ce1c5cb..99bfd34b5a9ab34e0de538408ff831fe3b055d66 100644 --- a/fileTransfer/send.go +++ b/fileTransfer/send.go @@ -439,7 +439,9 @@ func partitionFile(file []byte, partSize int) [][]byte { buff := bytes.NewBuffer(file) for n := buff.Next(partSize); len(n) > 0; n = buff.Next(partSize) { - parts = append(parts, n) + newPart := make([]byte, partSize) + copy(newPart, n) + parts = append(parts, newPart) } return parts diff --git a/fileTransfer/send_test.go b/fileTransfer/send_test.go index 56d6b2ef582f223c6ba8db250eabd4a0c4f76da2..51ab9f3aaac046cdea5126ca8469248de00803ee 100644 --- a/fileTransfer/send_test.go +++ b/fileTransfer/send_test.go @@ -122,7 +122,7 @@ func TestManager_sendThread_Timeout(t *testing.T) { go func(i int, st sentTransferInfo) { for j := 0; j < 2; j++ { select { - case <-time.NewTimer(20*time.Millisecond + pollSleepDuration).C: + case <-time.NewTimer(80*time.Millisecond + pollSleepDuration).C: t.Errorf("Timed out waiting for callback #%d", i) case r := <-st.cbChan: if j > 0 { diff --git a/storage/fileTransfer/partStore.go b/storage/fileTransfer/partStore.go index dfda88cd1b302bd5ad0b6f69cdc392afb9729f64..72765fffeae536e47ec71910c888e8108beba1a5 100644 --- a/storage/fileTransfer/partStore.go +++ b/storage/fileTransfer/partStore.go @@ -158,7 +158,7 @@ func loadPartStore(kv *versioned.KV) (*partStore, error) { // Load each part from storage and add to the map for _, partNum := range list { - vo, err := kv.Get(makePartsKey(partNum), partsStoreVersion) + vo, err = kv.Get(makePartsKey(partNum), partsStoreVersion) if err != nil { return nil, errors.Errorf(loadPartsErr, partNum, err) } diff --git a/storage/fileTransfer/receiveFileTransfers.go b/storage/fileTransfer/receiveFileTransfers.go index 99bc6921be29638ab849fe70916270f3a4d4a4d6..bd5c84244f11f9251b05c87e2fb9463e8ac9ab0d 100644 --- a/storage/fileTransfer/receiveFileTransfers.go +++ b/storage/fileTransfer/receiveFileTransfers.go @@ -10,6 +10,7 @@ package fileTransfer import ( "bytes" "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/storage/versioned" ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" "gitlab.com/elixxir/primitives/format" @@ -134,6 +135,12 @@ func (rft *ReceivedFileTransfers) DeleteTransfer(tid ftCrypto.TransferID) error return errors.Errorf(getReceivedTransferErr, tid) } + // Cancel any scheduled callbacks + err := rt.StopScheduledProgressCB() + if err != nil { + jww.WARN.Print(errors.Errorf(cancelCallbackErr, tid, err)) + } + // Remove all unused fingerprints from map for n, err := rt.fpVector.Next(); err == nil; n, err = rt.fpVector.Next() { // Generate fingerprint @@ -144,7 +151,7 @@ func (rft *ReceivedFileTransfers) DeleteTransfer(tid ftCrypto.TransferID) error } // Delete all data the transfer saved to storage - err := rft.transfers[tid].delete() + err = rft.transfers[tid].delete() if err != nil { return errors.Errorf(deleteReceivedTransferErr, tid, err) } diff --git a/storage/fileTransfer/receiveFileTransfers_test.go b/storage/fileTransfer/receiveFileTransfers_test.go index c5c1d8ddd9f94958222c4f7f4aa0691a84d2465f..e60036a9d51ebfdb84578203aa20067fc384b8b1 100644 --- a/storage/fileTransfer/receiveFileTransfers_test.go +++ b/storage/fileTransfer/receiveFileTransfers_test.go @@ -296,7 +296,7 @@ func TestReceivedFileTransfers_DeleteTransfer(t *testing.T) { // Check that all the fingerprints in the info map were deleted for fpNum, fp := range ftCrypto.GenerateFingerprints(key, numFps) { - _, exists := rft.info[fp] + _, exists = rft.info[fp] if exists { t.Errorf("Part fingerprint %s (#%d) found in map when it should "+ "have been deleted.", fp, fpNum) diff --git a/storage/fileTransfer/receiveTransfer.go b/storage/fileTransfer/receiveTransfer.go index 141a52f4f8b25e637cc6ead2645ba282acaa2245..fc5b48fac25406b9edae66a65ac368ed45ddce90 100644 --- a/storage/fileTransfer/receiveTransfer.go +++ b/storage/fileTransfer/receiveTransfer.go @@ -11,6 +11,7 @@ import ( "bytes" "encoding/binary" "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces" "gitlab.com/elixxir/client/storage/utility" "gitlab.com/elixxir/client/storage/versioned" @@ -31,19 +32,29 @@ const ( // Error messages for ReceivedTransfer const ( + // NewReceivedTransfer newReceivedTransferFpVectorErr = "failed to create new StateVector for fingerprints: %+v" newReceivedTransferPartStoreErr = "failed to create new part store: %+v" newReceivedVectorErr = "failed to create new state vector for received status: %+v" - loadReceivedStoreErr = "failed to load received transfer info from storage: %+v" - loadReceivePartStoreErr = "failed to load received part store from storage: %+v" - loadReceivedVectorErr = "failed to load new received status state vector from storage: %+v" - loadReceiveFpVectorErr = "failed to load received fingerprint vector from storage: %+v" - getFileErr = "missing %d/%d parts of the file" - getTransferMacErr = "failed to verify transfer MAC" - deleteReceivedTransferInfoErr = "failed to delete received transfer info from storage: %+v" - deleteReceivedFpVectorErr = "failed to delete received fingerprint vector from storage: %+v" - deleteReceivedFilePartsErr = "failed to delete received file parts from storage: %+v" - deleteReceivedVectorErr = "failed to delete received status state vector from storage: %+v" + + // ReceivedTransfer.GetFile + getFileErr = "missing %d/%d parts of the file" + getTransferMacErr = "failed to verify transfer MAC" + + // loadReceivedTransfer + loadReceivedStoreErr = "failed to load received transfer info from storage: %+v" + loadReceivePartStoreErr = "failed to load received part store from storage: %+v" + loadReceivedVectorErr = "failed to load new received status state vector from storage: %+v" + loadReceiveFpVectorErr = "failed to load received fingerprint vector from storage: %+v" + + // ReceivedTransfer.delete + deleteReceivedTransferInfoErr = "failed to delete received transfer info from storage: %+v" + deleteReceivedFpVectorErr = "failed to delete received fingerprint vector from storage: %+v" + deleteReceivedFilePartsErr = "failed to delete received file parts from storage: %+v" + deleteReceivedVectorErr = "failed to delete received status state vector from storage: %+v" + + // ReceivedTransfer.StopScheduledProgressCB + cancelReceivedCallbacksErr = "could not cancel %d out of %d received progress callbacks: %d" ) // ReceivedTransfer contains information and progress data for receiving an in- @@ -191,7 +202,15 @@ func (rt *ReceivedTransfer) GetProgress() (completed bool, received, rt.mux.RLock() defer rt.mux.RUnlock() - received = uint16(rt.fpVector.GetNumUsed()) + completed, received, total, t = rt.getProgress() + return completed, received, total, t +} + +// getProgress is the thread-unsafe helper function for GetProgress. +func (rt *ReceivedTransfer) getProgress() (completed bool, received, + total uint16, t ReceivedPartTracker) { + + received = uint16(rt.receivedStatus.GetNumUsed()) total = rt.numParts if received == total { @@ -212,6 +231,31 @@ func (rt *ReceivedTransfer) CallProgressCB(err error) { } } +// StopScheduledProgressCB cancels all scheduled received progress callbacks +// calls. +func (rt *ReceivedTransfer) StopScheduledProgressCB() error { + rt.mux.Lock() + defer rt.mux.Unlock() + + // Tracks the index of callbacks that failed to stop + var failedCallbacks []int + + for i, cb := range rt.progressCallbacks { + err := cb.stopThread() + if err != nil { + failedCallbacks = append(failedCallbacks, i) + jww.WARN.Print(err.Error()) + } + } + + if len(failedCallbacks) > 0 { + return errors.Errorf(cancelReceivedCallbacksErr, len(failedCallbacks), + len(rt.progressCallbacks), failedCallbacks) + } + + return nil +} + // AddProgressCB appends a new interfaces.ReceivedProgressCallback to the list // of progress callbacks to be called and calls it. The period is how often the // callback should be called when there are updates. diff --git a/storage/fileTransfer/receiveTransfer_test.go b/storage/fileTransfer/receiveTransfer_test.go index e8e6b1fedac82896c22b5e611dfa402990b21afa..e3cff77508170c0b02153977e6c028247d472499 100644 --- a/storage/fileTransfer/receiveTransfer_test.go +++ b/storage/fileTransfer/receiveTransfer_test.go @@ -335,6 +335,7 @@ func TestReceivedTransfer_CallProgressCB(t *testing.T) { type progressResults struct { completed bool received, total uint16 + tr interfaces.FilePartTracker err error } @@ -348,8 +349,8 @@ func TestReceivedTransfer_CallProgressCB(t *testing.T) { progressChan := make(chan progressResults) cbFunc := func(completed bool, received, total uint16, - t interfaces.FilePartTracker, err error) { - progressChan <- progressResults{completed, received, total, err} + tr interfaces.FilePartTracker, err error) { + progressChan <- progressResults{completed, received, total, tr, err} } wg.Add(1) @@ -367,25 +368,25 @@ func TestReceivedTransfer_CallProgressCB(t *testing.T) { case 0: if err := checkReceivedProgress(r.completed, r.received, r.total, false, 0, rt.numParts); err != nil { - t.Errorf("%2d: %+v", i, err) + t.Errorf("%2d: %v", i, err) } atomic.AddUint64(&step0, 1) case 1: if err := checkReceivedProgress(r.completed, r.received, r.total, false, 0, rt.numParts); err != nil { - t.Errorf("%2d: %+v", i, err) + t.Errorf("%2d: %v", i, err) } atomic.AddUint64(&step1, 1) case 2: if err := checkReceivedProgress(r.completed, r.received, r.total, false, 4, rt.numParts); err != nil { - t.Errorf("%2d: %+v", i, err) + t.Errorf("%2d: %v", i, err) } atomic.AddUint64(&step2, 1) case 3: if err := checkReceivedProgress(r.completed, r.received, r.total, true, 16, rt.numParts); err != nil { - t.Errorf("%2d: %+v", i, err) + t.Errorf("%2d: %v", i, err) } atomic.AddUint64(&step3, 1) return @@ -410,7 +411,7 @@ func TestReceivedTransfer_CallProgressCB(t *testing.T) { } for i := 0; i < 4; i++ { - _, _ = rt.fpVector.Next() + _, _ = rt.receivedStatus.Next() } rt.CallProgressCB(nil) @@ -419,7 +420,7 @@ func TestReceivedTransfer_CallProgressCB(t *testing.T) { } for i := 0; i < 12; i++ { - _, _ = rt.fpVector.Next() + _, _ = rt.receivedStatus.Next() } rt.CallProgressCB(nil) @@ -427,6 +428,45 @@ func TestReceivedTransfer_CallProgressCB(t *testing.T) { wg.Wait() } +// Tests that ReceivedTransfer.StopScheduledProgressCB stops a scheduled +// callback from being triggered. +func TestReceivedTransfer_StopScheduledProgressCB(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, rt, _ := newEmptyReceivedTransfer(16, 20, kv, t) + + cbChan := make(chan struct{}, 5) + cbFunc := interfaces.ReceivedProgressCallback( + func(completed bool, received, total uint16, + t interfaces.FilePartTracker, err error) { + cbChan <- struct{}{} + }) + rt.AddProgressCB(cbFunc, 150*time.Millisecond) + select { + case <-time.NewTimer(10 * time.Millisecond).C: + t.Error("Timed out waiting for callback.") + case <-cbChan: + } + + rt.CallProgressCB(nil) + rt.CallProgressCB(nil) + select { + case <-time.NewTimer(10 * time.Millisecond).C: + t.Error("Timed out waiting for callback.") + case <-cbChan: + } + + err := rt.StopScheduledProgressCB() + if err != nil { + t.Errorf("StopScheduledProgressCB returned an error: %+v", err) + } + + select { + case <-time.NewTimer(200 * time.Millisecond).C: + case <-cbChan: + t.Error("Callback called when it should have been stopped.") + } +} + // Tests that ReceivedTransfer.AddProgressCB adds an item to the progress // callback list. func TestReceivedTransfer_AddProgressCB(t *testing.T) { diff --git a/storage/fileTransfer/receivedCallbackTracker.go b/storage/fileTransfer/receivedCallbackTracker.go index ac4295c8dbaa20f0e378c5410f4da7de93be2d6a..a0b3f558c212d8ee1603b01c80eacfc00bfa10e2 100644 --- a/storage/fileTransfer/receivedCallbackTracker.go +++ b/storage/fileTransfer/receivedCallbackTracker.go @@ -9,11 +9,15 @@ package fileTransfer import ( "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/xx_network/primitives/netTime" "sync" "time" ) +// receivedCallbackTrackerStoppable is the name used for the tracker stoppable. +const receivedCallbackTrackerStoppable = "receivedCallbackTrackerStoppable" + // receivedCallbackTracker tracks the interfaces.ReceivedProgressCallback and // information on when to call it. The callback will be called on each file part // reception, unless the time since the lastCall is smaller than the period. In @@ -21,9 +25,10 @@ import ( // end of the period. A callback is called once every period, regardless of the // number of receptions that occur. type receivedCallbackTracker struct { - period time.Duration // How often to call the callback - lastCall time.Time // Timestamp of the last call - scheduled bool // Denotes if callback call is scheduled + period time.Duration // How often to call the callback + lastCall time.Time // Timestamp of the last call + scheduled bool // Denotes if callback call is scheduled + stop *stoppable.Single // Stops the scheduled callback from triggering cb interfaces.ReceivedProgressCallback mux sync.RWMutex } @@ -35,6 +40,7 @@ func newReceivedCallbackTracker(cb interfaces.ReceivedProgressCallback, period: period, lastCall: time.Time{}, scheduled: false, + stop: stoppable.NewSingle(receivedCallbackTrackerStoppable), cb: cb, } } @@ -64,7 +70,7 @@ func (rct *receivedCallbackTracker) call(tracker receivedProgressTracker, err er timeSinceLastCall := netTime.Since(rct.lastCall) if timeSinceLastCall > rct.period { // If no callback occurred, then trigger the callback now - rct.callNow(tracker, err) + rct.callNowUnsafe(tracker, err) rct.lastCall = netTime.Now() } else { // If a callback did occur, then schedule a new callback to occur at the @@ -72,6 +78,9 @@ func (rct *receivedCallbackTracker) call(tracker receivedProgressTracker, err er rct.scheduled = true go func() { select { + case <-rct.stop.Quit(): + rct.stop.ToStopped() + return case <-time.NewTimer(rct.period - timeSinceLastCall).C: rct.mux.Lock() rct.callNow(tracker, err) @@ -83,13 +92,37 @@ func (rct *receivedCallbackTracker) call(tracker receivedProgressTracker, err er } } +// stopThread stops all scheduled callbacks. +func (rct *receivedCallbackTracker) stopThread() error { + return rct.stop.Close() +} + // callNow calls the callback immediately regardless of the schedule or period. -func (rct *receivedCallbackTracker) callNow(tracker receivedProgressTracker, err error) { +func (rct *receivedCallbackTracker) callNow( + tracker receivedProgressTracker, err error) { completed, received, total, t := tracker.GetProgress() go rct.cb(completed, received, total, t, err) } +// callNowUnsafe calls the callback immediately regardless of the schedule or +// period without taking a thread lock. This function should be used if a lock +// is already taken on the receivedProgressTracker. +func (rct *receivedCallbackTracker) callNowUnsafe( + tracker receivedProgressTracker, err error) { + completed, received, total, t := tracker.getProgress() + go rct.cb(completed, received, total, t, err) +} + // receivedProgressTracker interface tracks the progress of a transfer. type receivedProgressTracker interface { - GetProgress() (completed bool, received, total uint16, t ReceivedPartTracker) + // GetProgress returns the received transfer progress in a thread-safe + // manner. + GetProgress() ( + completed bool, received, total uint16, t ReceivedPartTracker) + + // getProgress returns the received transfer progress in a thread-unsafe + // manner. This function should be used if a lock is already taken on the + // sent transfer. + getProgress() ( + completed bool, received, total uint16, t ReceivedPartTracker) } diff --git a/storage/fileTransfer/receivedCallbackTracker_test.go b/storage/fileTransfer/receivedCallbackTracker_test.go index 80bf43e8211a249553b43fcd0fb67d44df3b227a..8809f537ca733f78087aaaf4ab931ac034054152 100644 --- a/storage/fileTransfer/receivedCallbackTracker_test.go +++ b/storage/fileTransfer/receivedCallbackTracker_test.go @@ -36,9 +36,9 @@ func Test_newReceivedCallbackTracker(t *testing.T) { cb: cbFunc, } - receivedSCT := newReceivedCallbackTracker(expectedRCT.cb, expectedRCT.period) + receivedRCT := newReceivedCallbackTracker(expectedRCT.cb, expectedRCT.period) - go receivedSCT.cb(false, 0, 0, nil, nil) + go receivedRCT.cb(false, 0, 0, nil, nil) select { case <-time.NewTimer(time.Millisecond).C: @@ -51,16 +51,20 @@ func Test_newReceivedCallbackTracker(t *testing.T) { } // Nil the callbacks so that DeepEqual works - receivedSCT.cb = nil + receivedRCT.cb = nil expectedRCT.cb = nil - if !reflect.DeepEqual(expectedRCT, receivedSCT) { + receivedRCT.stop = expectedRCT.stop + + if !reflect.DeepEqual(expectedRCT, receivedRCT) { t.Errorf("New receivedCallbackTracker does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expectedRCT, receivedSCT) + "\nexpected: %+v\nreceived: %+v", expectedRCT, receivedRCT) } } -// Tests that receivedCallbackTracker.call triggers the proper callback. +// Tests that receivedCallbackTracker.call calls the tracker immediately when +// no other calls are scheduled and that it schedules a call to the tracker when +// one has been called recently. func Test_receivedCallbackTracker_call(t *testing.T) { type cbFields struct { completed bool @@ -74,10 +78,10 @@ func Test_receivedCallbackTracker_call(t *testing.T) { cbChan <- cbFields{completed, received, total, err} } - sct := newReceivedCallbackTracker(cbFunc, 50*time.Millisecond) + rct := newReceivedCallbackTracker(cbFunc, 50*time.Millisecond) tracker := testReceiveTrack{false, 1, 3, ReceivedPartTracker{}} - sct.call(tracker, nil) + rct.call(tracker, nil) select { case <-time.NewTimer(10 * time.Millisecond).C: @@ -90,19 +94,19 @@ func Test_receivedCallbackTracker_call(t *testing.T) { } tracker = testReceiveTrack{true, 3, 3, ReceivedPartTracker{}} - sct.call(tracker, nil) + rct.call(tracker, nil) select { case <-time.NewTimer(10 * time.Millisecond).C: - if !sct.scheduled { + if !rct.scheduled { t.Error("Callback should be scheduled.") } case r := <-cbChan: t.Errorf("Received message when period of %s should not have been "+ - "reached: %+v", sct.period, r) + "reached: %+v", rct.period, r) } - sct.call(tracker, nil) + rct.call(tracker, nil) select { case <-time.NewTimer(60 * time.Millisecond).C: @@ -115,6 +119,64 @@ func Test_receivedCallbackTracker_call(t *testing.T) { } } +// Tests that receivedCallbackTracker.stopThread prevents a scheduled call to +// the tracker from occurring. +func Test_receivedCallbackTracker_stopThread(t *testing.T) { + type cbFields struct { + completed bool + received, total uint16 + err error + } + + cbChan := make(chan cbFields) + cbFunc := func(completed bool, received, total uint16, + t interfaces.FilePartTracker, err error) { + cbChan <- cbFields{completed, received, total, err} + } + + rct := newReceivedCallbackTracker(cbFunc, 50*time.Millisecond) + + tracker := testReceiveTrack{false, 1, 3, ReceivedPartTracker{}} + rct.call(tracker, nil) + + select { + case <-time.NewTimer(10 * time.Millisecond).C: + t.Error("Timed out waiting for callback to be called.") + case r := <-cbChan: + err := checkReceivedProgress(r.completed, r.received, r.total, false, 1, 3) + if err != nil { + t.Error(err) + } + } + + tracker = testReceiveTrack{true, 3, 3, ReceivedPartTracker{}} + rct.call(tracker, nil) + + select { + case <-time.NewTimer(10 * time.Millisecond).C: + if !rct.scheduled { + t.Error("Callback should be scheduled.") + } + case r := <-cbChan: + t.Errorf("Received message when period of %s should not have been "+ + "reached: %+v", rct.period, r) + } + + rct.call(tracker, nil) + + err := rct.stopThread() + if err != nil { + t.Errorf("stopThread returned an error: %+v", err) + } + + select { + case <-time.NewTimer(60 * time.Millisecond).C: + case r := <-cbChan: + t.Errorf("Received message when period of %s should not have been "+ + "reached: %+v", rct.period, r) + } +} + // Tests that ReceivedTransfer satisfies the receivedProgressTracker interface. func TestReceivedTransfer_ReceivedProgressTrackerInterface(t *testing.T) { var _ receivedProgressTracker = &ReceivedTransfer{} @@ -128,7 +190,13 @@ type testReceiveTrack struct { t ReceivedPartTracker } +func (trt testReceiveTrack) getProgress() (completed bool, received, + total uint16, t ReceivedPartTracker) { + return trt.completed, trt.received, trt.total, trt.t +} + // GetProgress returns the values in the testTrack. -func (trt testReceiveTrack) GetProgress() (completed bool, received, total uint16, t ReceivedPartTracker) { +func (trt testReceiveTrack) GetProgress() (completed bool, received, + total uint16, t ReceivedPartTracker) { return trt.completed, trt.received, trt.total, trt.t } diff --git a/storage/fileTransfer/sentCallbackTracker.go b/storage/fileTransfer/sentCallbackTracker.go index 53e210b2ed905258437c7d6fb1331838bd453c1e..51b25ed453392e46600276dd0c638f5508425d0d 100644 --- a/storage/fileTransfer/sentCallbackTracker.go +++ b/storage/fileTransfer/sentCallbackTracker.go @@ -9,11 +9,15 @@ package fileTransfer import ( "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/xx_network/primitives/netTime" "sync" "time" ) +// sentCallbackTrackerStoppable is the name used for the tracker stoppable. +const sentCallbackTrackerStoppable = "sentCallbackTrackerStoppable" + // sentCallbackTracker tracks the interfaces.SentProgressCallback and // information on when to call it. The callback will be called on each send, // unless the time since the lastCall is smaller than the period. In that case, @@ -21,9 +25,10 @@ import ( // period. A callback is called once every period, regardless of the number of // sends that occur. type sentCallbackTracker struct { - period time.Duration // How often to call the callback - lastCall time.Time // Timestamp of the last call - scheduled bool // Denotes if callback call is scheduled + period time.Duration // How often to call the callback + lastCall time.Time // Timestamp of the last call + scheduled bool // Denotes if callback call is scheduled + stop *stoppable.Single // Stops the scheduled callback from triggering cb interfaces.SentProgressCallback mux sync.RWMutex } @@ -35,6 +40,7 @@ func newSentCallbackTracker(cb interfaces.SentProgressCallback, period: period, lastCall: time.Time{}, scheduled: false, + stop: stoppable.NewSingle(sentCallbackTrackerStoppable), cb: cb, } } @@ -64,7 +70,7 @@ func (sct *sentCallbackTracker) call(tracker sentProgressTracker, err error) { timeSinceLastCall := netTime.Since(sct.lastCall) if timeSinceLastCall > sct.period { // If no callback occurred, then trigger the callback now - sct.callNow(tracker, err) + sct.callNowUnsafe(tracker, err) sct.lastCall = netTime.Now() } else { // If a callback did occur, then schedule a new callback to occur at the @@ -72,6 +78,9 @@ func (sct *sentCallbackTracker) call(tracker sentProgressTracker, err error) { sct.scheduled = true go func() { select { + case <-sct.stop.Quit(): + sct.stop.ToStopped() + return case <-time.NewTimer(sct.period - timeSinceLastCall).C: sct.mux.Lock() sct.callNow(tracker, err) @@ -83,8 +92,14 @@ func (sct *sentCallbackTracker) call(tracker sentProgressTracker, err error) { } } +// stopThread stops all scheduled callbacks. +func (sct *sentCallbackTracker) stopThread() error { + return sct.stop.Close() +} + // callNow calls the callback immediately regardless of the schedule or period. -func (sct *sentCallbackTracker) callNow(tracker sentProgressTracker, err error) { +func (sct *sentCallbackTracker) callNow( + tracker sentProgressTracker, err error) { completed, sent, arrived, total, t := tracker.GetProgress() go sct.cb(completed, sent, arrived, total, t, err) } @@ -92,7 +107,8 @@ func (sct *sentCallbackTracker) callNow(tracker sentProgressTracker, err error) // callNowUnsafe calls the callback immediately regardless of the schedule or // period without taking a thread lock. This function should be used if a lock // is already taken on the sentProgressTracker. -func (sct *sentCallbackTracker) callNowUnsafe(tracker sentProgressTracker, err error) { +func (sct *sentCallbackTracker) callNowUnsafe( + tracker sentProgressTracker, err error) { completed, sent, arrived, total, t := tracker.getProgress() go sct.cb(completed, sent, arrived, total, t, err) } @@ -100,10 +116,12 @@ func (sct *sentCallbackTracker) callNowUnsafe(tracker sentProgressTracker, err e // sentProgressTracker interface tracks the progress of a transfer. type sentProgressTracker interface { // GetProgress returns the sent transfer progress in a thread-safe manner. - GetProgress() (completed bool, sent, arrived, total uint16, t SentPartTracker) + GetProgress() ( + completed bool, sent, arrived, total uint16, t SentPartTracker) // getProgress returns the sent transfer progress in a thread-unsafe manner. // This function should be used if a lock is already taken on the sent // transfer. - getProgress() (completed bool, sent, arrived, total uint16, t SentPartTracker) + getProgress() ( + completed bool, sent, arrived, total uint16, t SentPartTracker) } diff --git a/storage/fileTransfer/sentCallbackTracker_test.go b/storage/fileTransfer/sentCallbackTracker_test.go index a99726e2b73a1c4d298094b008cae6611aa2cd46..237cea82ff66c51e41b7dd98d8a60bc3ec524818 100644 --- a/storage/fileTransfer/sentCallbackTracker_test.go +++ b/storage/fileTransfer/sentCallbackTracker_test.go @@ -9,6 +9,7 @@ package fileTransfer import ( "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/stoppable" "reflect" "testing" "time" @@ -33,6 +34,7 @@ func Test_newSentCallbackTracker(t *testing.T) { period: time.Millisecond, lastCall: time.Time{}, scheduled: false, + stop: stoppable.NewSingle(sentCallbackTrackerStoppable), cb: cbFunc, } @@ -55,13 +57,17 @@ func Test_newSentCallbackTracker(t *testing.T) { receivedSCT.cb = nil expectedSCT.cb = nil + receivedSCT.stop = expectedSCT.stop + if !reflect.DeepEqual(expectedSCT, receivedSCT) { t.Errorf("New sentCallbackTracker does not match expected."+ "\nexpected: %+v\nreceived: %+v", expectedSCT, receivedSCT) } } -// Tests that +// Tests that sentCallbackTracker.call calls the tracker immediately when no +// other calls are scheduled and that it schedules a call to the tracker when +// one has been called recently. func Test_sentCallbackTracker_call(t *testing.T) { type cbFields struct { completed bool @@ -118,6 +124,65 @@ func Test_sentCallbackTracker_call(t *testing.T) { } } +// Tests that sentCallbackTracker.stopThread prevents a scheduled call to the +// tracker from occurring. +func Test_sentCallbackTracker_stopThread(t *testing.T) { + type cbFields struct { + completed bool + sent, arrived, total uint16 + err error + } + + cbChan := make(chan cbFields) + cbFunc := func(completed bool, sent, arrived, total uint16, + t interfaces.FilePartTracker, err error) { + cbChan <- cbFields{completed, sent, arrived, total, err} + } + + sct := newSentCallbackTracker(cbFunc, 50*time.Millisecond) + + tracker := testSentTrack{false, 1, 2, 3, SentPartTracker{}} + sct.call(tracker, nil) + + select { + case <-time.NewTimer(10 * time.Millisecond).C: + t.Error("Timed out waiting for callback to be called.") + case r := <-cbChan: + err := checkSentProgress( + r.completed, r.sent, r.arrived, r.total, false, 1, 2, 3) + if err != nil { + t.Error(err) + } + } + + tracker = testSentTrack{false, 1, 2, 3, SentPartTracker{}} + sct.call(tracker, nil) + + select { + case <-time.NewTimer(10 * time.Millisecond).C: + if !sct.scheduled { + t.Error("Callback should be scheduled.") + } + case r := <-cbChan: + t.Errorf("Received message when period of %s should not have been "+ + "reached: %+v", sct.period, r) + } + + sct.call(tracker, nil) + + err := sct.stopThread() + if err != nil { + t.Errorf("stopThread returned an error: %+v", err) + } + + select { + case <-time.NewTimer(60 * time.Millisecond).C: + case r := <-cbChan: + t.Errorf("Received message when period of %s should not have been "+ + "reached: %+v", sct.period, r) + } +} + // Tests that SentTransfer satisfies the sentProgressTracker interface. func TestSentTransfer_SentProgressTrackerInterface(t *testing.T) { var _ sentProgressTracker = &SentTransfer{} diff --git a/storage/fileTransfer/sentFileTransfers.go b/storage/fileTransfer/sentFileTransfers.go index 96aa1886cc4313b6acb43218adfb9eae63bdb307..c8a858f017ac62b63cef4af327cb83212519d654 100644 --- a/storage/fileTransfer/sentFileTransfers.go +++ b/storage/fileTransfer/sentFileTransfers.go @@ -10,6 +10,7 @@ package fileTransfer import ( "bytes" "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces" "gitlab.com/elixxir/client/storage/versioned" ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" @@ -35,6 +36,7 @@ const ( newSentTransferErr = "failed to create new sent transfer: %+v" getSentTransferErr = "sent file transfer not found" + cancelCallbackErr = "Transfer with ID %s: %+v" deleteSentTransferErr = "failed to delete sent transfer with ID %s from store: %+v" ) @@ -109,13 +111,19 @@ func (sft *SentFileTransfers) DeleteTransfer(tid ftCrypto.TransferID) error { defer sft.mux.Unlock() // Return an error if the transfer does not exist - _, exists := sft.transfers[tid] + st, exists := sft.transfers[tid] if !exists { return errors.New(getSentTransferErr) } + // Cancel any scheduled callbacks + err := st.StopScheduledProgressCB() + if err != nil { + jww.WARN.Print(errors.Errorf(cancelCallbackErr, tid, err)) + } + // Delete all data the transfer saved to storage - err := sft.transfers[tid].delete() + err = st.delete() if err != nil { return errors.Errorf(deleteSentTransferErr, tid, err) } diff --git a/storage/fileTransfer/sentTransfer.go b/storage/fileTransfer/sentTransfer.go index fface05f00f2184840abc0120b5ea1599f55ff99..f6142743cf41b84842969cae846b918d973ea847 100644 --- a/storage/fileTransfer/sentTransfer.go +++ b/storage/fileTransfer/sentTransfer.go @@ -11,6 +11,7 @@ import ( "bytes" "encoding/binary" "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces" "gitlab.com/elixxir/client/storage/utility" "gitlab.com/elixxir/client/storage/versioned" @@ -50,6 +51,9 @@ const ( reInitSentInProgressVectorErr = "failed to overwrite in-progress state vector with new vector: %+v" reInitSentFinishedVectorErr = "failed to overwrite finished state vector with new vector: %+v" + // SentTransfer.StopScheduledProgressCB + cancelSentCallbacksErr = "could not cancel %d out of %d sent progress callbacks: %d" + // loadSentTransfer loadSentStoreErr = "failed to load sent transfer info from storage: %+v" loadSentFpVectorErr = "failed to load sent fingerprint vector from storage: %+v" @@ -213,6 +217,7 @@ func (st *SentTransfer) ReInit(numFps uint16, progressCB interfaces.SentProgressCallback, period time.Duration) error { st.mux.Lock() defer st.mux.Unlock() + var err error // Mark the status as running @@ -332,22 +337,24 @@ func (st *SentTransfer) GetProgress() (completed bool, sent, arrived, st.mux.RLock() defer st.mux.RUnlock() - return st.getProgress() + completed, sent, arrived, total, t = st.getProgress() + return completed, sent, arrived, total, t } // getProgress is the thread-unsafe helper function for GetProgress. func (st *SentTransfer) getProgress() (completed bool, sent, arrived, total uint16, t SentPartTracker) { - arrived = st.finishedTransfers.getNumParts() - sent = st.inProgressTransfers.getNumParts() + arrived = uint16(st.finishedStatus.GetNumUsed()) + sent = uint16(st.inProgressStatus.GetNumUsed()) total = st.numParts if sent == 0 && arrived == total { completed = true } - return completed, sent, arrived, total, - NewSentPartTracker(st.inProgressStatus, st.finishedStatus) + partTracker := NewSentPartTracker(st.inProgressStatus, st.finishedStatus) + + return completed, sent, arrived, total, partTracker } // CallProgressCB calls all the progress callbacks with the most recent progress @@ -372,6 +379,30 @@ func (st *SentTransfer) CallProgressCB(err error) { } } +// StopScheduledProgressCB cancels all scheduled sent progress callbacks calls. +func (st *SentTransfer) StopScheduledProgressCB() error { + st.mux.Lock() + defer st.mux.Unlock() + + // Tracks the index of callbacks that failed to stop + var failedCallbacks []int + + for i, cb := range st.progressCallbacks { + err := cb.stopThread() + if err != nil { + failedCallbacks = append(failedCallbacks, i) + jww.WARN.Print(err.Error()) + } + } + + if len(failedCallbacks) > 0 { + return errors.Errorf(cancelSentCallbacksErr, len(failedCallbacks), + len(st.progressCallbacks), failedCallbacks) + } + + return nil +} + // AddProgressCB appends a new interfaces.SentProgressCallback to the list of // progress callbacks to be called and calls it. The period is how often the // callback should be called when there are updates. diff --git a/storage/fileTransfer/sentTransfer_test.go b/storage/fileTransfer/sentTransfer_test.go index 10a76c2a6e442d15ec8200cedc60616e39178b7d..e25307c794c1cc6d2f7d140af9faabc9d4dff8bf 100644 --- a/storage/fileTransfer/sentTransfer_test.go +++ b/storage/fileTransfer/sentTransfer_test.go @@ -619,6 +619,45 @@ func TestSentTransfer_CallProgressCB(t *testing.T) { wg.Wait() } +// Tests that SentTransfer.StopScheduledProgressCB stops a scheduled callback +// from being triggered. +func TestSentTransfer_StopScheduledProgressCB(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + _, st := newRandomSentTransfer(16, 24, kv, t) + + cbChan := make(chan struct{}, 5) + cbFunc := interfaces.SentProgressCallback( + func(completed bool, sent, arrived, total uint16, + t interfaces.FilePartTracker, err error) { + cbChan <- struct{}{} + }) + st.AddProgressCB(cbFunc, 150*time.Millisecond) + select { + case <-time.NewTimer(10 * time.Millisecond).C: + t.Error("Timed out waiting for callback.") + case <-cbChan: + } + + st.CallProgressCB(nil) + st.CallProgressCB(nil) + select { + case <-time.NewTimer(10 * time.Millisecond).C: + t.Error("Timed out waiting for callback.") + case <-cbChan: + } + + err := st.StopScheduledProgressCB() + if err != nil { + t.Errorf("StopScheduledProgressCB returned an error: %+v", err) + } + + select { + case <-time.NewTimer(200 * time.Millisecond).C: + case <-cbChan: + t.Error("Callback called when it should have been stopped.") + } +} + // Tests that SentTransfer.AddProgressCB adds an item to the progress callback // list. func TestSentTransfer_AddProgressCB(t *testing.T) { @@ -1165,7 +1204,7 @@ func Test_loadSentTransfer_LoadInProgressTransfersError(t *testing.T) { } // Error path: tests that loadSentTransfer returns the expected error when the -// finished transfers bundle was deleted from storage. +// finished transfer bundle was deleted from storage. func Test_loadSentTransfer_LoadFinishedTransfersError(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) tid, st := newRandomSentTransfer(16, 24, kv, t)