From 4e83be61aca522249ee2354580059cc257fd6a01 Mon Sep 17 00:00:00 2001
From: josh <josh@elixxir.io>
Date: Tue, 9 Mar 2021 10:29:03 -0800
Subject: [PATCH] Improve testing code

---
 network/dataStructures/roundEvents_test.go   |  4 ++--
 network/dataStructures/roundUpdates_test.go  |  6 +++---
 network/dataStructures/round_test.go         |  2 +-
 network/dataStructures/waitingRounds_test.go |  6 +++---
 network/instance_test.go                     |  6 +++---
 network/validation.go                        |  2 +-
 testutils/utils.go                           | 13 ++++++++++++-
 7 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/network/dataStructures/roundEvents_test.go b/network/dataStructures/roundEvents_test.go
index 287bde4e..0a864732 100644
--- a/network/dataStructures/roundEvents_test.go
+++ b/network/dataStructures/roundEvents_test.go
@@ -100,7 +100,7 @@ func TestRoundEvents_TriggerRoundEvent(t *testing.T) {
 		State: uint32(states.PENDING),
 	}
 
-	if err := testutils.SignRoundInfo(ri); err != nil {
+	if err := testutils.SignRoundInfo(ri, t); err != nil {
 		t.Errorf("Failed to sign mock round info: %v", err)
 	}
 
@@ -139,7 +139,7 @@ func TestRoundEvents_AddRoundEventChan(t *testing.T) {
 		ID:    uint64(rid),
 		State: uint32(states.PENDING),
 	}
-	if err := testutils.SignRoundInfo(ri); err != nil {
+	if err := testutils.SignRoundInfo(ri, t); err != nil {
 		t.Errorf("Failed to sign mock round info: %v", err)
 	}
 
diff --git a/network/dataStructures/roundUpdates_test.go b/network/dataStructures/roundUpdates_test.go
index 576ebdbd..29cd4d51 100644
--- a/network/dataStructures/roundUpdates_test.go
+++ b/network/dataStructures/roundUpdates_test.go
@@ -35,7 +35,7 @@ func TestUpdates_GetUpdate(t *testing.T) {
 		ID:       0,
 		UpdateID: uint64(updateID),
 	}
-	if err := testutils.SignRoundInfo(ri); err != nil {
+	if err := testutils.SignRoundInfo(ri, t); err != nil {
 		t.Errorf("Failed to sign mock round info: %v", err)
 	}
 	rnd := NewRound(ri, testutils.LoadKeyTesting(t))
@@ -54,7 +54,7 @@ func TestUpdates_GetUpdates(t *testing.T) {
 		ID:       0,
 		UpdateID: uint64(updateID),
 	}
-	if err := testutils.SignRoundInfo(roundInfoOne); err != nil {
+	if err := testutils.SignRoundInfo(roundInfoOne, t); err != nil {
 		t.Errorf("Failed to sign mock round info: %v", err)
 	}
 	roundOne := NewRound(roundInfoOne, testutils.LoadKeyTesting(t))
@@ -64,7 +64,7 @@ func TestUpdates_GetUpdates(t *testing.T) {
 		ID:       0,
 		UpdateID: uint64(updateID + 1),
 	}
-	if err := testutils.SignRoundInfo(roundInfoTwo); err != nil {
+	if err := testutils.SignRoundInfo(roundInfoTwo, t); err != nil {
 		t.Errorf("Failed to sign mock round info: %v", err)
 	}
 	roundTwo := NewRound(roundInfoTwo, testutils.LoadKeyTesting(t))
diff --git a/network/dataStructures/round_test.go b/network/dataStructures/round_test.go
index 8653b318..f1439166 100644
--- a/network/dataStructures/round_test.go
+++ b/network/dataStructures/round_test.go
@@ -37,7 +37,7 @@ func TestNewRound_Get(t *testing.T) {
 	pubKey := testutils.LoadKeyTesting(t)
 	ri := &mixmessages.RoundInfo{ID: uint64(1), UpdateID: uint64(1)}
 	// Mock signature of roundInfo as it will be verified in codepath
-	testutils.SignRoundInfo(ri)
+	testutils.SignRoundInfo(ri, t)
 	rnd := NewRound(ri, pubKey)
 
 	// Check the initial value of the atomic value (lazily)
diff --git a/network/dataStructures/waitingRounds_test.go b/network/dataStructures/waitingRounds_test.go
index 440d6dd6..31943076 100644
--- a/network/dataStructures/waitingRounds_test.go
+++ b/network/dataStructures/waitingRounds_test.go
@@ -159,7 +159,7 @@ func TestWaitingRounds_GetUpcomingRealtime_NoWait(t *testing.T) {
 	testWR := NewWaitingRounds()
 
 	for _, round := range testRounds {
-		testutils.SignRoundInfo(round.info)
+		testutils.SignRoundInfo(round.info, t)
 		testWR.Insert(round)
 	}
 
@@ -202,7 +202,7 @@ func TestWaitingRounds_GetUpcomingRealtime(t *testing.T) {
 	for i, round := range expectedRounds {
 		go func(round *Round) {
 			time.Sleep(30 * time.Millisecond)
-			testutils.SignRoundInfo(round.info)
+			testutils.SignRoundInfo(round.info, t)
 			testWR.Insert(round)
 		}(round)
 
@@ -232,7 +232,7 @@ func TestWaitingRounds_GetSlice(t *testing.T) {
 		Timestamps: []uint64{0, 0, 0, 0, 0},
 	}
 
-	testutils.SignRoundInfo(ri)
+	testutils.SignRoundInfo(ri, t)
 
 	rnd := NewRound(ri, testutils.LoadKeyTesting(t))
 
diff --git a/network/instance_test.go b/network/instance_test.go
index e68b7358..b06493c7 100644
--- a/network/instance_test.go
+++ b/network/instance_test.go
@@ -166,7 +166,7 @@ func TestInstance_GetRoundUpdate(t *testing.T) {
 	}
 
 	ri := &mixmessages.RoundInfo{ID: uint64(1), UpdateID: uint64(1)}
-	testutils.SignRoundInfo(ri)
+	testutils.SignRoundInfo(ri, t)
 	rnd := ds.NewRound(ri, testutils.LoadKeyTesting(t))
 
 	_ = i.roundUpdates.AddRound(rnd)
@@ -182,9 +182,9 @@ func TestInstance_GetRoundUpdates(t *testing.T) {
 	}
 
 	roundInfoOne := &mixmessages.RoundInfo{ID: uint64(1), UpdateID: uint64(1)}
-	testutils.SignRoundInfo(roundInfoOne)
+	testutils.SignRoundInfo(roundInfoOne, t)
 	roundInfoTwo := &mixmessages.RoundInfo{ID: uint64(1), UpdateID: uint64(2)}
-	testutils.SignRoundInfo(roundInfoTwo)
+	testutils.SignRoundInfo(roundInfoTwo, t)
 	roundOne := ds.NewRound(roundInfoOne, testutils.LoadKeyTesting(t))
 	roundTwo := ds.NewRound(roundInfoTwo, testutils.LoadKeyTesting(t))
 
diff --git a/network/validation.go b/network/validation.go
index 5fb1691e..85141943 100644
--- a/network/validation.go
+++ b/network/validation.go
@@ -32,4 +32,4 @@ func (s ValidationType) String() string {
 	default:
 		return fmt.Sprintf("UNKNOWN STATE: %d", s)
 	}
-}
\ No newline at end of file
+}
diff --git a/testutils/utils.go b/testutils/utils.go
index 3ec15dc6..218e2c6b 100644
--- a/testutils/utils.go
+++ b/testutils/utils.go
@@ -35,7 +35,18 @@ func LoadKeyTesting(t *testing.T) *rsa.PublicKey {
 }
 
 // Utility function which signs a round info message
-func SignRoundInfo(ri *pb.RoundInfo) error {
+func SignRoundInfo(ri *pb.RoundInfo, i interface{}) error {
+	switch i.(type) {
+	case *testing.T:
+		break
+	case *testing.M:
+		break
+	case *testing.B:
+		break
+	default:
+		jww.FATAL.Panicf("SignRoundInfo is restricted to testing only. Got %T", i)
+	}
+
 	keyPath := testkeys.GetNodeKeyPath()
 	keyData := testkeys.LoadFromPath(keyPath)
 
-- 
GitLab