From 68c03dbee3e363e13d654c043bebbcfa15a3ce96 Mon Sep 17 00:00:00 2001
From: benjamin <ben@elixxir.io>
Date: Sun, 23 Oct 2022 22:40:49 -0700
Subject: [PATCH] fixed a bug in channel send tracker which caused the system
 to recieve the message multiple times

---
 channels/sendTracker.go      | 47 +++++++++++++++++++++---------------
 channels/sendTracker_test.go | 11 ++++++---
 2 files changed, 35 insertions(+), 23 deletions(-)

diff --git a/channels/sendTracker.go b/channels/sendTracker.go
index 437340a51..f75819fe2 100644
--- a/channels/sendTracker.go
+++ b/channels/sendTracker.go
@@ -47,12 +47,17 @@ type tracked struct {
 	UUID      uint64
 }
 
+type trackedList struct {
+	List           []*tracked
+	RoundCompleted bool
+}
+
 // the sendTracker tracks outbound messages and denotes when they are delivered
 // to the event model. It also captures incoming messages and in the event they
 // were sent by this user diverts them as status updates on the previously sent
 // messages
 type sendTracker struct {
-	byRound map[id.Round][]*tracked
+	byRound map[id.Round]trackedList
 
 	byMessageID map[cryptoChannel.MessageID]*tracked
 
@@ -82,7 +87,7 @@ func loadSendTracker(net Client, kv *versioned.KV, trigger triggerEventFunc,
 	adminTrigger triggerAdminEventFunc, updateStatus updateStatusFunc,
 	rngSource *fastRNG.StreamGenerator) *sendTracker {
 	st := &sendTracker{
-		byRound:      make(map[id.Round][]*tracked),
+		byRound:      make(map[id.Round]trackedList),
 		byMessageID:  make(map[cryptoChannel.MessageID]*tracked),
 		unsent:       make(map[uint64]*tracked),
 		trigger:      trigger,
@@ -112,7 +117,11 @@ func loadSendTracker(net Client, kv *versioned.KV, trigger triggerEventFunc,
 			return
 		}
 		net.RemoveHealthCallback(callBackID)
-		for rid := range st.byRound {
+		for rid, oldTracked := range st.byRound {
+
+			if oldTracked.RoundCompleted {
+				continue
+			}
 
 			rr := &roundResults{
 				round: rid,
@@ -178,7 +187,7 @@ func (st *sendTracker) load() error {
 	}
 
 	for rid := range st.byRound {
-		roundList := st.byRound[rid]
+		roundList := st.byRound[rid].List
 		for j := range roundList {
 			st.byMessageID[roundList[j].MsgID] = roundList[j]
 		}
@@ -338,7 +347,8 @@ func (st *sendTracker) handleSend(uuid uint64,
 
 	//add the roundID
 	roundsList, existsRound := st.byRound[round.ID]
-	st.byRound[round.ID] = append(roundsList, t)
+	roundsList.List = append(roundsList.List, t)
+	st.byRound[round.ID] = roundsList
 
 	//add the round
 	st.byMessageID[messageID] = t
@@ -408,16 +418,19 @@ func (st *sendTracker) MessageReceive(messageID cryptoChannel.MessageID, round r
 	delete(st.byMessageID, messageID)
 
 	roundList := st.byRound[msgData.RoundID]
-	if len(roundList) == 1 {
+	if len(roundList.List) == 1 {
 		delete(st.byRound, msgData.RoundID)
 	} else {
-		newRoundList := make([]*tracked, 0, len(roundList)-1)
-		for i := range roundList {
-			if !roundList[i].MsgID.Equals(messageID) {
-				newRoundList = append(newRoundList, roundList[i])
+		newRoundList := make([]*tracked, 0, len(roundList.List)-1)
+		for i := range roundList.List {
+			if !roundList.List[i].MsgID.Equals(messageID) {
+				newRoundList = append(newRoundList, roundList.List[i])
 			}
 		}
-		st.byRound[msgData.RoundID] = newRoundList
+		st.byRound[msgData.RoundID] = trackedList{
+			List:           newRoundList,
+			RoundCompleted: roundList.RoundCompleted,
+		}
 	}
 
 	ts := mutateTimestamp(round.Timestamps[states.QUEUED], messageID)
@@ -474,12 +487,8 @@ func (rr *roundResults) callback(allRoundsSucceeded, timedOut bool, results map[
 
 	}
 
-	delete(rr.st.byRound, rr.round)
-
-	for i := range registered {
-		delete(rr.st.byMessageID, registered[i].MsgID)
-	}
-
+	registered.RoundCompleted = true
+	rr.st.byRound[rr.round] = registered
 	if err := rr.st.store(); err != nil {
 		jww.FATAL.Panicf("failed to store update after "+
 			"finalizing delivery of sent messages: %+v", err)
@@ -487,9 +496,9 @@ func (rr *roundResults) callback(allRoundsSucceeded, timedOut bool, results map[
 
 	rr.st.mux.Unlock()
 	if status == Failed {
-		for i := range registered {
+		for i := range registered.List {
 			round := results[rr.round].Round
-			go rr.st.updateStatus(registered[i].UUID, registered[i].MsgID, time.Time{},
+			go rr.st.updateStatus(registered.List[i].UUID, registered.List[i].MsgID, time.Time{},
 				round, Failed)
 		}
 	}
diff --git a/channels/sendTracker_test.go b/channels/sendTracker_test.go
index c2062c6d8..6948df3be 100644
--- a/channels/sendTracker_test.go
+++ b/channels/sendTracker_test.go
@@ -181,7 +181,7 @@ func TestSendTracker_failedSend(t *testing.T) {
 	if ok {
 		t.Fatal("Should not have found a tracked round")
 	}
-	if len(trackedRound) != 0 {
+	if len(trackedRound.List) != 0 {
 		t.Fatal("Did not find expected number of trackedRounds")
 	}
 
@@ -253,10 +253,10 @@ func TestSendTracker_send(t *testing.T) {
 	if !ok {
 		t.Fatal("Should have found a tracked round")
 	}
-	if len(trackedRound) != 1 {
+	if len(trackedRound.List) != 1 {
 		t.Fatal("Did not find expected number of trackedRounds")
 	}
-	if trackedRound[0].MsgID != mid {
+	if trackedRound.List[0].MsgID != mid {
 		t.Fatalf("Did not find expected message ID in trackedRounds")
 	}
 
@@ -279,7 +279,10 @@ func TestSendTracker_load_store(t *testing.T) {
 	cid := id.NewIdFromString("channel", id.User, t)
 	mid := cryptoChannel.MakeMessageID([]byte("hello"), cid)
 	rid := id.Round(2)
-	st.byRound[rid] = []*tracked{{MsgID: mid, ChannelID: cid, RoundID: rid}}
+	st.byRound[rid] = trackedList{
+		List:           []*tracked{{MsgID: mid, ChannelID: cid, RoundID: rid}},
+		RoundCompleted: false,
+	}
 	err := st.store()
 	if err != nil {
 		t.Fatalf("Failed to store byRound: %+v", err)
-- 
GitLab