From 45584b1f9a9d956e7f30f8184522d02893192d60 Mon Sep 17 00:00:00 2001
From: Jono Wenger <jono@elixxir.io>
Date: Mon, 20 Jun 2022 13:41:18 -0700
Subject: [PATCH] Fix file manager test wait group crash

---
 fileTransfer2/connect/wrapper_test.go   | 5 +++--
 fileTransfer2/e2e/wrapper_test.go       | 5 +++--
 fileTransfer2/groupChat/wrapper_test.go | 5 +++--
 fileTransfer2/manager_test.go           | 5 +++--
 4 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/fileTransfer2/connect/wrapper_test.go b/fileTransfer2/connect/wrapper_test.go
index ff538756f..686580a2a 100644
--- a/fileTransfer2/connect/wrapper_test.go
+++ b/fileTransfer2/connect/wrapper_test.go
@@ -20,6 +20,7 @@ import (
 	"gitlab.com/xx_network/primitives/netTime"
 	"math"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 )
@@ -113,14 +114,14 @@ func Test_FileTransfer_Smoke(t *testing.T) {
 	// a progress callback that then checks that the file received is correct
 	// when done
 	wg.Add(1)
-	var called bool
+	called := uint32(0)
 	timeReceived := make(chan time.Time)
 	go func() {
 		select {
 		case r := <-receiveCbChan2:
 			receiveProgressCB := func(completed bool, received, total uint16,
 				rt ft.ReceivedTransfer, fpt ft.FilePartTracker, err error) {
-				if completed && !called {
+				if completed && atomic.CompareAndSwapUint32(&called, 0, 1) {
 					timeReceived <- netTime.Now()
 					receivedFile, err2 := m2.Receive(r.tid)
 					if err2 != nil {
diff --git a/fileTransfer2/e2e/wrapper_test.go b/fileTransfer2/e2e/wrapper_test.go
index 50490e4a1..92133edc5 100644
--- a/fileTransfer2/e2e/wrapper_test.go
+++ b/fileTransfer2/e2e/wrapper_test.go
@@ -20,6 +20,7 @@ import (
 	"gitlab.com/xx_network/primitives/netTime"
 	"math"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 )
@@ -115,14 +116,14 @@ func Test_FileTransfer_Smoke(t *testing.T) {
 	// a progress callback that then checks that the file received is correct
 	// when done
 	wg.Add(1)
-	var called bool
+	called := uint32(0)
 	timeReceived := make(chan time.Time)
 	go func() {
 		select {
 		case r := <-receiveCbChan2:
 			receiveProgressCB := func(completed bool, received, total uint16,
 				rt ft.ReceivedTransfer, fpt ft.FilePartTracker, err error) {
-				if completed && !called {
+				if completed && atomic.CompareAndSwapUint32(&called, 0, 1) {
 					timeReceived <- netTime.Now()
 					receivedFile, err2 := m2.Receive(r.tid)
 					if err2 != nil {
diff --git a/fileTransfer2/groupChat/wrapper_test.go b/fileTransfer2/groupChat/wrapper_test.go
index da36d23db..a1bb774bf 100644
--- a/fileTransfer2/groupChat/wrapper_test.go
+++ b/fileTransfer2/groupChat/wrapper_test.go
@@ -18,6 +18,7 @@ import (
 	"gitlab.com/xx_network/primitives/netTime"
 	"math"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 )
@@ -100,14 +101,14 @@ func Test_FileTransfer_Smoke(t *testing.T) {
 	// a progress callback that then checks that the file received is correct
 	// when done
 	wg.Add(1)
-	var called bool
+	called := uint32(0)
 	timeReceived := make(chan time.Time)
 	go func() {
 		select {
 		case r := <-receiveCbChan2:
 			receiveProgressCB := func(completed bool, received, total uint16,
 				rt ft.ReceivedTransfer, fpt ft.FilePartTracker, err error) {
-				if completed && !called {
+				if completed && atomic.CompareAndSwapUint32(&called, 0, 1) {
 					timeReceived <- netTime.Now()
 					receivedFile, err2 := m2.Receive(r.tid)
 					if err2 != nil {
diff --git a/fileTransfer2/manager_test.go b/fileTransfer2/manager_test.go
index cda5efc7c..d9827a543 100644
--- a/fileTransfer2/manager_test.go
+++ b/fileTransfer2/manager_test.go
@@ -19,6 +19,7 @@ import (
 	"math/rand"
 	"reflect"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 )
@@ -137,7 +138,7 @@ func Test_FileTransfer_Smoke(t *testing.T) {
 	// a progress callback that then checks that the file received is correct
 	// when done
 	wg.Add(1)
-	var called bool
+	called := uint32(0)
 	timeReceived := make(chan time.Time)
 	go func() {
 		select {
@@ -148,7 +149,7 @@ func Test_FileTransfer_Smoke(t *testing.T) {
 			}
 			receiveProgressCB := func(completed bool, received, total uint16,
 				rt ReceivedTransfer, fpt FilePartTracker, err error) {
-				if completed && !called {
+				if completed && atomic.CompareAndSwapUint32(&called, 0, 1) {
 					timeReceived <- netTime.Now()
 					receivedFile, err2 := m2.Receive(tid)
 					if err2 != nil {
-- 
GitLab