From 7eab6bd854c6a51c41076ca207c0c9c54de05d86 Mon Sep 17 00:00:00 2001
From: josh <josh@elixxir.io>
Date: Thu, 24 Jun 2021 11:43:13 -0700
Subject: [PATCH] Test requestFormat

---
 auth/fmt_test.go   | 132 +++++++++++++++++++++++++++++++++++++++++++--
 auth/utils_test.go |  10 +++-
 2 files changed, 137 insertions(+), 5 deletions(-)

diff --git a/auth/fmt_test.go b/auth/fmt_test.go
index c28ce0a91..9e87d2f58 100644
--- a/auth/fmt_test.go
+++ b/auth/fmt_test.go
@@ -9,6 +9,8 @@ package auth
 
 import (
 	"bytes"
+	"gitlab.com/xx_network/primitives/id"
+	"math/rand"
 	"reflect"
 	"testing"
 )
@@ -128,7 +130,7 @@ func TestBaseFormat_SetGetEcrPayload(t *testing.T) {
 
 	// Test setter
 	ecrPayloadSize := payloadSize - (pubKeySize + saltSize)
-	ecrPayload := newEcrPayload(ecrPayloadSize, "ecrPayload")
+	ecrPayload := newPayload(ecrPayloadSize, "ecrPayload")
 	baseMsg.SetEcrPayload(ecrPayload)
 	if !bytes.Equal(ecrPayload, baseMsg.ecrPayload) {
 		t.Errorf("SetEcrPayload() error: "+
@@ -163,7 +165,7 @@ func TestBaseFormat_MarshalUnmarshal(t *testing.T) {
 	payloadSize := (saltSize + pubKeySize) * 2
 	baseMsg := newBaseFormat(payloadSize, pubKeySize)
 	ecrPayloadSize := payloadSize - (pubKeySize + saltSize)
-	ecrPayload := newEcrPayload(ecrPayloadSize, "ecrPayload")
+	ecrPayload := newPayload(ecrPayloadSize, "ecrPayload")
 	baseMsg.SetEcrPayload(ecrPayload)
 	salt := newSalt("salt")
 	baseMsg.SetSalt(salt)
@@ -278,7 +280,7 @@ func TestEcrFormat_SetGetPayload(t *testing.T) {
 	ecrMsg := newEcrFormat(payloadSize)
 
 	// Test set
-	expectedPayload := newEcrPayload(payloadSize-ownershipSize, "ownership")
+	expectedPayload := newPayload(payloadSize-ownershipSize, "ownership")
 	ecrMsg.SetPayload(expectedPayload)
 
 	if !bytes.Equal(expectedPayload, ecrMsg.payload) {
@@ -312,7 +314,7 @@ func TestEcrFormat_MarshalUnmarshal(t *testing.T) {
 	// Construct message
 	payloadSize := ownershipSize * 2
 	ecrMsg := newEcrFormat(payloadSize)
-	expectedPayload := newEcrPayload(payloadSize-ownershipSize, "ownership")
+	expectedPayload := newPayload(payloadSize-ownershipSize, "ownership")
 	ecrMsg.SetPayload(expectedPayload)
 	ownership := newOwnership("owner")
 	ecrMsg.SetOwnership(ownership)
@@ -346,3 +348,125 @@ func TestEcrFormat_MarshalUnmarshal(t *testing.T) {
 	}
 
 }
+
+// Tests newRequestFormat
+func TestNewRequestFormat(t *testing.T) {
+	// Construct message
+	payloadSize := id.ArrIDLen*2 - 1
+	ecrMsg := newEcrFormat(payloadSize)
+	expectedPayload := newPayload(id.ArrIDLen, "ownership")
+	ecrMsg.SetPayload(expectedPayload)
+	reqMsg, err := newRequestFormat(ecrMsg)
+	if err != nil {
+		t.Fatalf("newRequestFormat() error: "+
+			"Failed to construct message: %v", err)
+	}
+
+	// Check that the requestFormat was constructed properly
+	if !bytes.Equal(reqMsg.id, expectedPayload) {
+		t.Errorf("newRequestFormat() error: "+
+			"Unexpected id field in requestFormat."+
+			"\n\tExpected: %v"+
+			"\n\tReceived: %v", make([]byte, id.ArrIDLen), reqMsg.id)
+	}
+
+	if !bytes.Equal(reqMsg.msgPayload, make([]byte, 0)) {
+		t.Errorf("newRequestFormat() error: "+
+			"Unexpected msgPayload field in requestFormat."+
+			"\n\tExpected: %v"+
+			"\n\tReceived: %v", make([]byte, 0), reqMsg.msgPayload)
+	}
+
+	payloadSize = ownershipSize * 2
+	ecrMsg = newEcrFormat(payloadSize)
+	reqMsg, err = newRequestFormat(ecrMsg)
+	if err == nil {
+		t.Errorf("Expecter error: Should be invalid size when calling newRequestFormat")
+	}
+
+}
+
+/* Setter/Getter tests for RequestFormat */
+
+// Unit test for Get/SetID
+func TestRequestFormat_SetGetID(t *testing.T) {
+	// Construct message
+	payloadSize := id.ArrIDLen*2 - 1
+	ecrMsg := newEcrFormat(payloadSize)
+	expectedPayload := newPayload(id.ArrIDLen, "ownership")
+	ecrMsg.SetPayload(expectedPayload)
+	reqMsg, err := newRequestFormat(ecrMsg)
+	if err != nil {
+		t.Fatalf("newRequestFormat() error: "+
+			"Failed to construct message: %v", err)
+	}
+
+	// Test SetID
+	prng := rand.New(rand.NewSource(42))
+	expectedId := randID(prng, id.User)
+	reqMsg.SetID(expectedId)
+	if !bytes.Equal(reqMsg.id, expectedId.Bytes()) {
+		t.Errorf("SetID() error: "+
+			"Id field does not have expected value."+
+			"\n\tExpected: %v\n\tReceived: %v", expectedId, reqMsg.msgPayload)
+	}
+
+	// Test GetID
+	receivedId, err := reqMsg.GetID()
+	if err != nil {
+		t.Fatalf("GetID() error: "+
+			"Retrieved id does not match expected value:"+
+			"\n\tExpected: %v\n\tReceived: %v", expectedId, receivedId)
+	}
+
+	// Test GetID error: unmarshal-able ID in requestFormat
+	reqMsg.id = []byte("badId")
+	receivedId, err = reqMsg.GetID()
+	if err == nil {
+		t.Errorf("GetID() error: " +
+			"Should not be able get ID from request message ")
+	}
+
+}
+
+// Unit test for Get/SetMsgPayload
+func TestRequestFormat_SetGetMsgPayload(t *testing.T) {
+	// Construct message
+	payloadSize := id.ArrIDLen*3 - 1
+	ecrMsg := newEcrFormat(payloadSize)
+	expectedPayload := newPayload(id.ArrIDLen*2, "ownership")
+	ecrMsg.SetPayload(expectedPayload)
+	reqMsg, err := newRequestFormat(ecrMsg)
+	if err != nil {
+		t.Fatalf("newRequestFormat() error: "+
+			"Failed to construct message: %v", err)
+	}
+
+	// Test SetMsgPayload
+	msgPayload := newPayload(id.ArrIDLen, "msgPayload")
+	reqMsg.SetMsgPayload(msgPayload)
+	if !bytes.Equal(reqMsg.msgPayload, msgPayload) {
+		t.Errorf("SetMsgPayload() error: "+
+			"MsgPayload has unexpected value: "+
+			"\n\tExpected: %v\n\tReceived: %v", msgPayload, reqMsg.msgPayload)
+	}
+
+	// Test GetMsgPayload
+	retrievedMsgPayload := reqMsg.GetMsgPayload()
+	if !bytes.Equal(retrievedMsgPayload, msgPayload) {
+		t.Errorf("GetMsgPayload() error: "+
+			"MsgPayload has unexpected value: "+
+			"\n\tExpected: %v\n\tReceived: %v", msgPayload, retrievedMsgPayload)
+
+	}
+
+	// Test SetMsgPayload error: Invalid message payload size
+	defer func() {
+		if r := recover(); r == nil {
+			t.Error("SetMsgPayload() did not panic when the size of " +
+				"the payload is the incorrect size.")
+		}
+	}()
+	expectedPayload = append(expectedPayload, expectedPayload...)
+	reqMsg.SetMsgPayload(expectedPayload)
+}
diff --git a/auth/utils_test.go b/auth/utils_test.go
index b2535747e..95ff91489 100644
--- a/auth/utils_test.go
+++ b/auth/utils_test.go
@@ -3,6 +3,8 @@ package auth
 import (
 	"gitlab.com/elixxir/crypto/cyclic"
 	"gitlab.com/xx_network/crypto/large"
+	"gitlab.com/xx_network/primitives/id"
+	"math/rand"
 )
 
 func getGroup() *cyclic.Group {
@@ -23,13 +25,19 @@ func getGroup() *cyclic.Group {
 		large.NewIntFromString("2", 16))
 }
 
+// randID returns a new random ID of the specified type.
+func randID(rng *rand.Rand, t id.Type) *id.ID {
+	newID, _ := id.NewRandomID(rng, t)
+	return newID
+}
+
 func newSalt(s string) []byte {
 	salt := make([]byte, saltSize)
 	copy(salt[:], s)
 	return salt
 }
 
-func newEcrPayload(size int, s string) []byte {
+func newPayload(size int, s string) []byte {
 	b := make([]byte, size)
 	copy(b[:], s)
 	return b
-- 
GitLab