From d6318fbb533c327d89636f5c6446fb6294fd8a20 Mon Sep 17 00:00:00 2001
From: Jono Wenger <jono@elixxir.io>
Date: Thu, 2 Dec 2021 10:51:12 -0800
Subject: [PATCH] Add support for json.Marshal and json.Unmarshal to id.ID

---
 id/id.go      | 25 +++++++++++++++++++++
 id/id_test.go | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 86 insertions(+)

diff --git a/id/id.go b/id/id.go
index 6739f1e..8914f87 100644
--- a/id/id.go
+++ b/id/id.go
@@ -15,6 +15,7 @@ import (
 	"encoding/base64"
 	"encoding/binary"
 	"encoding/hex"
+	"encoding/json"
 	"github.com/pkg/errors"
 	jww "github.com/spf13/jwalterweatherman"
 	"io"
@@ -112,6 +113,30 @@ func (id *ID) SetType(idType Type) {
 	id[ArrIDLen-1] = byte(idType)
 }
 
+// UnmarshalJSON is part of the json.Unmarshaler interface and allows IDs to be
+// marshaled into JSON.
+func (id *ID) UnmarshalJSON(b []byte) error {
+	var buff []byte
+	if err := json.Unmarshal(b, &buff); err != nil {
+		return err
+	}
+
+	newID, err := Unmarshal(buff)
+	if err != nil {
+		return err
+	}
+
+	*id = *newID
+
+	return nil
+}
+
+// MarshalJSON is part of the json.Marshaler interface and allows IDs to be
+// marshaled into JSON.
+func (id ID) MarshalJSON() ([]byte, error) {
+	return json.Marshal(id.Marshal())
+}
+
 // NewRandomID generates a random ID using the passed in io.Reader r
 // and sets the ID to Type t. If the base64 string of the generated
 // ID does not begin with an alphanumeric character, then another ID
diff --git a/id/id_test.go b/id/id_test.go
index 71cb079..5d48d23 100644
--- a/id/id_test.go
+++ b/id/id_test.go
@@ -10,6 +10,7 @@ import (
 	"bytes"
 	"encoding/base64"
 	"encoding/binary"
+	"encoding/json"
 	"fmt"
 	"math/rand"
 	"reflect"
@@ -301,6 +302,50 @@ func TestID_SetType_NilError(t *testing.T) {
 	id.SetType(Generic)
 }
 
+// Tests that an ID can be JSON marshaled and unmarshalled.
+func TestID_MarshalJSON_UnmarshalJSON(t *testing.T) {
+	testID := NewIdFromBytes(rngBytes(ArrIDLen, 42, t), t)
+
+	jsonData, err := json.Marshal(testID)
+	if err != nil {
+		t.Errorf("json.Marshal returned an error: %+v", err)
+	}
+
+	newID := &ID{}
+	err = json.Unmarshal(jsonData, newID)
+	if err != nil {
+		t.Errorf("json.Unmarshal returned an error: %+v", err)
+	}
+
+	if *testID != *newID {
+		t.Errorf("Failed the JSON marshal and unmarshal ID."+
+			"\noriginal ID: %s\nreceived ID: %s", testID, newID)
+	}
+}
+
+// Error path: supplied data is invalid JSON.
+func TestID_UnmarshalJSON_JsonUnmarshalError(t *testing.T) {
+	expectedErr := "invalid character"
+	id := ID{}
+	err := id.UnmarshalJSON([]byte("invalid JSON"))
+	if err == nil || !strings.Contains(err.Error(), expectedErr) {
+		t.Errorf("UnmarshalJSON failed to return the expected error for "+
+			"invalid JSON.\nexpected: %s\nreceived: %+v", expectedErr, err)
+	}
+}
+
+// Error path: supplied data is valid JSON but an invalid ID.
+func TestID_UnmarshalJSON_IdUnmarshalError(t *testing.T) {
+	expectedErr := fmt.Sprintf("Failed to unmarshal ID: length of data "+
+		"must be %d, length received is %d", ArrIDLen, 0)
+	id := ID{}
+	err := id.UnmarshalJSON([]byte("\"\""))
+	if err == nil || !strings.Contains(err.Error(), expectedErr) {
+		t.Errorf("UnmarshalJSON failed to return the expected error for "+
+			"invalid ID.\nexpected: %s\nreceived: %+v", expectedErr, err)
+	}
+}
+
 // Tests that NewRandomID returns the expected IDs for a given PRNG.
 func TestNewRandomID_Consistency(t *testing.T) {
 	prng := rand.New(rand.NewSource(42))
@@ -599,3 +644,19 @@ func (prng *alphaNumericPRNG) Read(p []byte) (n int, err error) {
 		return len(p), nil
 	}
 }
+
+// Generates a byte slice of the specified length containing random numbers.
+func rngBytes(length int, seed int64, t *testing.T) []byte {
+	prng := rand.New(rand.NewSource(seed))
+
+	// Create new byte slice of the correct size
+	idBytes := make([]byte, length)
+
+	// Create random bytes
+	_, err := prng.Read(idBytes)
+	if err != nil {
+		t.Fatalf("Failed to generate random bytes: %+v", err)
+	}
+
+	return idBytes
+}
-- 
GitLab