From 69e95e2c8f382f47c4ebca5575ba53ebdca6930a Mon Sep 17 00:00:00 2001
From: joshemb <josh@elixxir.io>
Date: Wed, 17 Aug 2022 11:36:00 -0700
Subject: [PATCH] Fix UserMessageInternal and make test

---
 channels/messages.go      | 36 +++++++++++++++++++++++++++-----
 channels/messages.proto   |  3 ++-
 channels/messages_test.go | 44 +++++++++++++++++++++++++++++++++++++++
 3 files changed, 77 insertions(+), 6 deletions(-)
 create mode 100644 channels/messages_test.go

diff --git a/channels/messages.go b/channels/messages.go
index 0769bf01a..52d445ec8 100644
--- a/channels/messages.go
+++ b/channels/messages.go
@@ -7,16 +7,42 @@
 
 package channels
 
-import "github.com/golang/protobuf/proto"
+import (
+	"github.com/golang/protobuf/proto"
+	"sync"
+)
 
 // UserMessageInternal is the internal structure of a UserMessage protobuf.
 type UserMessageInternal struct {
-	UserMessage
-	channelMsg ChannelMessage
+	mux sync.Mutex
+	*UserMessage
+	channelMsg *ChannelMessage
+}
+
+func NewUserMessageInternal(ursMsg *UserMessage) *UserMessageInternal {
+	return &UserMessageInternal{
+		mux:         sync.Mutex{},
+		UserMessage: ursMsg,
+		channelMsg:  nil,
+	}
 }
 
 // GetChannelMessage retrieves a serializes ChannelMessage within the
 // UserMessageInternal object.
-func (umi *UserMessageInternal) GetChannelMessage() ([]byte, error) {
-	return proto.Marshal(&umi.channelMsg)
+func (umi *UserMessageInternal) GetChannelMessage() (*ChannelMessage, error) {
+	umi.mux.Lock()
+	defer umi.mux.Unlock()
+
+	// check if channel message
+	if umi.channelMsg != nil {
+		return umi.channelMsg, nil
+	}
+	// if not, deserialize and store
+	unmarshalledChannelMsg := &ChannelMessage{}
+	err := proto.Unmarshal(umi.UserMessage.Message, unmarshalledChannelMsg)
+	if err != nil {
+		return nil, err
+	}
+
+	return unmarshalledChannelMsg, nil
 }
diff --git a/channels/messages.proto b/channels/messages.proto
index c88932967..d1f4d4162 100644
--- a/channels/messages.proto
+++ b/channels/messages.proto
@@ -31,7 +31,8 @@ message ChannelMessage{
 // UserMessage is a message sent by a user who is a member within the channel.
 message UserMessage {
     // Message contains the contents of the message. This is typically what
-    // the end-user has submitted to the channel.
+    // the end-user has submitted to the channel. This is a serialization of the
+    // ChannelMessage.
     bytes  Message = 1;
 
     // ValidationSignature is the signature validating this user owns
diff --git a/channels/messages_test.go b/channels/messages_test.go
new file mode 100644
index 000000000..f3bf0eb02
--- /dev/null
+++ b/channels/messages_test.go
@@ -0,0 +1,44 @@
+////////////////////////////////////////////////////////////////////////////////
+// Copyright © 2020 xx network SEZC                                           //
+//                                                                            //
+// Use of this source code is governed by a license that can be found in the  //
+// LICENSE file                                                               //
+////////////////////////////////////////////////////////////////////////////////
+
+package channels
+
+import (
+	"bytes"
+	"github.com/golang/protobuf/proto"
+	"testing"
+)
+
+func TestUserMessageInternal_GetChannelMessage(t *testing.T) {
+	channelMsg := &ChannelMessage{
+		Payload: []byte("ban_badUSer"),
+	}
+
+	serialized, err := proto.Marshal(channelMsg)
+	if err != nil {
+		t.Fatalf("Marshal error: %v", err)
+	}
+
+	usrMsg := &UserMessage{
+		Message:             serialized,
+		ValidationSignature: []byte("sig"),
+		Signature:           []byte("sig"),
+		Username:            "hunter",
+	}
+
+	internal := NewUserMessageInternal(usrMsg)
+	received, err := internal.GetChannelMessage()
+	if err != nil {
+		t.Fatalf("GetChannelMessage error: %v", err)
+	}
+
+	if !bytes.Equal(received.Payload, channelMsg.Payload) {
+		t.Fatalf("GetChannelMessage did not return expected data."+
+			"\nExpected: %v"+
+			"\nReceived: %v", channelMsg, received)
+	}
+}
-- 
GitLab