From 930ef2e68c64f3e64eb4bc58ac8903d8834705ad Mon Sep 17 00:00:00 2001
From: Jono Wenger <jono@elixxir.io>
Date: Wed, 22 Jun 2022 19:46:14 +0000
Subject: [PATCH] Fix duplicate file transfer callback bug

---
 .../callbackTracker/callbackTracker.go        |  4 +--
 fileTransfer/manager.go                       | 10 ++++++
 fileTransfer/manager_test.go                  |  2 --
 fileTransfer/params_test.go                   | 13 +++----
 fileTransfer/store/receivedTransfer.go        | 35 +++++++++++++++++++
 fileTransfer/store/sentTransfer.go            | 35 +++++++++++++++++++
 6 files changed, 86 insertions(+), 13 deletions(-)

diff --git a/fileTransfer/callbackTracker/callbackTracker.go b/fileTransfer/callbackTracker/callbackTracker.go
index 6cf7a641c..441ee157e 100644
--- a/fileTransfer/callbackTracker/callbackTracker.go
+++ b/fileTransfer/callbackTracker/callbackTracker.go
@@ -74,7 +74,7 @@ func (ct *callbackTracker) call(err error) {
 	if timeSinceLastCall > ct.period {
 
 		// If no callback occurred, then trigger the callback now
-		go ct.cb(err)
+		ct.cb(err)
 		ct.lastCall = netTime.Now()
 	} else {
 		// If a callback did occur, then schedule a new callback to occur at the
@@ -89,7 +89,7 @@ func (ct *callbackTracker) call(err error) {
 				return
 			case <-timer.C:
 				ct.mux.Lock()
-				go ct.cb(err)
+				ct.cb(err)
 				ct.lastCall = netTime.Now()
 				ct.scheduled = false
 				ct.mux.Unlock()
diff --git a/fileTransfer/manager.go b/fileTransfer/manager.go
index 9325a5763..afae51d80 100644
--- a/fileTransfer/manager.go
+++ b/fileTransfer/manager.go
@@ -358,6 +358,11 @@ func (m *manager) registerSentProgressCallback(st *store.SentTransfer,
 		// Build part tracker from copy of part statuses vector
 		tracker := &sentFilePartTracker{st.CopyPartStatusVector()}
 
+		// If the callback data is the same as the last call, skip the call
+		if !st.CompareAndSwapCallbackFps(completed, arrived, total, err) {
+			return
+		}
+
 		// Call the progress callback
 		progressCB(completed, arrived, total, st, tracker, err)
 	}
@@ -518,6 +523,11 @@ func (m *manager) registerReceivedProgressCallback(rt *store.ReceivedTransfer,
 		// Build part tracker from copy of part statuses vector
 		tracker := &receivedFilePartTracker{rt.CopyPartStatusVector()}
 
+		// If the callback data is the same as the last call, skip the call
+		if !rt.CompareAndSwapCallbackFps(completed, received, total, err) {
+			return
+		}
+
 		// Call the progress callback
 		progressCB(completed, received, total, rt, tracker, err)
 	}
diff --git a/fileTransfer/manager_test.go b/fileTransfer/manager_test.go
index 94593da1d..f9699a626 100644
--- a/fileTransfer/manager_test.go
+++ b/fileTransfer/manager_test.go
@@ -85,8 +85,6 @@ func Test_FileTransfer_Smoke(t *testing.T) {
 	cMixHandler := newMockCmixHandler()
 	rngGen := fastRNG.NewStreamGenerator(1000, 10, csprng.NewSystemRNG)
 	params := DefaultParams()
-	// params.MaxThroughput = math.MaxInt
-	// params.MaxThroughput = 0
 
 	// Set up the first client
 	myID1 := id.NewIdFromString("myID1", id.User, t)
diff --git a/fileTransfer/params_test.go b/fileTransfer/params_test.go
index c7015a449..965ff70f9 100644
--- a/fileTransfer/params_test.go
+++ b/fileTransfer/params_test.go
@@ -15,8 +15,8 @@ import (
 	"testing"
 )
 
-// Tests that no data is lost when marshaling and
-// unmarshaling the Params object.
+// Tests that no data is lost when marshaling and unmarshalling the Params
+// object.
 func TestParams_MarshalUnmarshal(t *testing.T) {
 	// Construct a set of params
 	p := DefaultParams()
@@ -27,8 +27,6 @@ func TestParams_MarshalUnmarshal(t *testing.T) {
 		t.Fatalf("Marshal error: %v", err)
 	}
 
-	t.Logf("%s", string(data))
-
 	// Unmarshal the params object
 	received := Params{}
 	err = json.Unmarshal(data, &received)
@@ -42,11 +40,8 @@ func TestParams_MarshalUnmarshal(t *testing.T) {
 		t.Fatalf("Marshal error: %v", err)
 	}
 
-	t.Logf("%s", string(data2))
-
-	// Check that they match (it is done this way to avoid
-	// false failures with the reflect.DeepEqual function and
-	// pointers)
+	// Check that they match (it is done this way to avoid false failures with
+	// the reflect.DeepEqual function and pointers)
 	if !bytes.Equal(data, data2) {
 		t.Fatalf("Data was lost in marshal/unmarshal.")
 	}
diff --git a/fileTransfer/store/receivedTransfer.go b/fileTransfer/store/receivedTransfer.go
index 396925d37..db6f54264 100644
--- a/fileTransfer/store/receivedTransfer.go
+++ b/fileTransfer/store/receivedTransfer.go
@@ -85,6 +85,10 @@ type ReceivedTransfer struct {
 	// Stores the received status for each file part in a bitstream format
 	partStatus *utility.StateVector
 
+	// Unique identifier of the last progress callback called (used to prevent
+	// callback calls with duplicate data)
+	lastCallbackFingerprint string
+
 	mux sync.RWMutex
 	kv  *versioned.KV
 }
@@ -202,6 +206,37 @@ func (rt *ReceivedTransfer) CopyPartStatusVector() *utility.StateVector {
 	return rt.partStatus.DeepCopy()
 }
 
+// CompareAndSwapCallbackFps compares the fingerprint to the previous callback
+// call's fingerprint. If they are different, the new one is stored, and it
+// returns true. Returns fall if they are the same.
+func (rt *ReceivedTransfer) CompareAndSwapCallbackFps(
+	completed bool, received, total uint16, err error) bool {
+	fp := generateReceivedFp(completed, received, total, err)
+
+	rt.mux.Lock()
+	defer rt.mux.Unlock()
+
+	if fp != rt.lastCallbackFingerprint {
+		rt.lastCallbackFingerprint = fp
+		return true
+	}
+
+	return false
+}
+
+// generateReceivedFp generates a fingerprint for a received progress callback.
+func generateReceivedFp(completed bool, received, total uint16, err error) string {
+	errString := "<nil>"
+	if err != nil {
+		errString = err.Error()
+	}
+
+	return strconv.FormatBool(completed) +
+		strconv.FormatUint(uint64(received), 10) +
+		strconv.FormatUint(uint64(total), 10) +
+		errString
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Storage Functions                                                          //
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/fileTransfer/store/sentTransfer.go b/fileTransfer/store/sentTransfer.go
index 05046f183..83b9594fb 100644
--- a/fileTransfer/store/sentTransfer.go
+++ b/fileTransfer/store/sentTransfer.go
@@ -18,6 +18,7 @@ import (
 	ftCrypto "gitlab.com/elixxir/crypto/fileTransfer"
 	"gitlab.com/xx_network/primitives/id"
 	"gitlab.com/xx_network/primitives/netTime"
+	"strconv"
 	"sync"
 )
 
@@ -83,6 +84,10 @@ type SentTransfer struct {
 	// Stores the status of each part in a bitstream format
 	partStatus *utility.StateVector
 
+	// Unique identifier of the last progress callback called (used to prevent
+	// callback calls with duplicate data)
+	lastCallbackFingerprint string
+
 	mux sync.RWMutex
 	kv  *versioned.KV
 }
@@ -214,6 +219,36 @@ func (st *SentTransfer) CopyPartStatusVector() *utility.StateVector {
 	return st.partStatus.DeepCopy()
 }
 
+// CompareAndSwapCallbackFps compares the fingerprint to the previous callback
+// call's fingerprint. If they are different, the new one is stored, and it
+// returns true. Returns fall if they are the same.
+func (st *SentTransfer) CompareAndSwapCallbackFps(
+	completed bool, arrived, total uint16, err error) bool {
+	fp := generateSentFp(completed, arrived, total, err)
+	st.mux.Lock()
+	defer st.mux.Unlock()
+
+	if fp != st.lastCallbackFingerprint {
+		st.lastCallbackFingerprint = fp
+		return true
+	}
+
+	return false
+}
+
+// generateSentFp generates a fingerprint for a sent progress callback.
+func generateSentFp(completed bool, arrived, total uint16, err error) string {
+	errString := "<nil>"
+	if err != nil {
+		errString = err.Error()
+	}
+
+	return strconv.FormatBool(completed) +
+		strconv.FormatUint(uint64(arrived), 10) +
+		strconv.FormatUint(uint64(total), 10) +
+		errString
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Storage Functions                                                          //
 ////////////////////////////////////////////////////////////////////////////////
-- 
GitLab