Skip to content
Snippets Groups Projects
Commit 68c03dbe authored by benjamin's avatar benjamin
Browse files

fixed a bug in channel send tracker which caused the system to recieve the message multiple times

parent 5bc533a9
No related branches found
No related tags found
2 merge requests!510Release,!340Project/channels
...@@ -47,12 +47,17 @@ type tracked struct { ...@@ -47,12 +47,17 @@ type tracked struct {
UUID uint64 UUID uint64
} }
type trackedList struct {
List []*tracked
RoundCompleted bool
}
// the sendTracker tracks outbound messages and denotes when they are delivered // the sendTracker tracks outbound messages and denotes when they are delivered
// to the event model. It also captures incoming messages and in the event they // to the event model. It also captures incoming messages and in the event they
// were sent by this user diverts them as status updates on the previously sent // were sent by this user diverts them as status updates on the previously sent
// messages // messages
type sendTracker struct { type sendTracker struct {
byRound map[id.Round][]*tracked byRound map[id.Round]trackedList
byMessageID map[cryptoChannel.MessageID]*tracked byMessageID map[cryptoChannel.MessageID]*tracked
...@@ -82,7 +87,7 @@ func loadSendTracker(net Client, kv *versioned.KV, trigger triggerEventFunc, ...@@ -82,7 +87,7 @@ func loadSendTracker(net Client, kv *versioned.KV, trigger triggerEventFunc,
adminTrigger triggerAdminEventFunc, updateStatus updateStatusFunc, adminTrigger triggerAdminEventFunc, updateStatus updateStatusFunc,
rngSource *fastRNG.StreamGenerator) *sendTracker { rngSource *fastRNG.StreamGenerator) *sendTracker {
st := &sendTracker{ st := &sendTracker{
byRound: make(map[id.Round][]*tracked), byRound: make(map[id.Round]trackedList),
byMessageID: make(map[cryptoChannel.MessageID]*tracked), byMessageID: make(map[cryptoChannel.MessageID]*tracked),
unsent: make(map[uint64]*tracked), unsent: make(map[uint64]*tracked),
trigger: trigger, trigger: trigger,
...@@ -112,7 +117,11 @@ func loadSendTracker(net Client, kv *versioned.KV, trigger triggerEventFunc, ...@@ -112,7 +117,11 @@ func loadSendTracker(net Client, kv *versioned.KV, trigger triggerEventFunc,
return return
} }
net.RemoveHealthCallback(callBackID) net.RemoveHealthCallback(callBackID)
for rid := range st.byRound { for rid, oldTracked := range st.byRound {
if oldTracked.RoundCompleted {
continue
}
rr := &roundResults{ rr := &roundResults{
round: rid, round: rid,
...@@ -178,7 +187,7 @@ func (st *sendTracker) load() error { ...@@ -178,7 +187,7 @@ func (st *sendTracker) load() error {
} }
for rid := range st.byRound { for rid := range st.byRound {
roundList := st.byRound[rid] roundList := st.byRound[rid].List
for j := range roundList { for j := range roundList {
st.byMessageID[roundList[j].MsgID] = roundList[j] st.byMessageID[roundList[j].MsgID] = roundList[j]
} }
...@@ -338,7 +347,8 @@ func (st *sendTracker) handleSend(uuid uint64, ...@@ -338,7 +347,8 @@ func (st *sendTracker) handleSend(uuid uint64,
//add the roundID //add the roundID
roundsList, existsRound := st.byRound[round.ID] roundsList, existsRound := st.byRound[round.ID]
st.byRound[round.ID] = append(roundsList, t) roundsList.List = append(roundsList.List, t)
st.byRound[round.ID] = roundsList
//add the round //add the round
st.byMessageID[messageID] = t st.byMessageID[messageID] = t
...@@ -408,16 +418,19 @@ func (st *sendTracker) MessageReceive(messageID cryptoChannel.MessageID, round r ...@@ -408,16 +418,19 @@ func (st *sendTracker) MessageReceive(messageID cryptoChannel.MessageID, round r
delete(st.byMessageID, messageID) delete(st.byMessageID, messageID)
roundList := st.byRound[msgData.RoundID] roundList := st.byRound[msgData.RoundID]
if len(roundList) == 1 { if len(roundList.List) == 1 {
delete(st.byRound, msgData.RoundID) delete(st.byRound, msgData.RoundID)
} else { } else {
newRoundList := make([]*tracked, 0, len(roundList)-1) newRoundList := make([]*tracked, 0, len(roundList.List)-1)
for i := range roundList { for i := range roundList.List {
if !roundList[i].MsgID.Equals(messageID) { if !roundList.List[i].MsgID.Equals(messageID) {
newRoundList = append(newRoundList, roundList[i]) newRoundList = append(newRoundList, roundList.List[i])
} }
} }
st.byRound[msgData.RoundID] = newRoundList st.byRound[msgData.RoundID] = trackedList{
List: newRoundList,
RoundCompleted: roundList.RoundCompleted,
}
} }
ts := mutateTimestamp(round.Timestamps[states.QUEUED], messageID) ts := mutateTimestamp(round.Timestamps[states.QUEUED], messageID)
...@@ -474,12 +487,8 @@ func (rr *roundResults) callback(allRoundsSucceeded, timedOut bool, results map[ ...@@ -474,12 +487,8 @@ func (rr *roundResults) callback(allRoundsSucceeded, timedOut bool, results map[
} }
delete(rr.st.byRound, rr.round) registered.RoundCompleted = true
rr.st.byRound[rr.round] = registered
for i := range registered {
delete(rr.st.byMessageID, registered[i].MsgID)
}
if err := rr.st.store(); err != nil { if err := rr.st.store(); err != nil {
jww.FATAL.Panicf("failed to store update after "+ jww.FATAL.Panicf("failed to store update after "+
"finalizing delivery of sent messages: %+v", err) "finalizing delivery of sent messages: %+v", err)
...@@ -487,9 +496,9 @@ func (rr *roundResults) callback(allRoundsSucceeded, timedOut bool, results map[ ...@@ -487,9 +496,9 @@ func (rr *roundResults) callback(allRoundsSucceeded, timedOut bool, results map[
rr.st.mux.Unlock() rr.st.mux.Unlock()
if status == Failed { if status == Failed {
for i := range registered { for i := range registered.List {
round := results[rr.round].Round round := results[rr.round].Round
go rr.st.updateStatus(registered[i].UUID, registered[i].MsgID, time.Time{}, go rr.st.updateStatus(registered.List[i].UUID, registered.List[i].MsgID, time.Time{},
round, Failed) round, Failed)
} }
} }
......
...@@ -181,7 +181,7 @@ func TestSendTracker_failedSend(t *testing.T) { ...@@ -181,7 +181,7 @@ func TestSendTracker_failedSend(t *testing.T) {
if ok { if ok {
t.Fatal("Should not have found a tracked round") t.Fatal("Should not have found a tracked round")
} }
if len(trackedRound) != 0 { if len(trackedRound.List) != 0 {
t.Fatal("Did not find expected number of trackedRounds") t.Fatal("Did not find expected number of trackedRounds")
} }
...@@ -253,10 +253,10 @@ func TestSendTracker_send(t *testing.T) { ...@@ -253,10 +253,10 @@ func TestSendTracker_send(t *testing.T) {
if !ok { if !ok {
t.Fatal("Should have found a tracked round") t.Fatal("Should have found a tracked round")
} }
if len(trackedRound) != 1 { if len(trackedRound.List) != 1 {
t.Fatal("Did not find expected number of trackedRounds") t.Fatal("Did not find expected number of trackedRounds")
} }
if trackedRound[0].MsgID != mid { if trackedRound.List[0].MsgID != mid {
t.Fatalf("Did not find expected message ID in trackedRounds") t.Fatalf("Did not find expected message ID in trackedRounds")
} }
...@@ -279,7 +279,10 @@ func TestSendTracker_load_store(t *testing.T) { ...@@ -279,7 +279,10 @@ func TestSendTracker_load_store(t *testing.T) {
cid := id.NewIdFromString("channel", id.User, t) cid := id.NewIdFromString("channel", id.User, t)
mid := cryptoChannel.MakeMessageID([]byte("hello"), cid) mid := cryptoChannel.MakeMessageID([]byte("hello"), cid)
rid := id.Round(2) rid := id.Round(2)
st.byRound[rid] = []*tracked{{MsgID: mid, ChannelID: cid, RoundID: rid}} st.byRound[rid] = trackedList{
List: []*tracked{{MsgID: mid, ChannelID: cid, RoundID: rid}},
RoundCompleted: false,
}
err := st.store() err := st.store()
if err != nil { if err != nil {
t.Fatalf("Failed to store byRound: %+v", err) t.Fatalf("Failed to store byRound: %+v", err)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment