From 54a8794b6defe94ed0e02dab77a06cb8ee7fde11 Mon Sep 17 00:00:00 2001
From: Benjamin Wenger <ben@elixxir.ioo>
Date: Tue, 22 Mar 2022 11:25:57 -0700
Subject: [PATCH] modified cmix message buffer

---
 network/message/garbled.go                  |  72 +---------
 network/message/handler.go                  | 152 +++-----------------
 network/message/manager.go                  |  37 ++++-
 network/message/sendCmix.go                 |   2 +-
 network/message/sendCmixUtils.go            |   2 +-
 network/message/sendManyCmix.go             |   2 +-
 storage/messages.go                         |   1 -
 storage/utility/meteredCmixMessageBuffer.go |  41 +++++-
 8 files changed, 95 insertions(+), 214 deletions(-)

diff --git a/network/message/garbled.go b/network/message/garbled.go
index b2449ca9f..7295eaf08 100644
--- a/network/message/garbled.go
+++ b/network/message/garbled.go
@@ -53,76 +53,11 @@ func (m *Manager) processGarbledMessages(stop *stoppable.Single) {
 
 //handler for a single run of garbled messages
 func (m *Manager) handleGarbledMessages() {
-	garbledMsgs := m.Session.GetGarbledMessages()
-	e2eKv := m.Session.E2e()
-	var failedMsgs []format.Message
 	//try to decrypt every garbled message, excising those who's counts are too high
-	for grbldMsg, count, timestamp, has := garbledMsgs.Next(); has; grbldMsg, count, timestamp, has = garbledMsgs.Next() {
+	for grbldMsg, count, timestamp, has := m.garbledStore.Next(); has; grbldMsg, count, timestamp, has = m.garbledStore.Next() {
 		//if it exists, check against all in the list
 		grbldContents := grbldMsg.GetContents()
-		identity := m.Session.GetUser().ReceptionID
-		_, forMe, _ := m.Session.GetEdge().Check(identity, grbldMsg.GetIdentityFP(), grbldContents)
-		if forMe {
-			fingerprint := grbldMsg.GetKeyFP()
-			// Check if the key is there, process it if it is
-			if key, isE2E := e2eKv.PopKey(fingerprint); isE2E {
-				jww.INFO.Printf("[GARBLE] Check E2E for %s, KEYFP: %s",
-					grbldMsg.Digest(), grbldMsg.GetKeyFP())
-				// Decrypt encrypted message
-				msg, err := key.Decrypt(grbldMsg)
-				if err == nil {
-					// get the sender
-					sender := key.GetSession().GetPartner()
-					//remove from the buffer if decryption is successful
-					garbledMsgs.Remove(grbldMsg)
-
-					jww.INFO.Printf("[GARBLE] message decoded as E2E from "+
-						"%s, msgDigest: %s", sender, grbldMsg.Digest())
-
-					//handle the successfully decrypted message
-					xxMsg, ok := m.partitioner.HandlePartition(sender, message.E2E,
-						msg.GetContents(),
-						key.GetSession().GetRelationshipFingerprint())
-					if ok {
-						m.Switchboard.Speak(xxMsg)
-						continue
-					}
-				}
-			} else {
-				// todo: figure out how to get the ephermal reception id in here.
-				// we have the raw data, but do not know what address space was
-				// used int he round
-				// todo: figure out how to get the round id, the recipient id, and the round timestamp
-				/*
-					ephid, err := ephemeral.Marshal(garbledMsg.GetEphemeralRID())
-					if err!=nil{
-						jww.WARN.Printf("failed to get the ephemeral id for a garbled " +
-							"message, clearing the message: %+v", err)
-						garbledMsgs.Remove(garbledMsg)
-						continue
-					}
-
-					ephid.Clear(m.)*/
-
-				raw := message.Receive{
-					Payload:        grbldMsg.Marshal(),
-					MessageType:    message.Raw,
-					Sender:         &id.ID{},
-					EphemeralID:    ephemeral.Id{},
-					Timestamp:      time.Time{},
-					Encryption:     message.None,
-					RecipientID:    &id.ID{},
-					RoundId:        0,
-					RoundTimestamp: time.Time{},
-				}
-				im := fmt.Sprintf("[GARBLE] RAW Message reprocessed: keyFP: %v, "+
-					"msgDigest: %s", grbldMsg.GetKeyFP(), grbldMsg.Digest())
-				jww.INFO.Print(im)
-				m.Internal.Events.Report(1, "MessageReception", "Garbled", im)
-				m.Session.GetGarbledMessages().Add(grbldMsg)
-				m.Switchboard.Speak(raw)
-			}
-		}
+		identity := m.session.GetUser().ReceptionID
 
 		// fail the message if any part of the decryption fails,
 		// unless it is the last attempts and has been in the buffer long
@@ -134,7 +69,4 @@ func (m *Manager) handleGarbledMessages() {
 			failedMsgs = append(failedMsgs, grbldMsg)
 		}
 	}
-	for _, grbldMsg := range failedMsgs {
-		garbledMsgs.Failed(grbldMsg)
-	}
 }
diff --git a/network/message/handler.go b/network/message/handler.go
index 9ffb92bab..0a051e079 100644
--- a/network/message/handler.go
+++ b/network/message/handler.go
@@ -9,6 +9,7 @@ package message
 
 import (
 	"fmt"
+	"sync"
 	"time"
 
 	jww "github.com/spf13/jwalterweatherman"
@@ -25,146 +26,33 @@ import (
 )
 
 func (m *Manager) handleMessages(stop *stoppable.Single) {
-	preimageList := m.Session.GetEdge()
 	for {
 		select {
 		case <-stop.Quit():
 			stop.ToStopped()
 			return
 		case bundle := <-m.messageReception:
-			for _, msg := range bundle.Messages {
-				m.handleMessage(msg, bundle, preimageList)
-			}
-			bundle.Finish()
-		}
-	}
-
-}
-
-func (m *Manager) handleMessage(ecrMsg format.Message, bundle Bundle, edge *edge.Store) {
-	// We've done all the networking, now process the message
-	fingerprint := ecrMsg.GetKeyFP()
-	msgDigest := ecrMsg.Digest()
-	identity := bundle.Identity
-
-	e2eKv := m.Session.E2e()
-
-	var sender *id.ID
-	var msg format.Message
-	var encTy message.EncryptionType
-	var err error
-	var relationshipFingerprint []byte
-
-	//if it exists, check against all in the list
-	ecrMsgContents := ecrMsg.GetContents()
-	has, forMe, _ := m.Session.GetEdge().Check(identity.Source, ecrMsg.GetIdentityFP(), ecrMsgContents)
-	if !has {
-		jww.INFO.Printf("checking backup %v", preimage.MakeDefault(identity.Source))
-		//if it doesnt exist, check against the default fingerprint for the identity
-		forMe = fingerprint2.CheckIdentityFP(ecrMsg.GetIdentityFP(),
-			ecrMsgContents, preimage.MakeDefault(identity.Source))
-	}
-
-	if !forMe {
-		if jww.GetLogThreshold() == jww.LevelTrace {
-			expectedFP := fingerprint2.IdentityFP(ecrMsgContents,
-				preimage.MakeDefault(identity.Source))
-			jww.TRACE.Printf("Message for %d (%s) failed identity "+
-				"check: %v (expected-default) vs %v (received)", identity.EphId,
-				identity.Source, expectedFP, ecrMsg.GetIdentityFP())
-		}
-		im := fmt.Sprintf("Garbled/RAW Message: keyFP: %v, round: %d"+
-			"msgDigest: %s, not determined to be for client", ecrMsg.GetKeyFP(), bundle.Round, ecrMsg.Digest())
-		m.Internal.Events.Report(1, "MessageReception", "Garbled", im)
-		m.Session.GetGarbledMessages().Add(ecrMsg)
-		return
-	}
+				go func(){
+					wg := sync.WaitGroup{}
+					wg.Add(len(bundle.Messages))
+					for _, msg := range bundle.Messages {
+						go func() {
+							m.handleMessage(msg, bundle)
+							wg.Done()
+						}()
+					}
+					wg.Wait()
+					bundle.Finish()
+				}()
 
-	// try to get the key fingerprint, process as e2e encryption if
-	// the fingerprint is found
-	if key, isE2E := e2eKv.PopKey(fingerprint); isE2E {
-		// Decrypt encrypted message
-		msg, err = key.Decrypt(ecrMsg)
-		// get the sender
-		sender = key.GetSession().GetPartner()
-		relationshipFingerprint = key.GetSession().GetRelationshipFingerprint()
+			}
 
-		//drop the message is decryption failed
-		if err != nil {
-			//if decryption failed, print an error
-			msg := fmt.Sprintf("Failed to decrypt message with "+
-				"fp %s from partner %s: %s", key.Fingerprint(),
-				sender, err)
-			jww.WARN.Printf(msg)
-			m.Internal.Events.Report(9, "MessageReception",
-				"DecryptionError", msg)
-			return
-		}
-		//set the type as E2E encrypted
-		encTy = message.E2E
-	} else if isUnencrypted, uSender := e2e.IsUnencrypted(ecrMsg); isUnencrypted {
-		// if the key fingerprint does not match, try to treat it as an
-		// unencrypted message
-		sender = uSender
-		msg = ecrMsg
-		encTy = message.None
-	} else {
-		// if it doesn't match any form of encrypted, hear it as a raw message
-		// and add it to garbled messages to be handled later
-		msg = ecrMsg
-		raw := message.Receive{
-			Payload:        msg.Marshal(),
-			MessageType:    message.Raw,
-			Sender:         &id.ID{},
-			EphemeralID:    identity.EphId,
-			Timestamp:      time.Time{},
-			Encryption:     message.None,
-			RecipientID:    identity.Source,
-			RoundId:        id.Round(bundle.RoundInfo.ID),
-			RoundTimestamp: time.Unix(0, int64(bundle.RoundInfo.Timestamps[states.QUEUED])),
 		}
-		im := fmt.Sprintf("Received message of type Garbled/RAW: keyFP: %v, round: %d, "+
-			"msgDigest: %s", msg.GetKeyFP(), bundle.Round, msg.Digest())
-		jww.INFO.Print(im)
-		m.Internal.Events.Report(1, "MessageReception", "Garbled", im)
-		m.Session.GetGarbledMessages().Add(msg)
-		m.Switchboard.Speak(raw)
-		return
 	}
 
-	// Process the decrypted/unencrypted message partition, to see if
-	// we get a full message
-	xxMsg, ok := m.partitioner.HandlePartition(sender, encTy, msg.GetContents(),
-		relationshipFingerprint)
-
-	im := fmt.Sprintf("Received message of ecr type %s and msg type "+
-		"%d from %s in round %d,msgDigest: %s, keyFP: %v", encTy,
-		xxMsg.MessageType, sender, bundle.Round, msgDigest, msg.GetKeyFP())
-	jww.INFO.Print(im)
-	m.Internal.Events.Report(2, "MessageReception", "MessagePart", im)
-
-	// If the reception completed a message, hear it on the switchboard
-	if ok {
-		//Set the identities
-		xxMsg.RecipientID = identity.Source
-		xxMsg.EphemeralID = identity.EphId
-		xxMsg.Encryption = encTy
-		xxMsg.RoundId = id.Round(bundle.RoundInfo.ID)
-		xxMsg.RoundTimestamp = time.Unix(0, int64(bundle.RoundInfo.Timestamps[states.QUEUED]))
-		if xxMsg.MessageType == message.Raw {
-			rm := fmt.Sprintf("Received a message of type 'Raw' from %s."+
-				"Message Ignored, 'Raw' is a reserved type. Message supressed.",
-				xxMsg.ID)
-			jww.WARN.Print(rm)
-			m.Internal.Events.Report(10, "MessageReception",
-				"Error", rm)
-		} else {
-			m.Switchboard.Speak(xxMsg)
-		}
-	}
 }
 
-func (m *Manager) handleMessage2(ecrMsg format.Message, bundle Bundle) {
+func (m *Manager) handleMessage(ecrMsg format.Message, bundle Bundle) {
 	fingerprint := ecrMsg.GetKeyFP()
 	// msgDigest := ecrMsg.Digest()
 	identity := bundle.Identity
@@ -175,11 +63,16 @@ func (m *Manager) handleMessage2(ecrMsg format.Message, bundle Bundle) {
 	// 	ID id.ID
 	// 	ephID ephemeral.Id
 	// }
+
+	//save to garbled
+	m.garbledStore.Add(ecrMsg)
+
 	var receptionID interfaces.Identity
 
 	// If we have a fingerprint, process it.
 	if proc, exists := m.pop(fingerprint); exists {
 		proc.Process(ecrMsg, receptionID, round)
+		m.garbledStore.Remove(ecrMsg)
 		return
 	}
 
@@ -191,8 +84,8 @@ func (m *Manager) handleMessage2(ecrMsg format.Message, bundle Bundle) {
 		if len(triggers) == 0 {
 			jww.ERROR.Printf("empty trigger list for %s",
 				ecrMsg.GetIdentityFP()) // get preimage
-			)
 		}
+		m.garbledStore.Remove(ecrMsg)
 		return
 	} else {
 		// TODO: delete this else block because it should not be needed.
@@ -214,5 +107,6 @@ func (m *Manager) handleMessage2(ecrMsg format.Message, bundle Bundle) {
 		"msgDigest: %s, not determined to be for client",
 		ecrMsg.GetKeyFP(), bundle.Round, ecrMsg.Digest())
 	m.Internal.Events.Report(1, "MessageReception", "Garbled", im)
-	m.Session.GetGarbledMessages().Add(ecrMsg)
+	//denote as active in garbled
+	m.garbledStore.Failed(ecrMsg)
 }
diff --git a/network/message/manager.go b/network/message/manager.go
index d9d1ec2c0..75918d941 100644
--- a/network/message/manager.go
+++ b/network/message/manager.go
@@ -10,18 +10,24 @@ package message
 import (
 	"encoding/base64"
 	"fmt"
+	"gitlab.com/elixxir/client/interfaces"
+	"gitlab.com/elixxir/client/storage"
+	"gitlab.com/elixxir/client/storage/utility"
+	"gitlab.com/elixxir/crypto/fastRNG"
 
 	jww "github.com/spf13/jwalterweatherman"
 	"gitlab.com/elixxir/client/interfaces/params"
 	"gitlab.com/elixxir/client/network/gateway"
-	"gitlab.com/elixxir/client/network/internal"
 	"gitlab.com/elixxir/client/stoppable"
 	"gitlab.com/elixxir/comms/network"
 )
 
+const (
+	garbledMessagesKey = "GarbledMessages"
+)
+
 type Manager struct {
-	param params.Network
-	internal.Internal
+	param            params.Network
 	sender           *gateway.Sender
 	blacklistedNodes map[string]interface{}
 
@@ -30,12 +36,27 @@ type Manager struct {
 	networkIsHealthy chan bool
 	triggerGarbled   chan struct{}
 
+	garbledStore *utility.MeteredCmixMessageBuffer
+
+	rng     *fastRNG.StreamGenerator
+	events  interfaces.EventManager
+	comms   SendCmixCommsInterface
+	session *storage.Session
+
 	FingerprintsManager
 	TriggersManager
 }
 
-func NewManager(internal internal.Internal, param params.Network,
-	nodeRegistration chan network.NodeGateway, sender *gateway.Sender) *Manager {
+func NewManager(param params.Network,
+	nodeRegistration chan network.NodeGateway, sender *gateway.Sender,
+	session *storage.Session, rng *fastRNG.StreamGenerator,
+	events interfaces.EventManager, comms SendCmixCommsInterface) *Manager {
+
+	garbled, err := utility.NewOrLoadMeteredCmixMessageBuffer(session.GetKV(), garbledMessagesKey)
+	if err != nil {
+		jww.FATAL.Panicf("Failed to load or new the Garbled Messages system")
+	}
+
 	m := Manager{
 		param:            param,
 		messageReception: make(chan Bundle, param.MessageReceptionBuffLen),
@@ -43,7 +64,11 @@ func NewManager(internal internal.Internal, param params.Network,
 		triggerGarbled:   make(chan struct{}, 100),
 		nodeRegistration: nodeRegistration,
 		sender:           sender,
-		Internal:         internal,
+		garbledStore:     garbled,
+		rng:              rng,
+		events:           events,
+		comms:            comms,
+		session:          session,
 	}
 	for _, nodeId := range param.BlacklistedNodes {
 		decodedId, err := base64.StdEncoding.DecodeString(nodeId)
diff --git a/network/message/sendCmix.go b/network/message/sendCmix.go
index 8d6a2961a..9faa51049 100644
--- a/network/message/sendCmix.go
+++ b/network/message/sendCmix.go
@@ -74,7 +74,7 @@ func sendCmixHelper(sender *gateway.Sender, msg format.Message,
 	recipient *id.ID, cmixParams params.CMIX, blacklistedNodes map[string]interface{}, instance *network.Instance,
 	session *storage.Session, nodeRegistration chan network.NodeGateway,
 	rng *fastRNG.StreamGenerator, events interfaces.EventManager,
-	senderId *id.ID, comms sendCmixCommsInterface,
+	senderId *id.ID, comms SendCmixCommsInterface,
 	stop *stoppable.Single) (id.Round, ephemeral.Id, error) {
 
 	timeStart := netTime.Now()
diff --git a/network/message/sendCmixUtils.go b/network/message/sendCmixUtils.go
index 49c682af1..feb86c585 100644
--- a/network/message/sendCmixUtils.go
+++ b/network/message/sendCmixUtils.go
@@ -30,7 +30,7 @@ import (
 )
 
 // Interface for SendCMIX comms; allows mocking this in testing.
-type sendCmixCommsInterface interface {
+type SendCmixCommsInterface interface {
 	// SendPutMessage places a cMix message on the gateway to be
 	// sent through cMix.
 	SendPutMessage(host *connect.Host, message *pb.GatewaySlot,
diff --git a/network/message/sendManyCmix.go b/network/message/sendManyCmix.go
index 777dc76f0..a814b8493 100644
--- a/network/message/sendManyCmix.go
+++ b/network/message/sendManyCmix.go
@@ -60,7 +60,7 @@ func sendManyCmixHelper(sender *gateway.Sender,
 	blacklistedNodes map[string]interface{}, instance *network.Instance,
 	session *storage.Session, nodeRegistration chan network.NodeGateway,
 	rng *fastRNG.StreamGenerator, events interfaces.EventManager,
-	senderId *id.ID, comms sendCmixCommsInterface, stop *stoppable.Single) (
+	senderId *id.ID, comms SendCmixCommsInterface, stop *stoppable.Single) (
 	id.Round, []ephemeral.Id, error) {
 
 	timeStart := netTime.Now()
diff --git a/storage/messages.go b/storage/messages.go
index 31f518fbd..71c30d84d 100644
--- a/storage/messages.go
+++ b/storage/messages.go
@@ -10,6 +10,5 @@ package storage
 const (
 	criticalMessagesKey    = "CriticalMessages"
 	criticalRawMessagesKey = "CriticalRawMessages"
-	garbledMessagesKey     = "GarbledMessages"
 	checkedRoundsKey       = "CheckedRounds"
 )
diff --git a/storage/utility/meteredCmixMessageBuffer.go b/storage/utility/meteredCmixMessageBuffer.go
index b0b425800..1061e2a3c 100644
--- a/storage/utility/meteredCmixMessageBuffer.go
+++ b/storage/utility/meteredCmixMessageBuffer.go
@@ -12,9 +12,11 @@ import (
 	"github.com/pkg/errors"
 	jww "github.com/spf13/jwalterweatherman"
 	"gitlab.com/elixxir/client/storage/versioned"
+	pb "gitlab.com/elixxir/comms/mixmessages"
 	"gitlab.com/elixxir/primitives/format"
 	"gitlab.com/xx_network/primitives/netTime"
 	"golang.org/x/crypto/blake2b"
+	"google.golang.org/protobuf/proto"
 	"time"
 )
 
@@ -24,6 +26,7 @@ type meteredCmixMessageHandler struct{}
 
 type meteredCmixMessage struct {
 	M         []byte
+	Ri        []byte
 	Count     uint
 	Timestamp time.Time
 }
@@ -80,6 +83,7 @@ func (*meteredCmixMessageHandler) HashMessage(m interface{}) MessageHash {
 	h, _ := blake2b.New256(nil)
 
 	h.Write(m.(meteredCmixMessage).M)
+	h.Write(m.(meteredCmixMessage).Ri)
 
 	var messageHash MessageHash
 	copy(messageHash[:], h.Sum(nil))
@@ -113,32 +117,54 @@ func LoadMeteredCmixMessageBuffer(kv *versioned.KV, key string) (*MeteredCmixMes
 	return &MeteredCmixMessageBuffer{mb: mb, kv: kv, key: key}, nil
 }
 
-func (mcmb *MeteredCmixMessageBuffer) Add(m format.Message) {
+func NewOrLoadMeteredCmixMessageBuffer(kv *versioned.KV, key string) (*MeteredCmixMessageBuffer, error) {
+	mb, err := LoadMessageBuffer(kv, &meteredCmixMessageHandler{}, key)
+	if err != nil {
+		jww.WARN.Printf("Failed to find MeteredCmixMessageBuffer %s, making a new one", key)
+		return NewMeteredCmixMessageBuffer(kv, key)
+	}
+
+	return &MeteredCmixMessageBuffer{mb: mb, kv: kv, key: key}, nil
+}
+
+func (mcmb *MeteredCmixMessageBuffer) Add(m format.Message, ri *pb.RoundInfo) {
 	if m.GetPrimeByteLen() == 0 {
 		jww.FATAL.Panicf("Cannot handle a metered " +
 			"cmix message with a length of 0")
 	}
+	riMarshal, err := proto.Marshal(ri)
+	if err != nil {
+		jww.FATAL.Panicf("Failed to marshal round info")
+	}
+
 	msg := meteredCmixMessage{
 		M:         m.Marshal(),
+		Ri:        riMarshal,
 		Count:     0,
 		Timestamp: netTime.Now(),
 	}
 	mcmb.mb.Add(msg)
 }
 
-func (mcmb *MeteredCmixMessageBuffer) AddProcessing(m format.Message) {
+func (mcmb *MeteredCmixMessageBuffer) AddProcessing(m format.Message, ri *pb.RoundInfo) {
+	riMarshal, err := proto.Marshal(ri)
+	if err != nil {
+		jww.FATAL.Panicf("Failed to marshal round info")
+	}
+
 	msg := meteredCmixMessage{
 		M:         m.Marshal(),
+		Ri:        riMarshal,
 		Count:     0,
 		Timestamp: netTime.Now(),
 	}
 	mcmb.mb.AddProcessing(msg)
 }
 
-func (mcmb *MeteredCmixMessageBuffer) Next() (format.Message, uint, time.Time, bool) {
+func (mcmb *MeteredCmixMessageBuffer) Next() (format.Message, *pb.RoundInfo, uint, time.Time, bool) {
 	m, ok := mcmb.mb.Next()
 	if !ok {
-		return format.Message{}, 0, time.Time{}, false
+		return format.Message{}, nil, 0, time.Time{}, false
 	}
 
 	msg := m.(meteredCmixMessage)
@@ -158,7 +184,12 @@ func (mcmb *MeteredCmixMessageBuffer) Next() (format.Message, uint, time.Time, b
 		jww.FATAL.Panicf("Failed to unmarshal message after count "+
 			"update: %s", err)
 	}
-	return msfFormat, rtnCnt, msg.Timestamp, true
+
+	ri := &pb.RoundInfo{}
+	err = proto.Unmarshal(msg.Ri, ri)
+	jww.FATAL.Panicf("Failed to unmarshal round info from msg format")
+
+	return msfFormat, ri, rtnCnt, msg.Timestamp, true
 }
 
 func (mcmb *MeteredCmixMessageBuffer) Remove(m format.Message) {
-- 
GitLab