diff --git a/cmd/single.go b/cmd/single.go new file mode 100644 index 0000000000000000000000000000000000000000..e9fbc5dcaeeb02060054e709956e7a55109ceeb4 --- /dev/null +++ b/cmd/single.go @@ -0,0 +1,266 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +// Package cmd initializes the CLI and config parsers as well as the logger. +package cmd + +import ( + "bytes" + "fmt" + "github.com/spf13/cobra" + jww "github.com/spf13/jwalterweatherman" + "github.com/spf13/viper" + "gitlab.com/elixxir/client/interfaces/contact" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/single" + "gitlab.com/elixxir/client/switchboard" + "gitlab.com/xx_network/primitives/utils" + "time" +) + +// singleCmd is the single-use subcommand that allows for sending and responding +// to single-use messages. +var singleCmd = &cobra.Command{ + Use: "single", + Short: "Send and respond to single-use messages.", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + + client := initClient() + + // Write user contact to file + user := client.GetUser() + jww.INFO.Printf("User: %s", user.ReceptionID) + writeContact(user.GetContact()) + + // Set up reception handler + swBoard := client.GetSwitchboard() + recvCh := make(chan message.Receive, 10000) + listenerID := swBoard.RegisterChannel("DefaultCLIReceiver", + switchboard.AnyUser(), message.Text, recvCh) + jww.INFO.Printf("Message ListenerID: %v", listenerID) + + // Set up auth request handler, which simply prints the user ID of the + // requester + authMgr := client.GetAuthRegistrar() + authMgr.AddGeneralRequestCallback(printChanRequest) + + // If unsafe channels, then add auto-acceptor + if viper.GetBool("unsafe-channel-creation") { + authMgr.AddGeneralRequestCallback(func( + requester contact.Contact, message string) { + jww.INFO.Printf("Got request: %s", requester.ID) + err := client.ConfirmAuthenticatedChannel(requester) + if err != nil { + jww.FATAL.Panicf("%+v", err) + } + }) + } + + err := client.StartNetworkFollower() + if err != nil { + jww.FATAL.Panicf("%+v", err) + } + + // Wait until connected or crash on timeout + connected := make(chan bool, 10) + client.GetHealth().AddChannel(connected) + waitUntilConnected(connected) + + // Make single-use manager and start receiving process + singleMng := single.NewManager(client, client.GetStorage().Reception()) + + // Register the callback + callbackChan := make(chan responseCallbackChan) + callback := func(payload []byte, c single.Contact) { + callbackChan <- responseCallbackChan{payload, c} + } + singleMng.RegisterCallback("tag", callback) + client.AddService(singleMng.StartProcesses) + + timeout := viper.GetDuration("timeout") + + // If the send flag is set, then send a message + if viper.GetBool("send") { + // Get message details + payload := []byte(viper.GetString("message")) + partner := readSingleUseContact("contact") + maxMessages := uint8(viper.GetUint("maxMessages")) + + sendSingleUse(singleMng, partner, payload, maxMessages, timeout) + } + + // If the reply flag is set, then start waiting for a message and reply + // when it is received + if viper.GetBool("reply") { + replySingleUse(singleMng, timeout, callbackChan) + } + }, +} + +func init() { + // Single-use subcommand options + + singleCmd.Flags().Bool("send", false, "Sends a single-use message.") + _ = viper.BindPFlag("send", singleCmd.Flags().Lookup("send")) + + singleCmd.Flags().Bool("reply", false, + "Listens for a single-use message and sends a reply.") + _ = viper.BindPFlag("reply", singleCmd.Flags().Lookup("reply")) + + singleCmd.Flags().StringP("contact", "c", "", + "Path to contact file to send message to.") + _ = viper.BindPFlag("contact", singleCmd.Flags().Lookup("contact")) + + singleCmd.Flags().StringP("message", "m", "", + "Message to send.") + _ = viper.BindPFlag("message", singleCmd.Flags().Lookup("message")) + + singleCmd.Flags().Uint8("maxMessages", 1, + "The max number of single-use response messages.") + _ = viper.BindPFlag("maxMessages", singleCmd.Flags().Lookup("maxMessages")) + + singleCmd.Flags().DurationP("timeout", "t", 30*time.Second, + "Duration before stopping to wait for single-use message.") + _ = viper.BindPFlag("timeout", singleCmd.Flags().Lookup("timeout")) + + rootCmd.AddCommand(singleCmd) +} + +// sendSingleUse sends a single use message. +func sendSingleUse(m *single.Manager, partner *contact.Contact, payload []byte, + maxMessages uint8, timeout time.Duration) { + // Construct callback + callbackChan := make(chan struct { + payload []byte + err error + }) + callback := func(payload []byte, err error) { + callbackChan <- struct { + payload []byte + err error + }{payload: payload, err: err} + } + + jww.INFO.Printf("Sending single-use message to contact: %+v", partner) + jww.INFO.Printf("Payload: \"%s\"", payload) + jww.INFO.Printf("Max number of replies: %d", maxMessages) + jww.INFO.Printf("Timeout: %s", timeout) + + // Send single-use message + fmt.Printf("Sending single-use transmission message: %s\n", payload) + jww.DEBUG.Printf("Sending single-use transmission to %s: %s", partner.ID, payload) + err := m.TransmitSingleUse(partner, payload, "tag", maxMessages, callback, timeout) + if err != nil { + jww.FATAL.Panicf("Failed to transmit single-use message: %+v", err) + } + + // Wait for callback to be called + fmt.Println("Waiting for response.") + results := <-callbackChan + if results.payload != nil { + fmt.Printf("Message received: %s\n", results.payload) + jww.DEBUG.Printf("Received single-use reply payload: %s", results.payload) + } else { + jww.ERROR.Print("Failed to receive single-use reply payload.") + } + + if results.err != nil { + jww.FATAL.Panicf("Received error when waiting for reply: %+v", results.err) + } +} + +// replySingleUse responds to any single-use message it receives by replying\ +// with the same payload. +func replySingleUse(m *single.Manager, timeout time.Duration, callbackChan chan responseCallbackChan) { + + // Wait to receive a message or stop after timeout occurs + fmt.Println("Waiting for single-use message.") + timer := time.NewTimer(timeout) + select { + case results := <-callbackChan: + if results.payload != nil { + fmt.Printf("Single-use transmission received: %s\n", results.payload) + jww.DEBUG.Printf("Received single-use transmission from %s: %s", + results.c.GetPartner(), results.payload) + } else { + jww.ERROR.Print("Failed to receive single-use payload.") + } + + // Create new payload from repeated received payloads so that each + // message part contains the same payload + payload := makeResponsePayload(m, results.payload, results.c.GetMaxParts()) + + fmt.Printf("Sending single-use response message: %s\n", payload) + jww.DEBUG.Printf("Sending single-use response to %s: %s", results.c.GetPartner(), payload) + err := m.RespondSingleUse(results.c, payload, timeout) + if err != nil { + jww.FATAL.Panicf("Failed to send response: %+v", err) + } + + case <-timer.C: + fmt.Println("Timed out!") + jww.FATAL.Panicf("Failed to receive transmission after %s.", timeout) + } +} + +// responseCallbackChan structure used to collect information sent to the +// response callback. +type responseCallbackChan struct { + payload []byte + c single.Contact +} + +// makeResponsePayload generates a new payload that will span the max number of +// message parts in the contact. Each resulting message payload will contain a +// copy of the supplied payload with spaces taking up any remaining data. +func makeResponsePayload(m *single.Manager, payload []byte, maxParts uint8) []byte { + payloads := make([][]byte, maxParts) + payloadPart := makeResponsePayloadPart(m, payload) + for i := range payloads { + payloads[i] = make([]byte, m.GetMaxResponsePayloadSize()) + copy(payloads[i], payloadPart) + } + return bytes.Join(payloads, []byte{}) +} + +// makeResponsePayloadPart creates a single response payload by coping the given +// payload and filling the rest with spaces. +func makeResponsePayloadPart(m *single.Manager, payload []byte) []byte { + payloadPart := make([]byte, m.GetMaxResponsePayloadSize()) + for i := range payloadPart { + payloadPart[i] = ' ' + } + copy(payloadPart, payload) + + return payloadPart +} + +// readSingleUseContact opens the contact specified in the CLI flags. Panics if +// no file provided or if an error occurs while reading or unmarshalling it. +func readSingleUseContact(key string) *contact.Contact { + // Get path + filePath := viper.GetString(key) + if filePath == "" { + jww.FATAL.Panicf("Failed to read contact file: no file path provided.") + } + + // Read from file + data, err := utils.ReadFile(filePath) + jww.INFO.Printf("Contact file size read in: %d bytes", len(data)) + if err != nil { + jww.FATAL.Panicf("Failed to read contact file: %+v", err) + } + + // Unmarshal contact + c, err := contact.Unmarshal(data) + if err != nil { + jww.FATAL.Panicf("Failed to unmarshal contact: %+v", err) + } + + return &c +} diff --git a/network/follow.go b/network/follow.go index b488ffdc53f839252398bd7e06e525c219bb204b..050501213144ea9dcfb2387b947d21ece62e12f1 100644 --- a/network/follow.go +++ b/network/follow.go @@ -72,7 +72,7 @@ func (m *manager) follow(rng csprng.Source, comms followNetworkComms) { //get the identity we will poll for identity, err := m.Session.Reception().GetIdentity(rng) if err != nil { - jww.FATAL.Panicf("Failed to get an ideneity, this should be "+ + jww.FATAL.Panicf("Failed to get an identity, this should be "+ "impossible: %+v", err) } @@ -191,7 +191,8 @@ func (m *manager) follow(rng csprng.Source, comms followNetworkComms) { } if len(pollResp.Filters.Filters) == 0 { - jww.DEBUG.Printf("no filters found for the passed ID, skipping processing") + jww.DEBUG.Printf("No filters found for the passed ID %d (%s), "+ + "skipping processing.", identity.EphId, identity.Source) return } @@ -211,13 +212,12 @@ func (m *manager) follow(rng csprng.Source, comms followNetworkComms) { //prepare the filter objects for processing filterList := make([]*rounds.RemoteFilter, filtersEnd-filtersStart) for i := filtersStart; i < filtersEnd; i++ { - if len(pollResp.Filters.Filters[i].Filter)!=0{ - jww.INFO.Printf("ima spam blooms: first: %d, last: %d, filter: %v", pollResp.Filters.Filters[i].FirstRound, pollResp.Filters.Filters[i].FirstRound+uint64(pollResp.Filters.Filters[i].RoundRange), pollResp.Filters.Filters[i].Filter) + if len(pollResp.Filters.Filters[i].Filter) != 0 { filterList[i-filtersStart] = rounds.NewRemoteFilter(pollResp.Filters.Filters[i]) - if filterList[i-filtersStart].FirstRound()<firstRound{ + if filterList[i-filtersStart].FirstRound() < firstRound { firstRound = filterList[i-filtersStart].FirstRound() } - if filterList[i-filtersStart].LastRound()>lastRound{ + if filterList[i-filtersStart].LastRound() > lastRound { lastRound = filterList[i-filtersStart].LastRound() } } diff --git a/network/message/handler.go b/network/message/handler.go index 0f20cd57b94b6edd9dc81c99fbd617823320d7b3..ae9ac78e29e14ec324d5ed5639a78c1b49be8565 100644 --- a/network/message/handler.go +++ b/network/message/handler.go @@ -53,7 +53,6 @@ func (m *Manager) handleMessage(ecrMsg format.Message, identity reception.Identi jww.FATAL.Panicf("Could not check IdentityFIngerprint: %+v", err) } if !forMe { - jww.INFO.Printf("I rejected: %s, am i dumb?", ecrMsg.GetContents()) return } diff --git a/network/message/sendCmix.go b/network/message/sendCmix.go index 5dd3b8af035999a5deac46df78eeb5a7afbd36af..7d1d716c4659254052563bce7bed8f55269bcb00 100644 --- a/network/message/sendCmix.go +++ b/network/message/sendCmix.go @@ -26,7 +26,7 @@ import ( "time" ) -// interface for sendcmix comms; allows mocking this in testing +// interface for SendCMIX comms; allows mocking this in testing type sendCmixCommsInterface interface { GetHost(hostId *id.ID) (*connect.Host, bool) SendPutMessage(host *connect.Host, message *pb.GatewaySlot) (*pb.GatewaySlotResponse, error) @@ -63,7 +63,7 @@ func sendCmixHelper(msg format.Message, recipient *id.ID, param params.CMIX, ins return 0, ephemeral.Id{}, errors.New("Sending cmix message timed out") } remainingTime := param.Timeout - elapsed - jww.TRACE.Printf("SendCMIX GetUpcommingRealtime") + jww.TRACE.Printf("SendCMIX GetUpcomingRealtime") //find the best round to send to, excluding attempted rounds bestRound, _ := instance.GetWaitingRounds().GetUpcomingRealtime(remainingTime, attempted) if bestRound == nil { @@ -76,7 +76,7 @@ func sendCmixHelper(msg format.Message, recipient *id.ID, param params.CMIX, ins now := time.Now() if now.After(roundCutoffTime) { - jww.WARN.Printf("Round %d received which has already started" + + jww.WARN.Printf("Round %d received which has already started"+ " realtime: \n\t started: %s \n\t now: %s", bestRound.ID, roundCutoffTime, now) attempted.Insert(bestRound) @@ -94,13 +94,13 @@ func sendCmixHelper(msg format.Message, recipient *id.ID, param params.CMIX, ins jww.INFO.Printf("Sending to EphID %v (source: %s)", ephID.Int64(), recipient) stream := rng.GetStream() - ephID, err = ephID.Fill(uint(bestRound.AddressSpaceSize), stream) + ephIdFilled, err := ephID.Fill(uint(bestRound.AddressSpaceSize), stream) if err != nil { - jww.FATAL.Panicf("Failed to obviscate the ephemeralID: %+v", err) + jww.FATAL.Panicf("Failed to obfuscate the ephemeralID: %+v", err) } stream.Close() - msg.SetEphemeralRID(ephID[:]) + msg.SetEphemeralRID(ephIdFilled[:]) //set the identity fingerprint ifp, err := fingerprint.IdentityFP(msg.GetContents(), recipient) diff --git a/network/rounds/remoteFilters.go b/network/rounds/remoteFilters.go index 2ef27e3665c0ad3ff759350ad113f6d1ab79dae5..24477369a4c797ca1280dbd071d8ab6a3b706259 100644 --- a/network/rounds/remoteFilters.go +++ b/network/rounds/remoteFilters.go @@ -70,9 +70,11 @@ func ValidFilterRange(identity reception.IdentityUse, filters *mixmessages.Clien return startIdx, endIdx, outOfBounds } - if int(endIdx) > len(filters.Filters)-1 { + if endIdx > len(filters.Filters)-1 { endIdx = len(filters.Filters) - 1 } - return startIdx, endIdx, outOfBounds + // Add 1 to the end index so that it follows Go's convention; the last index + // is exclusive to the range + return startIdx, endIdx + 1, outOfBounds } diff --git a/single/callbackMap.go b/single/callbackMap.go new file mode 100644 index 0000000000000000000000000000000000000000..be7b03380285b6d298f7b3344bfff08ce16228b0 --- /dev/null +++ b/single/callbackMap.go @@ -0,0 +1,60 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package single + +import ( + "github.com/pkg/errors" + "gitlab.com/elixxir/crypto/e2e/singleUse" + "sync" +) + +type receiveComm func(payload []byte, c Contact) + +// callbackMap stores a list of possible callbacks that can be called when a +// message is received. To receive a transmission, each transmitted message must +// use the same tag as the registered callback. The tag fingerprint is a hash of +// a tag string that is used to identify the module that the transmission +// message belongs to. The tag can be anything, but should be long enough so +// that it is unique. +type callbackMap struct { + callbacks map[singleUse.TagFP]receiveComm + sync.RWMutex +} + +// newCallbackMap initialises a new map. +func newCallbackMap() *callbackMap { + return &callbackMap{ + callbacks: map[singleUse.TagFP]receiveComm{}, + } +} + +// registerCallback adds a callback function to the map that associates it with +// its tag. The tag should be a unique string identifying the module using the +// callback. +func (cbm *callbackMap) registerCallback(tag string, callback receiveComm) { + cbm.Lock() + defer cbm.Unlock() + + tagFP := singleUse.NewTagFP(tag) + cbm.callbacks[tagFP] = callback +} + +// getCallback returns the callback registered with the given tag fingerprint. +// An error is returned if no associated callback exists. +func (cbm *callbackMap) getCallback(tagFP singleUse.TagFP) (receiveComm, error) { + cbm.RLock() + defer cbm.RUnlock() + + cb, exists := cbm.callbacks[tagFP] + if !exists { + return nil, errors.Errorf("no callback registered for the tag "+ + "fingerprint %s.", tagFP) + } + + return cb, nil +} diff --git a/single/callbackMap_test.go b/single/callbackMap_test.go new file mode 100644 index 0000000000000000000000000000000000000000..59ccb2dde618ebf9256cb82dd27aebb5d4d79903 --- /dev/null +++ b/single/callbackMap_test.go @@ -0,0 +1,82 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package single + +import ( + "gitlab.com/elixxir/crypto/e2e/singleUse" + "testing" +) + +// Happy path. +func Test_callbackMap_registerCallback(t *testing.T) { + m := newTestManager(0, false, t) + callbackChan := make(chan int) + testCallbacks := []struct { + tag string + cb receiveComm + }{ + {"tag1", func([]byte, Contact) { callbackChan <- 0 }}, + {"tag2", func([]byte, Contact) { callbackChan <- 1 }}, + {"tag3", func([]byte, Contact) { callbackChan <- 2 }}, + } + + for _, val := range testCallbacks { + m.callbackMap.registerCallback(val.tag, val.cb) + } + + for i, val := range testCallbacks { + go m.callbackMap.callbacks[singleUse.NewTagFP(val.tag)](nil, Contact{}) + result := <-callbackChan + if result != i { + t.Errorf("getCallback() did not return the expected callback."+ + "\nexpected: %d\nreceived: %d", i, result) + } + } +} + +// Happy path. +func Test_callbackMap_getCallback(t *testing.T) { + m := newTestManager(0, false, t) + callbackChan := make(chan int) + testCallbacks := []struct { + tagFP singleUse.TagFP + cb receiveComm + }{ + {singleUse.UnmarshalTagFP([]byte("tag1")), func([]byte, Contact) { callbackChan <- 0 }}, + {singleUse.UnmarshalTagFP([]byte("tag2")), func([]byte, Contact) { callbackChan <- 1 }}, + {singleUse.UnmarshalTagFP([]byte("tsg3")), func([]byte, Contact) { callbackChan <- 2 }}, + } + + for _, val := range testCallbacks { + m.callbackMap.callbacks[val.tagFP] = val.cb + } + + cb, err := m.callbackMap.getCallback(testCallbacks[1].tagFP) + if err != nil { + t.Errorf("getCallback() returned an error: %+v", err) + } + + go cb(nil, Contact{}) + + result := <-callbackChan + if result != 1 { + t.Errorf("getCallback() did not return the expected callback."+ + "\nexpected: %d\nreceived: %d", 1, result) + } +} + +// Error path: no callback exists for the given tag fingerprint. +func Test_callbackMap_getCallback_NoCallbackError(t *testing.T) { + m := newTestManager(0, false, t) + + _, err := m.callbackMap.getCallback(singleUse.UnmarshalTagFP([]byte("tag1"))) + if err == nil { + t.Error("getCallback() failed to return an error for a callback that " + + "does not exist.") + } +} diff --git a/single/collator.go b/single/collator.go new file mode 100644 index 0000000000000000000000000000000000000000..2ffe34cd656b0684857e3af21da80ad9682d1afb --- /dev/null +++ b/single/collator.go @@ -0,0 +1,76 @@ +package single + +import ( + "bytes" + "github.com/pkg/errors" + "sync" +) + +// Initial value of the collator maxNum that indicates it has yet to be set +const unsetCollatorMax = -1 + +// collator stores the list of payloads in the correct order. +type collator struct { + payloads [][]byte // List of payloads, in order + maxNum int // Max number of messages that can be received + count int // Current number of messages received + sync.Mutex +} + +// newCollator generates an empty list of payloads to fit the max number of +// possible messages. maxNum is set to indicate that it is not yet set. +func newCollator(messageCount uint64) *collator { + return &collator{ + payloads: make([][]byte, messageCount), + maxNum: unsetCollatorMax, + } +} + +// collate collects message payload parts. Once all parts are received, the full +// collated payload is returned along with true. Otherwise returns false. +func (c *collator) collate(payloadBytes []byte) ([]byte, bool, error) { + payload, err := unmarshalResponseMessage(payloadBytes) + if err != nil { + return nil, false, errors.Errorf("Failed to unmarshal response "+ + "payload: %+v", err) + } + + c.Lock() + defer c.Unlock() + + // If this is the first message received, then set the max number of + // messages expected to be received off its max number of parts. + if c.maxNum == unsetCollatorMax { + if int(payload.GetMaxParts()) > len(c.payloads) { + return nil, false, errors.Errorf("Max number of parts reported by "+ + "payload %d is larger than collator expected (%d).", + payload.GetMaxParts(), len(c.payloads)) + } + c.maxNum = int(payload.GetMaxParts()) + } + + // Make sure that the part number is within the expected number of parts + if int(payload.GetPartNum()) >= c.maxNum { + return nil, false, errors.Errorf("Payload part number (%d) greater "+ + "than max number of expected parts (%d).", + payload.GetPartNum(), c.maxNum) + } + + // Make sure no payload with the same part number exists + if c.payloads[payload.GetPartNum()] != nil { + return nil, false, errors.Errorf("A payload for the part number %d "+ + "already exists in the list.", payload.GetPartNum()) + } + + // Add the payload to the list + c.payloads[payload.GetPartNum()] = payload.GetContents() + c.count++ + + // Return false if not all messages have been received + if c.count < c.maxNum { + return nil, false, nil + } + + // Collate all the messages + return bytes.Join(c.payloads, nil), true, nil +} diff --git a/single/collator_test.go b/single/collator_test.go new file mode 100644 index 0000000000000000000000000000000000000000..79474123ecffa02a2cd68968579cde881bf11d6e --- /dev/null +++ b/single/collator_test.go @@ -0,0 +1,142 @@ +package single + +import ( + "bytes" + "reflect" + "strings" + "testing" +) + +// Happy path +func Test_newCollator(t *testing.T) { + messageCount := uint64(10) + expected := &collator{ + payloads: make([][]byte, messageCount), + maxNum: unsetCollatorMax, + count: 0, + } + c := newCollator(messageCount) + + if !reflect.DeepEqual(expected, c) { + t.Errorf("newCollator() failed to generated the expected collator."+ + "\nexepcted: %+v\nreceived: %+v", expected, c) + } +} + +// Happy path. +func TestCollator_collate(t *testing.T) { + messageCount := 16 + msgPayloadSize := 2 + msgParts := map[int]responseMessagePart{} + expectedData := make([]byte, messageCount*msgPayloadSize) + copy(expectedData, "This is the expected final data.") + + buff := bytes.NewBuffer(expectedData) + for i := 0; i < messageCount; i++ { + msgParts[i] = newResponseMessagePart(msgPayloadSize + 4) + msgParts[i].SetMaxParts(uint8(messageCount)) + msgParts[i].SetPartNum(uint8(i)) + msgParts[i].SetContents(buff.Next(msgPayloadSize)) + } + + c := newCollator(uint64(messageCount)) + + i := 0 + var fullPayload []byte + for j, part := range msgParts { + i++ + + var err error + var collated bool + + fullPayload, collated, err = c.collate(part.Marshal()) + if err != nil { + t.Errorf("collate() returned an error for part #%d: %+v", j, err) + } + + if i == messageCount && (!collated || fullPayload == nil) { + t.Errorf("collate() failed to collate a completed payload."+ + "\ncollated: %v\nfullPayload: %+v", collated, fullPayload) + } else if i < messageCount && (collated || fullPayload != nil) { + t.Errorf("collate() signaled it collated an unfinished payload."+ + "\ncollated: %v\nfullPayload: %+v", collated, fullPayload) + } + } + + if !bytes.Equal(expectedData, fullPayload) { + t.Errorf("collate() failed to return the correct collated data."+ + "\nexpected: %s\nreceived: %s", expectedData, fullPayload) + } +} + +// Error path: the byte slice cannot be unmarshaled. +func TestCollator_collate_UnmarshalError(t *testing.T) { + payloadBytes := []byte{1} + c := newCollator(1) + payload, collated, err := c.collate(payloadBytes) + + if err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Errorf("collate() failed to return an error for failing to "+ + "unmarshal the payload.\nerror: %+v", err) + } + + if payload != nil || collated { + t.Errorf("collate() signaled the payload was collated on error."+ + "\npayload: %+v\ncollated: %+v", payload, collated) + } +} + +// Error path: max reported parts by payload larger then set in collator +func TestCollator_collate_MaxPartsError(t *testing.T) { + payloadBytes := []byte{0xFF, 0xFF, 0xFF, 0xFF} + c := newCollator(1) + payload, collated, err := c.collate(payloadBytes) + + if err == nil || !strings.Contains(err.Error(), "Max number of parts") { + t.Errorf("collate() failed to return an error when the max number of "+ + "parts is larger than the payload size.\nerror: %+v", err) + } + + if payload != nil || collated { + t.Errorf("collate() signaled the payload was collated on error."+ + "\npayload: %+v\ncollated: %+v", payload, collated) + } +} + +// Error path: the message part number is greater than the max number of parts. +func TestCollator_collate_PartNumTooLargeError(t *testing.T) { + payloadBytes := []byte{25, 5, 5, 5} + c := newCollator(5) + payload, collated, err := c.collate(payloadBytes) + + if err == nil || !strings.Contains(err.Error(), "greater than max number of expected parts") { + t.Errorf("collate() failed to return an error when the part number is "+ + "greater than the max number of parts.\nerror: %+v", err) + } + + if payload != nil || collated { + t.Errorf("collate() signaled the payload was collated on error."+ + "\npayload: %+v\ncollated: %+v", payload, collated) + } +} + +// Error path: a message with the part number already exists. +func TestCollator_collate_PartExistsError(t *testing.T) { + payloadBytes := []byte{1, 5, 0, 1, 20} + c := newCollator(5) + payload, collated, err := c.collate(payloadBytes) + if err != nil { + t.Fatalf("collate() returned an error: %+v", err) + } + + payload, collated, err = c.collate(payloadBytes) + if err == nil || !strings.Contains(err.Error(), "A payload for the part number") { + t.Errorf("collate() failed to return an error when the part number "+ + "already exists.\nerror: %+v", err) + } + + if payload != nil || collated { + t.Errorf("collate() signaled the payload was collated on error."+ + "\npayload: %+v\ncollated: %+v", payload, collated) + } +} diff --git a/single/contact.go b/single/contact.go new file mode 100644 index 0000000000000000000000000000000000000000..25d207ddebb3fe36faffb03032a639c26b9ff271 --- /dev/null +++ b/single/contact.go @@ -0,0 +1,67 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package single + +import ( + "fmt" + "gitlab.com/elixxir/crypto/cyclic" + "gitlab.com/elixxir/crypto/e2e/singleUse" + "gitlab.com/xx_network/primitives/id" +) + +// Contact contains the information to respond to a single-use contact. +type Contact struct { + partner *id.ID // ID of the person to respond to + partnerPubKey *cyclic.Int // Public key of the partner + dhKey *cyclic.Int // DH key + tagFP singleUse.TagFP // Identifies which callback to use + maxParts uint8 // Max number of messages allowed in reply + used *int32 // Atomic variable +} + +// NewContact initialises a new Contact with the specified fields. +func NewContact(partner *id.ID, partnerPubKey, dhKey *cyclic.Int, + tagFP singleUse.TagFP, maxParts uint8) Contact { + used := int32(0) + return Contact{ + partner: partner.DeepCopy(), + partnerPubKey: partnerPubKey.DeepCopy(), + dhKey: dhKey.DeepCopy(), + tagFP: tagFP, + maxParts: maxParts, + used: &used, + } +} + +// GetMaxParts returns the maximum number of message parts that can be sent in a +// reply. +func (c Contact) GetMaxParts() uint8 { + return c.maxParts +} + +// GetPartner returns a copy of the partner ID. +func (c Contact) GetPartner() *id.ID { + return c.partner.DeepCopy() +} + +// String returns a string of the Contact structure. +func (c Contact) String() string { + format := "Contact{partner:%s partnerPubKey:%s dhKey:%s tagFP:%s maxParts:%d used:%d}" + return fmt.Sprintf(format, c.partner, c.partnerPubKey.Text(10), + c.dhKey.Text(10), c.tagFP, c.maxParts, *c.used) +} + +// Equal determines if c and b have equal field values. +func (c Contact) Equal(b Contact) bool { + return c.partner.Cmp(b.partner) && + c.partnerPubKey.Cmp(b.partnerPubKey) == 0 && + c.dhKey.Cmp(b.dhKey) == 0 && + c.tagFP == b.tagFP && + c.maxParts == b.maxParts && + *c.used == *b.used +} diff --git a/single/contact_test.go b/single/contact_test.go new file mode 100644 index 0000000000000000000000000000000000000000..533e5600f935f1feaf36aad48735e47917466357 --- /dev/null +++ b/single/contact_test.go @@ -0,0 +1,102 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package single + +import ( + "gitlab.com/elixxir/crypto/e2e/singleUse" + "gitlab.com/xx_network/primitives/id" + "math/rand" + "testing" +) + +// Happy path. +func TestNewContact(t *testing.T) { + grp := getGroup() + used := int32(0) + expected := Contact{ + partner: id.NewIdFromString("sender ID", id.User, t), + partnerPubKey: grp.NewInt(99), + dhKey: grp.NewInt(42), + tagFP: singleUse.UnmarshalTagFP([]byte("test tagFP")), + maxParts: uint8(rand.Uint64()), + used: &used, + } + + testC := NewContact(expected.partner, expected.partnerPubKey, + expected.dhKey, expected.tagFP, expected.maxParts) + + if !expected.Equal(testC) { + t.Errorf("NewContact() did not return the expected Contact."+ + "\nexpected: %s\nrecieved: %s", expected, testC) + } +} + +// Happy path. +func TestContact_GetMaxParts(t *testing.T) { + grp := getGroup() + maxParts := uint8(rand.Uint64()) + c := NewContact(id.NewIdFromString("sender ID", id.User, t), grp.NewInt(99), + grp.NewInt(42), singleUse.TagFP{}, maxParts) + + if maxParts != c.GetMaxParts() { + t.Errorf("GetMaxParts() failed to return the expected maxParts."+ + "\nexpected %d\nreceived: %d", maxParts, c.GetMaxParts()) + } +} + +// Happy path. +func TestContact_GetPartner(t *testing.T) { + grp := getGroup() + senderID := id.NewIdFromString("sender ID", id.User, t) + c := NewContact(senderID, grp.NewInt(99), grp.NewInt(42), singleUse.TagFP{}, + uint8(rand.Uint64())) + + if !senderID.Cmp(c.GetPartner()) { + t.Errorf("GetPartner() failed to return the expected sender ID."+ + "\nexpected %s\nreceived: %s", senderID, c.GetPartner()) + } + +} + +// Happy path. +func TestContact_String(t *testing.T) { + grp := getGroup() + a := NewContact(id.NewIdFromString("sender ID 1", id.User, t), grp.NewInt(99), + grp.NewInt(42), singleUse.UnmarshalTagFP([]byte("test tagFP 1")), uint8(rand.Uint64())) + b := NewContact(id.NewIdFromString("sender ID 2", id.User, t), + grp.NewInt(98), grp.NewInt(43), singleUse.UnmarshalTagFP([]byte("test tagFP 2")), uint8(rand.Uint64())) + c := NewContact(a.GetPartner(), a.partnerPubKey.DeepCopy(), a.dhKey.DeepCopy(), a.tagFP, a.maxParts) + + if a.String() == b.String() { + t.Errorf("String() did not return the expected string."+ + "\na: %s\nb: %s", a, b) + } + if a.String() != c.String() { + t.Errorf("String() did not return the expected string."+ + "\na: %s\nc: %s", a, c) + } +} + +// Happy path. +func TestContact_Equal(t *testing.T) { + grp := getGroup() + a := NewContact(id.NewIdFromString("sender ID 1", id.User, t), + grp.NewInt(99), grp.NewInt(42), singleUse.UnmarshalTagFP([]byte("test tagFP 1")), uint8(rand.Uint64())) + b := NewContact(id.NewIdFromString("sender ID 2", id.User, t), + grp.NewInt(98), grp.NewInt(43), singleUse.UnmarshalTagFP([]byte("test tagFP 2")), uint8(rand.Uint64())) + c := NewContact(a.GetPartner(), a.partnerPubKey.DeepCopy(), a.dhKey.DeepCopy(), a.tagFP, a.maxParts) + + if a.Equal(b) { + t.Errorf("Equal() found two different Contacts as equal."+ + "\na: %s\nb: %s", a, b) + } + if !a.Equal(c) { + t.Errorf("Equal() found two equal Contacts as not equal."+ + "\na: %s\nc: %s", a, c) + } +} diff --git a/single/fingerprintMap.go b/single/fingerprintMap.go new file mode 100644 index 0000000000000000000000000000000000000000..10230e7ec6297cc9380791e5e24fb637d8f06b45 --- /dev/null +++ b/single/fingerprintMap.go @@ -0,0 +1,55 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package single + +import ( + "gitlab.com/elixxir/crypto/cyclic" + "gitlab.com/elixxir/crypto/e2e/singleUse" + "gitlab.com/elixxir/primitives/format" + "sync" +) + +// fingerprintMap stores a map of fingerprint to key numbers. +type fingerprintMap struct { + fps map[format.Fingerprint]uint64 + sync.Mutex +} + +// newFingerprintMap returns a map of fingerprints generated from the provided +// key that is messageCount long. +func newFingerprintMap(dhKey *cyclic.Int, messageCount uint64) *fingerprintMap { + fpm := &fingerprintMap{ + fps: make(map[format.Fingerprint]uint64, messageCount), + } + + for i := uint64(0); i < messageCount; i++ { + fp := singleUse.NewResponseFingerprint(dhKey, i) + fpm.fps[fp] = i + } + + return fpm +} + +// getKey returns true and the corresponding key of the fingerprint exists in +// the map and returns false otherwise. If the fingerprint exists, then it is +// deleted prior to returning the key. +func (fpm *fingerprintMap) getKey(dhKey *cyclic.Int, fp format.Fingerprint) ([]byte, bool) { + fpm.Lock() + defer fpm.Unlock() + + num, exists := fpm.fps[fp] + if !exists { + return nil, false + } + + // Delete found fingerprint + delete(fpm.fps, fp) + + // Generate and return the key + return singleUse.NewResponseKey(dhKey, num), true +} diff --git a/single/fingerprintMap_test.go b/single/fingerprintMap_test.go new file mode 100644 index 0000000000000000000000000000000000000000..2be4e3f8d447beaa6385770bedfca87864a5ca97 --- /dev/null +++ b/single/fingerprintMap_test.go @@ -0,0 +1,73 @@ +package single + +import ( + "bytes" + "gitlab.com/elixxir/crypto/e2e/singleUse" + "gitlab.com/elixxir/primitives/format" + "reflect" + "testing" +) + +// Happy path. +func Test_newFingerprintMap(t *testing.T) { + baseKey := getGroup().NewInt(42) + messageCount := 3 + expected := &fingerprintMap{ + fps: map[format.Fingerprint]uint64{ + singleUse.NewResponseFingerprint(baseKey, 0): 0, + singleUse.NewResponseFingerprint(baseKey, 1): 1, + singleUse.NewResponseFingerprint(baseKey, 2): 2, + }, + } + + fpm := newFingerprintMap(baseKey, uint64(messageCount)) + + if !reflect.DeepEqual(expected, fpm) { + t.Errorf("newFingerprintMap() did not generate the expected map."+ + "\nexpected: %+v\nreceived: %+v", expected, fpm) + } +} + +// Happy path. +func TestFingerprintMap_getKey(t *testing.T) { + baseKey := getGroup().NewInt(42) + fpm := newFingerprintMap(baseKey, 5) + fp := singleUse.NewResponseFingerprint(baseKey, 3) + expectedKey := singleUse.NewResponseKey(baseKey, 3) + + testKey, exists := fpm.getKey(baseKey, fp) + if !exists { + t.Errorf("getKey() failed to find key that exists in map."+ + "\nfingerprint: %+v", fp) + } + + if !bytes.Equal(expectedKey, testKey) { + t.Errorf("getKey() returned the wrong key.\nexpected: %+v\nreceived: %+v", + expectedKey, testKey) + } + + testFP, exists := fpm.fps[fp] + if exists { + t.Errorf("getKey() failed to delete the found fingerprint."+ + "\nfingerprint: %+v", testFP) + } +} + +// Error path: fingerprint does not exist in map. +func TestFingerprintMap_getKey_FingerprintNotInMap(t *testing.T) { + baseKey := getGroup().NewInt(42) + fpm := newFingerprintMap(baseKey, 5) + fp := singleUse.NewResponseFingerprint(baseKey, 30) + + key, exists := fpm.getKey(baseKey, fp) + if exists { + t.Errorf("getKey() found a fingerprint in the map that should not exist."+ + "\nfingerprint: %+v\nkey: %+v", fp, key) + } + + // Ensure no fingerprints were deleted + if len(fpm.fps) != 5 { + t.Errorf("getKey() deleted fingerprint."+ + "\nexpected size: %d\nreceived size: %d", 5, len(fpm.fps)) + } +} diff --git a/single/manager.go b/single/manager.go new file mode 100644 index 0000000000000000000000000000000000000000..d24c88fdd684271262da556d746220d9dc9d914f --- /dev/null +++ b/single/manager.go @@ -0,0 +1,93 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package single + +import ( + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/api" + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" + "gitlab.com/elixxir/client/storage" + "gitlab.com/elixxir/client/storage/reception" + "gitlab.com/elixxir/crypto/fastRNG" + "gitlab.com/xx_network/primitives/id" +) + +const ( + rawMessageBuffSize = 100 + singleUseTransmission = "SingleUseTransmission" + singleUseReceiveTransmission = "SingleUseReceiveTransmission" + singleUseResponse = "SingleUseResponse" + singleUseReceiveResponse = "SingleUseReceiveResponse" + singleUseStop = "SingleUse" +) + +// Manager handles the transmission and reception of single-use communication. +type Manager struct { + // Client and its field + client *api.Client + store *storage.Session + reception *reception.Store + swb interfaces.Switchboard + net interfaces.NetworkManager + rng *fastRNG.StreamGenerator + + // Holds the information needed to manage each pending communication. A + // state is created when a transmission is started and is removed on + // response or timeout. + p *pending + + // List of callbacks that can be called when a transmission is received. For + // an entity to receive a message, it must register a callback in this map + // with the same tag used to send the message. + callbackMap *callbackMap +} + +// NewManager creates a new single-use communication manager. +func NewManager(client *api.Client, reception *reception.Store) *Manager { + return &Manager{ + client: client, + store: client.GetStorage(), + reception: reception, + swb: client.GetSwitchboard(), + net: client.GetNetworkInterface(), + rng: client.GetRng(), + p: newPending(), + callbackMap: newCallbackMap(), + } +} + +// StartProcesses starts the process of receiving single-use transmissions and +// replies. +func (m *Manager) StartProcesses() stoppable.Stoppable { + // Start waiting for single-use transmission + transmissionStop := stoppable.NewSingle(singleUseTransmission) + transmissionChan := make(chan message.Receive, rawMessageBuffSize) + m.swb.RegisterChannel(singleUseReceiveTransmission, &id.ID{}, message.Raw, transmissionChan) + go m.receiveTransmissionHandler(transmissionChan, transmissionStop.Quit()) + + // Start waiting for single-use response + responseStop := stoppable.NewSingle(singleUseResponse) + responseChan := make(chan message.Receive, rawMessageBuffSize) + m.swb.RegisterChannel(singleUseReceiveResponse, &id.ID{}, message.Raw, responseChan) + go m.receiveResponseHandler(responseChan, responseStop.Quit()) + + // Create a multi stoppable + singleUseMulti := stoppable.NewMulti(singleUseStop) + singleUseMulti.Add(transmissionStop) + singleUseMulti.Add(responseStop) + + return singleUseMulti +} + +// RegisterCallback registers a callback for received messages. +func (m *Manager) RegisterCallback(tag string, callback receiveComm) { + jww.DEBUG.Printf("Registering single-use callback with tag %s.", tag) + m.callbackMap.registerCallback(tag, callback) +} diff --git a/single/manager_test.go b/single/manager_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c7f94b1010116b43f290c51de1dcdf9e4cdc7a77 --- /dev/null +++ b/single/manager_test.go @@ -0,0 +1,358 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package single + +import ( + "bytes" + "errors" + "gitlab.com/elixxir/client/api" + "gitlab.com/elixxir/client/interfaces" + contact2 "gitlab.com/elixxir/client/interfaces/contact" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/stoppable" + "gitlab.com/elixxir/client/storage" + "gitlab.com/elixxir/client/storage/reception" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/client/switchboard" + "gitlab.com/elixxir/comms/network" + "gitlab.com/elixxir/crypto/e2e" + "gitlab.com/elixxir/crypto/e2e/singleUse" + "gitlab.com/elixxir/crypto/fastRNG" + "gitlab.com/elixxir/ekv" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/comms/connect" + "gitlab.com/xx_network/crypto/csprng" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/id/ephemeral" + "gitlab.com/xx_network/primitives/ndf" + "math/rand" + "reflect" + "sync" + "testing" + "time" +) + +// Happy path. +func TestNewManager(t *testing.T) { + client := &api.Client{} + e := &Manager{ + client: client, + p: newPending(), + } + m := NewManager(client, reception.NewStore(versioned.NewKV(make(ekv.Memstore)))) + + if e.client != m.client || e.store != m.store || e.net != m.net || + e.rng != m.rng || !reflect.DeepEqual(e.p, m.p) { + t.Errorf("NewManager() did not return the expected new Manager."+ + "\nexpected: %+v\nreceived: %+v", e, m) + } +} + +// Happy path. +func TestManager_StartProcesses(t *testing.T) { + m := newTestManager(0, false, t) + partner := &contact2.Contact{ + ID: id.NewIdFromString("recipientID", id.User, t), + DhPubKey: m.store.E2e().GetDHPublicKey(), + } + tag := "Test tag" + payload := make([]byte, 132) + rand.New(rand.NewSource(42)).Read(payload) + callback, callbackChan := createReceiveComm() + + transmitMsg, _, rid, _, err := m.makeTransmitCmixMessage(partner, payload, + tag, 8, 32, 30*time.Second, time.Now(), rand.New(rand.NewSource(42))) + if err != nil { + t.Fatalf("Failed to create tranmission CMIX message: %+v", err) + } + + _, sender, err := m.processTransmission(transmitMsg, singleUse.NewTransmitFingerprint(m.store.E2e().GetDHPublicKey())) + + replyMsgs, err := m.makeReplyCmixMessages(sender, payload) + if err != nil { + t.Fatalf("Failed to generate reply CMIX messages: %+v", err) + } + + receiveMsg := message.Receive{ + Payload: transmitMsg.Marshal(), + MessageType: message.Raw, + Sender: rid, + RecipientID: partner.ID, + } + + m.callbackMap.registerCallback(tag, callback) + + _ = m.StartProcesses() + m.swb.(*switchboard.Switchboard).Speak(receiveMsg) + + timer := time.NewTimer(50 * time.Millisecond) + + select { + case results := <-callbackChan: + if !bytes.Equal(results.payload, payload) { + t.Errorf("Callback received wrong payload."+ + "\nexpected: %+v\nreceived: %+v", payload, results.payload) + } + case <-timer.C: + t.Errorf("Callback failed to be called.") + } + + callbackFunc, callbackFuncChan := createReplyComm() + m.p.Lock() + m.p.singleUse[*rid] = newState(sender.dhKey, sender.maxParts, callbackFunc) + m.p.Unlock() + eid, _, _, _ := ephemeral.GetId(partner.ID, id.ArrIDLen, time.Now().UnixNano()) + replyMsg := message.Receive{ + Payload: replyMsgs[0].Marshal(), + MessageType: message.Raw, + Sender: partner.ID, + RecipientID: rid, + EphemeralID: eid, + } + + go func() { + timer := time.NewTimer(50 * time.Millisecond) + select { + case <-timer.C: + t.Errorf("quitChan never set.") + case <-m.p.singleUse[*rid].quitChan: + } + }() + + m.swb.(*switchboard.Switchboard).Speak(replyMsg) + + timer = time.NewTimer(50 * time.Millisecond) + + select { + case results := <-callbackFuncChan: + if !bytes.Equal(results.payload, payload) { + t.Errorf("Callback received wrong payload."+ + "\nexpected: %+v\nreceived: %+v", payload, results.payload) + } + case <-timer.C: + t.Errorf("Callback failed to be called.") + } +} + +// Happy path: tests that the stoppable stops both routines. +func TestManager_StartProcesses_Stop(t *testing.T) { + m := newTestManager(0, false, t) + partner := &contact2.Contact{ + ID: id.NewIdFromString("recipientID", id.User, t), + DhPubKey: m.store.E2e().GetDHPublicKey(), + } + tag := "Test tag" + payload := make([]byte, 132) + rand.New(rand.NewSource(42)).Read(payload) + callback, callbackChan := createReceiveComm() + + transmitMsg, _, rid, _, err := m.makeTransmitCmixMessage(partner, payload, + tag, 8, 32, 30*time.Second, time.Now(), rand.New(rand.NewSource(42))) + if err != nil { + t.Fatalf("Failed to create tranmission CMIX message: %+v", err) + } + + _, sender, err := m.processTransmission(transmitMsg, singleUse.NewTransmitFingerprint(m.store.E2e().GetDHPublicKey())) + + replyMsgs, err := m.makeReplyCmixMessages(sender, payload) + if err != nil { + t.Fatalf("Failed to generate reply CMIX messages: %+v", err) + } + + receiveMsg := message.Receive{ + Payload: transmitMsg.Marshal(), + MessageType: message.Raw, + Sender: rid, + RecipientID: partner.ID, + } + + m.callbackMap.registerCallback(tag, callback) + + stop := m.StartProcesses() + if !stop.IsRunning() { + t.Error("Stoppable is not running.") + } + + err = stop.Close(1 * time.Millisecond) + if err != nil { + t.Errorf("Failed to close: %+v", err) + } + + m.swb.(*switchboard.Switchboard).Speak(receiveMsg) + + timer := time.NewTimer(50 * time.Millisecond) + + select { + case results := <-callbackChan: + t.Errorf("Callback called when the thread should have stopped."+ + "\npayload: %+v\ncontact: %+v", results.payload, results.c) + case <-timer.C: + } + + callbackFunc, callbackFuncChan := createReplyComm() + m.p.Lock() + m.p.singleUse[*rid] = newState(sender.dhKey, sender.maxParts, callbackFunc) + m.p.Unlock() + eid, _, _, _ := ephemeral.GetId(partner.ID, id.ArrIDLen, time.Now().UnixNano()) + replyMsg := message.Receive{ + Payload: replyMsgs[0].Marshal(), + MessageType: message.Raw, + Sender: partner.ID, + RecipientID: rid, + EphemeralID: eid, + } + + m.swb.(*switchboard.Switchboard).Speak(replyMsg) + + timer = time.NewTimer(50 * time.Millisecond) + + select { + case results := <-callbackFuncChan: + t.Errorf("Callback called when the thread should have stopped."+ + "\npayload: %+v\nerror: %+v", results.payload, results.err) + case <-timer.C: + } +} + +type receiveCommData struct { + payload []byte + c Contact +} + +func createReceiveComm() (receiveComm, chan receiveCommData) { + callbackChan := make(chan receiveCommData) + + callback := func(payload []byte, c Contact) { + callbackChan <- receiveCommData{payload: payload, c: c} + } + return callback, callbackChan +} + +func newTestManager(timeout time.Duration, cmixErr bool, t *testing.T) *Manager { + return &Manager{ + client: nil, + store: storage.InitTestingSession(t), + reception: reception.NewStore(versioned.NewKV(make(ekv.Memstore))), + swb: switchboard.New(), + net: newTestNetworkManager(timeout, cmixErr, t), + rng: fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG), + p: newPending(), + callbackMap: newCallbackMap(), + } +} + +func newTestNetworkManager(timeout time.Duration, cmixErr bool, t *testing.T) interfaces.NetworkManager { + instanceComms := &connect.ProtoComms{ + Manager: connect.NewManagerTesting(t), + } + + thisInstance, err := network.NewInstanceTesting(instanceComms, getNDF(), + getNDF(), nil, nil, t) + if err != nil { + t.Fatalf("Failed to create new test instance: %v", err) + } + + return &testNetworkManager{ + instance: thisInstance, + msgs: []format.Message{}, + cmixTimeout: timeout, + cmixErr: cmixErr, + } +} + +// testNetworkManager is a test implementation of NetworkManager interface. +type testNetworkManager struct { + instance *network.Instance + msgs []format.Message + cmixTimeout time.Duration + cmixErr bool + sync.RWMutex +} + +func (tnm *testNetworkManager) GetMsg(i int) format.Message { + tnm.RLock() + defer tnm.RUnlock() + return tnm.msgs[i] +} + +func (tnm *testNetworkManager) SendE2E(_ message.Send, _ params.E2E) ([]id.Round, e2e.MessageID, error) { + return nil, [32]byte{}, nil +} + +func (tnm *testNetworkManager) SendUnsafe(_ message.Send, _ params.Unsafe) ([]id.Round, error) { + return []id.Round{}, nil +} + +func (tnm *testNetworkManager) SendCMIX(msg format.Message, _ *id.ID, _ params.CMIX) (id.Round, ephemeral.Id, error) { + if tnm.cmixTimeout != 0 { + time.Sleep(tnm.cmixTimeout) + } else if tnm.cmixErr { + return 0, ephemeral.Id{}, errors.New("sendCMIX error") + } + + tnm.Lock() + defer tnm.Unlock() + + tnm.msgs = append(tnm.msgs, msg) + + return id.Round(rand.Uint64()), ephemeral.Id{}, nil +} + +func (tnm *testNetworkManager) GetInstance() *network.Instance { + return tnm.instance +} + +func (tnm *testNetworkManager) GetHealthTracker() interfaces.HealthTracker { + return nil +} + +func (tnm *testNetworkManager) Follow() (stoppable.Stoppable, error) { + return nil, nil +} + +func (tnm *testNetworkManager) CheckGarbledMessages() {} + +func getNDF() *ndf.NetworkDefinition { + return &ndf.NetworkDefinition{ + E2E: ndf.Group{ + Prime: "E2EE983D031DC1DB6F1A7A67DF0E9A8E5561DB8E8D49413394C049B" + + "7A8ACCEDC298708F121951D9CF920EC5D146727AA4AE535B0922C688B55B3DD2AE" + + "DF6C01C94764DAB937935AA83BE36E67760713AB44A6337C20E7861575E745D31F" + + "8B9E9AD8412118C62A3E2E29DF46B0864D0C951C394A5CBBDC6ADC718DD2A3E041" + + "023DBB5AB23EBB4742DE9C1687B5B34FA48C3521632C4A530E8FFB1BC51DADDF45" + + "3B0B2717C2BC6669ED76B4BDD5C9FF558E88F26E5785302BEDBCA23EAC5ACE9209" + + "6EE8A60642FB61E8F3D24990B8CB12EE448EEF78E184C7242DD161C7738F32BF29" + + "A841698978825B4111B4BC3E1E198455095958333D776D8B2BEEED3A1A1A221A6E" + + "37E664A64B83981C46FFDDC1A45E3D5211AAF8BFBC072768C4F50D7D7803D2D4F2" + + "78DE8014A47323631D7E064DE81C0C6BFA43EF0E6998860F1390B5D3FEACAF1696" + + "015CB79C3F9C2D93D961120CD0E5F12CBB687EAB045241F96789C38E89D796138E" + + "6319BE62E35D87B1048CA28BE389B575E994DCA755471584A09EC723742DC35873" + + "847AEF49F66E43873", + Generator: "2", + }, + CMIX: ndf.Group{ + Prime: "9DB6FB5951B66BB6FE1E140F1D2CE5502374161FD6538DF1648218642F0B5C48" + + "C8F7A41AADFA187324B87674FA1822B00F1ECF8136943D7C55757264E5A1A44F" + + "FE012E9936E00C1D3E9310B01C7D179805D3058B2A9F4BB6F9716BFE6117C6B5" + + "B3CC4D9BE341104AD4A80AD6C94E005F4B993E14F091EB51743BF33050C38DE2" + + "35567E1B34C3D6A5C0CEAA1A0F368213C3D19843D0B4B09DCB9FC72D39C8DE41" + + "F1BF14D4BB4563CA28371621CAD3324B6A2D392145BEBFAC748805236F5CA2FE" + + "92B871CD8F9C36D3292B5509CA8CAA77A2ADFC7BFD77DDA6F71125A7456FEA15" + + "3E433256A2261C6A06ED3693797E7995FAD5AABBCFBE3EDA2741E375404AE25B", + Generator: "5C7FF6B06F8F143FE8288433493E4769C4D988ACE5BE25A0E24809670716C613" + + "D7B0CEE6932F8FAA7C44D2CB24523DA53FBE4F6EC3595892D1AA58C4328A06C4" + + "6A15662E7EAA703A1DECF8BBB2D05DBE2EB956C142A338661D10461C0D135472" + + "085057F3494309FFA73C611F78B32ADBB5740C361C9F35BE90997DB2014E2EF5" + + "AA61782F52ABEB8BD6432C4DD097BC5423B285DAFB60DC364E8161F4A2A35ACA" + + "3A10B1C4D203CC76A470A33AFDCBDD92959859ABD8B56E1725252D78EAC66E71" + + "BA9AE3F1DD2487199874393CD4D832186800654760E1E34C09E4D155179F9EC0" + + "DC4473F996BDCE6EED1CABED8B6F116F7AD9CF505DF0F998E34AB27514B0FFE7", + }, + } +} diff --git a/single/receiveResponse.go b/single/receiveResponse.go new file mode 100644 index 0000000000000000000000000000000000000000..d25880dae86dee70d882463aeba08d0310daa08b --- /dev/null +++ b/single/receiveResponse.go @@ -0,0 +1,108 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package single + +import ( + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/crypto/e2e/auth" + "gitlab.com/elixxir/crypto/e2e/singleUse" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/id/ephemeral" +) + +// receiveResponseHandler handles the reception of single-use response messages. +func (m *Manager) receiveResponseHandler(rawMessages chan message.Receive, + quitChan <-chan struct{}) { + jww.DEBUG.Print("Waiting to receive single-use response messages.") + for { + select { + case <-quitChan: + jww.DEBUG.Printf("Stopping waiting to receive single-use " + + "response message.") + return + case msg := <-rawMessages: + jww.DEBUG.Printf("Received CMIX message; checking if it is a " + + "single-use response.") + + // Process CMIX message + err := m.processesResponse(msg.RecipientID, msg.EphemeralID, msg.Payload) + if err != nil { + jww.WARN.Printf("Failed to read single-use CMIX message "+ + "response: %+v", err) + } + } + } +} + +// processesResponse processes the CMIX message and collates its payload. If the +// message is invalid, an error is returned. +func (m *Manager) processesResponse(rid *id.ID, ephID ephemeral.Id, + msgBytes []byte) error { + + // Get the state from the map + m.p.RLock() + state, exists := m.p.singleUse[*rid] + m.p.RUnlock() + + // Check that the state exists + if !exists { + return errors.Errorf("no state exists for the reception ID %s.", rid) + } + + // Unmarshal CMIX message + cmixMsg := format.Unmarshal(msgBytes) + + // Ensure the fingerprints match + fp := cmixMsg.GetKeyFP() + key, exists := state.fpMap.getKey(state.dhKey, fp) + if !exists { + return errors.New("message fingerprint does not correspond to the " + + "expected fingerprint.") + } + + // Verify the CMIX message MAC + if !singleUse.VerifyMAC(key, cmixMsg.GetContents(), cmixMsg.GetMac()) { + return errors.New("failed to verify the CMIX message MAC.") + } + + // Denote that the message is not garbled + jww.DEBUG.Print("Received single-use response message.") + m.store.GetGarbledMessages().Remove(cmixMsg) + + // Decrypt and collate the payload + decryptedPayload := auth.Crypt(key, fp[:24], cmixMsg.GetContents()) + collatedPayload, collated, err := state.c.collate(decryptedPayload) + if err != nil { + return errors.Errorf("failed to collate payload: %+v", err) + } + jww.DEBUG.Print("Successfully processed single-use response message part.") + + // Once all message parts have been received delete and close everything + if collated { + jww.DEBUG.Print("Received all parts of single-use response message.") + // Exit the timeout handler + state.quitChan <- struct{}{} + + // Remove identity + m.reception.RemoveIdentity(ephID) + + // Remove state from map + m.p.Lock() + delete(m.p.singleUse, *rid) + m.p.Unlock() + + // Call in separate routine to prevent blocking + jww.DEBUG.Print("Calling single-use response message callback.") + go state.callback(collatedPayload, nil) + } + + return nil +} diff --git a/single/receiveResponse_test.go b/single/receiveResponse_test.go new file mode 100644 index 0000000000000000000000000000000000000000..2782bc6fc1ad6d74eab2c0c2d51bf9988066734d --- /dev/null +++ b/single/receiveResponse_test.go @@ -0,0 +1,324 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package single + +import ( + "bytes" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/crypto/e2e/auth" + "gitlab.com/elixxir/crypto/e2e/singleUse" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/id/ephemeral" + "math/rand" + "strings" + "testing" + "time" +) + +// Happy path. +func TestManager_ReceiveResponseHandler(t *testing.T) { + m := newTestManager(0, false, t) + rawMessages := make(chan message.Receive, rawMessageBuffSize) + quitChan := make(chan struct{}) + partner := NewContact(id.NewIdFromString("recipientID", id.User, t), + m.store.E2e().GetGroup().NewInt(43), m.store.E2e().GetGroup().NewInt(42), + singleUse.TagFP{}, 8) + ephID, _, _, err := ephemeral.GetId(partner.partner, id.ArrIDLen, time.Now().UnixNano()) + payload := make([]byte, 2000) + rand.New(rand.NewSource(42)).Read(payload) + callback, callbackChan := createReplyComm() + rid := id.NewIdFromString("rid", id.User, t) + + m.p.singleUse[*rid] = newState(partner.dhKey, partner.maxParts, callback) + + msgs, err := m.makeReplyCmixMessages(partner, payload) + if err != nil { + t.Fatalf("Failed to generate CMIX messages: %+v", err) + } + + go func() { + timer := time.NewTimer(50 * time.Millisecond) + select { + case <-timer.C: + t.Errorf("quitChan never set.") + case <-m.p.singleUse[*rid].quitChan: + } + }() + + go m.receiveResponseHandler(rawMessages, quitChan) + + for _, msg := range msgs { + rawMessages <- message.Receive{ + Payload: msg.Marshal(), + Sender: partner.partner, + RecipientID: rid, + EphemeralID: ephID, + } + } + + timer := time.NewTimer(50 * time.Millisecond) + + select { + case results := <-callbackChan: + if !bytes.Equal(results.payload, payload) { + t.Errorf("Callback received wrong payload."+ + "\nexpected: %+v\nreceived: %+v", payload, results.payload) + } + if results.err != nil { + t.Errorf("Callback received an error: %+v", results.err) + } + case <-timer.C: + t.Errorf("Callback failed to be called.") + } + + quitChan <- struct{}{} +} + +// Error path: invalid CMIX message. +func TestManager_ReceiveResponseHandler_CmixMessageError(t *testing.T) { + m := newTestManager(0, false, t) + rawMessages := make(chan message.Receive, rawMessageBuffSize) + quitChan := make(chan struct{}) + partner := NewContact(id.NewIdFromString("recipientID", id.User, t), + m.store.E2e().GetGroup().NewInt(43), m.store.E2e().GetGroup().NewInt(42), + singleUse.TagFP{}, 8) + ephID, _, _, _ := ephemeral.GetId(partner.partner, id.ArrIDLen, time.Now().UnixNano()) + payload := make([]byte, 2000) + rand.New(rand.NewSource(42)).Read(payload) + callback, callbackChan := createReplyComm() + rid := id.NewIdFromString("rid", id.User, t) + + m.p.singleUse[*rid] = newState(partner.dhKey, partner.maxParts, callback) + + go func() { + timer := time.NewTimer(50 * time.Millisecond) + select { + case <-timer.C: + case <-m.p.singleUse[*rid].quitChan: + t.Error("quitChan called on error.") + } + }() + + go m.receiveResponseHandler(rawMessages, quitChan) + + rawMessages <- message.Receive{ + Payload: make([]byte, format.MinimumPrimeSize*2), + Sender: partner.partner, + RecipientID: rid, + EphemeralID: ephID, + } + + timer := time.NewTimer(50 * time.Millisecond) + + select { + case results := <-callbackChan: + t.Errorf("Callback called when it should not have been."+ + "payload: %+v\nerror: %+v", results.payload, results.err) + case <-timer.C: + } + + quitChan <- struct{}{} +} + +// Happy path. +func TestManager_processesResponse(t *testing.T) { + m := newTestManager(0, false, t) + rid := id.NewIdFromString("test RID", id.User, t) + ephID, _, _, err := ephemeral.GetId(rid, id.ArrIDLen, time.Now().UnixNano()) + if err != nil { + t.Fatalf("Failed to create ephemeral ID: %+v", err) + } + dhKey := getGroup().NewInt(5) + maxMsgs := uint8(6) + timeout := 5 * time.Millisecond + callback, callbackChan := createReplyComm() + + m.p.singleUse[*rid] = newState(dhKey, maxMsgs, callback) + + expectedData := []string{"This i", "s the ", "expect", "ed fin", "al dat", "a."} + var msgs []format.Message + for fp, i := range m.p.singleUse[*rid].fpMap.fps { + newMsg := format.NewMessage(format.MinimumPrimeSize) + part := newResponseMessagePart(newMsg.ContentsSize()) + part.SetContents([]byte(expectedData[i])) + part.SetPartNum(uint8(i)) + part.SetMaxParts(maxMsgs) + + key := singleUse.NewResponseKey(dhKey, i) + encryptedPayload := auth.Crypt(key, fp[:24], part.Marshal()) + + newMsg.SetKeyFP(fp) + newMsg.SetMac(singleUse.MakeMAC(key, encryptedPayload)) + newMsg.SetContents(encryptedPayload) + msgs = append(msgs, newMsg) + } + + go func() { + timer := time.NewTimer(timeout) + select { + case <-timer.C: + t.Errorf("quitChan never set.") + case <-m.p.singleUse[*rid].quitChan: + } + }() + + for i, msg := range msgs { + err := m.processesResponse(rid, ephID, msg.Marshal()) + if err != nil { + t.Errorf("processesResponse() returned an error (%d): %+v", i, err) + } + } + + timer := time.NewTimer(timeout) + select { + case <-timer.C: + t.Errorf("Callback never called.") + case results := <-callbackChan: + if results.err != nil { + t.Errorf("Callback returned an error: %+v", err) + } + if !bytes.Equal([]byte(strings.Join(expectedData, "")), results.payload) { + t.Errorf("Callback returned incorrect data."+ + "\nexpected: %s\nreceived: %s", expectedData, results.payload) + } + } + + if _, exists := m.p.singleUse[*rid]; exists { + t.Error("Failed to delete the state after collation is complete.") + } +} + +// Error path: no state in map. +func TestManager_processesResponse_NoStateError(t *testing.T) { + m := newTestManager(0, false, t) + rid := id.NewIdFromString("test RID", id.User, t) + + err := m.processesResponse(rid, ephemeral.Id{}, []byte{}) + if !check(err, "no state exists for the reception ID") { + t.Errorf("processesResponse() did not return an error when the state "+ + "is not in the map: %+v", err) + } +} + +// Error path: failed to verify MAC. +func TestManager_processesResponse_MacVerificationError(t *testing.T) { + m := newTestManager(0, false, t) + rid := id.NewIdFromString("test RID", id.User, t) + dhKey := getGroup().NewInt(5) + timeout := 5 * time.Millisecond + callback := func(payload []byte, err error) {} + + quitChan, _, err := m.p.addState(rid, dhKey, 1, callback, timeout) + if err != nil { + t.Fatalf("Failed to add state: %+v", err) + } + quitChan <- struct{}{} + + newMsg := format.NewMessage(format.MinimumPrimeSize) + part := newResponseMessagePart(newMsg.ContentsSize()) + part.SetContents([]byte("payload data")) + part.SetPartNum(0) + part.SetMaxParts(1) + newMsg.SetMac(singleUse.MakeMAC(dhKey.Bytes(), []byte("some data"))) + newMsg.SetContents([]byte("payload data")) + + var fp format.Fingerprint + for fpt, i := range m.p.singleUse[*rid].fpMap.fps { + if i == 0 { + fp = fpt + } + } + newMsg.SetKeyFP(fp) + + err = m.processesResponse(rid, ephemeral.Id{}, newMsg.Marshal()) + if !check(err, "MAC") { + t.Errorf("processesResponse() did not return an error when MAC "+ + "verification should have failed: %+v", err) + } +} + +// Error path: CMIX message fingerprint does not match fingerprints in map. +func TestManager_processesResponse_FingerprintError(t *testing.T) { + m := newTestManager(0, false, t) + rid := id.NewIdFromString("test RID", id.User, t) + dhKey := getGroup().NewInt(5) + timeout := 5 * time.Millisecond + callback := func(payload []byte, err error) {} + + quitChan, _, err := m.p.addState(rid, dhKey, 1, callback, timeout) + if err != nil { + t.Fatalf("Failed to add state: %+v", err) + } + quitChan <- struct{}{} + + var fp format.Fingerprint + for fpt, i := range m.p.singleUse[*rid].fpMap.fps { + if i == 0 { + fp = fpt + } + } + newMsg := format.NewMessage(format.MinimumPrimeSize) + part := newResponseMessagePart(newMsg.ContentsSize()) + part.SetContents([]byte("payload data")) + part.SetPartNum(0) + part.SetMaxParts(1) + + key := singleUse.NewResponseKey(dhKey, 0) + encryptedPayload := auth.Crypt(key, fp[:24], part.Marshal()) + + newMsg.SetKeyFP(format.NewFingerprint([]byte("Invalid Fingerprint"))) + newMsg.SetMac(singleUse.MakeMAC(key, encryptedPayload)) + newMsg.SetContents(encryptedPayload) + + err = m.processesResponse(rid, ephemeral.Id{}, newMsg.Marshal()) + if !check(err, "fingerprint") { + t.Errorf("processesResponse() did not return an error when "+ + "fingerprint was wrong: %+v", err) + } +} + +// Error path: collator fails because part number is wrong. +func TestManager_processesResponse_CollatorError(t *testing.T) { + m := newTestManager(0, false, t) + rid := id.NewIdFromString("test RID", id.User, t) + dhKey := getGroup().NewInt(5) + timeout := 5 * time.Millisecond + callback := func(payload []byte, err error) {} + + quitChan, _, err := m.p.addState(rid, dhKey, 1, callback, timeout) + if err != nil { + t.Fatalf("Failed to add state: %+v", err) + } + quitChan <- struct{}{} + + var fp format.Fingerprint + for fpt, i := range m.p.singleUse[*rid].fpMap.fps { + if i == 0 { + fp = fpt + } + } + newMsg := format.NewMessage(format.MinimumPrimeSize) + part := newResponseMessagePart(newMsg.ContentsSize()) + part.SetContents([]byte("payload data")) + part.SetPartNum(5) + part.SetMaxParts(1) + + key := singleUse.NewResponseKey(dhKey, 0) + encryptedPayload := auth.Crypt(key, fp[:24], part.Marshal()) + + newMsg.SetKeyFP(fp) + newMsg.SetMac(singleUse.MakeMAC(key, encryptedPayload)) + newMsg.SetContents(encryptedPayload) + + err = m.processesResponse(rid, ephemeral.Id{}, newMsg.Marshal()) + if !check(err, "collate") { + t.Errorf("processesResponse() did not return an error when "+ + "collation should have failed: %+v", err) + } +} diff --git a/single/reception.go b/single/reception.go new file mode 100644 index 0000000000000000000000000000000000000000..53b6c12eda52386a3544b547bee0b54b91bc1d2a --- /dev/null +++ b/single/reception.go @@ -0,0 +1,110 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package single + +import ( + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/interfaces/message" + cAuth "gitlab.com/elixxir/crypto/e2e/auth" + "gitlab.com/elixxir/crypto/e2e/singleUse" + "gitlab.com/elixxir/primitives/format" +) + +// receiveTransmissionHandler waits to receive single-use transmissions. When +// a message is received, its is returned via its registered callback. +func (m *Manager) receiveTransmissionHandler(rawMessages chan message.Receive, + quitChan <-chan struct{}) { + fp := singleUse.NewTransmitFingerprint(m.store.E2e().GetDHPublicKey()) + jww.DEBUG.Print("Waiting to receive single-use transmission messages.") + for { + select { + case <-quitChan: + jww.DEBUG.Printf("Stopping waiting to receive single-use " + + "transmission message.") + return + case msg := <-rawMessages: + jww.DEBUG.Printf("Received CMIX message; checking if it is a " + + "single-use transmission.") + + // Check if message is a single-use transmit message + cmixMsg := format.Unmarshal(msg.Payload) + if fp != cmixMsg.GetKeyFP() { + // If the verification fails, then ignore the message as it is + // likely garbled or for a different protocol + jww.INFO.Print("Failed to read single-use CMIX message: " + + "fingerprint verification failed.") + continue + } + + // Denote that the message is not garbled + jww.DEBUG.Printf("Received single-use transmission message.") + m.store.GetGarbledMessages().Remove(cmixMsg) + + // Handle message + payload, c, err := m.processTransmission(cmixMsg, fp) + if err != nil { + jww.WARN.Printf("Failed to read single-use CMIX message: %+v", + err) + continue + } + jww.DEBUG.Printf("Successfully processed single-use transmission message.") + + // Lookup the registered callback for the message's tag fingerprint + callback, err := m.callbackMap.getCallback(c.tagFP) + if err != nil { + jww.WARN.Printf("Failed to find module to pass single-use "+ + "payload: %+v", err) + continue + } + + jww.DEBUG.Printf("Calling single-use callback with tag "+ + "fingerprint %s.", c.tagFP) + + go callback(payload, c) + } + } +} + +// processTransmission unmarshalls and decrypts the message payload and +// returns the decrypted payload and the contact information for the sender. +func (m *Manager) processTransmission(msg format.Message, + fp format.Fingerprint) ([]byte, Contact, error) { + grp := m.store.E2e().GetGroup() + dhPrivKey := m.store.E2e().GetDHPrivateKey() + + // Unmarshal the CMIX message contents to a transmission message + transmitMsg, err := unmarshalTransmitMessage(msg.GetContents(), + grp.GetP().ByteLen()) + if err != nil { + return nil, Contact{}, errors.Errorf("failed to unmarshal contents: %+v", err) + } + + // Generate DH key and symmetric key + dhKey := grp.Exp(transmitMsg.GetPubKey(grp), dhPrivKey, grp.NewInt(1)) + key := singleUse.NewTransmitKey(dhKey) + + // Verify the MAC + if !singleUse.VerifyMAC(key, transmitMsg.GetPayload(), msg.GetMac()) { + return nil, Contact{}, errors.New("failed to verify MAC.") + } + + // Decrypt the transmission message payload + decryptedPayload := cAuth.Crypt(key, fp[:24], transmitMsg.GetPayload()) + + // Unmarshal the decrypted payload to a transmission message payload + payload, err := unmarshalTransmitMessagePayload(decryptedPayload) + if err != nil { + return nil, Contact{}, errors.Errorf("failed to unmarshal payload: %+v", err) + } + + c := NewContact(payload.GetRID(transmitMsg.GetPubKey(grp)), + transmitMsg.GetPubKey(grp), dhKey, payload.GetTagFP(), payload.GetMaxParts()) + + return payload.GetContents(), c, nil +} diff --git a/single/reception_test.go b/single/reception_test.go new file mode 100644 index 0000000000000000000000000000000000000000..5230f364533d183a8007ee465b47a9ef07cca1cd --- /dev/null +++ b/single/reception_test.go @@ -0,0 +1,256 @@ +package single + +import ( + "bytes" + contact2 "gitlab.com/elixxir/client/interfaces/contact" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/crypto/e2e/singleUse" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/primitives/id" + "math/rand" + "testing" + "time" +) + +// Happy path. +func TestManager_receiveTransmissionHandler(t *testing.T) { + m := newTestManager(0, false, t) + rawMessages := make(chan message.Receive, rawMessageBuffSize) + quitChan := make(chan struct{}) + partner := &contact2.Contact{ + ID: id.NewIdFromString("recipientID", id.User, t), + DhPubKey: m.store.E2e().GetDHPublicKey(), + } + tag := "Test tag" + payload := make([]byte, 132) + rand.New(rand.NewSource(42)).Read(payload) + callback, callbackChan := createReceiveComm() + + msg, _, _, _, err := m.makeTransmitCmixMessage(partner, payload, tag, 8, 32, + 30*time.Second, time.Now(), rand.New(rand.NewSource(42))) + if err != nil { + t.Fatalf("Failed to create tranmission CMIX message: %+v", err) + } + + m.callbackMap.registerCallback(tag, callback) + + go m.receiveTransmissionHandler(rawMessages, quitChan) + rawMessages <- message.Receive{ + Payload: msg.Marshal(), + } + + timer := time.NewTimer(50 * time.Millisecond) + + select { + case results := <-callbackChan: + if !bytes.Equal(results.payload, payload) { + t.Errorf("Callback received wrong payload."+ + "\nexpected: %+v\nreceived: %+v", payload, results.payload) + } + case <-timer.C: + t.Errorf("Callback failed to be called.") + } +} + +// Happy path: quit channel. +func TestManager_receiveTransmissionHandler_QuitChan(t *testing.T) { + m := newTestManager(0, false, t) + rawMessages := make(chan message.Receive, rawMessageBuffSize) + quitChan := make(chan struct{}) + tag := "Test tag" + payload := make([]byte, 132) + rand.New(rand.NewSource(42)).Read(payload) + callback, callbackChan := createReceiveComm() + + m.callbackMap.registerCallback(tag, callback) + + go m.receiveTransmissionHandler(rawMessages, quitChan) + quitChan <- struct{}{} + + timer := time.NewTimer(50 * time.Millisecond) + + select { + case results := <-callbackChan: + t.Errorf("Callback called when the message should not have been processed."+ + "\npayload: %+v\ncontact: %+v", results.payload, results.c) + case <-timer.C: + } +} + +// Error path: CMIX message fingerprint does not match. +func TestManager_receiveTransmissionHandler_FingerPrintError(t *testing.T) { + m := newTestManager(0, false, t) + rawMessages := make(chan message.Receive, rawMessageBuffSize) + quitChan := make(chan struct{}) + partner := &contact2.Contact{ + ID: id.NewIdFromString("recipientID", id.User, t), + DhPubKey: m.store.E2e().GetGroup().NewInt(42), + } + tag := "Test tag" + payload := make([]byte, 132) + rand.New(rand.NewSource(42)).Read(payload) + callback, callbackChan := createReceiveComm() + + msg, _, _, _, err := m.makeTransmitCmixMessage(partner, payload, tag, 8, 32, + 30*time.Second, time.Now(), rand.New(rand.NewSource(42))) + if err != nil { + t.Fatalf("Failed to create tranmission CMIX message: %+v", err) + } + + m.callbackMap.registerCallback(tag, callback) + + go m.receiveTransmissionHandler(rawMessages, quitChan) + rawMessages <- message.Receive{ + Payload: msg.Marshal(), + } + + timer := time.NewTimer(50 * time.Millisecond) + + select { + case results := <-callbackChan: + t.Errorf("Callback called when the fingerprints do not match."+ + "\npayload: %+v\ncontact: %+v", results.payload, results.c) + case <-timer.C: + } +} + +// Error path: cannot process transmission message. +func TestManager_receiveTransmissionHandler_ProcessMessageError(t *testing.T) { + m := newTestManager(0, false, t) + rawMessages := make(chan message.Receive, rawMessageBuffSize) + quitChan := make(chan struct{}) + partner := &contact2.Contact{ + ID: id.NewIdFromString("recipientID", id.User, t), + DhPubKey: m.store.E2e().GetDHPublicKey(), + } + tag := "Test tag" + payload := make([]byte, 132) + rand.New(rand.NewSource(42)).Read(payload) + callback, callbackChan := createReceiveComm() + + msg, _, _, _, err := m.makeTransmitCmixMessage(partner, payload, tag, 8, 32, + 30*time.Second, time.Now(), rand.New(rand.NewSource(42))) + if err != nil { + t.Fatalf("Failed to create tranmission CMIX message: %+v", err) + } + + msg.SetMac(make([]byte, format.MacLen)) + + m.callbackMap.registerCallback(tag, callback) + + go m.receiveTransmissionHandler(rawMessages, quitChan) + rawMessages <- message.Receive{ + Payload: msg.Marshal(), + } + + timer := time.NewTimer(50 * time.Millisecond) + + select { + case results := <-callbackChan: + t.Errorf("Callback called when the message should not have been processed."+ + "\npayload: %+v\ncontact: %+v", results.payload, results.c) + case <-timer.C: + } +} + +// Error path: tag fingerprint does not match. +func TestManager_receiveTransmissionHandler_TagFpError(t *testing.T) { + m := newTestManager(0, false, t) + rawMessages := make(chan message.Receive, rawMessageBuffSize) + quitChan := make(chan struct{}) + partner := &contact2.Contact{ + ID: id.NewIdFromString("recipientID", id.User, t), + DhPubKey: m.store.E2e().GetDHPublicKey(), + } + tag := "Test tag" + payload := make([]byte, 132) + rand.New(rand.NewSource(42)).Read(payload) + + msg, _, _, _, err := m.makeTransmitCmixMessage(partner, payload, tag, 8, 32, + 30*time.Second, time.Now(), rand.New(rand.NewSource(42))) + if err != nil { + t.Fatalf("Failed to create tranmission CMIX message: %+v", err) + } + + go m.receiveTransmissionHandler(rawMessages, quitChan) + rawMessages <- message.Receive{ + Payload: msg.Marshal(), + } +} + +// Happy path. +func TestManager_processTransmission(t *testing.T) { + m := newTestManager(0, false, t) + partner := &contact2.Contact{ + ID: id.NewIdFromString("partnerID", id.User, t), + DhPubKey: m.store.E2e().GetDHPublicKey(), + } + tag := "test tag" + payload := []byte("This is the payload.") + maxMsgs := uint8(6) + cmixMsg, dhKey, rid, _, err := m.makeTransmitCmixMessage(partner, payload, + tag, maxMsgs, 32, 30*time.Second, time.Now(), rand.New(rand.NewSource(42))) + if err != nil { + t.Fatalf("Failed to generate expected CMIX message: %+v", err) + } + + tMsg, err := unmarshalTransmitMessage(cmixMsg.GetContents(), m.store.E2e().GetGroup().GetP().ByteLen()) + if err != nil { + t.Fatalf("Failed to make transmitMessage: %+v", err) + } + + expectedC := NewContact(rid, tMsg.GetPubKey(m.store.E2e().GetGroup()), + dhKey, singleUse.NewTagFP(tag), maxMsgs) + + fp := singleUse.NewTransmitFingerprint(m.store.E2e().GetDHPublicKey()) + content, testC, err := m.processTransmission(cmixMsg, fp) + if err != nil { + t.Errorf("processTransmission() produced an error: %+v", err) + } + + if !expectedC.Equal(testC) { + t.Errorf("processTransmission() did not return the expected values."+ + "\nexpected: %+v\nrecieved: %+v", expectedC, testC) + } + + if !bytes.Equal(payload, content) { + t.Errorf("processTransmission() returned the wrong payload."+ + "\nexpected: %+v\nreceived: %+v", payload, content) + } +} + +// Error path: fails to unmarshal transmitMessage. +func TestManager_processTransmission_TransmitMessageUnmarshalError(t *testing.T) { + m := newTestManager(0, false, t) + cmixMsg := format.NewMessage(format.MinimumPrimeSize) + + fp := singleUse.NewTransmitFingerprint(m.store.E2e().GetDHPublicKey()) + _, _, err := m.processTransmission(cmixMsg, fp) + if !check(err, "failed to unmarshal contents") { + t.Errorf("processTransmission() did not produce an error when "+ + "the transmitMessage failed to unmarshal: %+v", err) + } +} + +// Error path: MAC fails to verify. +func TestManager_processTransmission_MacVerifyError(t *testing.T) { + m := newTestManager(0, false, t) + partner := &contact2.Contact{ + ID: id.NewIdFromString("partnerID", id.User, t), + DhPubKey: m.store.E2e().GetDHPublicKey(), + } + cmixMsg, _, _, _, err := m.makeTransmitCmixMessage(partner, []byte{}, "", 6, + 32, 30*time.Second, time.Now(), rand.New(rand.NewSource(42))) + if err != nil { + t.Fatalf("Failed to generate expected CMIX message: %+v", err) + } + + cmixMsg.SetMac(make([]byte, 32)) + + fp := singleUse.NewTransmitFingerprint(m.store.E2e().GetDHPublicKey()) + _, _, err = m.processTransmission(cmixMsg, fp) + if !check(err, "failed to verify MAC") { + t.Errorf("processTransmission() did not produce an error when "+ + "the MAC failed to verify: %+v", err) + } +} diff --git a/single/response.go b/single/response.go new file mode 100644 index 0000000000000000000000000000000000000000..e847847b0f5f8206952550dc84d39b05c7e2be24 --- /dev/null +++ b/single/response.go @@ -0,0 +1,189 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package single + +import ( + "bytes" + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/interfaces/utility" + ds "gitlab.com/elixxir/comms/network/dataStructures" + cAuth "gitlab.com/elixxir/crypto/e2e/auth" + "gitlab.com/elixxir/crypto/e2e/singleUse" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/elixxir/primitives/states" + "gitlab.com/xx_network/primitives/id" + "sync" + "sync/atomic" + "time" +) + +// GetMaxResponsePayloadSize returns the maximum payload size for a response +// message. +func (m *Manager) GetMaxResponsePayloadSize() int { + // Generate empty messages to determine the available space for the payload + cmixMsg := format.NewMessage(m.store.Cmix().GetGroup().GetP().ByteLen()) + responseMsg := newResponseMessagePart(cmixMsg.ContentsSize()) + + return responseMsg.GetMaxContentsSize() +} + +// RespondSingleUse creates the single-use response messages with the given +// payload and sends them to the given partner. +func (m *Manager) RespondSingleUse(partner Contact, payload []byte, + timeout time.Duration) error { + return m.respondSingleUse(partner, payload, timeout, + m.net.GetInstance().GetRoundEvents()) +} + +// respondSingleUse allows for easier testing. +func (m *Manager) respondSingleUse(partner Contact, payload []byte, + timeout time.Duration, roundEvents roundEvents) error { + // Ensure that only a single reply can be sent in response + firstUse := atomic.CompareAndSwapInt32(partner.used, 0, 1) + if !firstUse { + return errors.Errorf("cannot send to single-use contact that has " + + "already been sent to.") + } + + // Generate messages from payload + msgs, err := m.makeReplyCmixMessages(partner, payload) + if err != nil { + return errors.Errorf("failed to create new CMIX messages: %+v", err) + } + + jww.DEBUG.Printf("Created %d single-use response CMIX message parts.", len(msgs)) + + // Tracks the numbers of rounds that messages are sent on + rounds := make([]id.Round, len(msgs)) + + sendResults := make(chan ds.EventReturn, len(msgs)) + + // Send CMIX messages + wg := sync.WaitGroup{} + for i, cmixMsg := range msgs { + wg.Add(1) + cmixMsgFunc := cmixMsg + j := i + go func() { + defer wg.Done() + // Send Message + round, ephID, err := m.net.SendCMIX(cmixMsgFunc, partner.partner, params.GetDefaultCMIX()) + if err != nil { + jww.ERROR.Print("Failed to send single-use response CMIX "+ + "message part %d: %+v", j, err) + } + jww.DEBUG.Printf("Sending single-use response CMIX message part "+ + "%d on round %d to ephemeral ID %d.", j, round, ephID.Int64()) + rounds[j] = round + + roundEvents.AddRoundEventChan(round, sendResults, timeout, + states.COMPLETED, states.FAILED) + }() + } + + // Wait for all go routines to finish + wg.Wait() + jww.DEBUG.Printf("Sent %d single-use response CMIX messages to %s.", len(msgs), partner.partner) + + // Count the number of rounds + roundMap := map[id.Round]struct{}{} + for _, roundID := range rounds { + roundMap[roundID] = struct{}{} + } + + // Wait until the result tracking responds + success, numRoundFail, numTimeOut := utility.TrackResults(sendResults, len(roundMap)) + if !success { + return errors.Errorf("tracking results of %d rounds: %d round "+ + "failures, %d round event time outs; the send cannot be retried.", + len(rounds), numRoundFail, numTimeOut) + } + jww.DEBUG.Printf("Tracked %d single-use response message round(s).", len(roundMap)) + + return nil +} + +// makeReplyCmixMessages +func (m *Manager) makeReplyCmixMessages(partner Contact, payload []byte) ([]format.Message, error) { + // Generate internal payloads based off key size to determine if the passed + // in payload is too large to fit in the available contents + cmixMsg := format.NewMessage(m.store.Cmix().GetGroup().GetP().ByteLen()) + responseMsg := newResponseMessagePart(cmixMsg.ContentsSize()) + + // Maximum payload size is the maximum amount of room in each message + // multiplied by the number of messages + maxPayloadSize := responseMsg.GetMaxContentsSize() * int(partner.GetMaxParts()) + + if maxPayloadSize < len(payload) { + return nil, errors.Errorf("length of provided payload (%d) too long "+ + "for message payload capacity (%d = %d byte payload * %d messages).", + len(payload), maxPayloadSize, responseMsg.GetMaxContentsSize(), + partner.GetMaxParts()) + } + + // Split payloads + payloadParts := m.splitPayload(payload, responseMsg.GetMaxContentsSize(), + int(partner.GetMaxParts())) + + // Create CMIX messages + cmixMsgs := make([]format.Message, len(payloadParts)) + wg := sync.WaitGroup{} + for i, contents := range payloadParts { + wg.Add(1) + go func(partner Contact, contents []byte, i uint8) { + defer wg.Done() + cmixMsgs[i] = m.makeMessagePart(partner, contents, uint8(len(payloadParts)), i) + }(partner, contents, uint8(i)) + } + + // Wait for all go routines to finish + wg.Wait() + + return cmixMsgs, nil +} + +// makeMessagePart generates a CMIX message containing a responseMessagePart. +func (m *Manager) makeMessagePart(partner Contact, contents []byte, maxPart, i uint8) format.Message { + cmixMsg := format.NewMessage(m.store.Cmix().GetGroup().GetP().ByteLen()) + responseMsg := newResponseMessagePart(cmixMsg.ContentsSize()) + + // Compose response message + responseMsg.SetMaxParts(maxPart) + responseMsg.SetPartNum(i) + responseMsg.SetContents(contents) + + // Encrypt payload + fp := singleUse.NewResponseFingerprint(partner.dhKey, uint64(i)) + key := singleUse.NewResponseKey(partner.dhKey, uint64(i)) + encryptedPayload := cAuth.Crypt(key, fp[:24], responseMsg.Marshal()) + + // Generate CMIX message MAC + mac := singleUse.MakeMAC(key, encryptedPayload) + + // Compose CMIX message contents + cmixMsg.SetContents(encryptedPayload) + cmixMsg.SetKeyFP(fp) + cmixMsg.SetMac(mac) + + return cmixMsg +} + +// splitPayload splits the given payload into separate payload parts and returns +// them in a slice. Each part's size is less than or equal to maxSize. Any extra +// data in the payload is not used if it is longer than the maximum capacity. +func (m *Manager) splitPayload(payload []byte, maxSize, maxParts int) [][]byte { + var parts [][]byte + buff := bytes.NewBuffer(payload) + + for i := 0; i < maxParts && buff.Len() > 0; i++ { + parts = append(parts, buff.Next(maxSize)) + } + return parts +} diff --git a/single/responseMessage.go b/single/responseMessage.go index a4ec404626f7837f3710760a0fd95c844e87051f..28aec540b40c2386842c0e4eb79701d82ee78e41 100644 --- a/single/responseMessage.go +++ b/single/responseMessage.go @@ -8,30 +8,42 @@ package single import ( + "encoding/binary" "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" ) const ( - partNumLen = 1 - maxPartsLen = 1 + partNumLen = 1 + maxPartsLen = 1 + responseMinSize = partNumLen + maxPartsLen + sizeSize ) +/* ++-----------------------------------------+ +| CMIX Message Contents | ++---------+----------+---------+----------+ +| partNum | maxParts | size | contents | +| 1 bytes | 1 byte | 2 bytes | variable | ++------------+----------+---------+-------+ +*/ + type responseMessagePart struct { data []byte // Serial of all contents partNum []byte // Index of message in a series of messages maxParts []byte // The number of parts in this message. - payload []byte // The encrypted payload + size []byte // Size of the contents + contents []byte // The encrypted contents } // newResponseMessagePart generates a new response message part of the specified // size. func newResponseMessagePart(externalPayloadSize int) responseMessagePart { - if externalPayloadSize < partNumLen+maxPartsLen { - jww.FATAL.Panicf("Failed to create new single use response message "+ + if externalPayloadSize < responseMinSize { + jww.FATAL.Panicf("Failed to create new single-use response message "+ "part: size of external payload (%d) is too small to contain the "+ "message part number and max parts (%d)", - externalPayloadSize, partNumLen+maxPartsLen) + externalPayloadSize, responseMinSize) } return mapResponseMessagePart(make([]byte, externalPayloadSize)) @@ -44,16 +56,17 @@ func mapResponseMessagePart(data []byte) responseMessagePart { data: data, partNum: data[:partNumLen], maxParts: data[partNumLen : maxPartsLen+partNumLen], - payload: data[maxPartsLen+partNumLen:], + size: data[maxPartsLen+partNumLen : responseMinSize], + contents: data[responseMinSize:], } } // unmarshalResponseMessage converts a byte buffer into a response message part. func unmarshalResponseMessage(b []byte) (responseMessagePart, error) { - if len(b) < partNumLen+maxPartsLen { + if len(b) < responseMinSize { return responseMessagePart{}, errors.Errorf("Size of passed in bytes "+ "(%d) is too small to contain the message part number and max "+ - "parts (%d).", len(b), partNumLen+maxPartsLen) + "parts (%d).", len(b), responseMinSize) } return mapResponseMessagePart(b), nil } @@ -83,23 +96,32 @@ func (m responseMessagePart) SetMaxParts(max uint8) { copy(m.maxParts, []byte{max}) } -// GetPayload returns the encrypted payload of the message part. -func (m responseMessagePart) GetPayload() []byte { - return m.payload +// GetContents returns the contents of the message part. +func (m responseMessagePart) GetContents() []byte { + return m.contents[:binary.BigEndian.Uint16(m.size)] +} + +// GetContentsSize returns the length of the contents. +func (m responseMessagePart) GetContentsSize() int { + return int(binary.BigEndian.Uint16(m.size)) } -// GetPayloadSize returns the length of the encrypted payload. -func (m responseMessagePart) GetPayloadSize() int { - return len(m.payload) +// GetMaxContentsSize returns the max capacity of the contents. +func (m responseMessagePart) GetMaxContentsSize() int { + return len(m.contents) } -// SetPayload sets the encrypted payload of the message part. -func (m responseMessagePart) SetPayload(payload []byte) { - if len(payload) != m.GetPayloadSize() { - jww.FATAL.Panicf("Failed to set payload of single use response "+ - "message part: size of supplied payload (%d) is different from "+ - "the size of the message payload (%d).", - len(payload), m.GetPayloadSize()+maxPartsLen) +// SetContents sets the contents of the message part. Does not zero out previous +// contents. +func (m responseMessagePart) SetContents(contents []byte) { + if len(contents) > len(m.contents) { + jww.FATAL.Panicf("Failed to set contents of single-use response "+ + "message part: max size of message contents (%d) is smaller than "+ + "the size of the supplied contents (%d).", + len(m.contents), len(contents)) } - copy(m.payload, payload) + + binary.BigEndian.PutUint16(m.size, uint16(len(contents))) + + copy(m.contents, contents) } diff --git a/single/responseMessage_test.go b/single/responseMessage_test.go index cd3041090bfb537e94ce12fc652d2784a86dd801..8871fdc898efa1b00f84c27a914dc6850bcef25c 100644 --- a/single/responseMessage_test.go +++ b/single/responseMessage_test.go @@ -11,6 +11,7 @@ import ( "bytes" "math/rand" "reflect" + "strings" "testing" ) @@ -22,21 +23,22 @@ func Test_newResponseMessagePart(t *testing.T) { data: make([]byte, payloadSize), partNum: make([]byte, partNumLen), maxParts: make([]byte, maxPartsLen), - payload: make([]byte, payloadSize-partNumLen-maxPartsLen), + size: make([]byte, sizeSize), + contents: make([]byte, payloadSize-partNumLen-maxPartsLen-sizeSize), } rmp := newResponseMessagePart(payloadSize) if !reflect.DeepEqual(expected, rmp) { t.Errorf("newResponseMessagePart() did not return the expected "+ - "responseMessagePart.\nexpected: %+v\nreceived: %v", expected, rmp) + "responseMessagePart.\nexpected: %+v\nreceived: %+v", expected, rmp) } } -// Error path: provided payload size is not large enough. +// Error path: provided contents size is not large enough. func Test_newResponseMessagePart_PayloadSizeError(t *testing.T) { defer func() { - if r := recover(); r == nil { + if r := recover(); r == nil || !strings.Contains(r.(string), "size of external payload") { t.Error("newResponseMessagePart() did not panic when the size of " + "the payload is smaller than the required size.") } @@ -50,11 +52,13 @@ func Test_mapResponseMessagePart(t *testing.T) { prng := rand.New(rand.NewSource(42)) expectedPartNum := uint8(prng.Uint32()) expectedMaxParts := uint8(prng.Uint32()) - expectedPayload := make([]byte, prng.Intn(2000)) - prng.Read(expectedPayload) + size := []byte{uint8(prng.Uint64()), uint8(prng.Uint64())} + expectedContents := make([]byte, prng.Intn(2000)) + prng.Read(expectedContents) var data []byte data = append(data, expectedPartNum, expectedMaxParts) - data = append(data, expectedPayload...) + data = append(data, size...) + data = append(data, expectedContents...) rmp := mapResponseMessagePart(data) @@ -68,9 +72,9 @@ func Test_mapResponseMessagePart(t *testing.T) { "\nexpected: %d\nreceived: %d", expectedMaxParts, rmp.maxParts[0]) } - if !bytes.Equal(expectedPayload, rmp.payload) { - t.Errorf("mapResponseMessagePart() did not correctly map payload."+ - "\nexpected: %+v\nreceived: %+v", expectedPayload, rmp.payload) + if !bytes.Equal(expectedContents, rmp.contents) { + t.Errorf("mapResponseMessagePart() did not correctly map contents."+ + "\nexpected: %+v\nreceived: %+v", expectedContents, rmp.contents) } if !bytes.Equal(data, rmp.data) { @@ -137,35 +141,41 @@ func TestResponseMessagePart_SetMaxParts_GetMaxParts(t *testing.T) { } // Happy path. -func TestResponseMessagePart_SetPayload_GetPayload_GetPayloadSize(t *testing.T) { +func TestResponseMessagePart_SetContents_GetContents_GetContentsSize_GetMaxContentsSize(t *testing.T) { prng := rand.New(rand.NewSource(42)) externalPayloadSize := prng.Intn(2000) - payloadSize := externalPayloadSize - partNumLen - maxPartsLen - expectedPayload := make([]byte, payloadSize) - prng.Read(expectedPayload) + contentSize := externalPayloadSize - responseMinSize - 10 + expectedContents := make([]byte, contentSize) + prng.Read(expectedContents) rmp := newResponseMessagePart(externalPayloadSize) - rmp.SetPayload(expectedPayload) + rmp.SetContents(expectedContents) - if !bytes.Equal(expectedPayload, rmp.GetPayload()) { - t.Errorf("GetPayload() failed to return the expected payload."+ - "\nexpected: %+v\nrecieved: %+v", expectedPayload, rmp.GetPayload()) + if !bytes.Equal(expectedContents, rmp.GetContents()) { + t.Errorf("GetContents() failed to return the expected contents."+ + "\nexpected: %+v\nrecieved: %+v", expectedContents, rmp.GetContents()) } - if payloadSize != rmp.GetPayloadSize() { - t.Errorf("GetPayloadSize() failed to return the expected payload size."+ - "\nexpected: %d\nrecieved: %d", payloadSize, rmp.GetPayloadSize()) + if contentSize != rmp.GetContentsSize() { + t.Errorf("GetContentsSize() failed to return the expected contents size."+ + "\nexpected: %d\nrecieved: %d", contentSize, rmp.GetContentsSize()) + } + + if externalPayloadSize-responseMinSize != rmp.GetMaxContentsSize() { + t.Errorf("GetMaxContentsSize() failed to return the expected max contents size."+ + "\nexpected: %d\nrecieved: %d", + externalPayloadSize-responseMinSize, rmp.GetMaxContentsSize()) } } -// Error path: size of supplied payload does not match message payload size. -func TestResponseMessagePart_SetPayload_PayloadSizeError(t *testing.T) { +// Error path: size of supplied contents does not match message contents size. +func TestResponseMessagePart_SetContents_ContentsSizeError(t *testing.T) { defer func() { - if r := recover(); r == nil { - t.Error("SetPayload() did not panic when the size of the supplied " + - "bytes is not the same as the payload content size.") + if r := recover(); r == nil || !strings.Contains(r.(string), "max size of message contents") { + t.Error("SetContents() did not panic when the size of the supplied " + + "bytes is larger than the content size.") } }() rmp := newResponseMessagePart(255) - rmp.SetPayload([]byte{1, 2, 3}) + rmp.SetContents(make([]byte, 500)) } diff --git a/single/response_test.go b/single/response_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c19f472bafa42e122aea75f3f1085f3fc2f522cb --- /dev/null +++ b/single/response_test.go @@ -0,0 +1,260 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package single + +import ( + "bytes" + cAuth "gitlab.com/elixxir/crypto/e2e/auth" + "gitlab.com/elixxir/crypto/e2e/singleUse" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/primitives/id" + "math/rand" + "reflect" + "testing" + "time" +) + +// Happy path. +func TestManager_GetMaxResponsePayloadSize(t *testing.T) { + m := newTestManager(0, false, t) + cmixPrimeSize := m.store.Cmix().GetGroup().GetP().ByteLen() + expectedSize := 2*cmixPrimeSize - format.KeyFPLen - format.MacLen - format.RecipientIDLen - responseMinSize + testSize := m.GetMaxResponsePayloadSize() + + if expectedSize != testSize { + t.Errorf("GetMaxResponsePayloadSize() failed to return the expected size."+ + "\nexpected: %d\nreceived: %d", expectedSize, testSize) + } +} + +// Happy path. +func TestManager_respondSingleUse(t *testing.T) { + m := newTestManager(0, false, t) + used := int32(0) + partner := Contact{ + partner: id.NewIdFromString("partner ID:", id.User, t), + partnerPubKey: m.store.E2e().GetGroup().NewInt(42), + dhKey: m.store.E2e().GetGroup().NewInt(99), + tagFP: singleUse.NewTagFP("tag"), + maxParts: 10, + used: &used, + } + payload := make([]byte, 400*int(partner.maxParts)) + rand.New(rand.NewSource(42)).Read(payload) + re := newTestRoundEvents(false) + + expectedMsgs, err := m.makeReplyCmixMessages(partner, payload) + if err != nil { + t.Fatalf("Failed to created expected messages: %+v", err) + } + + err = m.respondSingleUse(partner, payload, 10*time.Millisecond, re) + if err != nil { + t.Errorf("respondSingleUse() produced an error: %+v", err) + } + + // Check that all messages are expected and received + if len(m.net.(*testNetworkManager).msgs) != int(partner.GetMaxParts()) { + t.Errorf("Recieved incorrect number of messages."+ + "\nexpected: %d\nreceived: %d", int(partner.GetMaxParts()), + len(m.net.(*testNetworkManager).msgs)) + } + + // Check that all received messages were expected + var exists bool + for _, received := range m.net.(*testNetworkManager).msgs { + exists = false + for _, msg := range expectedMsgs { + if reflect.DeepEqual(msg, received) { + exists = true + } + } + if !exists { + t.Errorf("Unexpected message: %+v", received) + } + } +} + +// Error path: response has already been sent. +func TestManager_respondSingleUse_ResponseUsedError(t *testing.T) { + m := newTestManager(0, false, t) + used := int32(1) + partner := Contact{ + partner: id.NewIdFromString("partner ID:", id.User, t), + partnerPubKey: m.store.E2e().GetGroup().NewInt(42), + dhKey: m.store.E2e().GetGroup().NewInt(99), + tagFP: singleUse.NewTagFP("tag"), + maxParts: 10, + used: &used, + } + + err := m.respondSingleUse(partner, []byte{}, 10*time.Millisecond, newTestRoundEvents(false)) + if !check(err, "cannot send to single-use contact that has already been sent to") { + t.Errorf("respondSingleUse() did not produce the expected error when "+ + "the contact has been used: %+v", err) + } +} + +// Error path: cannot create CMIX message when payload is too large. +func TestManager_respondSingleUse_MakeCmixMessageError(t *testing.T) { + m := newTestManager(0, false, t) + used := int32(0) + partner := Contact{ + partner: id.NewIdFromString("partner ID:", id.User, t), + partnerPubKey: m.store.E2e().GetGroup().NewInt(42), + dhKey: m.store.E2e().GetGroup().NewInt(99), + tagFP: singleUse.NewTagFP("tag"), + maxParts: 10, + used: &used, + } + payload := make([]byte, 500*int(partner.maxParts)) + rand.New(rand.NewSource(42)).Read(payload) + + err := m.respondSingleUse(partner, payload, 10*time.Millisecond, newTestRoundEvents(false)) + if !check(err, "failed to create new CMIX messages") { + t.Errorf("respondSingleUse() did not produce the expected error when "+ + "the CMIX message creation failed: %+v", err) + } +} + +// Error path: TrackResults returns an error. +func TestManager_respondSingleUse_TrackResultsError(t *testing.T) { + m := newTestManager(0, false, t) + used := int32(0) + partner := Contact{ + partner: id.NewIdFromString("partner ID:", id.User, t), + partnerPubKey: m.store.E2e().GetGroup().NewInt(42), + dhKey: m.store.E2e().GetGroup().NewInt(99), + tagFP: singleUse.NewTagFP("tag"), + maxParts: 10, + used: &used, + } + payload := make([]byte, 400*int(partner.maxParts)) + rand.New(rand.NewSource(42)).Read(payload) + + err := m.respondSingleUse(partner, payload, 10*time.Millisecond, newTestRoundEvents(true)) + if !check(err, "tracking results of") { + t.Errorf("respondSingleUse() did not produce the expected error when "+ + "the CMIX message creation failed: %+v", err) + } +} + +// Happy path. +func TestManager_makeReplyCmixMessages(t *testing.T) { + m := newTestManager(0, false, t) + partner := Contact{ + partner: id.NewIdFromString("partner ID:", id.User, t), + partnerPubKey: m.store.E2e().GetGroup().NewInt(42), + dhKey: m.store.E2e().GetGroup().NewInt(99), + tagFP: singleUse.NewTagFP("tag"), + maxParts: 255, + } + payload := make([]byte, 400*int(partner.maxParts)) + rand.New(rand.NewSource(42)).Read(payload) + + msgs, err := m.makeReplyCmixMessages(partner, payload) + if err != nil { + t.Errorf("makeReplyCmixMessages() returned an error: %+v", err) + } + + buff := bytes.NewBuffer(payload) + for i, msg := range msgs { + checkReplyCmixMessage(partner, + buff.Next(len(msgs[0].GetContents())-responseMinSize), msg, len(msgs), i, t) + } +} + +// Error path: size of payload too large. +func TestManager_makeReplyCmixMessages_PayloadSizeError(t *testing.T) { + m := newTestManager(0, false, t) + partner := Contact{ + partner: id.NewIdFromString("partner ID:", id.User, t), + partnerPubKey: m.store.E2e().GetGroup().NewInt(42), + dhKey: m.store.E2e().GetGroup().NewInt(99), + tagFP: singleUse.NewTagFP("tag"), + maxParts: 255, + } + payload := make([]byte, 500*int(partner.maxParts)) + rand.New(rand.NewSource(42)).Read(payload) + + _, err := m.makeReplyCmixMessages(partner, payload) + if err == nil { + t.Error("makeReplyCmixMessages() did not return an error when the " + + "payload is too large.") + } +} + +func checkReplyCmixMessage(c Contact, payload []byte, msg format.Message, maxParts, i int, t *testing.T) { + expectedFP := singleUse.NewResponseFingerprint(c.dhKey, uint64(i)) + key := singleUse.NewResponseKey(c.dhKey, uint64(i)) + expectedMac := singleUse.MakeMAC(key, msg.GetContents()) + + // Check CMIX message + if expectedFP != msg.GetKeyFP() { + t.Errorf("CMIX message #%d had incorrect fingerprint."+ + "\nexpected: %s\nrecieved: %s", i, expectedFP, msg.GetKeyFP()) + } + + if !singleUse.VerifyMAC(key, msg.GetContents(), msg.GetMac()) { + t.Errorf("CMIX message #%d had incorrect MAC."+ + "\nexpected: %+v\nrecieved: %+v", i, expectedMac, msg.GetMac()) + } + + // Decrypt payload + decryptedPayload := cAuth.Crypt(key, expectedFP[:24], msg.GetContents()) + responseMsg, err := unmarshalResponseMessage(decryptedPayload) + if err != nil { + t.Errorf("Failed to unmarshal pay load of CMIX message #%d: %+v", i, err) + } + + if !bytes.Equal(payload, responseMsg.GetContents()) { + t.Errorf("Response message #%d had incorrect contents."+ + "\nexpected: %+v\nrecieved: %+v", + i, payload, responseMsg.GetContents()) + } + + if uint8(maxParts) != responseMsg.GetMaxParts() { + t.Errorf("Response message #%d had incorrect max parts."+ + "\nexpected: %+v\nrecieved: %+v", + i, maxParts, responseMsg.GetMaxParts()) + } + + if i != int(responseMsg.GetPartNum()) { + t.Errorf("Response message #%d had incorrect part number."+ + "\nexpected: %+v\nrecieved: %+v", + i, i, responseMsg.GetPartNum()) + } +} + +// Happy path. +func TestManager_splitPayload(t *testing.T) { + m := newTestManager(0, false, t) + maxSize := 5 + maxParts := 10 + payload := []byte("0123456789012345678901234567890123456789012345678901234" + + "5678901234567890123456789012345678901234567890123456789") + expectedParts := [][]byte{ + payload[:maxSize], + payload[maxSize : 2*maxSize], + payload[2*maxSize : 3*maxSize], + payload[3*maxSize : 4*maxSize], + payload[4*maxSize : 5*maxSize], + payload[5*maxSize : 6*maxSize], + payload[6*maxSize : 7*maxSize], + payload[7*maxSize : 8*maxSize], + payload[8*maxSize : 9*maxSize], + payload[9*maxSize : 10*maxSize], + } + + testParts := m.splitPayload(payload, maxSize, maxParts) + + if !reflect.DeepEqual(expectedParts, testParts) { + t.Errorf("splitPayload() failed to correctly split the payload."+ + "\nexpected: %s\nreceived: %s", expectedParts, testParts) + } +} diff --git a/single/singleUseMap.go b/single/singleUseMap.go new file mode 100644 index 0000000000000000000000000000000000000000..48c3f096f864d5d94e88cf1d7e4cd34b36811f07 --- /dev/null +++ b/single/singleUseMap.go @@ -0,0 +1,123 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package single + +import ( + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/crypto/cyclic" + "gitlab.com/xx_network/primitives/id" + "sync" + "sync/atomic" + "time" +) + +// pending contains a map of all pending single-use states. +type pending struct { + singleUse map[id.ID]*state + sync.RWMutex +} + +type replyComm func(payload []byte, err error) + +// state contains the information and state of each single-use message that is +// being transmitted. +type state struct { + dhKey *cyclic.Int + fpMap *fingerprintMap // List of fingerprints for each response part + c *collator // Collects all response message parts + callback replyComm // Returns the error status of the communication + quitChan chan struct{} // Sending on channel kills the timeout handler +} + +// newPending creates a pending object with an empty map. +func newPending() *pending { + return &pending{ + singleUse: map[id.ID]*state{}, + } +} + +// newState generates a new state object with the fingerprint map and collator +// initialised. +func newState(dhKey *cyclic.Int, messageCount uint8, callback replyComm) *state { + return &state{ + dhKey: dhKey, + fpMap: newFingerprintMap(dhKey, uint64(messageCount)), + c: newCollator(uint64(messageCount)), + callback: callback, + quitChan: make(chan struct{}), + } +} + +// addState adds a new state to the map and starts a thread waiting for all the +// message parts or for the timeout to occur. +func (p *pending) addState(rid *id.ID, dhKey *cyclic.Int, maxMsgs uint8, + callback replyComm, timeout time.Duration) (chan struct{}, *int32, error) { + p.Lock() + + // Check if the state already exists + if _, exists := p.singleUse[*rid]; exists { + return nil, nil, errors.Errorf("a state already exists in the map with "+ + "the ID %s.", rid) + } + + jww.DEBUG.Printf("Successfully added single-use state with the ID %s to "+ + "the map.", rid) + + // Add the state + p.singleUse[*rid] = newState(dhKey, maxMsgs, callback) + quitChan := p.singleUse[*rid].quitChan + p.Unlock() + + // Create atomic which is set when the timeoutHandler thread is killed + quit := int32(0) + + go p.timeoutHandler(rid, callback, timeout, quitChan, &quit) + + return quitChan, &quit, nil +} + +// timeoutHandler waits for the signal to complete or times out and deletes the +// state. +func (p *pending) timeoutHandler(rid *id.ID, callback replyComm, + timeout time.Duration, quitChan chan struct{}, quit *int32) { + jww.DEBUG.Printf("Starting handler for sending single-use transmission "+ + "that will timeout after %s.", timeout) + + timer := time.NewTimer(timeout) + + // Signal on the atomic when this thread quits + defer func() { + atomic.StoreInt32(quit, 1) + }() + + select { + case <-quitChan: + jww.DEBUG.Print("Single-use transmission timeout handler quitting.") + return + case <-timer.C: + jww.WARN.Printf("Single-use transmission timeout handler timed out "+ + "after %s.", timeout) + + p.Lock() + if _, exists := p.singleUse[*rid]; !exists { + p.Unlock() + return + } + delete(p.singleUse, *rid) + + p.Unlock() + + err := errors.Errorf("waiting for response to single-use transmission "+ + "timed out after %s.", timeout.String()) + jww.DEBUG.Printf("Deleted single-use from map. Calling callback with "+ + "error: %+v", err) + + callback(nil, err) + } +} diff --git a/single/singleUseMap_test.go b/single/singleUseMap_test.go new file mode 100644 index 0000000000000000000000000000000000000000..3b4a18171640935735b5d3b28042141c77ded768 --- /dev/null +++ b/single/singleUseMap_test.go @@ -0,0 +1,206 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package single + +import ( + "gitlab.com/xx_network/primitives/id" + "reflect" + "strings" + "sync/atomic" + "testing" + "time" +) + +// Happy path: trigger the exit channel. +func Test_pending_addState_ExitChan(t *testing.T) { + p := newPending() + rid := id.NewIdFromString("test RID", id.User, t) + dhKey := getGroup().NewInt(5) + maxMsgs := uint8(6) + timeout := 500 * time.Millisecond + callback, callbackChan := createReplyComm() + + quitChan, quit, err := p.addState(rid, dhKey, maxMsgs, callback, timeout) + if err != nil { + t.Errorf("addState() returned an error: %+v", err) + } + + hasQuit := atomic.CompareAndSwapInt32(quit, 0, 1) + if !hasQuit { + t.Error("Quit atomic called.") + } + + expectedState := newState(dhKey, maxMsgs, callback) + + quitChan <- struct{}{} + + timer := time.NewTimer(timeout + 1*time.Millisecond) + + select { + case results := <-callbackChan: + t.Errorf("Callback called when the quit channel was used."+ + "\npayload: %+v\nerror: %+v", results.payload, results.err) + case <-timer.C: + } + + state, exists := p.singleUse[*rid] + if !exists { + t.Error("State not found in map.") + } + if !equalState(*expectedState, *state, t) { + t.Errorf("State in map is incorrect.\nexpected: %+v\nreceived: %+v", + *expectedState, *state) + } + + hasQuit = atomic.CompareAndSwapInt32(quit, 0, 1) + if hasQuit { + t.Error("Quit atomic not called.") + } +} + +// Happy path: state is removed before deletion can occur. +func Test_pending_addState_StateRemoved(t *testing.T) { + p := newPending() + rid := id.NewIdFromString("test RID", id.User, t) + dhKey := getGroup().NewInt(5) + maxMsgs := uint8(6) + timeout := 5 * time.Millisecond + callback, callbackChan := createReplyComm() + + _, _, err := p.addState(rid, dhKey, maxMsgs, callback, timeout) + if err != nil { + t.Errorf("addState() returned an error: %+v", err) + } + + p.Lock() + delete(p.singleUse, *rid) + p.Unlock() + + timer := time.NewTimer(timeout + 1*time.Millisecond) + + select { + case results := <-callbackChan: + t.Errorf("Callback should not have been called.\npayload: %+v\nerror: %+v", + results.payload, results.err) + case <-timer.C: + } +} + +// Error path: timeout occurs and deletes the entry from the map. +func Test_pending_addState_TimeoutError(t *testing.T) { + p := newPending() + rid := id.NewIdFromString("test RID", id.User, t) + dhKey := getGroup().NewInt(5) + maxMsgs := uint8(6) + timeout := 5 * time.Millisecond + callback, callbackChan := createReplyComm() + + _, _, err := p.addState(rid, dhKey, maxMsgs, callback, timeout) + if err != nil { + t.Errorf("addState() returned an error: %+v", err) + } + + expectedState := newState(dhKey, maxMsgs, callback) + p.Lock() + state, exists := p.singleUse[*rid] + p.Unlock() + if !exists { + t.Error("State not found in map.") + } + if !equalState(*expectedState, *state, t) { + t.Errorf("State in map is incorrect.\nexpected: %+v\nreceived: %+v", + *expectedState, *state) + } + + timer := time.NewTimer(timeout * 2) + + select { + case results := <-callbackChan: + state, exists = p.singleUse[*rid] + if exists { + t.Errorf("State found in map when it should have been deleted."+ + "\nstate: %+v", state) + } + if results.payload != nil { + t.Errorf("Payload not nil on timeout.\npayload: %+v", results.payload) + } + if results.err == nil || !strings.Contains(results.err.Error(), "timed out") { + t.Errorf("Callback did not return a time out error on return: %+v", results.err) + } + case <-timer.C: + t.Error("Failed to time out.") + } +} + +// Error path: state already exists. +func Test_pending_addState_StateExistsError(t *testing.T) { + p := newPending() + rid := id.NewIdFromString("test RID", id.User, t) + dhKey := getGroup().NewInt(5) + maxMsgs := uint8(6) + timeout := 5 * time.Millisecond + callback, _ := createReplyComm() + + quitChan, _, err := p.addState(rid, dhKey, maxMsgs, callback, timeout) + if err != nil { + t.Errorf("addState() returned an error: %+v", err) + } + quitChan <- struct{}{} + + quitChan, _, err = p.addState(rid, dhKey, maxMsgs, callback, timeout) + if !check(err, "a state already exists in the map") { + t.Errorf("addState() did not return an error when the state already "+ + "exists: %+v", err) + } +} + +type replyCommData struct { + payload []byte + err error +} + +func createReplyComm() (func(payload []byte, err error), chan replyCommData) { + callbackChan := make(chan replyCommData) + callback := func(payload []byte, err error) { + callbackChan <- replyCommData{ + payload: payload, + err: err, + } + } + return callback, callbackChan +} + +// equalState determines if the two states have equal values. +func equalState(a, b state, t *testing.T) bool { + if a.dhKey.Cmp(b.dhKey) != 0 { + t.Errorf("DH Keys differ.\nexpected: %s\nreceived: %s", + a.dhKey.Text(10), b.dhKey.Text(10)) + return false + } + if !reflect.DeepEqual(a.fpMap.fps, b.fpMap.fps) { + t.Errorf("Fingerprint maps differ.\nexpected: %+v\nreceived: %+v", + a.fpMap.fps, b.fpMap.fps) + return false + } + if !reflect.DeepEqual(b.c, b.c) { + t.Errorf("collators differ.\nexpected: %+v\nreceived: %+v", + a.c, b.c) + return false + } + if reflect.ValueOf(a.callback).Pointer() != reflect.ValueOf(b.callback).Pointer() { + t.Errorf("callbackFuncs differ.\nexpected: %p\nreceived: %p", + a.callback, b.callback) + return false + } + return true +} + +// check returns true if the error is not nil and contains the substring. +func check(err error, subStr string) bool { + return err != nil && strings.Contains(err.Error(), subStr) +} diff --git a/single/transmission.go b/single/transmission.go new file mode 100644 index 0000000000000000000000000000000000000000..2e5493d34ee19149ba04cb65c6ad4d5a3203ac54 --- /dev/null +++ b/single/transmission.go @@ -0,0 +1,289 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package single + +import ( + "fmt" + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + contact2 "gitlab.com/elixxir/client/interfaces/contact" + "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/interfaces/utility" + "gitlab.com/elixxir/client/storage/reception" + ds "gitlab.com/elixxir/comms/network/dataStructures" + "gitlab.com/elixxir/crypto/cyclic" + "gitlab.com/elixxir/crypto/e2e/auth" + "gitlab.com/elixxir/crypto/e2e/singleUse" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/elixxir/primitives/states" + "gitlab.com/xx_network/crypto/csprng" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/id/ephemeral" + "io" + "sync/atomic" + "time" +) + +// GetMaxTransmissionPayloadSize returns the maximum payload size for a +// transmission message. +func (m *Manager) GetMaxTransmissionPayloadSize() int { + // Generate empty messages to determine the available space for the payload + cmixPrimeSize := m.store.Cmix().GetGroup().GetP().ByteLen() + e2ePrimeSize := m.store.E2e().GetGroup().GetP().ByteLen() + cmixMsg := format.NewMessage(cmixPrimeSize) + transmitMsg := newTransmitMessage(cmixMsg.ContentsSize(), e2ePrimeSize) + msgPayload := newTransmitMessagePayload(transmitMsg.GetPayloadSize()) + + return msgPayload.GetMaxContentsSize() +} + +// TransmitSingleUse creates a CMIX message, sends it, and waits for delivery. +func (m *Manager) TransmitSingleUse(partner *contact2.Contact, payload []byte, + tag string, maxMsgs uint8, callback replyComm, timeout time.Duration) error { + + rngReader := m.rng.GetStream() + defer rngReader.Close() + + return m.transmitSingleUse(partner, payload, tag, maxMsgs, rngReader, + callback, timeout, m.net.GetInstance().GetRoundEvents()) +} + +// roundEvents interface allows custom round events to be passed in for testing. +type roundEvents interface { + AddRoundEventChan(id.Round, chan ds.EventReturn, time.Duration, + ...states.Round) *ds.EventCallback +} + +// transmitSingleUse has the fields passed in for easier testing. +func (m *Manager) transmitSingleUse(partner *contact2.Contact, payload []byte, + tag string, MaxMsgs uint8, rng io.Reader, callback replyComm, + timeout time.Duration, roundEvents roundEvents) error { + + // Get ephemeral ID address size; this will block until the client knows the + // address size if it is currently unknown + if m.store.Reception().IsIdSizeDefault() { + m.store.Reception().WaitForIdSizeUpdate() + } + addressSize := m.store.Reception().GetIDSize() + + // Create new CMIX message containing the transmission payload + cmixMsg, dhKey, rid, ephID, err := m.makeTransmitCmixMessage(partner, + payload, tag, MaxMsgs, addressSize, timeout, time.Now(), rng) + if err != nil { + return errors.Errorf("failed to create new CMIX message: %+v", err) + } + + jww.DEBUG.Printf("Created single-use transmission CMIX message with new ID "+ + "%s and ephemeral ID %d", rid, ephID.Int64()) + + timeStart := time.Now() + + // Add message state to map + quitChan, quit, err := m.p.addState(rid, dhKey, MaxMsgs, callback, timeout) + if err != nil { + return errors.Errorf("failed to add pending state: %+v", err) + } + + // Add identity for newly generated ID + err = m.reception.AddIdentity(reception.Identity{ + EphId: ephID, + Source: rid, + End: timeStart.Add(2 * timeout), + ExtraChecks: 10, + StartValid: timeStart.Add(-2 * timeout), + EndValid: timeStart.Add(2 * timeout), + RequestMask: 48*time.Hour - timeout, + Ephemeral: true, + }) + if err != nil { + errorString := fmt.Sprintf("failed to add new identity to "+ + "reception storage for single-use: %+v", err) + jww.ERROR.Print(errorString) + + // Exit the state timeout handler, delete the state from map, and + // return an error on the callback + quitChan <- struct{}{} + m.p.Lock() + delete(m.p.singleUse, *rid) + m.p.Unlock() + go callback(nil, errors.New(errorString)) + } + + go func() { + // Send Message + jww.DEBUG.Printf("Sending single-use transmission CMIX message to %s.", partner.ID) + round, _, err := m.net.SendCMIX(cmixMsg, partner.ID, params.GetDefaultCMIX()) + if err != nil { + errorString := fmt.Sprintf("failed to send single-use transmission "+ + "CMIX message: %+v", err) + jww.ERROR.Print(errorString) + + // Exit the state timeout handler, delete the state from map, and + // return an error on the callback + quitChan <- struct{}{} + m.p.Lock() + delete(m.p.singleUse, *rid) + m.p.Unlock() + go callback(nil, errors.New(errorString)) + } + + // Check if the state timeout handler has quit + if atomic.LoadInt32(quit) == 1 { + jww.ERROR.Print("Stopping to send single-use transmission CMIX " + + "message because the timeout handler quit.") + return + } + + // Update the timeout for the elapsed time + roundEventTimeout := timeout - time.Now().Sub(timeStart) - time.Millisecond + + // Check message delivery + sendResults := make(chan ds.EventReturn, 1) + roundEvents.AddRoundEventChan(round, sendResults, roundEventTimeout, + states.COMPLETED, states.FAILED) + + jww.DEBUG.Printf("Sent single-use transmission CMIX message to %s and "+ + "ephemeral ID %d on round %d.", partner.ID, ephID.Int64(), round) + + // Wait until the result tracking responds + success, numRoundFail, numTimeOut := utility.TrackResults(sendResults, 1) + if !success { + errorString := fmt.Sprintf("failed to send single-use transmission "+ + "message: %d round failures, %d round event time outs.", + numRoundFail, numTimeOut) + jww.ERROR.Print(errorString) + + // Exit the state timeout handler, delete the state from map, and + // return an error on the callback + quitChan <- struct{}{} + m.p.Lock() + delete(m.p.singleUse, *rid) + m.p.Unlock() + go callback(nil, errors.New(errorString)) + } + jww.DEBUG.Print("Tracked single-use transmission message round.") + }() + + return nil +} + +// makeTransmitCmixMessage generates a CMIX message containing the transmission message, +// which contains the encrypted payload. +func (m *Manager) makeTransmitCmixMessage(partner *contact2.Contact, + payload []byte, tag string, maxMsgs uint8, addressSize uint, + timeout time.Duration, timeNow time.Time, rng io.Reader) (format.Message, + *cyclic.Int, *id.ID, ephemeral.Id, error) { + e2eGrp := m.store.E2e().GetGroup() + + // Generate internal payloads based off key size to determine if the passed + // in payload is too large to fit in the available contents + cmixMsg := format.NewMessage(m.store.Cmix().GetGroup().GetP().ByteLen()) + transmitMsg := newTransmitMessage(cmixMsg.ContentsSize(), e2eGrp.GetP().ByteLen()) + msgPayload := newTransmitMessagePayload(transmitMsg.GetPayloadSize()) + + if msgPayload.GetMaxContentsSize() < len(payload) { + return format.Message{}, nil, nil, ephemeral.Id{}, + errors.Errorf("length of provided payload (%d) too long for message "+ + "payload capacity (%d).", len(payload), len(msgPayload.contents)) + } + + // Generate DH key and public key + dhKey, publicKey, err := generateDhKeys(e2eGrp, partner.DhPubKey, rng) + if err != nil { + return format.Message{}, nil, nil, ephemeral.Id{}, err + } + + // Compose payload + msgPayload.SetTagFP(singleUse.NewTagFP(tag)) + msgPayload.SetMaxParts(maxMsgs) + msgPayload.SetContents(payload) + + // Generate new user ID and ephemeral ID + rid, ephID, err := makeIDs(&msgPayload, publicKey, addressSize, timeout, + timeNow, rng) + if err != nil { + return format.Message{}, nil, nil, ephemeral.Id{}, + errors.Errorf("failed to generate IDs: %+v", err) + } + + // Encrypt payload + fp := singleUse.NewTransmitFingerprint(partner.DhPubKey) + key := singleUse.NewTransmitKey(dhKey) + encryptedPayload := auth.Crypt(key, fp[:24], msgPayload.Marshal()) + + // Generate CMIX message MAC + mac := singleUse.MakeMAC(key, encryptedPayload) + + // Compose transmission message + transmitMsg.SetPubKey(publicKey) + transmitMsg.SetPayload(encryptedPayload) + + // Compose CMIX message contents + cmixMsg.SetContents(transmitMsg.Marshal()) + cmixMsg.SetKeyFP(fp) + cmixMsg.SetMac(mac) + + return cmixMsg, dhKey, rid, ephID, nil +} + +// generateDhKeys generates a new public key and DH key. +func generateDhKeys(grp *cyclic.Group, dhPubKey *cyclic.Int, + rng io.Reader) (*cyclic.Int, *cyclic.Int, error) { + // Generate private key + privKeyBytes, err := csprng.GenerateInGroup(grp.GetP().Bytes(), + grp.GetP().ByteLen(), rng) + if err != nil { + return nil, nil, errors.Errorf("failed to generate key in group: %+v", + err) + } + privKey := grp.NewIntFromBytes(privKeyBytes) + + // Generate public key and DH key + publicKey := grp.ExpG(privKey, grp.NewInt(1)) + dhKey := grp.Exp(dhPubKey, privKey, grp.NewInt(1)) + + return dhKey, publicKey, nil +} + +// makeIDs generates a new user ID and ephemeral ID with a start and end within +// the given timout. The ID is generated from the unencrypted msg payload, which +// contains a nonce. If the generated ephemeral ID has a window that is not +// within +/- the given 2*timeout from now, then the IDs are generated again +// using a new nonce. +func makeIDs(msg *transmitMessagePayload, publicKey *cyclic.Int, addressSize uint, + timeout time.Duration, timeNow time.Time, rng io.Reader) (*id.ID, ephemeral.Id, error) { + var rid *id.ID + var ephID ephemeral.Id + + // Generate acceptable window for the ephemeral ID to exist in + windowStart, windowEnd := timeNow.Add(-2*timeout), timeNow.Add(2*timeout) + start, end := timeNow, timeNow + + // Loop until the ephemeral ID's start and end are within bounds + for windowStart.Before(start) || windowEnd.After(end) { + // Generate new nonce + err := msg.SetNonce(rng) + if err != nil { + return nil, ephemeral.Id{}, + errors.Errorf("failed to generate nonce: %+v", err) + } + + // Generate ID from unencrypted payload + rid = msg.GetRID(publicKey) + + // Generate the ephemeral ID + ephID, start, end, err = ephemeral.GetId(rid, addressSize, timeNow.UnixNano()) + if err != nil { + return nil, ephemeral.Id{}, errors.Errorf("failed to generate "+ + "ephemeral ID from newly generated ID: %+v", err) + } + jww.DEBUG.Printf("ephemeral.GetId(%s, %d, %d) = %d", rid, addressSize, timeNow.UnixNano(), ephID.Int64()) + } + + return rid, ephID, nil +} diff --git a/single/transmission_test.go b/single/transmission_test.go new file mode 100644 index 0000000000000000000000000000000000000000..85bf4ebf450f0537fae9ad950afa10234d1281cf --- /dev/null +++ b/single/transmission_test.go @@ -0,0 +1,456 @@ +package single + +import ( + "bytes" + contact2 "gitlab.com/elixxir/client/interfaces/contact" + pb "gitlab.com/elixxir/comms/mixmessages" + ds "gitlab.com/elixxir/comms/network/dataStructures" + "gitlab.com/elixxir/crypto/cyclic" + "gitlab.com/elixxir/crypto/e2e/auth" + "gitlab.com/elixxir/crypto/e2e/singleUse" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/elixxir/primitives/states" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/id/ephemeral" + "math/rand" + "reflect" + "strings" + "sync" + "testing" + "time" +) + +// Happy path. +func TestManager_GetMaxTransmissionPayloadSize(t *testing.T) { + m := newTestManager(0, false, t) + cmixPrimeSize := m.store.Cmix().GetGroup().GetP().ByteLen() + e2ePrimeSize := m.store.E2e().GetGroup().GetP().ByteLen() + expectedSize := 2*cmixPrimeSize - e2ePrimeSize - format.KeyFPLen - format.MacLen - format.RecipientIDLen - transmitPlMinSize + testSize := m.GetMaxTransmissionPayloadSize() + + if expectedSize != testSize { + t.Errorf("GetMaxTransmissionPayloadSize() failed to return the expected size."+ + "\nexpected: %d\nreceived: %d", expectedSize, testSize) + } +} + +// Happy path. +func TestManager_transmitSingleUse(t *testing.T) { + m := newTestManager(0, false, t) + prng := rand.New(rand.NewSource(42)) + partner := &contact2.Contact{ + ID: id.NewIdFromString("Contact ID", id.User, t), + DhPubKey: m.store.E2e().GetGroup().NewInt(5), + } + payload := make([]byte, 95) + rand.New(rand.NewSource(42)).Read(payload) + tag := "testTag" + maxMsgs := uint8(8) + callback, callbackChan := createReplyComm() + timeout := 15 * time.Millisecond + + err := m.transmitSingleUse(partner, payload, tag, maxMsgs, prng, + callback, timeout, newTestRoundEvents(false)) + if err != nil { + t.Errorf("transmitSingleUse() returned an error: %+v", err) + } + + for _, state := range m.p.singleUse { + state.quitChan <- struct{}{} + } + + expectedMsg, _, _, _, err := m.makeTransmitCmixMessage(partner, payload, + tag, maxMsgs, 32, 30*time.Second, time.Now(), rand.New(rand.NewSource(42))) + if err != nil { + t.Fatalf("Failed to make expected message: %+v", err) + } + + if !reflect.DeepEqual(expectedMsg, m.net.(*testNetworkManager).GetMsg(0)) { + t.Errorf("transmitSingleUse() failed to send the correct CMIX message."+ + "\nexpected: %+v\nreceived: %+v", + expectedMsg, m.net.(*testNetworkManager).GetMsg(0)) + } + + timer := time.NewTimer(timeout * 2) + + select { + case results := <-callbackChan: + t.Errorf("Callback called when the thread should have quit."+ + "\npayload: %+v\nerror: %+v", results.payload, results.err) + case <-timer.C: + } +} + +// Error path: function quits early if the timoutHandler quit. +func TestManager_transmitSingleUse_QuitChanError(t *testing.T) { + m := newTestManager(10*time.Millisecond, false, t) + partner := &contact2.Contact{ + ID: id.NewIdFromString("Contact ID", id.User, t), + DhPubKey: m.store.E2e().GetGroup().NewInt(5), + } + callback, callbackChan := createReplyComm() + timeout := 15 * time.Millisecond + + err := m.transmitSingleUse(partner, []byte{}, "testTag", 9, + rand.New(rand.NewSource(42)), callback, timeout, newTestRoundEvents(false)) + if err != nil { + t.Errorf("transmitSingleUse() returned an error: %+v", err) + } + + for _, state := range m.p.singleUse { + state.quitChan <- struct{}{} + } + + timer := time.NewTimer(2 * timeout) + + select { + case results := <-callbackChan: + if results.payload != nil || results.err != nil { + t.Errorf("Callback called when the timeout thread should have quit."+ + "\npayload: %+v\nerror: %+v", results.payload, results.err) + } + case <-timer.C: + } +} + +// Error path: fails to add a new identity. +func TestManager_transmitSingleUse_AddIdentityError(t *testing.T) { + timeout := 15 * time.Millisecond + m := newTestManager(timeout, false, t) + partner := &contact2.Contact{ + ID: id.NewIdFromString("Contact ID", id.User, t), + DhPubKey: m.store.E2e().GetGroup().NewInt(5), + } + callback, callbackChan := createReplyComm() + + err := m.transmitSingleUse(partner, []byte{}, "testTag", 9, + rand.New(rand.NewSource(42)), callback, timeout, newTestRoundEvents(false)) + if err != nil { + t.Errorf("transmitSingleUse() returned an error: %+v", err) + } + + for _, state := range m.p.singleUse { + state.quitChan <- struct{}{} + } + + timer := time.NewTimer(2 * timeout) + + select { + case results := <-callbackChan: + if results.payload != nil || !check(results.err, "Failed to add new identity") { + t.Errorf("Callback did not return the correct error when the "+ + "routine quit early.\npayload: %+v\nerror: %+v", + results.payload, results.err) + } + case <-timer.C: + } +} + +// Error path: SendCMIX fails to send message. +func TestManager_transmitSingleUse_SendCMIXError(t *testing.T) { + m := newTestManager(0, true, t) + partner := &contact2.Contact{ + ID: id.NewIdFromString("Contact ID", id.User, t), + DhPubKey: m.store.E2e().GetGroup().NewInt(5), + } + callback, callbackChan := createReplyComm() + timeout := 15 * time.Millisecond + + err := m.transmitSingleUse(partner, []byte{}, "testTag", 9, + rand.New(rand.NewSource(42)), callback, timeout, newTestRoundEvents(false)) + if err != nil { + t.Errorf("transmitSingleUse() returned an error: %+v", err) + } + + timer := time.NewTimer(timeout * 2) + + select { + case results := <-callbackChan: + if results.payload != nil || !check(results.err, "failed to send single-use transmission CMIX message") { + t.Errorf("Callback did not return the correct error when the "+ + "routine quit early.\npayload: %+v\nerror: %+v", + results.payload, results.err) + } + case <-timer.C: + } +} + +// Error path: failed to create CMIX message because the payload is too large. +func TestManager_transmitSingleUse_MakeTransmitCmixMessageError(t *testing.T) { + m := newTestManager(0, false, t) + prng := rand.New(rand.NewSource(42)) + payload := make([]byte, m.store.Cmix().GetGroup().GetP().ByteLen()) + + err := m.transmitSingleUse(nil, payload, "", 0, prng, nil, 0, nil) + if err == nil { + t.Error("transmitSingleUse() did not return an error when the payload " + + "is too large.") + } +} + +// Error path: failed to add pending state because is already exists. +func TestManager_transmitSingleUse_AddStateError(t *testing.T) { + m := newTestManager(0, false, t) + partner := &contact2.Contact{ + ID: id.NewIdFromString("Contact ID", id.User, t), + DhPubKey: m.store.E2e().GetGroup().NewInt(5), + } + payload := make([]byte, 95) + rand.New(rand.NewSource(42)).Read(payload) + tag := "testTag" + maxMsgs := uint8(8) + callback, _ := createReplyComm() + timeout := 15 * time.Millisecond + + // Create new CMIX and add a state + _, dhKey, rid, _, err := m.makeTransmitCmixMessage(partner, payload, tag, + maxMsgs, 32, 30*time.Second, time.Now(), rand.New(rand.NewSource(42))) + if err != nil { + t.Fatalf("Failed to create new CMIX message: %+v", err) + } + m.p.singleUse[*rid] = newState(dhKey, maxMsgs, nil) + + err = m.transmitSingleUse(partner, payload, tag, maxMsgs, + rand.New(rand.NewSource(42)), callback, timeout, nil) + if !check(err, "failed to add pending state") { + t.Errorf("transmitSingleUse() failed to error when on adding state "+ + "when the state already exists: %+v", err) + } +} + +// Error path: timeout occurs on tracking results of round. +func TestManager_transmitSingleUse_RoundTimeoutError(t *testing.T) { + m := newTestManager(0, false, t) + prng := rand.New(rand.NewSource(42)) + partner := &contact2.Contact{ + ID: id.NewIdFromString("Contact ID", id.User, t), + DhPubKey: m.store.E2e().GetGroup().NewInt(5), + } + payload := make([]byte, 95) + rand.New(rand.NewSource(42)).Read(payload) + callback, callbackChan := createReplyComm() + timeout := 15 * time.Millisecond + + err := m.transmitSingleUse(partner, payload, "testTag", 8, prng, callback, + timeout, newTestRoundEvents(true)) + if err != nil { + t.Errorf("transmitSingleUse() returned an error: %+v", err) + } + + timer := time.NewTimer(timeout * 2) + + select { + case results := <-callbackChan: + if results.payload != nil || !check(results.err, "round failures") { + t.Errorf("Callback did not return the correct error when it "+ + "should have timed out.\npayload: %+v\nerror: %+v", + results.payload, results.err) + } + case <-timer.C: + } +} + +// Happy path +func TestManager_makeTransmitCmixMessage(t *testing.T) { + m := newTestManager(0, false, t) + prng := rand.New(rand.NewSource(42)) + partner := &contact2.Contact{ + ID: id.NewIdFromString("recipientID", id.User, t), + DhPubKey: m.store.E2e().GetGroup().NewInt(42), + } + tag := "Test tag" + payload := make([]byte, 132) + rand.New(rand.NewSource(42)).Read(payload) + maxMsgs := uint8(8) + timeNow := time.Now() + + msg, dhKey, rid, ephID, err := m.makeTransmitCmixMessage(partner, payload, + tag, maxMsgs, 32, 30*time.Second, timeNow, prng) + + if err != nil { + t.Errorf("makeTransmitCmixMessage() produced an error: %+v", err) + } + + fp := singleUse.NewTransmitFingerprint(partner.DhPubKey) + key := singleUse.NewTransmitKey(dhKey) + + encPayload, err := unmarshalTransmitMessage(msg.GetContents(), + m.store.E2e().GetGroup().GetP().ByteLen()) + if err != nil { + t.Errorf("Failed to unmarshal contents: %+v", err) + } + + decryptedPayload, err := unmarshalTransmitMessagePayload(auth.Crypt(key, + fp[:24], encPayload.GetPayload())) + if err != nil { + t.Errorf("Failed to unmarshal payload: %+v", err) + } + + if !bytes.Equal(payload, decryptedPayload.GetContents()) { + t.Errorf("Failed to decrypt payload.\nexpected: %+v\nreceived: %+v", + payload, decryptedPayload.GetContents()) + } + + if !singleUse.VerifyMAC(key, encPayload.GetPayload(), msg.GetMac()) { + t.Error("Failed to verify the message MAC.") + } + + if fp != msg.GetKeyFP() { + t.Errorf("Failed to verify the CMIX message fingperprint."+ + "\nexpected: %s\nreceived: %s", fp, msg.GetKeyFP()) + } + + if maxMsgs != decryptedPayload.GetMaxParts() { + t.Errorf("Incorrect maxMsgs.\nexpected: %d\nreceived: %d", + maxMsgs, decryptedPayload.GetMaxParts()) + } + + expectedTagFP := singleUse.NewTagFP(tag) + if decryptedPayload.GetTagFP() != expectedTagFP { + t.Errorf("Incorrect TagFP.\nexpected: %s\nreceived: %s", + expectedTagFP, decryptedPayload.GetTagFP()) + } + + if !rid.Cmp(decryptedPayload.GetRID(encPayload.GetPubKey(m.store.E2e().GetGroup()))) { + t.Errorf("Returned incorrect recipient ID.\nexpected: %s\nreceived: %s", + decryptedPayload.GetRID(encPayload.GetPubKey(m.store.E2e().GetGroup())), rid) + } + + expectedEphID, _, _, err := ephemeral.GetId(rid, uint(len(rid)), timeNow.UnixNano()) + if err != nil { + t.Fatalf("Failed to generate expected ephemeral ID: %+v", err) + } + + if expectedEphID != ephID { + t.Errorf("Returned incorrect ephemeral ID.\nexpected: %d\nreceived: %d", + expectedEphID.Int64(), ephID.Int64()) + } +} + +// Error path: supplied payload to large for message. +func TestManager_makeTransmitCmixMessage_PayloadTooLargeError(t *testing.T) { + m := newTestManager(0, false, t) + prng := rand.New(rand.NewSource(42)) + payload := make([]byte, 1000) + rand.New(rand.NewSource(42)).Read(payload) + + _, _, _, _, err := m.makeTransmitCmixMessage(nil, payload, "", 8, 32, + 30*time.Second, time.Now(), prng) + + if !check(err, "too long for message payload capacity") { + t.Errorf("makeTransmitCmixMessage() failed to error when the payload is too "+ + "large: %+v", err) + } +} + +// Error path: key generation fails. +func TestManager_makeTransmitCmixMessage_KeyGenerationError(t *testing.T) { + m := newTestManager(0, false, t) + prng := strings.NewReader("a") + partner := &contact2.Contact{ + ID: id.NewIdFromString("recipientID", id.User, t), + DhPubKey: m.store.E2e().GetGroup().NewInt(42), + } + + _, _, _, _, err := m.makeTransmitCmixMessage(partner, nil, "", 8, 32, + 30*time.Second, time.Now(), prng) + + if !check(err, "failed to generate key in group") { + t.Errorf("makeTransmitCmixMessage() failed to error when key "+ + "generation failed: %+v", err) + } +} + +// Happy path: test for consistency. +func Test_makeIDs_Consistency(t *testing.T) { + m := newTestManager(0, false, t) + cmixMsg := format.NewMessage(m.store.Cmix().GetGroup().GetP().ByteLen()) + transmitMsg := newTransmitMessage(cmixMsg.ContentsSize(), m.store.E2e().GetGroup().GetP().ByteLen()) + msgPayload := newTransmitMessagePayload(transmitMsg.GetPayloadSize()) + msgPayload.SetTagFP(singleUse.NewTagFP("tag")) + msgPayload.SetMaxParts(8) + msgPayload.SetContents([]byte("payload")) + _, publicKey, err := generateDhKeys(m.store.E2e().GetGroup(), + m.store.E2e().GetGroup().NewInt(42), rand.New(rand.NewSource(42))) + if err != nil { + t.Fatalf("Failed to generate public key: %+v", err) + } + addressSize := uint(32) + + expectedPayload, err := unmarshalTransmitMessagePayload(msgPayload.Marshal()) + if err != nil { + t.Fatalf("Failed to copy payload: %+v", err) + } + + err = expectedPayload.SetNonce(rand.New(rand.NewSource(42))) + if err != nil { + t.Fatalf("Failed to set nonce: %+v", err) + } + + timeNow := time.Now() + + rid, ephID, err := makeIDs(&msgPayload, publicKey, addressSize, + 30*time.Second, timeNow, rand.New(rand.NewSource(42))) + if err != nil { + t.Errorf("makeIDs() returned an error: %+v", err) + } + + if expectedPayload.GetNonce() != msgPayload.GetNonce() { + t.Errorf("makeIDs() failed to set the expected nonce."+ + "\nexpected: %d\nreceived: %d", expectedPayload.GetNonce(), msgPayload.GetNonce()) + } + + if !expectedPayload.GetRID(publicKey).Cmp(rid) { + t.Errorf("makeIDs() did not return the expected ID."+ + "\nexpected: %s\nreceived: %s", expectedPayload.GetRID(publicKey), rid) + } + + expectedEphID, _, _, err := ephemeral.GetId(expectedPayload.GetRID(publicKey), + addressSize, timeNow.UnixNano()) + if err != nil { + t.Fatalf("Failed to generate expected ephemeral ID: %+v", err) + } + + if expectedEphID != ephID { + t.Errorf("makeIDs() did not return the expected ephemeral ID."+ + "\nexpected: %d\nreceived: %d", expectedEphID.Int64(), ephID.Int64()) + } +} + +// Error path: failed to generate nonce. +func Test_makeIDs_NonceError(t *testing.T) { + msgPayload := newTransmitMessagePayload(transmitPlMinSize) + + _, _, err := makeIDs(&msgPayload, &cyclic.Int{}, 32, 30*time.Second, + time.Now(), strings.NewReader("")) + if !check(err, "failed to generate nonce") { + t.Errorf("makeIDs() did not return an error when failing to make nonce: %+v", err) + } +} + +type testRoundEvents struct { + callbacks map[id.Round][states.NUM_STATES]map[*ds.EventCallback]*ds.EventCallback + timeoutError bool + mux sync.RWMutex +} + +func newTestRoundEvents(timeoutError bool) *testRoundEvents { + return &testRoundEvents{ + callbacks: make(map[id.Round][states.NUM_STATES]map[*ds.EventCallback]*ds.EventCallback), + timeoutError: timeoutError, + } +} + +func (r *testRoundEvents) AddRoundEventChan(_ id.Round, + eventChan chan ds.EventReturn, _ time.Duration, _ ...states.Round) *ds.EventCallback { + + eventChan <- struct { + RoundInfo *pb.RoundInfo + TimedOut bool + }{ + RoundInfo: &pb.RoundInfo{State: uint32(states.COMPLETED)}, + TimedOut: r.timeoutError, + } + + return nil +} diff --git a/single/transmitMessage.go b/single/transmitMessage.go index 31e1ae5db51964ebf44623dd201e935e359bbd01..768c47e2697dfe38c731ae0fbd2ffb7759442755 100644 --- a/single/transmitMessage.go +++ b/single/transmitMessage.go @@ -8,12 +8,27 @@ package single import ( + "encoding/binary" "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/crypto/cyclic" + "gitlab.com/elixxir/crypto/e2e/singleUse" "gitlab.com/xx_network/primitives/id" + "io" ) +/* ++-----------------------------------------------------------------+ +| CMIX Message Contents | ++------------+----------------------------------------------------+ +| pubKey | payload (transmitMessagePayload) | +| pubKeySize | externalPayloadSize - pubKeySize | ++------------+----------+---------+----------+---------+----------+ + | Tag FP | nonce | maxParts | size | contents | + | 16 bytes | 8 bytes | 1 byte | 2 bytes | variable | + +----------+---------+----------+---------+----------+ +*/ + type transmitMessage struct { data []byte // Serial of all contents pubKey []byte @@ -24,7 +39,7 @@ type transmitMessage struct { // size of the specified external payload. func newTransmitMessage(externalPayloadSize, pubKeySize int) transmitMessage { if externalPayloadSize < pubKeySize { - jww.FATAL.Panicf("Payload size of single use transmission message "+ + jww.FATAL.Panicf("Payload size of single-use transmission message "+ "(%d) too small to contain the public key (%d).", externalPayloadSize, pubKeySize) } @@ -87,7 +102,7 @@ func (m transmitMessage) GetPayloadSize() int { // size is correct. func (m transmitMessage) SetPayload(b []byte) { if len(b) != len(m.payload) { - jww.FATAL.Panicf("Size of payload of single use transmission message "+ + jww.FATAL.Panicf("Size of payload of single-use transmission message "+ "(%d) is not the same as the size of the supplied payload (%d).", len(m.payload), len(b)) } @@ -95,13 +110,21 @@ func (m transmitMessage) SetPayload(b []byte) { copy(m.payload, b) } -const numSize = 1 +const ( + tagFPSize = singleUse.TagFpSize + nonceSize = 8 + maxPartsSize = 1 + sizeSize = 2 + transmitPlMinSize = tagFPSize + nonceSize + maxPartsSize + sizeSize +) // transmitMessagePayload is the structure of transmitMessage's payload. type transmitMessagePayload struct { data []byte // Serial of all contents - rid []byte // Response reception ID - num []byte // Number of messages expected in response + tagFP []byte // Tag fingerprint identifies the type of message + nonce []byte + maxParts []byte // Max number of messages expected in response + size []byte // Size of the contents contents []byte } @@ -109,35 +132,41 @@ type transmitMessagePayload struct { // size of the specified payload, which should match the size of the payload in // the corresponding transmitMessage. func newTransmitMessagePayload(payloadSize int) transmitMessagePayload { - if payloadSize < id.ArrIDLen+numSize { - jww.FATAL.Panicf("Size of single use transmission message payload "+ - "(%d) too small to contain the reception ID (%d) + the message "+ - "count (%d).", - payloadSize, id.ArrIDLen, numSize) + if payloadSize < transmitPlMinSize { + jww.FATAL.Panicf("Size of single-use transmission message payload "+ + "(%d) too small to contain the necessary data (%d).", + payloadSize, transmitPlMinSize) } - return mapTransmitMessagePayload(make([]byte, payloadSize)) + // Map fields to data + mp := mapTransmitMessagePayload(make([]byte, payloadSize)) + + return mp } // mapTransmitMessagePayload builds a message payload mapped to the passed in // data. It is mapped by reference; a copy is not made. func mapTransmitMessagePayload(data []byte) transmitMessagePayload { - return transmitMessagePayload{ + mp := transmitMessagePayload{ data: data, - rid: data[:id.ArrIDLen], - num: data[id.ArrIDLen : id.ArrIDLen+numSize], - contents: data[id.ArrIDLen+numSize:], + tagFP: data[:tagFPSize], + nonce: data[tagFPSize : tagFPSize+nonceSize], + maxParts: data[tagFPSize+nonceSize : tagFPSize+nonceSize+maxPartsSize], + size: data[tagFPSize+nonceSize+maxPartsSize : transmitPlMinSize], + contents: data[transmitPlMinSize:], } + + return mp } // unmarshalTransmitMessagePayload unmarshalls a byte slice into a // transmitMessagePayload. An error is returned if the slice is not large enough // for the reception ID and message count. func unmarshalTransmitMessagePayload(b []byte) (transmitMessagePayload, error) { - if len(b) < id.ArrIDLen+numSize { + if len(b) < transmitPlMinSize { return transmitMessagePayload{}, errors.Errorf("Length of marshaled "+ - "bytes(%d) too small to contain the reception ID (%d) + the "+ - "message count (%d).", len(b), id.ArrIDLen, numSize) + "bytes(%d) too small to contain the necessary data (%d).", + len(b), transmitPlMinSize) } return mapTransmitMessagePayload(b), nil @@ -148,49 +177,72 @@ func (mp transmitMessagePayload) Marshal() []byte { return mp.data } -// GetRID returns the reception ID. -func (mp transmitMessagePayload) GetRID() *id.ID { - rid, err := id.Unmarshal(mp.rid) - if err != nil { - jww.FATAL.Panicf("Failed to unmarshal transmission ID of single use "+ - "transmission message payload: %+v", err) - } +// GetRID generates the reception ID from the bytes of the payload. +func (mp transmitMessagePayload) GetRID(pubKey *cyclic.Int) *id.ID { + return singleUse.NewRecipientID(pubKey, mp.Marshal()) +} - return rid +// GetTagFP returns the tag fingerprint. +func (mp transmitMessagePayload) GetTagFP() singleUse.TagFP { + return singleUse.UnmarshalTagFP(mp.tagFP) } -// SetRID sets the reception ID of the payload. -func (mp transmitMessagePayload) SetRID(rid *id.ID) { - copy(mp.rid, rid.Marshal()) +// SetTagFP sets the tag fingerprint. +func (mp transmitMessagePayload) SetTagFP(tagFP singleUse.TagFP) { + copy(mp.tagFP, tagFP.Bytes()) } -// GetCount returns the number of messages expected in response. -func (mp transmitMessagePayload) GetCount() uint8 { - return mp.num[0] +// GetNonce returns the nonce as a uint64. +func (mp transmitMessagePayload) GetNonce() uint64 { + return binary.BigEndian.Uint64(mp.nonce) } -// SetCount sets the number of expected messages. -func (mp transmitMessagePayload) SetCount(num uint8) { - copy(mp.num, []byte{num}) +// SetNonce generates a random nonce from the RNG. An error is returned if the +// reader fails. +func (mp transmitMessagePayload) SetNonce(rng io.Reader) error { + if _, err := rng.Read(mp.nonce); err != nil { + return errors.Errorf("failed to generate nonce: %+v", err) + } + + return nil +} + +// GetMaxParts returns the number of messages expected in response. +func (mp transmitMessagePayload) GetMaxParts() uint8 { + return mp.maxParts[0] +} + +// SetMaxParts sets the number of expected messages. +func (mp transmitMessagePayload) SetMaxParts(num uint8) { + copy(mp.maxParts, []byte{num}) } // GetContents returns the payload's contents. func (mp transmitMessagePayload) GetContents() []byte { - return mp.contents + return mp.contents[:binary.BigEndian.Uint16(mp.size)] } // GetContentsSize returns the length of payload's contents. func (mp transmitMessagePayload) GetContentsSize() int { + return int(binary.BigEndian.Uint16(mp.size)) +} + +// GetMaxContentsSize returns the max capacity of the contents. +func (mp transmitMessagePayload) GetMaxContentsSize() int { return len(mp.contents) } -// SetContents saves the contents to the payload, if the size is correct. -func (mp transmitMessagePayload) SetContents(b []byte) { - if len(b) != len(mp.contents) { - jww.FATAL.Panicf("Size of content of single use transmission message "+ - "payload (%d) is not the same as the size of the supplied "+ - "contents (%d).", len(mp.contents), len(b)) +// SetContents saves the contents to the payload, if the size is correct. Does +// not zero out previous content. +func (mp transmitMessagePayload) SetContents(contents []byte) { + if len(contents) > len(mp.contents) { + jww.FATAL.Panicf("Failed to set contents of single-use transmission "+ + "message: max size of message content (%d) is smaller than the "+ + "size of the supplied contents (%d).", + len(mp.contents), len(contents)) } - copy(mp.contents, b) + binary.BigEndian.PutUint16(mp.size, uint16(len(contents))) + + copy(mp.contents, contents) } diff --git a/single/transmitMessage_test.go b/single/transmitMessage_test.go index 75bf80a01daea28d194a9ed49fddd4726939b37f..833fd5f4b91530547684e8a4f117ab22f628da43 100644 --- a/single/transmitMessage_test.go +++ b/single/transmitMessage_test.go @@ -1,12 +1,22 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + package single import ( "bytes" + "encoding/binary" "gitlab.com/elixxir/crypto/cyclic" + "gitlab.com/elixxir/crypto/e2e/singleUse" + "gitlab.com/elixxir/primitives/format" "gitlab.com/xx_network/crypto/large" - "gitlab.com/xx_network/primitives/id" "math/rand" "reflect" + "strings" "testing" ) @@ -32,7 +42,7 @@ func Test_newTransmitMessage(t *testing.T) { // Error path: public key size is larger than external payload size. func Test_newTransmitMessage_PubKeySizeError(t *testing.T) { defer func() { - if r := recover(); r == nil { + if r := recover(); r == nil || !strings.Contains(r.(string), "Payload size") { t.Error("newTransmitMessage() did not panic when the size of " + "the payload is smaller than the size of the public key.") } @@ -135,13 +145,13 @@ func TestTransmitMessage_SetPayload_GetPayload_GetPayloadSize(t *testing.T) { testPayload := m.GetPayload() if !bytes.Equal(payload, testPayload) { - t.Errorf("GetPayload() returned incorrect payload."+ + t.Errorf("GetContents() returned incorrect payload."+ "\nexpected: %+v\nreceived: %+v", payload, testPayload) } payloadSize := externalPayloadSize - pubKeySize if payloadSize != m.GetPayloadSize() { - t.Errorf("GetPayloadSize() returned incorrect payload size."+ + t.Errorf("GetContentsSize() returned incorrect content size."+ "\nexpected: %d\nreceived: %d", payloadSize, m.GetPayloadSize()) } } @@ -149,9 +159,9 @@ func TestTransmitMessage_SetPayload_GetPayload_GetPayloadSize(t *testing.T) { // Error path: supplied payload is not the same size as message payload. func TestTransmitMessage_SetPayload_PayloadSizeError(t *testing.T) { defer func() { - if r := recover(); r == nil { - t.Error("SetPayload() did not panic when the size of supplied " + - "payload is not the same size as message payload.") + if r := recover(); r == nil || !strings.Contains(r.(string), "is not the same as the size") { + t.Error("SetContents() did not panic when the size of supplied " + + "contents is not the same size as message contents.") } }() @@ -165,23 +175,25 @@ func Test_newTransmitMessagePayload(t *testing.T) { payloadSize := prng.Intn(2000) expected := transmitMessagePayload{ data: make([]byte, payloadSize), - rid: make([]byte, id.ArrIDLen), - num: make([]byte, numSize), - contents: make([]byte, payloadSize-id.ArrIDLen-numSize), + tagFP: make([]byte, tagFPSize), + nonce: make([]byte, nonceSize), + maxParts: make([]byte, maxPartsSize), + size: make([]byte, sizeSize), + contents: make([]byte, payloadSize-transmitPlMinSize), } mp := newTransmitMessagePayload(payloadSize) if !reflect.DeepEqual(expected, mp) { t.Errorf("newTransmitMessagePayload() did not produce the expected "+ - "transmitMessagePayload.\nexpected: %#v\nreceived: %#v", expected, mp) + "transmitMessagePayload.\nexpected: %+v\nreceived: %+v", expected, mp) } } -// Error path: payload size is smaller than than rid size + num size. +// Error path: payload size is smaller than than rid size + maxParts size. func Test_newTransmitMessagePayload_PayloadSizeError(t *testing.T) { defer func() { - if r := recover(); r == nil { + if r := recover(); r == nil || !strings.Contains(r.(string), "Size of single-use transmission message payload") { t.Error("newTransmitMessagePayload() did not panic when the size " + "of the payload is smaller than the size of the reception ID " + "+ the message count.") @@ -194,13 +206,17 @@ func Test_newTransmitMessagePayload_PayloadSizeError(t *testing.T) { // Happy path. func Test_mapTransmitMessagePayload(t *testing.T) { prng := rand.New(rand.NewSource(42)) - rid := id.NewIdFromUInt(prng.Uint64(), id.User, t) + tagFP := singleUse.NewTagFP("Tag") + nonceBytes := make([]byte, nonceSize) num := uint8(prng.Uint64()) + size := []byte{uint8(prng.Uint64()), uint8(prng.Uint64())} contents := make([]byte, prng.Intn(1000)) prng.Read(contents) var data []byte - data = append(data, rid.Bytes()...) + data = append(data, tagFP.Bytes()...) + data = append(data, nonceBytes...) data = append(data, num) + data = append(data, size...) data = append(data, contents...) mp := mapTransmitMessagePayload(data) @@ -209,14 +225,24 @@ func Test_mapTransmitMessagePayload(t *testing.T) { "for data.\nexpected: %+v\nreceived: %+v", data, mp.data) } - if !bytes.Equal(rid.Bytes(), mp.rid) { + if !bytes.Equal(tagFP.Bytes(), mp.tagFP) { + t.Errorf("mapTransmitMessagePayload() failed to map the correct bytes "+ + "for tagFP.\nexpected: %+v\nreceived: %+v", tagFP.Bytes(), mp.tagFP) + } + + if !bytes.Equal(nonceBytes, mp.nonce) { + t.Errorf("mapTransmitMessagePayload() failed to map the correct bytes "+ + "for the nonce.\nexpected: %s\nreceived: %s", nonceBytes, mp.nonce) + } + + if num != mp.maxParts[0] { t.Errorf("mapTransmitMessagePayload() failed to map the correct bytes "+ - "for rid.\nexpected: %+v\nreceived: %+v", rid.Bytes(), mp.rid) + "for maxParts.\nexpected: %d\nreceived: %d", num, mp.maxParts[0]) } - if num != mp.num[0] { + if !bytes.Equal(size, mp.size) { t.Errorf("mapTransmitMessagePayload() failed to map the correct bytes "+ - "for num.\nexpected: %d\nreceived: %d", num, mp.num[0]) + "for size.\nexpected: %+v\nreceived: %+v", size, mp.size) } if !bytes.Equal(contents, mp.contents) { @@ -255,41 +281,69 @@ func Test_unmarshalTransmitMessagePayload(t *testing.T) { } // Happy path. -func TestTransmitMessagePayload_SetRID_GetRID(t *testing.T) { +func TestTransmitMessagePayload_GetRID(t *testing.T) { prng := rand.New(rand.NewSource(42)) mp := newTransmitMessagePayload(prng.Intn(2000)) - rid := id.NewIdFromUInt(prng.Uint64(), id.User, t) + expectedRID := singleUse.NewRecipientID(getGroup().NewInt(42), mp.Marshal()) - mp.SetRID(rid) - testRID := mp.GetRID() + testRID := mp.GetRID(getGroup().NewInt(42)) - if !rid.Cmp(testRID) { + if !expectedRID.Cmp(testRID) { t.Errorf("GetRID() did not return the expected ID."+ - "\nexpected: %s\nreceived: %s", rid, testRID) + "\nexpected: %s\nreceived: %s", expectedRID, testRID) + } +} + +// Happy path. +func Test_transmitMessagePayload_SetNonce_GetNonce(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + mp := newTransmitMessagePayload(prng.Intn(2000)) + + expectedNonce := prng.Uint64() + expectedNonceBytes := make([]byte, 8) + binary.BigEndian.PutUint64(expectedNonceBytes, expectedNonce) + err := mp.SetNonce(strings.NewReader(string(expectedNonceBytes))) + if err != nil { + t.Errorf("SetNonce() produced an error: %+v", err) + } + + if expectedNonce != mp.GetNonce() { + t.Errorf("GetNonce() did not return the expected nonce."+ + "\nexpected: %d\nreceived: %d", expectedNonce, mp.GetNonce()) + } +} + +// Error path: RNG return an error. +func Test_transmitMessagePayload_SetNonce_RngError(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + mp := newTransmitMessagePayload(prng.Intn(2000)) + err := mp.SetNonce(strings.NewReader("")) + if !check(err, "failed to generate nonce") { + t.Errorf("SetNonce() did not return an error when nonce generation "+ + "fails: %+v", err) } } // Happy path. -func TestTransmitMessagePayload_SetCount_GetCount(t *testing.T) { +func TestTransmitMessagePayload_SetMaxParts_GetMaxParts(t *testing.T) { prng := rand.New(rand.NewSource(42)) mp := newTransmitMessagePayload(prng.Intn(2000)) count := uint8(prng.Uint64()) - mp.SetCount(count) - testCount := mp.GetCount() + mp.SetMaxParts(count) + testCount := mp.GetMaxParts() if count != testCount { - t.Errorf("GetCount() did not return the expected count."+ + t.Errorf("GetMaxParts() did not return the expected count."+ "\nexpected: %d\nreceived: %d", count, testCount) } } // Happy path. -func TestTransmitMessagePayload_SetContents_GetContents_GetContentsSize(t *testing.T) { +func TestTransmitMessagePayload_SetContents_GetContents_GetContentsSize_GetMaxContentsSize(t *testing.T) { prng := rand.New(rand.NewSource(42)) - payloadSize := prng.Intn(2000) - mp := newTransmitMessagePayload(payloadSize) - contentsSize := payloadSize - id.ArrIDLen - numSize + mp := newTransmitMessagePayload(format.MinimumPrimeSize) + contentsSize := (format.MinimumPrimeSize - transmitPlMinSize) / 2 contents := make([]byte, contentsSize) prng.Read(contents) @@ -304,19 +358,24 @@ func TestTransmitMessagePayload_SetContents_GetContents_GetContentsSize(t *testi t.Errorf("GetContentsSize() did not return the expected size."+ "\nexpected: %d\nreceived: %d", contentsSize, mp.GetContentsSize()) } + + if format.MinimumPrimeSize-transmitPlMinSize != mp.GetMaxContentsSize() { + t.Errorf("GetMaxContentsSize() did not return the expected size."+ + "\nexpected: %d\nreceived: %d", format.MinimumPrimeSize-transmitPlMinSize, mp.GetMaxContentsSize()) + } } // Error path: supplied bytes are smaller than payload contents. func TestTransmitMessagePayload_SetContents(t *testing.T) { defer func() { - if r := recover(); r == nil { + if r := recover(); r == nil || !strings.Contains(r.(string), "max size of message content") { t.Error("SetContents() did not panic when the size of the " + "supplied bytes is not the same as the payload content size.") } }() - mp := newTransmitMessagePayload(255) - mp.SetContents([]byte{1, 2, 3}) + mp := newTransmitMessagePayload(format.MinimumPrimeSize) + mp.SetContents(make([]byte, format.MinimumPrimeSize+1)) } func getGroup() *cyclic.Group { diff --git a/storage/reception/store.go b/storage/reception/store.go index df6aada4a8edb61d0466ce89c9a7fb6c4f12a285..0802ccc9832f4a4c9127fdd1e1a064de42318d08 100644 --- a/storage/reception/store.go +++ b/storage/reception/store.go @@ -6,12 +6,10 @@ import ( "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/storage/versioned" - "gitlab.com/elixxir/crypto/hash" - "gitlab.com/xx_network/crypto/randomness" + "gitlab.com/xx_network/crypto/large" "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/id/ephemeral" "io" - "math/big" "strconv" "sync" "time" @@ -66,7 +64,7 @@ func NewStore(kv *versioned.KV) *Store { func LoadStore(kv *versioned.KV) *Store { kv = kv.Prefix(receptionPrefix) s := &Store{ - kv: kv, + kv: kv, idSizeCond: sync.NewCond(&sync.Mutex{}), } @@ -190,8 +188,8 @@ func (s *Store) AddIdentity(identity Identity) error { s.mux.Lock() defer s.mux.Unlock() - if identity.StartValid.After(identity.EndValid){ - return errors.Errorf("Cannot add an identity which start valid " + + if identity.StartValid.After(identity.EndValid) { + return errors.Errorf("Cannot add an identity which start valid "+ "time (%s) is after its end valid time(%s)", identity.StartValid, identity.EndValid) } @@ -335,13 +333,8 @@ func (s *Store) selectIdentity(rng io.Reader, now time.Time) (IdentityUse, error return IdentityUse{}, errors.WithMessage(err, "Failed to "+ "choose ID due to rng failure") } - h, err := hash.NewCMixHash() - if err != nil { - return IdentityUse{}, err - } - selectedNum := randomness.RandInInterval( - big.NewInt(int64(len(s.active)-1)), seed, h) + selectedNum := large.NewInt(1).Mod(large.NewIntFromBytes(seed), large.NewInt(int64(len(s.active)))) selected = s.active[selectedNum.Uint64()] } @@ -349,6 +342,10 @@ func (s *Store) selectIdentity(rng io.Reader, now time.Time) (IdentityUse, error selected.ExtraChecks-- } + jww.DEBUG.Printf("Selected identity: EphId: %d ID: %s End: %s StartValid: %s EndValid: %s", + selected.EphId.Int64(), selected.Source, selected.End.Format("01/02/06 03:04:05 pm"), + selected.StartValid.Format("01/02/06 03:04:05 pm"), selected.EndValid.Format("01/02/06 03:04:05 pm")) + return IdentityUse{ Identity: selected.Identity, Fake: false,