diff --git a/context/networkManager.go b/context/networkManager.go index 8f1fa83d6e2efaa483b21f71d3934770d1cc2d88..9c5717ae23c937542760d2d86a8f29a840c7a9cb 100644 --- a/context/networkManager.go +++ b/context/networkManager.go @@ -9,9 +9,9 @@ import ( ) type NetworkManager interface { - SendE2E(m message.Send, e2eP params.E2E, cmixP params.CMIX) ([]id.Round, error) - SendUnsafe(m message.Send) ([]id.Round, error) - SendCMIX(message format.Message) (id.Round, error) + SendE2E(m message.Send, p params.E2E) ([]id.Round, error) + SendUnsafe(m message.Send, p params.Unsafe) ([]id.Round, error) + SendCMIX(message format.Message, p params.CMIX) (id.Round, error) GetInstance() *network.Instance GetHealthTracker() HealthTracker } diff --git a/context/params/E2E.go b/context/params/E2E.go index 2d0d0ee9cd3b9f24aaa4f3c499ae9fdfb77ec279..22d6cd98b3a187d18751cbc9ed8844d027100afb 100644 --- a/context/params/E2E.go +++ b/context/params/E2E.go @@ -4,10 +4,13 @@ import "fmt" type E2E struct { Type SendType + CMIX } func GetDefaultE2E() E2E { - return E2E{Type: Standard} + return E2E{Type: Standard, + CMIX: GetDefaultCMIX(), + } } type SendType uint8 diff --git a/context/params/Unsafe.go b/context/params/Unsafe.go new file mode 100644 index 0000000000000000000000000000000000000000..b18c8c03162240b08d591ea50308e046199aa786 --- /dev/null +++ b/context/params/Unsafe.go @@ -0,0 +1,9 @@ +package params + +type Unsafe struct { + CMIX +} + +func GetDefaultUnsafe() Unsafe { + return Unsafe{CMIX: GetDefaultCMIX()} +} diff --git a/context/stoppable/cleanup.go b/context/stoppable/cleanup.go new file mode 100644 index 0000000000000000000000000000000000000000..6c84deba1a01afbb48ed98947fa0ecfb5cbdb1ee --- /dev/null +++ b/context/stoppable/cleanup.go @@ -0,0 +1,81 @@ +package stoppable + +import ( + "github.com/pkg/errors" + "sync" + "sync/atomic" + "time" +) + +// Wraps any stoppable and runs a callback after to stop for cleanup behavior +// the cleanup is run under the remainder of the timeout but will not be canceled +// if the timeout runs out +// the cleanup function does not run if the thread does not stop +type Cleanup struct { + stop Stoppable + // the clean function receives how long it has to run before the timeout, + // this is nto expected to be used in most cases + clean func(duration time.Duration) error + running uint32 + once sync.Once +} + +// Creates a new cleanup from the passed stoppable and the function +func NewCleanup(stop Stoppable, clean func(duration time.Duration) error) *Cleanup { + return &Cleanup{ + stop: stop, + clean: clean, + running: 0, + } +} + +// returns true if the thread is still running and its cleanup has completed +func (c *Cleanup) IsRunning() bool { + return atomic.LoadUint32(&c.running) == 1 +} + +// returns the name of the stoppable denoting it has cleanup +func (c *Cleanup) Name() string { + return c.stop.Name() + " with cleanup" +} + +// stops the contained stoppable and runs the cleanup function after. +// the cleanup function does not run if the thread does not stop +func (c *Cleanup) Close(timeout time.Duration) error { + var err error + + c.once.Do( + func() { + defer atomic.StoreUint32(&c.running, 0) + start := time.Now() + + //run the stopable + if err := c.stop.Close(timeout); err != nil { + err = errors.WithMessagef(err, "Cleanup for %s not executed", + c.stop.Name()) + return + } + + //run the cleanup function with the remaining time as a timeout + elapsed := time.Since(start) + + complete := make(chan error, 1) + go func() { + complete <- c.clean(elapsed) + }() + + timer := time.NewTimer(elapsed) + + select { + case err := <-complete: + if err != nil { + err = errors.WithMessagef(err, "Cleanup for %s "+ + "failed", c.stop.Name()) + } + case <-timer.C: + err = errors.Errorf("Clean up for %s timedout", c.stop.Name()) + } + }) + + return err +} diff --git a/context/stoppable/multi.go b/context/stoppable/multi.go index f30c484758a5689e6fd0388b6b996f6bf764a9e0..c9ee1ff61fc79a76ed9f268be0fe009a597a0e81 100644 --- a/context/stoppable/multi.go +++ b/context/stoppable/multi.go @@ -14,6 +14,7 @@ type Multi struct { name string running uint32 mux sync.RWMutex + once sync.Once } //returns a new multi stoppable @@ -56,37 +57,37 @@ func (m *Multi) Name() string { // closes all child stoppers. It does not return their errors and assumes they // print them to the log func (m *Multi) Close(timeout time.Duration) error { - if !m.IsRunning() { - return nil - } - - m.mux.Lock() - defer m.mux.Unlock() - - numErrors := uint32(0) - - wg := &sync.WaitGroup{} - - for _, stoppable := range m.stoppables { - wg.Add(1) - go func() { - if stoppable.Close(timeout) != nil { - atomic.AddUint32(&numErrors, 1) + var err error + m.once.Do( + func() { + atomic.StoreUint32(&m.running, 0) + + numErrors := uint32(0) + + wg := &sync.WaitGroup{} + + m.mux.Lock() + for _, stoppable := range m.stoppables { + wg.Add(1) + go func() { + if stoppable.Close(timeout) != nil { + atomic.AddUint32(&numErrors, 1) + } + wg.Done() + }() } - wg.Done() - }() - } + m.mux.Unlock() - wg.Wait() + wg.Wait() - atomic.StoreUint32(&m.running, 0) + if numErrors > 0 { + errStr := fmt.Sprintf("MultiStopper %s failed to close "+ + "%v/%v stoppers", m.name, numErrors, len(m.stoppables)) + jww.ERROR.Println(errStr) + err = errors.New(errStr) + } + }) - if numErrors > 0 { - errStr := fmt.Sprintf("MultiStopper %s failed to close "+ - "%v/%v stoppers", m.name, numErrors, len(m.stoppables)) - jww.ERROR.Println(errStr) - return errors.New(errStr) - } + return err - return nil } diff --git a/context/stoppable/single.go b/context/stoppable/single.go index bd0184656f86d56021e3cb1a351765dc56e40929..78a2655fa72abdcdc7bacf5a200bdc324f4bed10 100644 --- a/context/stoppable/single.go +++ b/context/stoppable/single.go @@ -3,6 +3,7 @@ package stoppable import ( "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" + "sync" "sync/atomic" "time" ) @@ -13,6 +14,7 @@ type Single struct { name string quit chan struct{} running uint32 + once sync.Once } //returns a new single stoppable @@ -41,18 +43,17 @@ func (s *Single) Name() string { // Close signals thread to time out and closes if it is still running. func (s *Single) Close(timeout time.Duration) error { - if !s.IsRunning() { - return nil - } - defer atomic.StoreUint32(&s.running, 0) - timer := time.NewTimer(timeout) - select { - case <-timer.C: - jww.ERROR.Printf("Stopper for %s failed to stop after "+ - "timeout of %s", s.name, timeout) - return errors.Errorf("%s failed to close", s.name) - case <-s.quit: - - return nil - } + var err error + s.once.Do(func() { + atomic.StoreUint32(&s.running, 0) + timer := time.NewTimer(timeout) + select { + case <-timer.C: + jww.ERROR.Printf("Stopper for %s failed to stop after "+ + "timeout of %s", s.name, timeout) + err = errors.Errorf("%s failed to close", s.name) + case <-s.quit: + } + }) + return err } diff --git a/network/keyExchange/confirm.go b/network/keyExchange/confirm.go new file mode 100644 index 0000000000000000000000000000000000000000..284b869fb5ba98e34352d0b539b868150601b2a4 --- /dev/null +++ b/network/keyExchange/confirm.go @@ -0,0 +1 @@ +package keyExchange diff --git a/network/keyExchange/init.go b/network/keyExchange/init.go index 854eb34f7b3ab26e2f84c3ec314baacae5b51774..22e2c3f23ea099113bfc4c416063d1e7d4f3158e 100644 --- a/network/keyExchange/init.go +++ b/network/keyExchange/init.go @@ -15,23 +15,12 @@ const keyExchangeTriggerName = "KeyExchangeTrigger" func Init(ctx *context.Context) stoppable.Stoppable { //register the rekey request thread - rekeyRequestCh := make(chan message.Receive, 10) + rekeyRequestCh := make(chan message.Receive, 100) ctx.Switchboard.RegisterChannel(keyExchangeTriggerName, &id.ID{}, message.KeyExchangeTrigger, rekeyRequestCh) triggerStop := stoppable.NewSingle(keyExchangeTriggerName) - go func() { - for true { - select { - case <-triggerStop.Quit(): - return - case request := <-rekeyRequestCh: - ctx.Session.request.Sender - } - } - () - } } diff --git a/network/keyExchange/rekey.go b/network/keyExchange/rekey.go index c05711cf21e6f79c044682175d13f53e82d605ab..2fcf6e7340fac74794db60f6100068556022a228 100644 --- a/network/keyExchange/rekey.go +++ b/network/keyExchange/rekey.go @@ -91,10 +91,8 @@ func negotiate(ctx *context.Context, session *e2e.Session) error { //send the message under the key exchange e2eParams := params.GetDefaultE2E() e2eParams.Type = params.KeyExchange - cmixParams := params.GetDefaultCMIX() - cmixParams.RoundTries = 20 - rounds, err := ctx.Manager.SendE2E(m, e2eParams, cmixParams) + rounds, err := ctx.Manager.SendE2E(m, e2eParams) // If the send fails, returns the error so it can be handled. The caller // should ensure the calling session is in a state where the Rekey will // be triggered next time a key is used diff --git a/network/keyExchange/trigger.go b/network/keyExchange/trigger.go index 5e3925868a9162d23b098fc0d56b1d864edff8be..0c91ae7b033d4cc518d793a1a2b13ea19b4cb7cb 100644 --- a/network/keyExchange/trigger.go +++ b/network/keyExchange/trigger.go @@ -7,6 +7,7 @@ import ( "gitlab.com/elixxir/client/context" "gitlab.com/elixxir/client/context/message" "gitlab.com/elixxir/client/context/params" + "gitlab.com/elixxir/client/context/stoppable" "gitlab.com/elixxir/client/context/utility" "gitlab.com/elixxir/client/storage/e2e" ds "gitlab.com/elixxir/comms/network/dataStructures" @@ -15,6 +16,18 @@ import ( "time" ) +func startTrigger(ctx *context.Context, triggerChan chan message.Receive, stop stoppable.Single) { + for true { + select { + case <-stop.Quit(): + return + case request := <-triggerChan: + handleTrigger(ctx, request) + } + } +} + + func handleTrigger(ctx *context.Context, request message.Receive) { //ensure the message was encrypted properly if request.Encryption != message.E2E { @@ -83,9 +96,12 @@ func handleTrigger(ctx *context.Context, request message.Receive) { //send the message under the key exchange e2eParams := params.GetDefaultE2E() - cmixParams := params.GetDefaultCMIX() - rounds, err := ctx.Manager.SendE2E(m, e2eParams, cmixParams) + // store in critical messages buffer first to ensure it is resent if the + // send fails + ctx.Session.GetCriticalMessages().AddProcessing(m, e2eParams) + + rounds, err := ctx.Manager.SendE2E(m, e2eParams) //Register the event for all rounds sendResults := make(chan ds.EventReturn, len(rounds)) @@ -102,20 +118,21 @@ func handleTrigger(ctx *context.Context, request message.Receive) { // transmit, the partner will not be able to read the confirmation. If // such a failure occurs if !success { - session.SetNegotiationStatus(e2e.Unconfirmed) - return errors.Errorf("Key Negotiation for %s failed to "+ + jww.ERROR.Printf("Key Negotiation for %s failed to "+ "transmit %v/%v paritions: %v round failures, %v timeouts", - session, numRoundFail+numTimeOut, len(rounds), numRoundFail, + newSession, numRoundFail+numTimeOut, len(rounds), numRoundFail, numTimeOut) + newSession.SetNegotiationStatus(e2e.Unconfirmed) + ctx.Session.GetCriticalMessages().Failed(m) + return } // otherwise, the transmission is a success and this should be denoted // in the session and the log + newSession.SetNegotiationStatus(e2e.Sent) + ctx.Session.GetCriticalMessages().Succeeded(m) jww.INFO.Printf("Key Negotiation transmission for %s sucesfull", - session) - session.SetNegotiationStatus(e2e.Sent) - - + newSession) } func unmarshalKeyExchangeTrigger(grp *cyclic.Group, payload []byte) (e2e.SessionID, diff --git a/storage/utility/cmixMessageBuffer.go b/storage/utility/cmixMessageBuffer.go index 1a1756787ba9c080e341053d611a4bbde846843d..ca93a11845b81f56930651352ef7ba91b472435f 100644 --- a/storage/utility/cmixMessageBuffer.go +++ b/storage/utility/cmixMessageBuffer.go @@ -82,6 +82,10 @@ func (cmb *CmixMessageBuffer) Add(m format.Message) { cmb.mb.Add(m) } +func (cmb *CmixMessageBuffer) AddProcessing(m format.Message) { + cmb.mb.AddProcessing(m) +} + func (cmb *CmixMessageBuffer) Next() (format.Message, bool) { m, ok := cmb.mb.Next() if !ok { diff --git a/storage/utility/e2eMessageBuffer.go b/storage/utility/e2eMessageBuffer.go index 81ec03ba3abd385f8395218d76b9e5829c98e570..ae9cd73d446c7b8632fd6e563933c74d61c81924 100644 --- a/storage/utility/e2eMessageBuffer.go +++ b/storage/utility/e2eMessageBuffer.go @@ -6,6 +6,7 @@ import ( "encoding/json" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/context/message" + "gitlab.com/elixxir/client/context/params" "gitlab.com/elixxir/client/storage/versioned" "gitlab.com/xx_network/primitives/id" "time" @@ -19,6 +20,7 @@ type e2eMessage struct { Recipient []byte Payload []byte MessageType uint32 + Params params.E2E } // SaveMessage saves the e2eMessage as a versioned object at the specified key @@ -68,6 +70,8 @@ func (emh *e2eMessageHandler) DeleteMessage(kv *versioned.KV, key string) error } // HashMessage generates a hash of the e2eMessage. +// Do not include the params in the hash so it is not needed to resubmit the +// message into succeeded or failed func (emh *e2eMessageHandler) HashMessage(m interface{}) MessageHash { msg := m.(e2eMessage) @@ -105,20 +109,32 @@ func LoadE2eMessageBuffer(kv *versioned.KV, key string) (*E2eMessageBuffer, erro return &E2eMessageBuffer{mb: mb}, nil } -func (emb *E2eMessageBuffer) Add(m message.Send) { +func (emb *E2eMessageBuffer) Add(m message.Send, p params.E2E) { e2eMsg := e2eMessage{ Recipient: m.Recipient.Marshal(), Payload: m.Payload, MessageType: uint32(m.MessageType), + Params: p, } emb.mb.Add(e2eMsg) } -func (emb *E2eMessageBuffer) Next() (message.Send, bool) { +func (emb *E2eMessageBuffer) AddProcessing(m message.Send, p params.E2E) { + e2eMsg := e2eMessage{ + Recipient: m.Recipient.Marshal(), + Payload: m.Payload, + MessageType: uint32(m.MessageType), + Params: p, + } + + emb.mb.AddProcessing(e2eMsg) +} + +func (emb *E2eMessageBuffer) Next() (message.Send, params.E2E, bool) { m, ok := emb.mb.Next() if !ok { - return message.Send{}, false + return message.Send{}, params.E2E{}, false } msg := m.(e2eMessage) @@ -126,13 +142,16 @@ func (emb *E2eMessageBuffer) Next() (message.Send, bool) { if err != nil { jww.FATAL.Panicf("Error unmarshaling recipient: %v", err) } - return message.Send{recipient, msg.Payload, message.Type(msg.MessageType)}, true + return message.Send{recipient, msg.Payload, + message.Type(msg.MessageType)}, msg.Params, true } func (emb *E2eMessageBuffer) Succeeded(m message.Send) { - emb.mb.Succeeded(e2eMessage{m.Recipient.Marshal(), m.Payload, uint32(m.MessageType)}) + emb.mb.Succeeded(e2eMessage{m.Recipient.Marshal(), + m.Payload, uint32(m.MessageType), params.E2E{}}) } func (emb *E2eMessageBuffer) Failed(m message.Send) { - emb.mb.Failed(e2eMessage{m.Recipient.Marshal(), m.Payload, uint32(m.MessageType)}) + emb.mb.Failed(e2eMessage{m.Recipient.Marshal(), + m.Payload, uint32(m.MessageType), params.E2E{}}) } diff --git a/storage/utility/e2eMessageBuffer_test.go b/storage/utility/e2eMessageBuffer_test.go index 8a2494903f86d4d8cdfefefc42b7048ce64ef2c6..067f5d46abce37cb65574411b6e181bba7dff1ea 100644 --- a/storage/utility/e2eMessageBuffer_test.go +++ b/storage/utility/e2eMessageBuffer_test.go @@ -3,6 +3,7 @@ package utility import ( "encoding/json" "gitlab.com/elixxir/client/context/message" + "gitlab.com/elixxir/client/context/params" "gitlab.com/elixxir/client/storage/versioned" "gitlab.com/elixxir/ekv" "gitlab.com/xx_network/primitives/id" @@ -92,15 +93,15 @@ func TestE2EMessageHandler_Smoke(t *testing.T) { } // Add two messages - cmb.Add(testMsgs[0]) - cmb.Add(testMsgs[1]) + cmb.Add(testMsgs[0], params.E2E{}) + cmb.Add(testMsgs[1], params.E2E{}) if len(cmb.mb.messages) != 2 { t.Errorf("Unexpected length of buffer.\n\texpected: %d\n\trecieved: %d", 2, len(cmb.mb.messages)) } - msg, exists := cmb.Next() + msg, _, exists := cmb.Next() if !exists { t.Error("Next() did not find any messages in buffer.") } @@ -111,7 +112,7 @@ func TestE2EMessageHandler_Smoke(t *testing.T) { 1, len(cmb.mb.messages)) } - msg, exists = cmb.Next() + msg, _, exists = cmb.Next() if !exists { t.Error("Next() did not find any messages in buffer.") } @@ -126,13 +127,13 @@ func TestE2EMessageHandler_Smoke(t *testing.T) { 1, len(cmb.mb.messages)) } - msg, exists = cmb.Next() + msg, _, exists = cmb.Next() if !exists { t.Error("Next() did not find any messages in buffer.") } cmb.Succeeded(msg) - msg, exists = cmb.Next() + msg, _, exists = cmb.Next() if exists { t.Error("Next() found a message in the buffer when it should be empty.") } diff --git a/storage/utility/messageBuffer.go b/storage/utility/messageBuffer.go index d40feafbf6e71746297280c517814e5515f467fd..e79d29e599524e086d8e4dd1d4b8dd14dfec5a37 100644 --- a/storage/utility/messageBuffer.go +++ b/storage/utility/messageBuffer.go @@ -197,6 +197,36 @@ func (mb *MessageBuffer) Add(m interface{}) { } } +// Add adds a message to the buffer in "processing" state. +func (mb *MessageBuffer) AddProcessing(m interface{}) { + h := mb.handler.HashMessage(m) + + mb.mux.Lock() + defer mb.mux.Unlock() + + // Ensure message does not already exist in buffer + _, exists1 := mb.messages[h] + _, exists2 := mb.processingMessages[h] + if exists1 || exists2 { + return + } + + // Save message as versioned object + err := mb.handler.SaveMessage(mb.kv, m, makeStoredMessageKey(mb.key, h)) + if err != nil { + jww.FATAL.Panicf("Error saving message: %v", err) + } + + // Add message to the buffer + mb.processingMessages[h] = struct{}{} + + // Save buffer + err = mb.save() + if err != nil { + jww.FATAL.Panicf("Error whilse saving buffer: %v", err) + } +} + // Next gets the next message from the buffer whose state is "not processing". // The returned messages are moved to the processing state. If there are no // messages remaining, then false is returned.