diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 8f259f54562711daaefe6a1b1394ac0d2197f6ed..51e3b61ecd7b4bb3ba4a755c5c232cf3d2a3d1b7 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -85,25 +85,33 @@ tag: - git tag $(release/client.linux64 version | grep "Elixxir Client v"| cut -d ' ' -f3) -f - git push origin_tags -f --tags -bindings: +bindings-ios: stage: build except: - tags tags: - ios script: - - export PATH="/usr/local/opt/go@1.13/bin:$PATH" - go get -u golang.org/x/mobile/cmd/gomobile - - go get -u golang.org/x/mobile/bind - - rm -rf $HOME/go/src/gitlab.com/elixxir/client/ - - mkdir -p $HOME/go/src/gitlab.com/elixxir/client/ - - cp -r * $HOME/go/src/gitlab.com/elixxir/client/ - - GO111MODULE=on gomobile bind -target android -androidapi 21 gitlab.com/elixxir/client/bindings - - GO111MODULE=on gomobile bind -target ios gitlab.com/elixxir/client/bindings + - gomobile bind -target ios gitlab.com/elixxir/client/bindings - zip -r iOS.zip Bindings.framework artifacts: paths: - iOS.zip + +bindings-android: + stage: build + except: + - tags + tags: + - android + script: + - export ANDROID_HOME=/android-sdk + - export PATH=$PATH:/android-sdk/platform-tools/:/usr/local/go/bin + - go get -u golang.org/x/mobile/cmd/gomobile + - gomobile bind -target android -androidapi 21 gitlab.com/elixxir/client/bindings + artifacts: + paths: - bindings.aar - bindings-sources.jar diff --git a/api/authenticatedChannel.go b/api/authenticatedChannel.go index 88682af983c107345b6415648bec9558f2eb4d6b..24b5a8e4af3c771d59dabc3e01d6ffa1b8125a17 100644 --- a/api/authenticatedChannel.go +++ b/api/authenticatedChannel.go @@ -125,3 +125,17 @@ func (c *Client) MakePrecannedContact(precannedID uint) contact.Contact { Facts: make([]fact.Fact, 0), } } + +// GetRelationshipFingerprint returns a unique 15 character fingerprint for an +// E2E relationship. An error is returned if no relationship with the partner +// is found. +func (c *Client) GetRelationshipFingerprint(partner *id.ID) (string, error) { + m, err := c.storage.E2e().GetPartner(partner) + if err != nil { + return "", errors.Errorf("could not get partner %s: %+v", partner, err) + } else if m == nil { + return "", errors.Errorf("manager for partner %s is nil.", partner) + } + + return m.GetRelationshipFingerprint(), nil +} diff --git a/api/client.go b/api/client.go index 7880fa697821243fd6b262389aadcc4e4e525789..e6fddd7f457db98ee26788e0b8c4676cd06277b9 100644 --- a/api/client.go +++ b/api/client.go @@ -8,10 +8,6 @@ package api import ( - "gitlab.com/xx_network/comms/connect" - "gitlab.com/xx_network/primitives/id" - "time" - "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/auth" @@ -28,12 +24,17 @@ import ( "gitlab.com/elixxir/crypto/cyclic" "gitlab.com/elixxir/crypto/fastRNG" "gitlab.com/elixxir/primitives/version" + "gitlab.com/xx_network/comms/connect" "gitlab.com/xx_network/crypto/csprng" "gitlab.com/xx_network/crypto/large" "gitlab.com/xx_network/crypto/signature/rsa" + "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/ndf" + "time" ) +const followerStoppableName = "client" + type Client struct { //generic RNG for client rng *fastRNG.StreamGenerator @@ -181,7 +182,7 @@ func OpenClient(storageDir string, password []byte, parameters params.Network) ( rng: rngStreamGen, comms: nil, network: nil, - runner: stoppable.NewMulti("client"), + runner: stoppable.NewMulti(followerStoppableName), status: newStatusTracker(), parameters: parameters, } @@ -350,6 +351,7 @@ func (c *Client) initPermissioning(def *ndf.NetworkDefinition) error { } // ----- Client Functions ----- + // StartNetworkFollower kicks off the tracking of the network. It starts // long running network client threads and returns an object for checking // state and stopping those threads. @@ -380,7 +382,7 @@ func (c *Client) initPermissioning(def *ndf.NetworkDefinition) error { // Responds to confirmations of successful rekey operations // - Auth Callback (/auth/callback.go) // Handles both auth confirm and requests -func (c *Client) StartNetworkFollower() (<-chan interfaces.ClientError, error) { +func (c *Client) StartNetworkFollower(timeout time.Duration) (<-chan interfaces.ClientError, error) { u := c.GetUser() jww.INFO.Printf("StartNetworkFollower() \n\tTransmisstionID: %s "+ "\n\tReceptionID: %s", u.TransmissionID, u.ReceptionID) @@ -399,12 +401,21 @@ func (c *Client) StartNetworkFollower() (<-chan interfaces.ClientError, error) { } } - err := c.status.toStarting() + // Wait for any threads from the previous follower to close and then create + // a new stoppable + err := stoppable.WaitForStopped(c.runner, timeout) + if err != nil { + return nil, err + } else { + c.runner = stoppable.NewMulti(followerStoppableName) + } + + err = c.status.toStarting() if err != nil { return nil, errors.WithMessage(err, "Failed to Start the Network Follower") } - stopAuth := c.auth.StartProcessies() + stopAuth := c.auth.StartProcesses() c.runner.Add(stopAuth) stopFollow, err := c.network.Follow(cer) @@ -431,19 +442,18 @@ func (c *Client) StartNetworkFollower() (<-chan interfaces.ClientError, error) { // fails to stop it. // if the network follower is running and this fails, the client object will // most likely be in an unrecoverable state and need to be trashed. -func (c *Client) StopNetworkFollower(timeout time.Duration) error { +func (c *Client) StopNetworkFollower() error { err := c.status.toStopping() if err != nil { return errors.WithMessage(err, "Failed to Stop the Network Follower") } - err = c.runner.Close(timeout) - c.runner = stoppable.NewMulti("client") + err = c.runner.Close() err2 := c.status.toStopped() if err2 != nil { - if err ==nil{ + if err == nil { err = err2 - }else{ - err = errors.WithMessage(err,err2.Error()) + } else { + err = errors.WithMessage(err, err2.Error()) } } return err diff --git a/api/results.go b/api/results.go index 590634b9646f2b92f3b56e29d422145806807d66..d87c538c38ecb9f8baebe489815cca148870d616 100644 --- a/api/results.go +++ b/api/results.go @@ -12,7 +12,6 @@ import ( jww "github.com/spf13/jwalterweatherman" pb "gitlab.com/elixxir/comms/mixmessages" - "gitlab.com/elixxir/comms/network" ds "gitlab.com/elixxir/comms/network/dataStructures" "gitlab.com/elixxir/primitives/states" "gitlab.com/xx_network/comms/connect" @@ -131,7 +130,7 @@ func (c *Client) getRoundResults(roundList []id.Round, timeout time.Duration, // Find out what happened to old (historical) rounds if any are needed if len(historicalRequest.Rounds) > 0 { - go c.getHistoricalRounds(historicalRequest, networkInstance, sendResults, commsInterface) + go c.getHistoricalRounds(historicalRequest, sendResults, commsInterface) } // Determine the results of all rounds requested @@ -180,7 +179,7 @@ func (c *Client) getRoundResults(roundList []id.Round, timeout time.Duration, // Helper function which asynchronously pings a random gateway until // it gets information on it's requested historical rounds func (c *Client) getHistoricalRounds(msg *pb.HistoricalRounds, - instance *network.Instance, sendResults chan ds.EventReturn, comms historicalRoundsComm) { + sendResults chan ds.EventReturn, comms historicalRoundsComm) { var resp *pb.HistoricalRoundsResponse @@ -189,7 +188,7 @@ func (c *Client) getHistoricalRounds(msg *pb.HistoricalRounds, // Find a gateway to request about the roundRequests result, err := c.GetNetworkInterface().GetSender().SendToAny(func(host *connect.Host) (interface{}, error) { return comms.RequestHistoricalRounds(host, msg) - }) + }, nil) // If an error, retry with (potentially) a different gw host. // If no error from received gateway request, exit loop diff --git a/api/send.go b/api/send.go index 19d6a54b853f911e47af7d2d9bcafa46e2b7bd4d..5ef62966daf88579091196898ce4e661583696e3 100644 --- a/api/send.go +++ b/api/send.go @@ -27,7 +27,7 @@ func (c *Client) SendE2E(m message.Send, param params.E2E) ([]id.Round, e2e.MessageID, error) { jww.INFO.Printf("SendE2E(%s, %d. %v)", m.Recipient, m.MessageType, m.Payload) - return c.network.SendE2E(m, param) + return c.network.SendE2E(m, param, nil) } // SendUnsafe sends an unencrypted payload to the provided recipient @@ -52,6 +52,14 @@ func (c *Client) SendCMIX(msg format.Message, recipientID *id.ID, return c.network.SendCMIX(msg, recipientID, param) } +// SendManyCMIX sends many "raw" CMIX message payloads to each of the +// provided recipients. Used for group chat functionality. Returns the +// round ID of the round the payload was sent or an error if it fails. +func (c *Client) SendManyCMIX(messages map[id.ID]format.Message, + params params.CMIX) (id.Round, []ephemeral.Id, error) { + return c.network.SendManyCMIX(messages, params) +} + // NewCMIXMessage Creates a new cMix message with the right properties // for the current cMix network. // FIXME: this is weird and shouldn't be necessary, but it is. diff --git a/api/utilsInterfaces_test.go b/api/utilsInterfaces_test.go index 259349681d4c3466fcf5af3bddc84be63215d076..1faabce01353d23af8665858b4561e8aeb856836 100644 --- a/api/utilsInterfaces_test.go +++ b/api/utilsInterfaces_test.go @@ -93,7 +93,7 @@ func (t *testNetworkManagerGeneric) Follow(report interfaces.ClientErrorReport) func (t *testNetworkManagerGeneric) CheckGarbledMessages() { return } -func (t *testNetworkManagerGeneric) SendE2E(m message.Send, p params.E2E) ( +func (t *testNetworkManagerGeneric) SendE2E(message.Send, params.E2E, *stoppable.Single) ( []id.Round, cE2e.MessageID, error) { rounds := []id.Round{id.Round(0), id.Round(1), id.Round(2)} return rounds, cE2e.MessageID{}, nil @@ -105,6 +105,9 @@ func (t *testNetworkManagerGeneric) SendUnsafe(m message.Send, p params.Unsafe) func (t *testNetworkManagerGeneric) SendCMIX(message format.Message, rid *id.ID, p params.CMIX) (id.Round, ephemeral.Id, error) { return id.Round(0), ephemeral.Id{}, nil } +func (t *testNetworkManagerGeneric) SendManyCMIX(messages map[id.ID]format.Message, p params.CMIX) (id.Round, []ephemeral.Id, error) { + return 0, []ephemeral.Id{}, nil +} func (t *testNetworkManagerGeneric) GetInstance() *network.Instance { return t.instance } diff --git a/auth/callback.go b/auth/callback.go index 7ece838249e97d3fadbb7fb6f6bcae46e059b458..107887b5542ed6f4184f4e07de858444ca0750f4 100644 --- a/auth/callback.go +++ b/auth/callback.go @@ -23,20 +23,21 @@ import ( "strings" ) -func (m *Manager) StartProcessies() stoppable.Stoppable { - +func (m *Manager) StartProcesses() stoppable.Stoppable { stop := stoppable.NewSingle("Auth") go func() { for { select { case <-stop.Quit(): + stop.ToStopped() return case msg := <-m.rawMessages: m.processAuthMessage(msg) } } }() + return stop } @@ -132,7 +133,7 @@ func (m *Manager) handleRequest(cmixMsg format.Message, // confirmation in case there are state issues. // do not store if _, err := m.storage.E2e().GetPartner(partnerID); err == nil { - jww.WARN.Printf("Recieved Auth request for %s, "+ + jww.WARN.Printf("Received Auth request for %s, "+ "channel already exists. Ignoring", partnerID) //exit return @@ -140,8 +141,8 @@ func (m *Manager) handleRequest(cmixMsg format.Message, //check if the relationship already exists, rType, sr2, _, err := m.storage.Auth().GetRequest(partnerID) if err != nil && !strings.Contains(err.Error(), auth.NoRequest) { - // if another error is recieved, print it and exit - jww.WARN.Printf("Recieved new Auth request for %s, "+ + // if another error is received, print it and exit + jww.WARN.Printf("Received new Auth request for %s, "+ "internal lookup produced bad result: %+v", partnerID, err) return diff --git a/bindings/authenticatedChannels.go b/bindings/authenticatedChannels.go index 30b12f0a1b3039fba5b18eafeff874cc7a438a60..da1d0d7ea128c15a6e4d407f8a7a88430666fd16 100644 --- a/bindings/authenticatedChannels.go +++ b/bindings/authenticatedChannels.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "gitlab.com/elixxir/crypto/contact" + "gitlab.com/xx_network/primitives/id" ) // Create an insecure e2e relationship with a precanned user @@ -123,3 +124,15 @@ func (c *Client) VerifyOwnership(receivedMarshaled, verifiedMarshaled []byte) (b return c.api.VerifyOwnership(received, verified), nil } + +// GetRelationshipFingerprint returns a unique 15 character fingerprint for an +// E2E relationship. An error is returned if no relationship with the partner +// is found. +func (c *Client) GetRelationshipFingerprint(partnerID []byte) (string, error) { + partner, err := id.Unmarshal(partnerID) + if err != nil { + return "", err + } + + return c.api.GetRelationshipFingerprint(partner) +} diff --git a/bindings/callback.go b/bindings/callback.go index a6526d24d3ba961db2b79bdc2fbd6a1b950821dc..897ddcfc68ee19c9e967076ba52764f63fa5f667 100644 --- a/bindings/callback.go +++ b/bindings/callback.go @@ -16,7 +16,7 @@ import ( // Listener provides a callback to hear a message // An object implementing this interface can be called back when the client -// gets a message of the type that the regi sterer specified at registration +// gets a message of the type that the registerer specified at registration // time. type Listener interface { // Hear is called to receive a message in the UI diff --git a/bindings/client.go b/bindings/client.go index d51e0de37bcd812068118c604c7dc85afd9f8ae5..898c609bdfa1e61f1b40fd25186b427cf8a8f949 100644 --- a/bindings/client.go +++ b/bindings/client.go @@ -198,8 +198,9 @@ func UnmarshalSendReport(b []byte) (*SendReport, error) { // Responds to sent rekeys and executes them // - KeyExchange Confirm (/keyExchange/confirm.go) // Responds to confirmations of successful rekey operations -func (c *Client) StartNetworkFollower(clientError ClientError) error { - errChan, err := c.api.StartNetworkFollower() +func (c *Client) StartNetworkFollower(clientError ClientError, timeoutMS int) error { + timeout := time.Duration(timeoutMS) * time.Millisecond + errChan, err := c.api.StartNetworkFollower(timeout) if err != nil { return errors.New(fmt.Sprintf("Failed to start the "+ "network follower: %+v", err)) @@ -218,9 +219,8 @@ func (c *Client) StartNetworkFollower(clientError ClientError) error { // fails to stop it. // if the network follower is running and this fails, the client object will // most likely be in an unrecoverable state and need to be trashed. -func (c *Client) StopNetworkFollower(timeoutMS int) error { - timeout := time.Duration(timeoutMS) * time.Millisecond - if err := c.api.StopNetworkFollower(timeout); err != nil { +func (c *Client) StopNetworkFollower() error { + if err := c.api.StopNetworkFollower(); err != nil { return errors.New(fmt.Sprintf("Failed to stop the "+ "network follower: %+v", err)) } diff --git a/bindings/group.go b/bindings/group.go new file mode 100644 index 0000000000000000000000000000000000000000..8752c6665dac3461c8c452d9fbdb558a2f834899 --- /dev/null +++ b/bindings/group.go @@ -0,0 +1,298 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package bindings + +import ( + "github.com/pkg/errors" + gc "gitlab.com/elixxir/client/groupChat" + gs "gitlab.com/elixxir/client/groupChat/groupStore" + "gitlab.com/elixxir/crypto/group" + "gitlab.com/xx_network/primitives/id" +) + +// GroupChat object contains the group chat manager. +type GroupChat struct { + m *gc.Manager +} + +// GroupRequestFunc contains a function callback that is called when a group +// request is received. +type GroupRequestFunc interface { + GroupRequestCallback(g Group) +} + +// GroupReceiveFunc contains a function callback that is called when a group +// message is received. +type GroupReceiveFunc interface { + GroupReceiveCallback(msg GroupMessageReceive) +} + +// NewGroupManager creates a new group chat manager. +func NewGroupManager(client *Client, requestFunc GroupRequestFunc, + receiveFunc GroupReceiveFunc) (GroupChat, error) { + + requestCallback := func(g gs.Group) { + requestFunc.GroupRequestCallback(Group{g}) + } + receiveCallback := func(msg gc.MessageReceive) { + receiveFunc.GroupReceiveCallback(GroupMessageReceive{msg}) + } + + // Create a new group chat manager + m, err := gc.NewManager(&client.api, requestCallback, receiveCallback) + if err != nil { + return GroupChat{}, err + } + + // Start group request and message retrieval workers + client.api.AddService(m.StartProcesses) + + return GroupChat{m}, nil +} + +// MakeGroup creates a new group and sends a group request to all members in the +// group. The ID of the new group, the rounds the requests were sent on, and the +// status of the send are contained in NewGroupReport. +func (g GroupChat) MakeGroup(membership IdList, name, message []byte) (NewGroupReport, error) { + grp, rounds, status, err := g.m.MakeGroup(membership.list, name, message) + return NewGroupReport{Group{grp}, rounds, status}, err +} + +// ResendRequest resends a group request to all members in the group. The rounds +// they were sent on and the status of the send are contained in NewGroupReport. +func (g GroupChat) ResendRequest(groupIdBytes []byte) (NewGroupReport, error) { + groupID, err := id.Unmarshal(groupIdBytes) + if err != nil { + return NewGroupReport{}, + errors.Errorf("Failed to unmarshal group ID: %+v", err) + } + + rounds, status, err := g.m.ResendRequest(groupID) + + return NewGroupReport{Group{}, rounds, status}, nil +} + +// JoinGroup allows a user to join a group when they receive a request. The +// caller must pass in the serialized bytes of a Group. +func (g GroupChat) JoinGroup(serializedGroupData []byte) error { + grp, err := gs.DeserializeGroup(serializedGroupData) + if err != nil { + return err + } + return g.m.JoinGroup(grp) +} + +// LeaveGroup deletes a group so a user no longer has access. +func (g GroupChat) LeaveGroup(groupIdBytes []byte) error { + groupID, err := id.Unmarshal(groupIdBytes) + if err != nil { + return errors.Errorf("Failed to unmarshal group ID: %+v", err) + } + + return g.m.LeaveGroup(groupID) +} + +// Send sends the message to the specified group. Returns the round the messages +// were sent on. +func (g GroupChat) Send(groupIdBytes, message []byte) (int64, error) { + groupID, err := id.Unmarshal(groupIdBytes) + if err != nil { + return 0, errors.Errorf("Failed to unmarshal group ID: %+v", err) + } + + round, err := g.m.Send(groupID, message) + return int64(round), err +} + +// GetGroups returns an IdList containing a list of group IDs that the user is a +// part of. +func (g GroupChat) GetGroups() IdList { + return IdList{g.m.GetGroups()} +} + +// GetGroup returns the group with the group ID. If no group exists, then the +// error "failed to find group" is returned. +func (g GroupChat) GetGroup(groupIdBytes []byte) (Group, error) { + groupID, err := id.Unmarshal(groupIdBytes) + if err != nil { + return Group{}, errors.Errorf("Failed to unmarshal group ID: %+v", err) + } + + grp, exists := g.m.GetGroup(groupID) + if !exists { + return Group{}, errors.New("failed to find group") + } + + return Group{grp}, nil +} + +// NumGroups returns the number of groups the user is a part of. +func (g GroupChat) NumGroups() int { + return g.m.NumGroups() +} + +// NewGroupReport is returned when creating a new group and contains the ID of +// the group, a list of rounds that the group requests were sent on, and the +// status of the send. +type NewGroupReport struct { + group Group + rounds []id.Round + status gc.RequestStatus +} + +// GetGroup returns the Group. +func (ngr NewGroupReport) GetGroup() Group { + return ngr.group +} + +// GetRoundList returns the RoundList containing a list of rounds requests were +// sent on. +func (ngr NewGroupReport) GetRoundList() RoundList { + return RoundList{ngr.rounds} +} + +// GetStatus returns the status of the requests sent when creating a new group. +// status = 0 an error occurred before any requests could be sent +// 1 all requests failed to send +// 2 some request failed and some succeeded +// 3, all requests sent successfully +func (ngr NewGroupReport) GetStatus() int { + return int(ngr.status) +} + +//// +// Group Structure +//// + +// Group structure contains the identifying and membership information of a +// group chat. +type Group struct { + g gs.Group +} + +// GetName returns the name set by the user for the group. +func (g Group) GetName() []byte { + return g.g.Name +} + +// GetID return the 33-byte unique group ID. +func (g Group) GetID() []byte { + return g.g.ID.Bytes() +} + +// GetMembership returns a list of contacts, one for each member in the group. +// The list is in order; the first contact is the leader/creator of the group. +// All subsequent members are ordered by their ID. +func (g Group) GetMembership() GroupMembership { + return GroupMembership{g.g.Members} +} + +// Serialize serializes the Group. +func (g Group) Serialize() []byte { + return g.g.Serialize() +} + +//// +// Membership Structure +//// + +// GroupMembership structure contains a list of members that are part of a +// group. The first member is the group leader. +type GroupMembership struct { + m group.Membership +} + +// Len returns the number of members in the group membership. +func (gm GroupMembership) Len() int { + return gm.Len() +} + +// Get returns the member at the index. The member at index 0 is always the +// group leader. An error is returned if the index is out of range. +func (gm GroupMembership) Get(i int) (GroupMember, error) { + if i < 0 || i > gm.Len() { + return GroupMember{}, errors.Errorf("ID list index must be between %d "+ + "and the last element %d.", 0, gm.Len()) + } + return GroupMember{gm.m[i]}, nil +} + +//// +// Member Structure +//// +// GroupMember represents a member in the group membership list. +type GroupMember struct { + group.Member +} + +// GetID returns the 33-byte user ID of the member. +func (gm GroupMember) GetID() []byte { + return gm.ID.Bytes() +} + +// GetDhKey returns the byte representation of the public Diffie–Hellman key of +// the member. +func (gm GroupMember) GetDhKey() []byte { + return gm.DhKey.Bytes() +} + +//// +// Message Receive Structure +//// + +// GroupMessageReceive contains a group message, its ID, and its data that a +// user receives. +type GroupMessageReceive struct { + gc.MessageReceive +} + +// GetGroupID returns the 33-byte group ID. +func (gmr GroupMessageReceive) GetGroupID() []byte { + return gmr.GroupID.Bytes() +} + +// GetMessageID returns the message ID. +func (gmr GroupMessageReceive) GetMessageID() []byte { + return gmr.ID.Bytes() +} + +// GetPayload returns the message payload. +func (gmr GroupMessageReceive) GetPayload() []byte { + return gmr.Payload +} + +// GetSenderID returns the 33-byte user ID of the sender. +func (gmr GroupMessageReceive) GetSenderID() []byte { + return gmr.SenderID.Bytes() +} + +// GetRecipientID returns the 33-byte user ID of the recipient. +func (gmr GroupMessageReceive) GetRecipientID() []byte { + return gmr.RecipientID.Bytes() +} + +// GetEphemeralID returns the ephemeral ID of the recipient. +func (gmr GroupMessageReceive) GetEphemeralID() int64 { + return gmr.EphemeralID.Int64() +} + +// GetTimestampNano returns the message timestamp in nanoseconds. +func (gmr GroupMessageReceive) GetTimestampNano() int64 { + return gmr.Timestamp.UnixNano() +} + +// GetRoundID returns the ID of the round the message was sent on. +func (gmr GroupMessageReceive) GetRoundID() int64 { + return int64(gmr.RoundID) +} + +// GetRoundTimestampNano returns the timestamp, in nanoseconds, of the round the +// message was sent on. +func (gmr GroupMessageReceive) GetRoundTimestampNano() int64 { + return gmr.RoundTimestamp.UnixNano() +} diff --git a/bindings/list.go b/bindings/list.go index c44fb1679fc0e6492c7c711e7d610bab01198c78..a97df24f299d994380d46de8ae9971ad9133c1e7 100644 --- a/bindings/list.go +++ b/bindings/list.go @@ -8,7 +8,7 @@ package bindings import ( - "errors" + "github.com/pkg/errors" "gitlab.com/elixxir/crypto/contact" "gitlab.com/elixxir/primitives/fact" "gitlab.com/xx_network/primitives/id" @@ -115,3 +115,41 @@ func (fl *FactList) Add(factData string, factType int) error { func (fl *FactList) Stringify() (string, error) { return fl.c.Facts.Stringify(), nil } + +/* ID list */ +// IdList contains a list of IDs. +type IdList struct { + list []*id.ID +} + +// MakeIdList creates a new empty IdList. +func MakeIdList() IdList { + return IdList{[]*id.ID{}} +} + +// Len returns the number of IDs in the list. +func (idl IdList) Len() int { + return len(idl.list) +} + +// Add appends the ID bytes to the end of the list. +func (idl IdList) Add(idBytes []byte) error { + newID, err := id.Unmarshal(idBytes) + if err != nil { + return err + } + + idl.list = append(idl.list, newID) + return nil +} + +// Get returns the ID at the index. An error is returned if the index is out of +// range. +func (idl IdList) Get(i int) ([]byte, error) { + if i < 0 || i > len(idl.list) { + return nil, errors.Errorf("ID list index must be between %d and the "+ + "last element %d.", 0, len(idl.list)) + } + + return idl.list[i].Bytes(), nil +} diff --git a/bindings/message.go b/bindings/message.go index 32947d5fc3b0fc7e87579bdacd39c1bc69edb404..44526fa2e3756868e1b895452d1484f1ff4366fa 100644 --- a/bindings/message.go +++ b/bindings/message.go @@ -17,33 +17,51 @@ type Message struct { r message.Receive } -//Returns the id of the message +// GetID returns the id of the message func (m *Message) GetID() []byte { return m.r.ID[:] } -// Returns the message's sender ID, if available +// GetSender returns the message's sender ID, if available func (m *Message) GetSender() []byte { return m.r.Sender.Bytes() } -// Returns the message's payload/contents +// GetPayload returns the message's payload/contents func (m *Message) GetPayload() []byte { return m.r.Payload } -// Returns the message's type +// GetMessageType returns the message's type func (m *Message) GetMessageType() int { return int(m.r.MessageType) } -// Returns the message's timestamp in ms +// GetTimestampMS returns the message's timestamp in milliseconds func (m *Message) GetTimestampMS() int64 { ts := m.r.Timestamp.UnixNano() ts = (ts + 999999) / 1000000 return ts } +// GetTimestampNano returns the message's timestamp in nanoseconds func (m *Message) GetTimestampNano() int64 { return m.r.Timestamp.UnixNano() } + +// GetRoundTimestampMS returns the message's round timestamp in milliseconds +func (m *Message) GetRoundTimestampMS() int64 { + ts := m.r.RoundTimestamp.UnixNano() + ts = (ts + 999999) / 1000000 + return ts +} + +// GetRoundTimestampNano returns the message's round timestamp in nanoseconds +func (m *Message) GetRoundTimestampNano() int64 { + return m.r.RoundTimestamp.UnixNano() +} + +// GetRoundId returns the message's round ID +func (m *Message) GetRoundId() int64 { + return int64(m.r.RoundId) +} diff --git a/bindings/send.go b/bindings/send.go index 886bf0ea819557c50ddef4b41846d0ae3d5a5d00..586100d051b0f4f0248b235630111476b87b930f 100644 --- a/bindings/send.go +++ b/bindings/send.go @@ -57,6 +57,53 @@ func (c *Client) SendCmix(recipient, contents []byte, parameters string) (int, e return int(rid), nil } +// SendManyCMIX sends many "raw" CMIX message payloads to each of the +// provided recipients. Used for group chat functionality. Returns the +// round ID of the round the payload was sent or an error if it fails. +// This will return an error if: +// - any recipient ID is invalid +// - any of the the message contents are too long for the message structure +// - the message cannot be sent + +// This will return the round the message was sent on if it is successfully sent +// This can be used to register a round event to learn about message delivery. +// on failure a round id of -1 is returned +// fixme: cannot use a slice of slices over bindings. Will need to modify this function once +// a proper input format has been specified +//func (c *Client) SendManyCMIX(recipients, contents [][]byte, parameters string) (int, error) { +// +// p, err := params.GetCMIXParameters(parameters) +// if err != nil { +// return -1, errors.New(fmt.Sprintf("Failed to sendCmix: %+v", +// err)) +// } +// +// // Build messages +// messages := make(map[id.ID]format.Message, len(contents)) +// for i := 0; i < len(contents); i++ { +// msg, err := c.api.NewCMIXMessage(contents[i]) +// if err != nil { +// return -1, errors.New(fmt.Sprintf("Failed to sendCmix: %+v", +// err)) +// } +// +// u, err := id.Unmarshal(recipients[i]) +// if err != nil { +// return -1, errors.New(fmt.Sprintf("Failed to sendCmix: %+v", +// err)) +// } +// +// messages[*u] = msg +// } +// +// rid, _, err := c.api.SendManyCMIX(messages, p) +// if err != nil { +// return -1, errors.New(fmt.Sprintf("Failed to sendCmix: %+v", +// err)) +// } +// return int(rid), nil +//} + // SendUnsafe sends an unencrypted payload to the provided recipient // with the provided msgType. Returns the list of rounds in which parts // of the message were sent or an error if it fails. diff --git a/cmd/group.go b/cmd/group.go new file mode 100644 index 0000000000000000000000000000000000000000..b8064fcb67df8c94c334017023449d1c6e8f99e9 --- /dev/null +++ b/cmd/group.go @@ -0,0 +1,345 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +// The group subcommand allows creation and sending messages to groups + +package cmd + +import ( + "bufio" + "fmt" + "github.com/spf13/cobra" + jww "github.com/spf13/jwalterweatherman" + "github.com/spf13/viper" + "gitlab.com/elixxir/client/api" + "gitlab.com/elixxir/client/groupChat" + "gitlab.com/elixxir/client/groupChat/groupStore" + "gitlab.com/xx_network/primitives/id" + "os" + "time" +) + +// groupCmd represents the base command when called without any subcommands +var groupCmd = &cobra.Command{ + Use: "group", + Short: "Group commands for cMix client", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + + client := initClient() + + // Print user's reception ID + user := client.GetUser() + jww.INFO.Printf("User: %s", user.ReceptionID) + + _, _ = initClientCallbacks(client) + + _, err := client.StartNetworkFollower(5 * time.Second) + if err != nil { + jww.FATAL.Panicf("%+v", err) + } + + // Initialize the group chat manager + groupManager, recChan, reqChan := initGroupManager(client) + + // Wait until connected or crash on timeout + connected := make(chan bool, 10) + client.GetHealth().AddChannel(connected) + waitUntilConnected(connected) + + // After connection, make sure we have registered with at least 85% of + // the nodes + for numReg, total := 1, 100; numReg < (total*3)/4; { + time.Sleep(1 * time.Second) + numReg, total, err = client.GetNodeRegistrationStatus() + if err != nil { + jww.FATAL.Panicf("%+v", err) + } + + jww.INFO.Printf("Registering with nodes (%d/%d)...", numReg, total) + } + + // Get group message and name + msgBody := []byte(viper.GetString("message")) + name := []byte(viper.GetString("name")) + timeout := viper.GetDuration("receiveTimeout") + + if viper.IsSet("create") { + filePath := viper.GetString("create") + createGroup(name, msgBody, filePath, groupManager) + } + + if viper.IsSet("resend") { + groupIdString := viper.GetString("resend") + resendRequests(groupIdString, groupManager) + } + + if viper.GetBool("join") { + joinGroup(reqChan, timeout, groupManager) + } + + if viper.IsSet("leave") { + groupIdString := viper.GetString("leave") + leaveGroup(groupIdString, groupManager) + } + + if viper.IsSet("sendMessage") { + groupIdString := viper.GetString("sendMessage") + sendGroup(groupIdString, msgBody, groupManager) + } + + if viper.IsSet("wait") { + numMessages := viper.GetUint("wait") + messageWait(numMessages, timeout, recChan) + } + + if viper.GetBool("list") { + listGroups(groupManager) + } + + if viper.IsSet("show") { + groupIdString := viper.GetString("show") + showGroup(groupIdString, groupManager) + } + }, +} + +// initGroupManager creates a new group chat manager and starts the process +// service. +func initGroupManager(client *api.Client) (*groupChat.Manager, + chan groupChat.MessageReceive, chan groupStore.Group) { + recChan := make(chan groupChat.MessageReceive, 10) + receiveCb := func(msg groupChat.MessageReceive) { + recChan <- msg + } + + reqChan := make(chan groupStore.Group, 10) + requestCb := func(g groupStore.Group) { + reqChan <- g + } + + jww.INFO.Print("Creating new group manager.") + manager, err := groupChat.NewManager(client, requestCb, receiveCb) + if err != nil { + jww.FATAL.Panicf("Failed to initialize group chat manager: %+v", err) + } + + // Start group request and message receiver + client.AddService(manager.StartProcesses) + + return manager, recChan, reqChan +} + +// createGroup creates a new group with the provided name and sends out requests +// to the list of user IDs found at the given file path. +func createGroup(name, msg []byte, filePath string, gm *groupChat.Manager) { + userIdStrings := ReadLines(filePath) + userIDs := make([]*id.ID, 0, len(userIdStrings)) + for _, userIdStr := range userIdStrings { + userID, _ := parseRecipient(userIdStr) + userIDs = append(userIDs, userID) + } + + grp, rids, status, err := gm.MakeGroup(userIDs, name, msg) + if err != nil { + jww.FATAL.Panicf("Failed to create new group: %+v", err) + } + + // Integration grabs the group ID from this line + jww.INFO.Printf("NewGroupID: b64:%s", grp.ID) + jww.INFO.Printf("Created Group: Requests:%s on rounds %#v, %v", status, rids, grp) + fmt.Printf("Created new group with name %q and message %q\n", grp.Name, + grp.InitMessage) +} + +// resendRequests resends group requests for the group ID. +func resendRequests(groupIdString string, gm *groupChat.Manager) { + groupID, _ := parseRecipient(groupIdString) + rids, status, err := gm.ResendRequest(groupID) + if err != nil { + jww.FATAL.Panicf("Failed to resend requests to group %s: %+v", + groupID, err) + } + + jww.INFO.Printf("Resending requests to group %s: %v, %s", groupID, rids, status) + fmt.Println("Resending group requests to group.") +} + +// joinGroup joins a group when a request is received on the group request +// channel. +func joinGroup(reqChan chan groupStore.Group, timeout time.Duration, gm *groupChat.Manager) { + jww.INFO.Print("Waiting for group request to be received.") + fmt.Println("Waiting for group request to be received.") + + select { + case grp := <-reqChan: + err := gm.JoinGroup(grp) + if err != nil { + jww.FATAL.Panicf("%+v", err) + } + + jww.INFO.Printf("Joined group: %s", grp.ID) + fmt.Printf("Joined group with name %q and message %q\n", + grp.Name, grp.InitMessage) + case <-time.NewTimer(timeout).C: + jww.INFO.Printf("Timed out after %s waiting for group request.", timeout) + fmt.Println("Timed out waiting for group request.") + return + } +} + +// leaveGroup leaves the group. +func leaveGroup(groupIdString string, gm *groupChat.Manager) { + groupID, _ := parseRecipient(groupIdString) + jww.INFO.Printf("Leaving group %s.", groupID) + + err := gm.LeaveGroup(groupID) + if err != nil { + jww.FATAL.Panicf("Failed to leave group %s: %+v", groupID, err) + } + + jww.INFO.Printf("Left group: %s", groupID) + fmt.Println("Left group.") +} + +// sendGroup send the message to the group. +func sendGroup(groupIdString string, msg []byte, gm *groupChat.Manager) { + groupID, _ := parseRecipient(groupIdString) + + jww.INFO.Printf("Sending to group %s message %q", groupID, msg) + + rid, err := gm.Send(groupID, msg) + if err != nil { + jww.FATAL.Panicf("Sending message to group %s: %+v", groupID, err) + } + + jww.INFO.Printf("Sent to group %s on round %d", groupID, rid) + fmt.Printf("Sent message %q to group.\n", msg) +} + +// messageWait waits for the given number of messages to be received on the +// groupChat.MessageReceive channel. +func messageWait(numMessages uint, timeout time.Duration, recChan chan groupChat.MessageReceive) { + jww.INFO.Printf("Waiting for %d group message(s) to be received.", numMessages) + fmt.Printf("Waiting for %d group message(s) to be received.\n", numMessages) + + for i := uint(0); i < numMessages; { + select { + case msg := <-recChan: + i++ + jww.INFO.Printf("Received group message %d/%d: %s", i, numMessages, msg) + fmt.Printf("Received group message: %q\n", msg.Payload) + case <-time.NewTimer(timeout).C: + jww.INFO.Printf("Timed out after %s waiting for group message.", timeout) + fmt.Printf("Timed out waiting for %d group message(s).\n", numMessages) + return + } + } +} + +// listGroups prints a list of all groups. +func listGroups(gm *groupChat.Manager) { + for i, gid := range gm.GetGroups() { + jww.INFO.Printf("Group %d: %s", i, gid) + } + + fmt.Printf("Printed list of %d groups.\n", gm.NumGroups()) +} + +// showGroup prints all the information of the group. +func showGroup(groupIdString string, gm *groupChat.Manager) { + groupID, _ := parseRecipient(groupIdString) + + grp, ok := gm.GetGroup(groupID) + if !ok { + jww.FATAL.Printf("Could not find group: %s", groupID) + } + + jww.INFO.Printf("Show group %#v", grp) + fmt.Printf("Got group with name %q and message %q\n", grp.Name, grp.InitMessage) +} + +// ReadLines returns each line in a file as a string. +func ReadLines(fileName string) []string { + file, err := os.Open(fileName) + if err != nil { + jww.FATAL.Panicf(err.Error()) + } + defer file.Close() + + var res []string + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + res = append(res, scanner.Text()) + } + + if err := scanner.Err(); err != nil { + jww.FATAL.Panicf(err.Error()) + } + return res +} + +func init() { + groupCmd.Flags().String("create", "", + "Create a group with from the list of contact file paths.") + err := viper.BindPFlag("create", groupCmd.Flags().Lookup("create")) + checkBindErr(err, "create") + + groupCmd.Flags().String("name", "Group Name", + "The name of the new group to create.") + err = viper.BindPFlag("name", groupCmd.Flags().Lookup("name")) + checkBindErr(err, "name") + + groupCmd.Flags().String("resend", "", + "Resend invites for all users in this group ID.") + err = viper.BindPFlag("resend", groupCmd.Flags().Lookup("resend")) + checkBindErr(err, "resend") + + groupCmd.Flags().Bool("join", false, + "Waits for group request joins the group.") + err = viper.BindPFlag("join", groupCmd.Flags().Lookup("join")) + checkBindErr(err, "join") + + groupCmd.Flags().String("leave", "", + "Leave this group ID.") + err = viper.BindPFlag("leave", groupCmd.Flags().Lookup("leave")) + checkBindErr(err, "leave") + + groupCmd.Flags().String("sendMessage", "", + "Send message to this group ID.") + err = viper.BindPFlag("sendMessage", groupCmd.Flags().Lookup("sendMessage")) + checkBindErr(err, "sendMessage") + + groupCmd.Flags().Uint("wait", 0, + "Waits for number of messages to be received.") + err = viper.BindPFlag("wait", groupCmd.Flags().Lookup("wait")) + checkBindErr(err, "wait") + + groupCmd.Flags().Duration("receiveTimeout", time.Minute, + "Amount of time to wait for a group request or message before timing out.") + err = viper.BindPFlag("receiveTimeout", groupCmd.Flags().Lookup("receiveTimeout")) + checkBindErr(err, "receiveTimeout") + + groupCmd.Flags().Bool("list", false, + "Prints list all groups to which this client belongs.") + err = viper.BindPFlag("list", groupCmd.Flags().Lookup("list")) + checkBindErr(err, "list") + + groupCmd.Flags().String("show", "", + "Prints the members of this group ID.") + err = viper.BindPFlag("show", groupCmd.Flags().Lookup("show")) + checkBindErr(err, "show") + + rootCmd.AddCommand(groupCmd) +} + +func checkBindErr(err error, key string) { + if err != nil { + jww.ERROR.Printf("viper.BindPFlag failed for %s: %+v", key, err) + } +} diff --git a/cmd/root.go b/cmd/root.go index 38019cac871de70b04dcf06b7b9de772df45fc94..c77caf8b79de8946d4bd1596bf627ded6596c6fa 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -80,42 +80,20 @@ var rootCmd = &cobra.Command{ recipientContact = user.GetContact() } - // Set up reception handler - swboard := client.GetSwitchboard() - recvCh := make(chan message.Receive, 10000) - listenerID := swboard.RegisterChannel("DefaultCLIReceiver", - switchboard.AnyUser(), message.Text, recvCh) - jww.INFO.Printf("Message ListenerID: %v", listenerID) - - // Set up auth request handler, which simply prints the - // user id of the requester. - authMgr := client.GetAuthRegistrar() - authMgr.AddGeneralRequestCallback(printChanRequest) - - // If unsafe channels, add auto-acceptor + confCh, recvCh := initClientCallbacks(client) + + // The following block is used to check if the request from + // a channel authorization is from the recipient we intend in + // this run. authConfirmed := false - authMgr.AddGeneralConfirmCallback(func( - partner contact.Contact) { - jww.INFO.Printf("Channel Confirmed: %s", - partner.ID) - authConfirmed = recipientID.Cmp(partner.ID) - }) - if viper.GetBool("unsafe-channel-creation") { - authMgr.AddGeneralRequestCallback(func( - requestor contact.Contact, message string) { - jww.INFO.Printf("Channel Request: %s", - requestor.ID) - _, err := client.ConfirmAuthenticatedChannel( - requestor) - if err != nil { - jww.FATAL.Panicf("%+v", err) - } - authConfirmed = recipientID.Cmp( - requestor.ID) - }) - } + go func() { + for { + requestor := <-confCh + authConfirmed = recipientID.Cmp(requestor) + } + }() - _, err := client.StartNetworkFollower() + _, err := client.StartNetworkFollower(5 * time.Second) if err != nil { jww.FATAL.Panicf("%+v", err) } @@ -270,7 +248,7 @@ var rootCmd = &cobra.Command{ } fmt.Printf("Received %d\n", receiveCnt) - err = client.StopNetworkFollower(5 * time.Second) + err = client.StopNetworkFollower() if err != nil { jww.WARN.Printf( "Failed to cleanly close threads: %+v\n", @@ -283,6 +261,44 @@ var rootCmd = &cobra.Command{ }, } +func initClientCallbacks(client *api.Client) (chan *id.ID, + chan message.Receive) { + // Set up reception handler + swboard := client.GetSwitchboard() + recvCh := make(chan message.Receive, 10000) + listenerID := swboard.RegisterChannel("DefaultCLIReceiver", + switchboard.AnyUser(), message.Text, recvCh) + jww.INFO.Printf("Message ListenerID: %v", listenerID) + + // Set up auth request handler, which simply prints the + // user id of the requester. + authMgr := client.GetAuthRegistrar() + authMgr.AddGeneralRequestCallback(printChanRequest) + + // If unsafe channels, add auto-acceptor + authConfirmed := make(chan *id.ID, 10) + authMgr.AddGeneralConfirmCallback(func( + partner contact.Contact) { + jww.INFO.Printf("Channel Confirmed: %s", + partner.ID) + authConfirmed <- partner.ID + }) + if viper.GetBool("unsafe-channel-creation") { + authMgr.AddGeneralRequestCallback(func( + requestor contact.Contact, message string) { + jww.INFO.Printf("Channel Request: %s", + requestor.ID) + _, err := client.ConfirmAuthenticatedChannel( + requestor) + if err != nil { + jww.FATAL.Panicf("%+v", err) + } + authConfirmed <- requestor.ID + }) + } + return authConfirmed, recvCh +} + // Helper function which prints the round resuls func printRoundResults(allRoundsSucceeded, timedOut bool, rounds map[id.Round]api.RoundResult, roundIDs []id.Round, msg message.Send) { @@ -352,7 +368,6 @@ func createClient() *api.Client { err = api.NewClient(string(ndfJSON), storeDir, []byte(pass), regCode) } - } if err != nil { diff --git a/cmd/single.go b/cmd/single.go index 15f803b11bec0f775845fdfdfb9f8292c75ea2a7..b827627fa2af0be3f5ec0fa338170f7af6a36493 100644 --- a/cmd/single.go +++ b/cmd/single.go @@ -62,7 +62,7 @@ var singleCmd = &cobra.Command{ }) } - _, err := client.StartNetworkFollower() + _, err := client.StartNetworkFollower(5 * time.Second) if err != nil { jww.FATAL.Panicf("%+v", err) } diff --git a/cmd/ud.go b/cmd/ud.go index 38cbecb2cd1c6572cd140b158f435fe4291500cd..2a3b3d31927ee62522444c179c7dab50c859f4cd 100644 --- a/cmd/ud.go +++ b/cmd/ud.go @@ -62,7 +62,7 @@ var udCmd = &cobra.Command{ }) } - _, err := client.StartNetworkFollower() + _, err := client.StartNetworkFollower(50 * time.Millisecond) if err != nil { jww.FATAL.Panicf("%+v", err) } @@ -176,7 +176,7 @@ var udCmd = &cobra.Command{ } if len(facts) == 0 { - err = client.StopNetworkFollower(10 * time.Second) + err = client.StopNetworkFollower() if err != nil { jww.WARN.Print(err) } @@ -196,7 +196,7 @@ var udCmd = &cobra.Command{ jww.FATAL.Panicf("%+v", err) } time.Sleep(91 * time.Second) - err = client.StopNetworkFollower(90 * time.Second) + err = client.StopNetworkFollower() if err != nil { jww.WARN.Print(err) } diff --git a/go.mod b/go.mod index 402ae1b22ca9b09dfbf5deed303fdc28172901f3..a534cac91c9efd0f75f3fcfd2d9c5f40587c06d3 100644 --- a/go.mod +++ b/go.mod @@ -17,13 +17,13 @@ require ( github.com/spf13/jwalterweatherman v1.1.0 github.com/spf13/viper v1.7.1 gitlab.com/elixxir/bloomfilter v0.0.0-20200930191214-10e9ac31b228 - gitlab.com/elixxir/comms v0.0.4-0.20210603164716-b39b297bbb57 - gitlab.com/elixxir/crypto v0.0.7-0.20210603164430-a52be8b1a3e8 + gitlab.com/elixxir/comms v0.0.4-0.20210607222512-0f2e89b475b4 + gitlab.com/elixxir/crypto v0.0.7-0.20210607221512-0a9bff216f7c gitlab.com/elixxir/ekv v0.1.5 - gitlab.com/elixxir/primitives v0.0.3-0.20210603164310-4bd6e45e65e1 + gitlab.com/elixxir/primitives v0.0.3-0.20210607210820-afd1b028b558 gitlab.com/xx_network/comms v0.0.4-0.20210603164237-d0c36076d7f0 gitlab.com/xx_network/crypto v0.0.5-0.20210603164136-743cb9b0a967 - gitlab.com/xx_network/primitives v0.0.4-0.20210603164056-0abf3f914f25 + gitlab.com/xx_network/primitives v0.0.4-0.20210607221158-361a2cbc5529 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 golang.org/x/net v0.0.0-20210525063256-abc453219eb5 google.golang.org/genproto v0.0.0-20210105202744-fe13368bc0e1 // indirect diff --git a/go.sum b/go.sum index 94d7d1b1af2f749ed09387597fb5b6f2da3697e6..66a72049798930218404a8ba6a848ee93c29d687 100644 --- a/go.sum +++ b/go.sum @@ -247,20 +247,20 @@ github.com/zeebo/pcg v1.0.0 h1:dt+dx+HvX8g7Un32rY9XWoYnd0NmKmrIzpHF7qiTDj0= github.com/zeebo/pcg v1.0.0/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4= gitlab.com/elixxir/bloomfilter v0.0.0-20200930191214-10e9ac31b228 h1:Gi6rj4mAlK0BJIk1HIzBVMjWNjIUfstrsXC2VqLYPcA= gitlab.com/elixxir/bloomfilter v0.0.0-20200930191214-10e9ac31b228/go.mod h1:H6jztdm0k+wEV2QGK/KYA+MY9nj9Zzatux/qIvDDv3k= -gitlab.com/elixxir/comms v0.0.4-0.20210603164716-b39b297bbb57 h1:7brO8ImMildCKPL1EXYgWPGz4GEIKBRSZmTGeGD33n8= -gitlab.com/elixxir/comms v0.0.4-0.20210603164716-b39b297bbb57/go.mod h1:s+aHLLk2bJOajrTb237G3J3gRUFkFWMddnaGOgt5Deg= +gitlab.com/elixxir/comms v0.0.4-0.20210607222512-0f2e89b475b4 h1:9h/Dkvu8t4/Q2LY81qoxjbO4GJtVanjWV3B9UMuZgDA= +gitlab.com/elixxir/comms v0.0.4-0.20210607222512-0f2e89b475b4/go.mod h1:Ox1NgdvFRy4/AfWAIrKZR9W6O+PAYeNOZviGbhqH+eo= gitlab.com/elixxir/crypto v0.0.0-20200804182833-984246dea2c4/go.mod h1:ucm9SFKJo+K0N2GwRRpaNr+tKXMIOVWzmyUD0SbOu2c= gitlab.com/elixxir/crypto v0.0.3/go.mod h1:ZNgBOblhYToR4m8tj4cMvJ9UsJAUKq+p0gCp07WQmhA= -gitlab.com/elixxir/crypto v0.0.7-0.20210603164430-a52be8b1a3e8 h1:3d4I+PxwyhvLwj6jInDO43kCGcSiybhrREiYH9jsFmc= -gitlab.com/elixxir/crypto v0.0.7-0.20210603164430-a52be8b1a3e8/go.mod h1:WacVVYlNPYbeDR1gBRJHtAx6T5P6ZZZ0tTHW5GTdi5Q= +gitlab.com/elixxir/crypto v0.0.7-0.20210607221512-0a9bff216f7c h1:G4IE4xEnoapQoZsMK1XRAB5QdOr2z+IyxZ6JOIXXuQ8= +gitlab.com/elixxir/crypto v0.0.7-0.20210607221512-0a9bff216f7c/go.mod h1:HAW5sLSUHKgylen7C4YoVNr0kp5g21/eKJw6sCuBRFs= gitlab.com/elixxir/ekv v0.1.5 h1:R8M1PA5zRU1HVnTyrtwybdABh7gUJSCvt1JZwUSeTzk= gitlab.com/elixxir/ekv v0.1.5/go.mod h1:e6WPUt97taFZe5PFLPb1Dupk7tqmDCTQu1kkstqJvw4= gitlab.com/elixxir/primitives v0.0.0-20200731184040-494269b53b4d/go.mod h1:OQgUZq7SjnE0b+8+iIAT2eqQF+2IFHn73tOo+aV11mg= gitlab.com/elixxir/primitives v0.0.0-20200804170709-a1896d262cd9/go.mod h1:p0VelQda72OzoUckr1O+vPW0AiFe0nyKQ6gYcmFSuF8= gitlab.com/elixxir/primitives v0.0.0-20200804182913-788f47bded40/go.mod h1:tzdFFvb1ESmuTCOl1z6+yf6oAICDxH2NPUemVgoNLxc= gitlab.com/elixxir/primitives v0.0.1/go.mod h1:kNp47yPqja2lHSiS4DddTvFpB/4D9dB2YKnw5c+LJCE= -gitlab.com/elixxir/primitives v0.0.3-0.20210603164310-4bd6e45e65e1 h1:RZFD4TbJCBO36WYGVH7gakUKbumFsAUsuEqR8ryYF/A= -gitlab.com/elixxir/primitives v0.0.3-0.20210603164310-4bd6e45e65e1/go.mod h1:1EFCSsERjE5RagCjys/70sqkAZC5PimEkWHcljh4bwQ= +gitlab.com/elixxir/primitives v0.0.3-0.20210607210820-afd1b028b558 h1:J8FllvIDv5RkON+Bg61NxNr78cYLOLifRrg/ugm5mW8= +gitlab.com/elixxir/primitives v0.0.3-0.20210607210820-afd1b028b558/go.mod h1:1EFCSsERjE5RagCjys/70sqkAZC5PimEkWHcljh4bwQ= gitlab.com/xx_network/comms v0.0.0-20200805174823-841427dd5023/go.mod h1:owEcxTRl7gsoM8c3RQ5KAm5GstxrJp5tn+6JfQ4z5Hw= gitlab.com/xx_network/comms v0.0.4-0.20210603164237-d0c36076d7f0 h1:+/U+6Ra5pqDhIHCqMniESovsakooCFTv/omlpfvffU8= gitlab.com/xx_network/comms v0.0.4-0.20210603164237-d0c36076d7f0/go.mod h1:cpogFfWweZFzldGnRmgI9ilW2IFyeINDXpa1sAP950U= @@ -271,8 +271,9 @@ gitlab.com/xx_network/crypto v0.0.5-0.20210603164136-743cb9b0a967/go.mod h1:qOOP gitlab.com/xx_network/primitives v0.0.0-20200803231956-9b192c57ea7c/go.mod h1:wtdCMr7DPePz9qwctNoAUzZtbOSHSedcK++3Df3psjA= gitlab.com/xx_network/primitives v0.0.0-20200804183002-f99f7a7284da/go.mod h1:OK9xevzWCaPO7b1wiluVJGk7R5ZsuC7pHY5hteZFQug= gitlab.com/xx_network/primitives v0.0.2/go.mod h1:cs0QlFpdMDI6lAo61lDRH2JZz+3aVkHy+QogOB6F/qc= -gitlab.com/xx_network/primitives v0.0.4-0.20210603164056-0abf3f914f25 h1:SH3pO8rAj59B6LLFDYWNrgj5yUncT7/TBXrSyORbQRQ= gitlab.com/xx_network/primitives v0.0.4-0.20210603164056-0abf3f914f25/go.mod h1:9imZHvYwNFobxueSvVtHneZLk9wTK7HQTzxPm+zhFhE= +gitlab.com/xx_network/primitives v0.0.4-0.20210607221158-361a2cbc5529 h1:zC1z2Pcxy+fQc3ZzRxz2lOorj1LqBms3xohIVmThPb0= +gitlab.com/xx_network/primitives v0.0.4-0.20210607221158-361a2cbc5529/go.mod h1:9imZHvYwNFobxueSvVtHneZLk9wTK7HQTzxPm+zhFhE= gitlab.com/xx_network/ring v0.0.3-0.20210527191221-ce3f170aabd5 h1:FY+4Rh1Q2rgLyv10aKJjhWApuKRCR/054XhreudfAvw= gitlab.com/xx_network/ring v0.0.3-0.20210527191221-ce3f170aabd5/go.mod h1:aLzpP2TiZTQut/PVHR40EJAomzugDdHXetbieRClXIM= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= diff --git a/groupChat/gcMessages.pb.go b/groupChat/gcMessages.pb.go new file mode 100644 index 0000000000000000000000000000000000000000..37b23167e412866c8012fd9f8c95bd99adab88b3 --- /dev/null +++ b/groupChat/gcMessages.pb.go @@ -0,0 +1,117 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: groupChat/gcMessages.proto + +package groupChat + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +// Request to join the group sent from leader to all members. +type Request struct { + Name []byte `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + IdPreimage []byte `protobuf:"bytes,2,opt,name=idPreimage,proto3" json:"idPreimage,omitempty"` + KeyPreimage []byte `protobuf:"bytes,3,opt,name=keyPreimage,proto3" json:"keyPreimage,omitempty"` + Members []byte `protobuf:"bytes,4,opt,name=members,proto3" json:"members,omitempty"` + Message []byte `protobuf:"bytes,5,opt,name=message,proto3" json:"message,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Request) Reset() { *m = Request{} } +func (m *Request) String() string { return proto.CompactTextString(m) } +func (*Request) ProtoMessage() {} +func (*Request) Descriptor() ([]byte, []int) { + return fileDescriptor_49d0b7a6ffb7e279, []int{0} +} + +func (m *Request) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Request.Unmarshal(m, b) +} +func (m *Request) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Request.Marshal(b, m, deterministic) +} +func (m *Request) XXX_Merge(src proto.Message) { + xxx_messageInfo_Request.Merge(m, src) +} +func (m *Request) XXX_Size() int { + return xxx_messageInfo_Request.Size(m) +} +func (m *Request) XXX_DiscardUnknown() { + xxx_messageInfo_Request.DiscardUnknown(m) +} + +var xxx_messageInfo_Request proto.InternalMessageInfo + +func (m *Request) GetName() []byte { + if m != nil { + return m.Name + } + return nil +} + +func (m *Request) GetIdPreimage() []byte { + if m != nil { + return m.IdPreimage + } + return nil +} + +func (m *Request) GetKeyPreimage() []byte { + if m != nil { + return m.KeyPreimage + } + return nil +} + +func (m *Request) GetMembers() []byte { + if m != nil { + return m.Members + } + return nil +} + +func (m *Request) GetMessage() []byte { + if m != nil { + return m.Message + } + return nil +} + +func init() { + proto.RegisterType((*Request)(nil), "gcRequestMessages.Request") +} + +func init() { + proto.RegisterFile("groupChat/gcMessages.proto", fileDescriptor_49d0b7a6ffb7e279) +} + +var fileDescriptor_49d0b7a6ffb7e279 = []byte{ + // 186 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x4a, 0x2f, 0xca, 0x2f, + 0x2d, 0x70, 0xce, 0x48, 0x2c, 0xd1, 0x4f, 0x4f, 0xf6, 0x4d, 0x2d, 0x2e, 0x4e, 0x4c, 0x4f, 0x2d, + 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x12, 0x4c, 0x4f, 0x0e, 0x4a, 0x2d, 0x2c, 0x4d, 0x2d, + 0x2e, 0x81, 0x49, 0x28, 0x4d, 0x66, 0xe4, 0x62, 0x87, 0x8a, 0x09, 0x09, 0x71, 0xb1, 0xe4, 0x25, + 0xe6, 0xa6, 0x4a, 0x30, 0x2a, 0x30, 0x6a, 0xf0, 0x04, 0x81, 0xd9, 0x42, 0x72, 0x5c, 0x5c, 0x99, + 0x29, 0x01, 0x45, 0xa9, 0x99, 0xb9, 0x89, 0xe9, 0xa9, 0x12, 0x4c, 0x60, 0x19, 0x24, 0x11, 0x21, + 0x05, 0x2e, 0xee, 0xec, 0xd4, 0x4a, 0xb8, 0x02, 0x66, 0xb0, 0x02, 0x64, 0x21, 0x21, 0x09, 0x2e, + 0xf6, 0xdc, 0xd4, 0xdc, 0xa4, 0xd4, 0xa2, 0x62, 0x09, 0x16, 0xb0, 0x2c, 0x8c, 0x0b, 0x91, 0x01, + 0xbb, 0x43, 0x82, 0x15, 0x26, 0x03, 0xe6, 0x3a, 0xa9, 0x46, 0x29, 0xa7, 0x67, 0x96, 0xe4, 0x24, + 0x26, 0xe9, 0x25, 0xe7, 0xe7, 0xea, 0xa7, 0xe6, 0x64, 0x56, 0x54, 0x64, 0x16, 0xe9, 0x27, 0xe7, + 0x64, 0xa6, 0xe6, 0x95, 0xe8, 0xc3, 0x3d, 0x98, 0xc4, 0x06, 0xf6, 0x96, 0x31, 0x20, 0x00, 0x00, + 0xff, 0xff, 0x6e, 0x63, 0x77, 0xd1, 0xf4, 0x00, 0x00, 0x00, +} diff --git a/groupChat/gcMessages.proto b/groupChat/gcMessages.proto new file mode 100644 index 0000000000000000000000000000000000000000..7ce1f3b40c5f34f61367dcec2e477dd8040faae5 --- /dev/null +++ b/groupChat/gcMessages.proto @@ -0,0 +1,20 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +syntax = "proto3"; +package gcRequestMessages; +option go_package = "gitlab.com/elixxir/client/groupChat"; + + +// Request to join the group sent from leader to all members. +message Request { + bytes name = 1; + bytes idPreimage = 2; + bytes keyPreimage = 3; + bytes members = 4; + bytes message = 5; +} \ No newline at end of file diff --git a/groupChat/generateProto.sh b/groupChat/generateProto.sh new file mode 100644 index 0000000000000000000000000000000000000000..43968a4aa112270ffb38ea9a2c5da91e309871f1 --- /dev/null +++ b/groupChat/generateProto.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +protoc --go_out=paths=source_relative:. groupChat/gcMessages.proto diff --git a/groupChat/group.go b/groupChat/group.go new file mode 100644 index 0000000000000000000000000000000000000000..5d67d19f1e52f80cf1628f2deece083b4d8f4568 --- /dev/null +++ b/groupChat/group.go @@ -0,0 +1,70 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +// Group chat is used to communicate the same content with multiple clients over +// cMix. A group chat is controlled by a group leader who creates the group, +// defines all group keys, and is responsible for key rotation. To create a +// group, the group leader must have an authenticated channel with all members +// of the group. +// +// Once a group is created, neither the leader nor other members can add or +// remove users to the group. Only members can leave a group themselves. +// +// When a message is sent to the group, the sender will send an individual +// message to every member of the group. + +package groupChat + +import ( + gs "gitlab.com/elixxir/client/groupChat/groupStore" + "gitlab.com/xx_network/primitives/id" +) + +// GroupChat is used to send and receive cMix messages to/from multiple users. +type GroupChat interface { + // MakeGroup sends GroupChat requests to all members over an authenticated + // channel. The leader of a GroupChat must have an authenticated channel + // with each member of the GroupChat to add them to the GroupChat. It blocks + // until all the GroupChat requests are sent. Returns the new group and the + // round IDs the requests were sent on. Returns an error if at least one + // request to a member fails to send. Also returns the status of the sent + // requests. + MakeGroup(membership []*id.ID, name, message []byte) (gs.Group, []id.Round, + RequestStatus, error) + + // ResendRequest allows a GroupChat request to be sent again. It returns + // the rounds that the requests were sent on and the status of the send. + ResendRequest(groupID *id.ID) ([]id.Round, RequestStatus, error) + + // JoinGroup allows a user to accept a GroupChat request and stores the + // GroupChat as active to allow receiving and sending of messages from/to + // the GroupChat. A user can only join a GroupChat once. + JoinGroup(g gs.Group) error + + // LeaveGroup removes a group from a list of groups the user is a part of. + LeaveGroup(groupID *id.ID) error + + // Send sends a message to all GroupChat members using Client.SendManyCMIX. + // The send fails if the message is too long. + Send(groupID *id.ID, message []byte) (id.Round, error) + + // GetGroups returns a list of all registered GroupChat IDs. + GetGroups() []*id.ID + + // GetGroup returns the group with the matching ID or returns false if none + // exist. + GetGroup(groupID *id.ID) (gs.Group, bool) + + // NumGroups returns the number of groups the user is a part of. + NumGroups() int +} + +// RequestCallback is called when a GroupChat request is received. +type RequestCallback func(g gs.Group) + +// ReceiveCallback is called when a GroupChat message is received. +type ReceiveCallback func(msg MessageReceive) diff --git a/groupChat/groupStore/dhKeyList.go b/groupChat/groupStore/dhKeyList.go new file mode 100644 index 0000000000000000000000000000000000000000..50c405c426d7a27f2f99767208d2b2915a4f93b5 --- /dev/null +++ b/groupChat/groupStore/dhKeyList.go @@ -0,0 +1,139 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupStore + +import ( + "bytes" + "encoding/binary" + "github.com/pkg/errors" + "gitlab.com/elixxir/crypto/cyclic" + "gitlab.com/elixxir/crypto/diffieHellman" + "gitlab.com/elixxir/crypto/group" + "gitlab.com/xx_network/primitives/id" + "sort" + "strings" +) + +// Error messages. +const ( + idUnmarshalErr = "failed to unmarshal member ID: %+v" + dhKeyDecodeErr = "failed to decode member DH key: %+v" +) + +type DhKeyList map[id.ID]*cyclic.Int + +// GenerateDhKeyList generates the symmetric/DH key between the user and all +// group members. +func GenerateDhKeyList(userID *id.ID, privKey *cyclic.Int, + members group.Membership, grp *cyclic.Group) DhKeyList { + dkl := make(DhKeyList, len(members)-1) + + for _, m := range members { + if !userID.Cmp(m.ID) { + dkl.Add(privKey, m, grp) + } + } + + return dkl +} + +// Add generates DH key between the user and the group member. The +func (dkl DhKeyList) Add(privKey *cyclic.Int, m group.Member, grp *cyclic.Group) { + dkl[*m.ID] = diffieHellman.GenerateSessionKey(privKey, m.DhKey, grp) +} + +// DeepCopy returns a copy of the DhKeyList. +func (dkl DhKeyList) DeepCopy() DhKeyList { + newDkl := make(DhKeyList, len(dkl)) + for uid, key := range dkl { + newDkl[uid] = key.DeepCopy() + } + return newDkl +} + +// Serialize serializes the DhKeyList and returns the byte slice. +func (dkl DhKeyList) Serialize() []byte { + buff := bytes.NewBuffer(nil) + + for uid, key := range dkl { + // Write ID + buff.Write(uid.Marshal()) + + // Write DH key length + b := make([]byte, 8) + keyBytes := key.BinaryEncode() + binary.LittleEndian.PutUint64(b, uint64(len(keyBytes))) + buff.Write(b) + + // Write DH key + buff.Write(keyBytes) + } + + return buff.Bytes() +} + +// DeserializeDhKeyList deserializes the bytes into a DhKeyList. +func DeserializeDhKeyList(data []byte) (DhKeyList, error) { + if len(data) == 0 { + return nil, nil + } + + buff := bytes.NewBuffer(data) + dkl := make(DhKeyList) + + for n := buff.Next(id.ArrIDLen); len(n) == id.ArrIDLen; n = buff.Next(id.ArrIDLen) { + // Read and unmarshal ID + uid, err := id.Unmarshal(n) + if err != nil { + return nil, errors.Errorf(idUnmarshalErr, err) + } + + // Get length of DH key + keyLen := int(binary.LittleEndian.Uint64(buff.Next(8))) + + // Read and decode DH key + key := &cyclic.Int{} + err = key.BinaryDecode(buff.Next(keyLen)) + if err != nil { + return nil, errors.Errorf(dhKeyDecodeErr, err) + } + + dkl[*uid] = key + } + + return dkl, nil +} + +// GoString returns all the elements in the DhKeyList as text in sorted order. +// This functions satisfies the fmt.GoStringer interface. +func (dkl DhKeyList) GoString() string { + str := make([]string, 0, len(dkl)) + + unsorted := make([]struct { + uid *id.ID + key *cyclic.Int + }, 0, len(dkl)) + + for uid, key := range dkl { + unsorted = append(unsorted, struct { + uid *id.ID + key *cyclic.Int + }{uid: uid.DeepCopy(), key: key.DeepCopy()}) + } + + sort.Slice(unsorted, func(i, j int) bool { + return bytes.Compare(unsorted[i].uid.Bytes(), + unsorted[j].uid.Bytes()) == -1 + }) + + for _, val := range unsorted { + str = append(str, val.uid.String()+": "+val.key.Text(10)) + } + + return "{" + strings.Join(str, ", ") + "}" +} diff --git a/groupChat/groupStore/dhKeyList_test.go b/groupChat/groupStore/dhKeyList_test.go new file mode 100644 index 0000000000000000000000000000000000000000..eca03f45a187c63a13276b7f11edfda2010e7b4d --- /dev/null +++ b/groupChat/groupStore/dhKeyList_test.go @@ -0,0 +1,93 @@ +package groupStore + +import ( + "math/rand" + "reflect" + "strings" + "testing" +) + +// // Unit test of GenerateDhKeyList. +// func TestGenerateDhKeyList(t *testing.T) { +// prng := rand.New(rand.NewSource(42)) +// grp := getGroup() +// userID := id.NewIdFromString("userID", id.User, t) +// privKey := grp.NewInt(42) +// pubKey := grp.ExpG(privKey, grp.NewInt(1)) +// members := createMembership(prng, 10, t) +// members[2].ID = userID +// members[2].DhKey = pubKey +// +// dkl := GenerateDhKeyList(userID, privKey, members, grp) +// +// t.Log(dkl) +// } + +// Unit test of DhKeyList.DeepCopy. +func TestDhKeyList_DeepCopy(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + dkl := createDhKeyList(prng, 10, t) + newDkl := dkl.DeepCopy() + + if !reflect.DeepEqual(dkl, newDkl) { + t.Errorf("DeepCopy() failed to return a copy of the original."+ + "\nexpected: %#v\nrecevied: %#v", dkl, newDkl) + } + + if &dkl == &newDkl { + t.Errorf("DeepCopy returned a copy of the pointer."+ + "\nexpected: %p\nreceived: %p", &dkl, &newDkl) + } +} + +// Tests that a DhKeyList that is serialized and deserialized matches the +// original. +func TestDhKeyList_Serialize_DeserializeDhKeyList(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + dkl := createDhKeyList(prng, 10, t) + + data := dkl.Serialize() + newDkl, err := DeserializeDhKeyList(data) + if err != nil { + t.Errorf("DeserializeDhKeyList returned an error: %+v", err) + } + + if !reflect.DeepEqual(dkl, newDkl) { + t.Errorf("Failed to serialize and deserialize DhKeyList."+ + "\nexpected: %#v\nreceived: %#v", dkl, newDkl) + } +} + +// Error path: an error is returned when DeserializeDhKeyList encounters invalid +// cyclic int. +func TestDeserializeDhKeyList_DhKeyBinaryDecodeError(t *testing.T) { + expectedErr := strings.SplitN(dhKeyDecodeErr, "%", 2)[0] + + _, err := DeserializeDhKeyList(make([]byte, 41)) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("DeserializeDhKeyList failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Unit test of DhKeyList.GoString. +func TestDhKeyList_GoString(t *testing.T) { + grp := createTestGroup(rand.New(rand.NewSource(42)), t) + expected := "{Grcjbkt1IWKQzyvrQsPKJzKFYPGqwGfOpui/RtSrK0YD: 5170411903... in GRP: 6SsQ/HAHUn..., QCxg8d6XgoPUoJo2+WwglBdG4+1NpkaprotPp7T8OiAD: 1754900790... in GRP: 6SsQ/HAHUn..., invD4ElbVxL+/b4MECiH4QDazS2IX2kstgfaAKEcHHAD: 2926033432... in GRP: 6SsQ/HAHUn..., wRYCP6iJdLrAyv2a0FaSsTYZ5ziWTf3Hno1TQ3NmHP0D: 2297312580... in GRP: 6SsQ/HAHUn..., 15ufnw07pVsMwNYUTIiFNYQay+BwmwdYCD9h03W8ArQD: 6199513233... in GRP: 6SsQ/HAHUn..., 3RqsBM4ux44bC6+uiBuCp1EQikLtPJA8qkNGWnhiBhYD: 4604475835... in GRP: 6SsQ/HAHUn..., 55ai4SlwXic/BckjJoKOKwVuOBdljhBhSYlH/fNEQQ4D: 9940605492... in GRP: 6SsQ/HAHUn..., 9PkZKU50joHnnku9b+NM3LqEPujWPoxP/hzr6lRtj6wD: 2451667393... in GRP: 6SsQ/HAHUn..., +hp17fHP0rO1EhnqeVM6v0SNLEedMmB1M5BZFMjMHPAD: 6029441980... in GRP: 6SsQ/HAHUn...}" + + if grp.DhKeys.GoString() != expected { + t.Errorf("GoString failed to return the expected string."+ + "\nexpected: %s\nreceived: %s", expected, grp.DhKeys.GoString()) + } +} + +// Tests that DhKeyList.GoString. returns the expected string for a nil map. +func TestDhKeyList_GoString_NilMap(t *testing.T) { + dkl := DhKeyList{} + expected := "{}" + + if dkl.GoString() != expected { + t.Errorf("GoString failed to return the expected string."+ + "\nexpected: %s\nreceived: %s", expected, dkl.GoString()) + } +} diff --git a/groupChat/groupStore/group.go b/groupChat/groupStore/group.go new file mode 100644 index 0000000000000000000000000000000000000000..98d0b489de808cf1ed5ca31818f4bc7f203fe878 --- /dev/null +++ b/groupChat/groupStore/group.go @@ -0,0 +1,235 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupStore + +import ( + "bytes" + "encoding/binary" + "fmt" + "github.com/pkg/errors" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/crypto/cyclic" + "gitlab.com/elixxir/crypto/group" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" + "strings" +) + +// Storage values. +const ( + // Key that is prepended to group ID to create a unique key to identify a + // Group in storage. + groupStorageKey = "GroupChat/" + groupStoreVersion = 0 +) + +// Error messages. +const ( + kvGetGroupErr = "failed to get group %s from storage: %+v" + membershipErr = "failed to deserialize member list: %+v" + dhKeyListErr = "failed to deserialize DH key list: %+v" +) + +// Group contains the membership list, the cryptographic information, and the +// identifying information of a group chat. +type Group struct { + Name []byte // Name of the group set by the user + ID *id.ID // Group ID + Key group.Key // Group key + IdPreimage group.IdPreimage // 256-bit value from CRNG + KeyPreimage group.KeyPreimage // 256-bit value from CRNG + InitMessage []byte // The original invite message + Members group.Membership // Sorted list of members in group + DhKeys DhKeyList // List of shared DH keys +} + +// NewGroup creates a new Group from copies of the given data. +func NewGroup(name []byte, groupID *id.ID, groupKey group.Key, + idPreimage group.IdPreimage, keyPreimage group.KeyPreimage, + initMessage []byte, members group.Membership, dhKeys DhKeyList) Group { + g := Group{ + Name: make([]byte, len(name)), + ID: groupID.DeepCopy(), + Key: groupKey, + IdPreimage: idPreimage, + KeyPreimage: keyPreimage, + InitMessage: make([]byte, len(initMessage)), + Members: members.DeepCopy(), + DhKeys: dhKeys, + } + + copy(g.Name, name) + copy(g.InitMessage, initMessage) + + return g +} + +// DeepCopy returns a copy of the Group. +func (g Group) DeepCopy() Group { + newGrp := Group{ + Name: make([]byte, len(g.Name)), + ID: g.ID.DeepCopy(), + Key: g.Key, + IdPreimage: g.IdPreimage, + KeyPreimage: g.KeyPreimage, + InitMessage: make([]byte, len(g.InitMessage)), + Members: g.Members.DeepCopy(), + DhKeys: make(map[id.ID]*cyclic.Int, len(g.Members)-1), + } + + copy(newGrp.Name, g.Name) + copy(newGrp.InitMessage, g.InitMessage) + + for uid, key := range g.DhKeys { + newGrp.DhKeys[uid] = key.DeepCopy() + } + + return newGrp +} + +// store saves an individual Group to storage keying on the group ID. +func (g Group) store(kv *versioned.KV) error { + obj := &versioned.Object{ + Version: groupStoreVersion, + Timestamp: netTime.Now(), + Data: g.Serialize(), + } + + return kv.Set(groupStoreKey(g.ID), groupStoreVersion, obj) +} + +// loadGroup returns the group with the corresponding ID from storage. +func loadGroup(groupID *id.ID, kv *versioned.KV) (Group, error) { + obj, err := kv.Get(groupStoreKey(groupID), groupStoreVersion) + if err != nil { + return Group{}, errors.Errorf(kvGetGroupErr, groupID, err) + } + + return DeserializeGroup(obj.Data) +} + +// removeGroup deletes the given group from storage. +func removeGroup(groupID *id.ID, kv *versioned.KV) error { + return kv.Delete(groupStoreKey(groupID), groupStoreVersion) +} + +// Serialize serializes the Group and returns the byte slice. +func (g Group) Serialize() []byte { + buff := bytes.NewBuffer(nil) + + // Write length of name and name + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(len(g.Name))) + buff.Write(b) + buff.Write(g.Name) + + // Write group ID + if g.ID != nil { + buff.Write(g.ID.Marshal()) + } else { + buff.Write(make([]byte, id.ArrIDLen)) + } + + // Write group key and preimages + buff.Write(g.Key[:]) + buff.Write(g.IdPreimage[:]) + buff.Write(g.KeyPreimage[:]) + + // Write length of InitMessage and InitMessage + b = make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(len(g.InitMessage))) + buff.Write(b) + buff.Write(g.InitMessage) + + // Write length of group membership and group membership + b = make([]byte, 8) + memberBytes := g.Members.Serialize() + binary.LittleEndian.PutUint64(b, uint64(len(memberBytes))) + buff.Write(b) + buff.Write(memberBytes) + + // Write DH key list + buff.Write(g.DhKeys.Serialize()) + + return buff.Bytes() +} + +// DeserializeGroup deserializes the bytes into a Group. +func DeserializeGroup(data []byte) (Group, error) { + buff := bytes.NewBuffer(data) + var g Group + var err error + + // Get name + nameLen := binary.LittleEndian.Uint64(buff.Next(8)) + if nameLen > 0 { + g.Name = buff.Next(int(nameLen)) + } + + // Get group ID + var groupID id.ID + copy(groupID[:], buff.Next(id.ArrIDLen)) + if groupID == [id.ArrIDLen]byte{} { + g.ID = nil + } else { + g.ID = &groupID + } + + // Get group key and preimages + copy(g.Key[:], buff.Next(group.KeyLen)) + copy(g.IdPreimage[:], buff.Next(group.IdPreimageLen)) + copy(g.KeyPreimage[:], buff.Next(group.KeyPreimageLen)) + + // Get InitMessage + initMessageLength := binary.LittleEndian.Uint64(buff.Next(8)) + if initMessageLength > 0 { + g.InitMessage = buff.Next(int(initMessageLength)) + } + + // Get member list + membersLength := binary.LittleEndian.Uint64(buff.Next(8)) + g.Members, err = group.DeserializeMembership(buff.Next(int(membersLength))) + if err != nil { + return Group{}, errors.Errorf(membershipErr, err) + } + + // Get DH key list + g.DhKeys, err = DeserializeDhKeyList(buff.Bytes()) + if err != nil { + return Group{}, errors.Errorf(dhKeyListErr, err) + } + + return g, err +} + +// groupStoreKey generates a unique key to save and load a Group to/from storage. +func groupStoreKey(groupID *id.ID) string { + return groupStorageKey + groupID.String() +} + +// GoString returns all the Group's fields as text. This functions satisfies the +// fmt.GoStringer interface. +func (g Group) GoString() string { + idString := "<nil>" + if g.ID != nil { + idString = g.ID.String() + } + + str := make([]string, 8) + + str[0] = "Name:" + fmt.Sprintf("%q", g.Name) + str[1] = "ID:" + idString + str[2] = "Key:" + g.Key.String() + str[3] = "IdPreimage:" + g.IdPreimage.String() + str[4] = "KeyPreimage:" + g.KeyPreimage.String() + str[5] = "InitMessage:" + fmt.Sprintf("%q", g.InitMessage) + str[6] = "Members:" + g.Members.String() + str[7] = "DhKeys:" + g.DhKeys.GoString() + + return "{" + strings.Join(str, ", ") + "}" +} diff --git a/groupChat/groupStore/group_test.go b/groupChat/groupStore/group_test.go new file mode 100644 index 0000000000000000000000000000000000000000..b0f84761b39c17059635c592e2f6f271471dbab1 --- /dev/null +++ b/groupChat/groupStore/group_test.go @@ -0,0 +1,289 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupStore + +import ( + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/crypto/group" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/primitives/id" + "math/rand" + "reflect" + "strings" + "testing" +) + +// Unit test of NewGroup. +func TestNewGroup(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + membership := createMembership(prng, 10, t) + dkl := GenerateDhKeyList(membership[0].ID, randCycInt(prng), membership, getGroup()) + + expectedGroup := Group{ + Name: []byte(groupName), + ID: id.NewIdFromUInt(uint64(42), id.Group, t), + Key: newKey(groupKey), + IdPreimage: newIdPreimage(groupIdPreimage), + KeyPreimage: newKeyPreimage(groupKeyPreimage), + InitMessage: []byte(initMessage), + Members: membership, + DhKeys: dkl, + } + + receivedGroup := NewGroup( + []byte(groupName), + id.NewIdFromUInt(uint64(42), id.Group, t), + newKey(groupKey), + newIdPreimage(groupIdPreimage), + newKeyPreimage(groupKeyPreimage), + []byte(initMessage), + membership, + dkl, + ) + + if !reflect.DeepEqual(receivedGroup, expectedGroup) { + t.Errorf("NewGroup did not return the expected Group."+ + "\nexpected: %#v\nreceived: %#v", expectedGroup, receivedGroup) + } +} + +// Unit test of Group.DeepCopy. +func TestGroup_DeepCopy(t *testing.T) { + grp := createTestGroup(rand.New(rand.NewSource(42)), t) + + newGrp := grp.DeepCopy() + + if !reflect.DeepEqual(grp, newGrp) { + t.Errorf("DeepCopy did not return a copy of the original Group."+ + "\nexpected: %#v\nreceived: %#v", grp, newGrp) + } + + if &grp.Name[0] == &newGrp.Name[0] { + t.Errorf("DeepCopy returned a copy of the pointer of Name."+ + "\nexpected: %p\nreceived: %p", &grp.Name[0], &newGrp.Name[0]) + } + + if &grp.ID[0] == &newGrp.ID[0] { + t.Errorf("DeepCopy returned a copy of the pointer of ID."+ + "\nexpected: %p\nreceived: %p", &grp.ID[0], &newGrp.ID[0]) + } + + if &grp.Key[0] == &newGrp.Key[0] { + t.Errorf("DeepCopy returned a copy of the pointer of Key."+ + "\nexpected: %p\nreceived: %p", &grp.Key[0], &newGrp.Key[0]) + } + + if &grp.IdPreimage[0] == &newGrp.IdPreimage[0] { + t.Errorf("DeepCopy returned a copy of the pointer of IdPreimage."+ + "\nexpected: %p\nreceived: %p", &grp.IdPreimage[0], &newGrp.IdPreimage[0]) + } + + if &grp.KeyPreimage[0] == &newGrp.KeyPreimage[0] { + t.Errorf("DeepCopy returned a copy of the pointer of KeyPreimage."+ + "\nexpected: %p\nreceived: %p", &grp.KeyPreimage[0], &newGrp.KeyPreimage[0]) + } + + if &grp.InitMessage[0] == &newGrp.InitMessage[0] { + t.Errorf("DeepCopy returned a copy of the pointer of InitMessage."+ + "\nexpected: %p\nreceived: %p", &grp.InitMessage[0], &newGrp.InitMessage[0]) + } + + if &grp.Members[0] == &newGrp.Members[0] { + t.Errorf("DeepCopy returned a copy of the pointer of Members."+ + "\nexpected: %p\nreceived: %p", &grp.Members[0], &newGrp.Members[0]) + } +} + +// Unit test of Group.store. +func TestGroup_store(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + g := createTestGroup(rand.New(rand.NewSource(42)), t) + + err := g.store(kv) + if err != nil { + t.Errorf("store returned an error: %+v", err) + } + + obj, err := kv.Get(groupStoreKey(g.ID), groupStoreVersion) + if err != nil { + t.Errorf("Failed to get group from storage: %+v", err) + } + + newGrp, err := DeserializeGroup(obj.Data) + if err != nil { + t.Errorf("Failed to deserialize group: %+v", err) + } + + if !reflect.DeepEqual(g, newGrp) { + t.Errorf("Failed to read correct group from storage."+ + "\nexpected: %#v\nreceived: %#v", g, newGrp) + } +} + +// Unit test of Group.loadGroup. +func Test_loadGroup(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + g := createTestGroup(rand.New(rand.NewSource(42)), t) + + err := g.store(kv) + if err != nil { + t.Errorf("store returned an error: %+v", err) + } + + newGrp, err := loadGroup(g.ID, kv) + if err != nil { + t.Errorf("loadGroup returned an error: %+v", err) + } + + if !reflect.DeepEqual(g, newGrp) { + t.Errorf("loadGroup failed to return the expected group."+ + "\nexpected: %#v\nreceived: %#v", g, newGrp) + } +} + +// Error path: an error is returned when no group with the ID exists in storage. +func Test_loadGroup_InvalidGroupIdError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + g := createTestGroup(rand.New(rand.NewSource(42)), t) + expectedErr := strings.SplitN(kvGetGroupErr, "%", 2)[0] + + _, err := loadGroup(g.ID, kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("loadGroup failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Unit test of Group.removeGroup. +func Test_removeGroup(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + g := createTestGroup(rand.New(rand.NewSource(42)), t) + + err := g.store(kv) + if err != nil { + t.Errorf("store returned an error: %+v", err) + } + + err = removeGroup(g.ID, kv) + if err != nil { + t.Errorf("removeGroup returned an error: %+v", err) + } + + foundGrp, err := loadGroup(g.ID, kv) + if err == nil { + t.Errorf("loadGroup found group that should have been removed: %#v", + foundGrp) + } +} + +// Tests that a group that is serialized and deserialized matches the original. +func TestGroup_Serialize_DeserializeGroup(t *testing.T) { + grp := createTestGroup(rand.New(rand.NewSource(42)), t) + + grpBytes := grp.Serialize() + + newGrp, err := DeserializeGroup(grpBytes) + if err != nil { + t.Errorf("DeserializeGroup returned an error: %+v", err) + } + + if !reflect.DeepEqual(grp, newGrp) { + t.Errorf("Deserialized group does not match original."+ + "\nexpected: %#v\nreceived: %#v", grp, newGrp) + } +} + +// Tests that a group with nil fields that is serialized and deserialized +// matches the original. +func TestGroup_Serialize_DeserializeGroup_NilGroup(t *testing.T) { + grp := Group{Members: make(group.Membership, 3)} + + grpBytes := grp.Serialize() + + newGrp, err := DeserializeGroup(grpBytes) + if err != nil { + t.Errorf("DeserializeGroup returned an error: %+v", err) + } + + if !reflect.DeepEqual(grp, newGrp) { + t.Errorf("Deserialized group does not match original."+ + "\nexpected: %#v\nreceived: %#v", grp, newGrp) + } +} + +// Error path: error returned when the group membership is too small. +func TestDeserializeGroup_DeserializeMembershipError(t *testing.T) { + grp := Group{} + grpBytes := grp.Serialize() + expectedErr := strings.SplitN(membershipErr, "%", 2)[0] + + _, err := DeserializeGroup(grpBytes) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("DeserializeGroup failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +func Test_groupStoreKey(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + expectedKeys := []string{ + "GroupChat/U4x/lrFkvxuXu59LtHLon1sUhPJSCcnZND6SugndnVID", + "GroupChat/15tNdkKbYXoMn58NO6VbDMDWFEyIhTWEGsvgcJsHWAgD", + "GroupChat/YdN1vAK0HfT5GSnhj9qeb4LlTnSOgeeeS71v40zcuoQD", + "GroupChat/6NY+jE/+HOvqVG2PrBPdGqwEzi6ih3xVec+ix44bC68D", + "GroupChat/iBuCp1EQikLtPJA8qkNGWnhiBhaXiu0M48bE8657w+AD", + "GroupChat/W1cS/v2+DBAoh+EA2s0tiF9pLLYH2gChHBxwceeWotwD", + "GroupChat/wlpbdLLhKXBeJz8FySMmgo4rBW44F2WOEGFJiUf980QD", + "GroupChat/DtTBFgI/qONXa2/tJ/+JdLrAyv2a0FaSsTYZ5ziWTf0D", + "GroupChat/no1TQ3NmHP1m10/sHhuJSRq3I25LdSFikM8r60LDyicD", + "GroupChat/hWDxqsBnzqbov0bUqytGgEAsX7KCDohdMmDx3peCg9QD", + } + for i, expected := range expectedKeys { + newID, _ := id.NewRandomID(prng, id.User) + + key := groupStoreKey(newID) + + if key != expected { + t.Errorf("groupStoreKey did not return the expected key (%d)."+ + "\nexpected: %s\nreceived: %s", i, expected, key) + } + + // fmt.Printf("\"%s\",\n", key) + } +} + +// Unit test of Group.GoString. +func TestGroup_GoString(t *testing.T) { + grp := createTestGroup(rand.New(rand.NewSource(42)), t) + expected := "{Name:\"groupName\", ID:ISTkX+tNhfEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAE, Key:a2V5AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=, IdPreimage:aWRQcmVpbWFnZQAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=, KeyPreimage:a2V5UHJlaW1hZ2UAAAAAAAAAAAAAAAAAAAAAAAAAAAA=, InitMessage:\"initMessage\", Members:{Leader: {U4x/lrFkvxuXu59LtHLon1sUhPJSCcnZND6SugndnVID, 3534334367... in GRP: 6SsQ/HAHUn...}, Participants: 0: {Grcjbkt1IWKQzyvrQsPKJzKFYPGqwGfOpui/RtSrK0YD, 5274380952... in GRP: 6SsQ/HAHUn...}, 1: {QCxg8d6XgoPUoJo2+WwglBdG4+1NpkaprotPp7T8OiAD, 1628829379... in GRP: 6SsQ/HAHUn...}, 2: {invD4ElbVxL+/b4MECiH4QDazS2IX2kstgfaAKEcHHAD, 4157513341... in GRP: 6SsQ/HAHUn...}, 3: {wRYCP6iJdLrAyv2a0FaSsTYZ5ziWTf3Hno1TQ3NmHP0D, 5785305945... in GRP: 6SsQ/HAHUn...}, 4: {15ufnw07pVsMwNYUTIiFNYQay+BwmwdYCD9h03W8ArQD, 2010156224... in GRP: 6SsQ/HAHUn...}, 5: {3RqsBM4ux44bC6+uiBuCp1EQikLtPJA8qkNGWnhiBhYD, 2643318057... in GRP: 6SsQ/HAHUn...}, 6: {55ai4SlwXic/BckjJoKOKwVuOBdljhBhSYlH/fNEQQ4D, 6482807720... in GRP: 6SsQ/HAHUn...}, 7: {9PkZKU50joHnnku9b+NM3LqEPujWPoxP/hzr6lRtj6wD, 6603068123... in GRP: 6SsQ/HAHUn...}, 8: {+hp17fHP0rO1EhnqeVM6v0SNLEedMmB1M5BZFMjMHPAD, 2628757933... in GRP: 6SsQ/HAHUn...}}, DhKeys:{Grcjbkt1IWKQzyvrQsPKJzKFYPGqwGfOpui/RtSrK0YD: 5170411903... in GRP: 6SsQ/HAHUn..., QCxg8d6XgoPUoJo2+WwglBdG4+1NpkaprotPp7T8OiAD: 1754900790... in GRP: 6SsQ/HAHUn..., invD4ElbVxL+/b4MECiH4QDazS2IX2kstgfaAKEcHHAD: 2926033432... in GRP: 6SsQ/HAHUn..., wRYCP6iJdLrAyv2a0FaSsTYZ5ziWTf3Hno1TQ3NmHP0D: 2297312580... in GRP: 6SsQ/HAHUn..., 15ufnw07pVsMwNYUTIiFNYQay+BwmwdYCD9h03W8ArQD: 6199513233... in GRP: 6SsQ/HAHUn..., 3RqsBM4ux44bC6+uiBuCp1EQikLtPJA8qkNGWnhiBhYD: 4604475835... in GRP: 6SsQ/HAHUn..., 55ai4SlwXic/BckjJoKOKwVuOBdljhBhSYlH/fNEQQ4D: 9940605492... in GRP: 6SsQ/HAHUn..., 9PkZKU50joHnnku9b+NM3LqEPujWPoxP/hzr6lRtj6wD: 2451667393... in GRP: 6SsQ/HAHUn..., +hp17fHP0rO1EhnqeVM6v0SNLEedMmB1M5BZFMjMHPAD: 6029441980... in GRP: 6SsQ/HAHUn...}}" + + if grp.GoString() != expected { + t.Errorf("GoString failed to return the expected string."+ + "\nexpected: %s\nreceived: %s", expected, grp.GoString()) + } +} + +// Test that Group.GoString returns the expected string for a nil group. +func TestGroup_GoString_NilGroup(t *testing.T) { + grp := Group{} + expected := "{" + + "Name:\"\", " + + "ID:<nil>, " + + "Key:AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=, " + + "IdPreimage:AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=, " + + "KeyPreimage:AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=, " + + "InitMessage:\"\", " + + "Members:{<nil>}, " + + "DhKeys:{}" + + "}" + + if grp.GoString() != expected { + t.Errorf("GoString failed to return the expected string."+ + "\nexpected: %s\nreceived: %s", expected, grp.GoString()) + } +} diff --git a/groupChat/groupStore/store.go b/groupChat/groupStore/store.go new file mode 100644 index 0000000000000000000000000000000000000000..88e980603e323114b2a0c39a138f0e2977627053 --- /dev/null +++ b/groupChat/groupStore/store.go @@ -0,0 +1,302 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupStore + +import ( + "bytes" + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/crypto/group" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" + "sync" + "testing" +) + +// Storage values. +const ( + // Key used to identify the list of Groups in storage. + groupStoragePrefix = "GroupChatListStore" + groupListStorageKey = "GroupChatList" + groupListVersion = 0 +) + +// Error messages. +const ( + kvGetGroupListErr = "failed to get list of group IDs from storage: %+v" + groupLoadErr = "failed to load group %d/%d: %+v" + groupSaveErr = "failed to save group %s to storage: %+v" + maxGroupsErr = "failed to add new group, max number of groups (%d) reached" + groupExistsErr = "group with ID %s already exists" + groupRemoveErr = "failed to remove group with ID %s, group not found in memory" + saveListRemoveErr = "failed to save new group ID list after removing group %s" + setUserPanic = "Store.SetUser is for testing only. Got %T" +) + +// The maximum number of group chats that a user can be a part of at once. +const MaxGroupChats = 64 + +// Store stores the list of Groups that a user is a part of. +type Store struct { + list map[id.ID]Group + user group.Member + kv *versioned.KV + mux sync.RWMutex +} + +// NewStore constructs a new Store object for the user and saves it to storage. +func NewStore(kv *versioned.KV, user group.Member) (*Store, error) { + s := &Store{ + list: make(map[id.ID]Group), + user: user.DeepCopy(), + kv: kv.Prefix(groupStoragePrefix), + } + + return s, s.save() +} + +// NewOrLoadStore loads the group store from storage or makes a new one if it +// does not exist. +func NewOrLoadStore(kv *versioned.KV, user group.Member) (*Store, error) { + prefixKv := kv.Prefix(groupStoragePrefix) + + // Load the list of group IDs from file if they exist + vo, err := prefixKv.Get(groupListStorageKey, groupListVersion) + if err == nil { + return loadStore(vo.Data, prefixKv, user) + } + + // If there is no group list saved, then make a new one + return NewStore(kv, user) +} + +// LoadStore loads all the Groups from storage into memory and return them in +// a Store object. +func LoadStore(kv *versioned.KV, user group.Member) (*Store, error) { + kv = kv.Prefix(groupStoragePrefix) + + // Load the list of group IDs from file + vo, err := kv.Get(groupListStorageKey, groupListVersion) + if err != nil { + return nil, errors.Errorf(kvGetGroupListErr, err) + } + + return loadStore(vo.Data, kv, user) +} + +// loadStore builds the list of group IDs and loads the groups from storage. +func loadStore(data []byte, kv *versioned.KV, user group.Member) (*Store, error) { + // Deserialize list of group IDs + groupIDs := deserializeGroupIdList(data) + + // Initialize the Store + s := &Store{ + list: make(map[id.ID]Group, len(groupIDs)), + user: user.DeepCopy(), + kv: kv, + } + + // Load each Group from storage into the map + for i, grpID := range groupIDs { + grp, err := loadGroup(grpID, kv) + if err != nil { + return nil, errors.Errorf(groupLoadErr, i, len(grpID), err) + } + s.list[*grpID] = grp + } + + return s, nil +} + +// saveGroupList saves a list of group IDs to storage. +func (s *Store) saveGroupList() error { + // Create the versioned object + obj := &versioned.Object{ + Version: groupListVersion, + Timestamp: netTime.Now(), + Data: serializeGroupIdList(s.list), + } + + // Save to storage + return s.kv.Set(groupListStorageKey, groupListVersion, obj) +} + +// serializeGroupIdList serializes the list of group IDs. +func serializeGroupIdList(list map[id.ID]Group) []byte { + buff := bytes.NewBuffer(nil) + buff.Grow(id.ArrIDLen * len(list)) + + // Create list of IDs from map + for grpId := range list { + buff.Write(grpId.Marshal()) + } + + return buff.Bytes() +} + +// deserializeGroupIdList deserializes data into a list of group IDs. +func deserializeGroupIdList(data []byte) []*id.ID { + idLen := id.ArrIDLen + groupIDs := make([]*id.ID, 0, len(data)/idLen) + buff := bytes.NewBuffer(data) + + // Copy each set of data into a new ID and append to list + for n := buff.Next(idLen); len(n) == idLen; n = buff.Next(idLen) { + var newID id.ID + copy(newID[:], n) + groupIDs = append(groupIDs, &newID) + } + + return groupIDs +} + +// save saves the group ID list and each group individually to storage. +func (s *Store) save() error { + // Store group ID list + err := s.saveGroupList() + if err != nil { + return err + } + + // Store individual groups + for grpID, grp := range s.list { + if err := grp.store(s.kv); err != nil { + return errors.Errorf(groupSaveErr, grpID, err) + } + } + + return nil +} + +// Len returns the number of groups stored. +func (s *Store) Len() int { + s.mux.RLock() + defer s.mux.RUnlock() + + return len(s.list) +} + +// Add adds a new group to the group list and saves it to storage. An error is +// returned if the user has the max number of groups (MaxGroupChats). +func (s *Store) Add(g Group) error { + s.mux.Lock() + defer s.mux.Unlock() + + // Check if the group list is full. + if len(s.list) >= MaxGroupChats { + return errors.Errorf(maxGroupsErr, MaxGroupChats) + } + + // Return an error if the group already exists in the map + if _, exists := s.list[*g.ID]; exists { + return errors.Errorf(groupExistsErr, g.ID) + } + + // Add the group to the map + s.list[*g.ID] = g.DeepCopy() + + // Update the group list in storage + err := s.saveGroupList() + if err != nil { + return err + } + + // Store the group to storage + return g.store(s.kv) +} + +// Remove removes the group with the corresponding ID from memory and storage. +// An error is returned if the group cannot be found in memory or storage. +func (s *Store) Remove(groupID *id.ID) error { + s.mux.Lock() + defer s.mux.Unlock() + + // Exit if the Group does not exist in memory + if _, exists := s.list[*groupID]; !exists { + return errors.Errorf(groupRemoveErr, groupID) + } + + // Delete Group from memory + delete(s.list, *groupID) + + // Remove group ID from list in memory + err := s.saveGroupList() + if err != nil { + return errors.Errorf(saveListRemoveErr, groupID) + } + + // Delete Group from storage + return removeGroup(groupID, s.kv) +} + +// GroupIDs returns a list of all group IDs. +func (s *Store) GroupIDs() []*id.ID { + s.mux.RLock() + defer s.mux.RUnlock() + + idList := make([]*id.ID, 0, len(s.list)) + for gid := range s.list { + idList = append(idList, gid.DeepCopy()) + } + + return idList +} + +// Get returns the Group for the given group ID. Returns false if no Group is +// found. +func (s *Store) Get(groupID *id.ID) (Group, bool) { + s.mux.RLock() + defer s.mux.RUnlock() + + grp, exists := s.list[*groupID] + if !exists { + return Group{}, false + } + + return grp.DeepCopy(), exists +} + +// GetByKeyFp returns the group with the matching key fingerprint and salt. +// Returns false if no group is found. +func (s *Store) GetByKeyFp(keyFp format.Fingerprint, salt [group.SaltLen]byte) (Group, bool) { + s.mux.RLock() + defer s.mux.RUnlock() + + // Iterate through each group to check if the key fingerprint matches + for _, grp := range s.list { + if group.CheckKeyFingerprint(keyFp, grp.Key, salt, s.user.ID) { + return grp.DeepCopy(), true + } + } + + return Group{}, false +} + +// GetUser returns the group member for the current user. +func (s *Store) GetUser() group.Member { + s.mux.RLock() + defer s.mux.RUnlock() + return s.user.DeepCopy() +} + +// SetUser allows a user to be set. This function is for testing purposes only. +// It panics if the interface is not of a testing type. +func (s *Store) SetUser(user group.Member, x interface{}) { + switch x.(type) { + case *testing.T, *testing.M, *testing.B, *testing.PB: + break + default: + jww.FATAL.Panicf(setUserPanic, x) + } + + s.mux.Lock() + defer s.mux.Unlock() + s.user = user.DeepCopy() +} diff --git a/groupChat/groupStore/store_test.go b/groupChat/groupStore/store_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0e0917ff38ce33b83cbecbfafb4f960ea8c38447 --- /dev/null +++ b/groupChat/groupStore/store_test.go @@ -0,0 +1,576 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupStore + +import ( + "bytes" + "fmt" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/crypto/group" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/primitives/id" + "math/rand" + "reflect" + "sort" + "strings" + "testing" +) + +// Unit test of NewStore. +func TestNewStore(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + user := randMember(prng) + + expectedStore := &Store{ + list: make(map[id.ID]Group), + user: user, + kv: kv.Prefix(groupStoragePrefix), + } + + store, err := NewStore(kv, user) + if err != nil { + t.Fatalf("NewStore returned an error: %+v", err) + } + + // Compare manually created object with NewUnknownRoundsStore + if !reflect.DeepEqual(expectedStore, store) { + t.Errorf("NewStore returned incorrect Store."+ + "\nexpected: %+v\nreceived: %+v", expectedStore, store) + } + + // Add information in store + testGroup := createTestGroup(prng, t) + + store.list[*testGroup.ID] = testGroup + + if err := store.save(); err != nil { + t.Fatalf("save() could not write to disk: %+v", err) + } + + groupIds := make([]id.ID, 0, len(store.list)) + for grpId := range store.list { + groupIds = append(groupIds, grpId) + } + + // Check that stored group Id list is expected value + expectedData := serializeGroupIdList(store.list) + + obj, err := store.kv.Get(groupListStorageKey, groupListVersion) + if err != nil { + t.Errorf("Could not get group list: %+v", err) + } + + // Check that the stored data is the data outputted by marshal + if !bytes.Equal(expectedData, obj.Data) { + t.Errorf("NewStore() returned incorrect Store."+ + "\nexpected: %+v\nreceived: %+v", expectedData, obj.Data) + } + + obj, err = store.kv.Get(groupStoreKey(testGroup.ID), groupListVersion) + if err != nil { + t.Errorf("Could not get group: %+v", err) + } + + newGrp, err := DeserializeGroup(obj.Data) + if err != nil { + t.Errorf("Failed to deserialize group: %+v", err) + } + + if !reflect.DeepEqual(testGroup, newGrp) { + t.Errorf("NewStore() returned incorrect Store."+ + "\nexpected: %#v\nreceived: %#v", testGroup, newGrp) + } +} + +func TestNewOrLoadStore(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + user := randMember(prng) + + store, err := NewOrLoadStore(kv, user) + if err != nil { + t.Fatalf("Failed to create new store: %+v", err) + } + + // Add group to store + testGroup := createTestGroup(prng, t) + if err = store.Add(testGroup); err != nil { + t.Fatalf("Failed to add test group: %+v", err) + } + + // Load the store from kv + receivedStore, err := NewOrLoadStore(kv, user) + if err != nil { + t.Fatalf("LoadStore returned an error: %+v", err) + } + + // Check that state in loaded store matches store that was saved + if len(receivedStore.list) != len(store.list) { + t.Errorf("LoadStore returned Store with incorrect number of groups."+ + "\nexpected len: %d\nreceived len: %d", + len(store.list), len(receivedStore.list)) + } + + if _, exists := receivedStore.list[*testGroup.ID]; !exists { + t.Fatalf("Failed to get group from loaded group map."+ + "\nexpected: %#v\nreceived: %#v", testGroup, receivedStore.list) + } +} + +// Unit test of LoadStore. +func TestLoadStore(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + user := randMember(prng) + + store, err := NewStore(kv, user) + if err != nil { + t.Fatalf("Failed to create new store: %+v", err) + } + + // Add group to store + testGroup := createTestGroup(prng, t) + if err = store.Add(testGroup); err != nil { + t.Fatalf("Failed to add test group: %+v", err) + } + + // Load the store from kv + receivedStore, err := LoadStore(kv, user) + if err != nil { + t.Fatalf("LoadStore returned an error: %+v", err) + } + + // Check that state in loaded store matches store that was saved + if len(receivedStore.list) != len(store.list) { + t.Errorf("LoadStore returned Store with incorrect number of groups."+ + "\nexpected len: %d\nreceived len: %d", + len(store.list), len(receivedStore.list)) + } + + if _, exists := receivedStore.list[*testGroup.ID]; !exists { + t.Fatalf("Failed to get group from loaded group map."+ + "\nexpected: %#v\nreceived: %#v", testGroup, receivedStore.list) + } +} + +// Error path: show that LoadStore returns an error when no group store can be +// found in storage. +func TestLoadStore_GetError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + user := randMember(rand.New(rand.NewSource(42))) + expectedErr := strings.SplitN(kvGetGroupListErr, "%", 2)[0] + + // Load the store from kv + _, err := LoadStore(kv, user) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("LoadStore did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: show that loadStore returns an error when no group can be found +// in storage. +func Test_loadStore_GetGroupError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + user := randMember(rand.New(rand.NewSource(42))) + var idList []byte + for i := 0; i < 10; i++ { + idList = append(idList, id.NewIdFromUInt(uint64(i), id.Group, t).Marshal()...) + } + expectedErr := strings.SplitN(groupLoadErr, "%", 2)[0] + + // Load the groups from kv + _, err := loadStore(idList, kv, user) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("loadStore did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } + +} + +// Tests that a map of groups can be serialized and deserialized into a list +// that has the same group IDs. +func Test_serializeGroupIdList_deserializeGroupIdList(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + n := 10 + testMap := make(map[id.ID]Group, n) + expected := make([]*id.ID, n) + for i := 0; i < n; i++ { + grp := createTestGroup(prng, t) + expected[i] = grp.ID + testMap[*grp.ID] = grp + } + + // Serialize and deserialize map + data := serializeGroupIdList(testMap) + newList := deserializeGroupIdList(data) + + // Sort expected and received lists so they are in the same order + sort.Slice(expected, func(i, j int) bool { + return bytes.Compare(expected[i].Bytes(), expected[j].Bytes()) == -1 + }) + sort.Slice(newList, func(i, j int) bool { + return bytes.Compare(newList[i].Bytes(), newList[j].Bytes()) == -1 + }) + + // Check if they match + if !reflect.DeepEqual(expected, newList) { + t.Errorf("Failed to serialize and deserilize group map into list."+ + "\nexpected: %+v\nreceived: %+v", expected, newList) + } +} + +// Unit test of Store.Len. +func TestStore_Len(t *testing.T) { + s := Store{list: make(map[id.ID]Group)} + + if s.Len() != 0 { + t.Errorf("Len returned the wrong length.\nexpected: %d\nreceived: %d", + 0, s.Len()) + } + + n := 10 + for i := 0; i < n; i++ { + s.list[*id.NewIdFromUInt(uint64(i), id.Group, t)] = Group{} + } + + if s.Len() != n { + t.Errorf("Len returned the wrong length.\nexpected: %d\nreceived: %d", + n, s.Len()) + } +} + +// Unit test of Store.Add. +func TestStore_Add(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + user := randMember(prng) + + store, err := NewStore(kv, user) + if err != nil { + t.Fatalf("Failed to create store: %+v", err) + } + + // Add maximum number of groups allowed + for i := 0; i < MaxGroupChats; i++ { + // Add group to store + grp := createTestGroup(prng, t) + err = store.Add(grp) + if err != nil { + t.Errorf("Add returned an error (%d): %v", i, err) + } + + if _, exists := store.list[*grp.ID]; !exists { + t.Errorf("Group %s was not added to the map (%d)", grp.ID, i) + } + } + + if len(store.list) != MaxGroupChats { + t.Errorf("Length of group map does not match number of groups added."+ + "\nexpected: %d\nreceived: %d", MaxGroupChats, len(store.list)) + } +} + +// Error path: shows that an error is returned when trying to add too many +// groups. +func TestStore_Add_MapFullError(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + user := randMember(prng) + expectedErr := strings.SplitN(maxGroupsErr, "%", 2)[0] + + store, err := NewStore(kv, user) + if err != nil { + t.Fatalf("Failed to create store: %+v", err) + } + + // Add maximum number of groups allowed + for i := 0; i < MaxGroupChats; i++ { + err = store.Add(createTestGroup(prng, t)) + if err != nil { + t.Errorf("Add returned an error (%d): %v", i, err) + } + } + + err = store.Add(createTestGroup(prng, t)) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("Add did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: show Store.Add returns an error when attempting to add a group +// that is already in the map. +func TestStore_Add_GroupExistsError(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + user := randMember(prng) + expectedErr := strings.SplitN(groupExistsErr, "%", 2)[0] + + store, err := NewStore(kv, user) + if err != nil { + t.Fatalf("Failed to create store: %+v", err) + } + + grp := createTestGroup(prng, t) + err = store.Add(grp) + if err != nil { + t.Errorf("Add returned an error: %+v", err) + } + + err = store.Add(grp) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("Add did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Unit test of Store.Remove. +func TestStore_Remove(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + user := randMember(prng) + + store, err := NewStore(kv, user) + if err != nil { + t.Fatalf("Failed to create store: %+v", err) + } + + // Add maximum number of groups allowed + groups := make([]Group, MaxGroupChats) + for i := 0; i < MaxGroupChats; i++ { + groups[i] = createTestGroup(prng, t) + if err = store.Add(groups[i]); err != nil { + t.Errorf("Failed to add group (%d): %v", i, err) + } + } + + // Remove all groups + for i, grp := range groups { + err = store.Remove(grp.ID) + if err != nil { + t.Errorf("Remove returned an error (%d): %+v", i, err) + } + + if _, exists := store.list[*grp.ID]; exists { + t.Fatalf("Group %s still exists in map (%d).", grp.ID, i) + } + } + + // Check that the list is empty now + if len(store.list) != 0 { + t.Fatalf("Remove failed to remove all groups.."+ + "\nexpected: %d\nreceived: %d", 0, len(store.list)) + } +} + +// Error path: shows that Store.Remove returns an error when no group with the +// given ID is found in the map. +func TestStore_Remove_RemoveGroupNotInMemoryError(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + user := randMember(prng) + expectedErr := strings.SplitN(groupRemoveErr, "%", 2)[0] + + store, err := NewStore(kv, user) + if err != nil { + t.Fatalf("Failed to create store: %+v", err) + } + + grp := createTestGroup(prng, t) + err = store.Remove(grp.ID) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("Remove did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Unit test of Store.GroupIDs. +func TestStore_GroupIDs(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + n := 10 + store := Store{list: make(map[id.ID]Group, n)} + expected := make([]*id.ID, n) + for i := 0; i < n; i++ { + grp := createTestGroup(prng, t) + expected[i] = grp.ID + store.list[*grp.ID] = grp + } + + newList := store.GroupIDs() + + // Sort expected and received lists so they are in the same order + sort.Slice(expected, func(i, j int) bool { + return bytes.Compare(expected[i].Bytes(), expected[j].Bytes()) == -1 + }) + sort.Slice(newList, func(i, j int) bool { + return bytes.Compare(newList[i].Bytes(), newList[j].Bytes()) == -1 + }) + + // Check if they match + if !reflect.DeepEqual(expected, newList) { + t.Errorf("GroupIDs did not return the expected list."+ + "\nexpected: %+v\nreceived: %+v", expected, newList) + } +} + +// Unit test of Store.Get. +func TestStore_Get(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + user := randMember(prng) + + store, err := NewStore(kv, user) + if err != nil { + t.Fatalf("Failed to make new Store: %+v", err) + } + + // Add group to store + grp := createTestGroup(prng, t) + if err = store.Add(grp); err != nil { + t.Errorf("Failed to add group to store: %+v", err) + } + + // Attempt to get group + retrieved, exists := store.Get(grp.ID) + if !exists { + t.Errorf("Get failed to return the expected group: %#v", grp) + } + + if !reflect.DeepEqual(grp, retrieved) { + t.Errorf("Get did not return the expected group."+ + "\nexpected: %#v\nreceived: %#v", grp, retrieved) + } +} + +// Error path: shows that Store.Get return false if no group is found. +func TestStore_Get_NoGroupError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + user := randMember(rand.New(rand.NewSource(42))) + + store, err := NewStore(kv, user) + if err != nil { + t.Fatalf("Failed to make new Store: %+v", err) + } + + // Attempt to get group + retrieved, exists := store.Get(id.NewIdFromString("testID", id.Group, t)) + if exists { + t.Errorf("Get returned a group that should not exist: %#v", retrieved) + } +} + +// Unit test of Store.GetByKeyFp. +func TestStore_GetByKeyFp(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + user := randMember(prng) + + store, err := NewStore(kv, user) + if err != nil { + t.Fatalf("Failed to make new Store: %+v", err) + } + + // Add group to store + grp := createTestGroup(prng, t) + if err = store.Add(grp); err != nil { + t.Fatalf("Failed to add group: %+v", err) + } + + // Get group by fingerprint + salt := newSalt(groupSalt) + generatedFP := group.NewKeyFingerprint(grp.Key, salt, store.user.ID) + retrieved, exists := store.GetByKeyFp(generatedFP, salt) + if !exists { + t.Errorf("GetByKeyFp failed to find a group with the matching key "+ + "fingerprint: %#v", grp) + } + + // check that retrieved value match + if !reflect.DeepEqual(grp, retrieved) { + t.Errorf("GetByKeyFp failed to return the expected group."+ + "\nexpected: %#v\nreceived: %#v", grp, retrieved) + } +} + +// Error path: shows that Store.GetByKeyFp return false if no group is found. +func TestStore_GetByKeyFp_NoGroupError(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + user := randMember(prng) + + store, err := NewStore(kv, user) + if err != nil { + t.Fatalf("Failed to make new Store: %+v", err) + } + + // Get group by fingerprint + grp := createTestGroup(prng, t) + salt := newSalt(groupSalt) + generatedFP := group.NewKeyFingerprint(grp.Key, salt, store.user.ID) + retrieved, exists := store.GetByKeyFp(generatedFP, salt) + if exists { + t.Errorf("GetByKeyFp found a group when none should exist: %#v", + retrieved) + } +} + +// Unit test of Store.GetUser. +func TestStore_GetUser(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + user := randMember(rand.New(rand.NewSource(42))) + + store, err := NewStore(kv, user) + if err != nil { + t.Fatalf("Failed to make new Store: %+v", err) + } + + if !user.Equal(store.GetUser()) { + t.Errorf("GetUser() failed to return the expected member."+ + "\nexpected: %#v\nreceived: %#v", user, store.GetUser()) + } +} + +// Unit test of Store.SetUser. +func TestStore_SetUser(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + prng := rand.New(rand.NewSource(42)) + oldUser := randMember(prng) + newUser := randMember(prng) + + store, err := NewStore(kv, oldUser) + if err != nil { + t.Fatalf("Failed to make new Store: %+v", err) + } + + store.SetUser(newUser, t) + + if !newUser.Equal(store.user) { + t.Errorf("SetUser() failed to set the correct user."+ + "\nexpected: %#v\nreceived: %#v", newUser, store.user) + } +} + +// Panic path: show that Store.SetUser panics when the interface is not of a +// testing type. +func TestStore_SetUser_NonTestingInterfacePanic(t *testing.T) { + user := randMember(rand.New(rand.NewSource(42))) + store := &Store{} + nonTestingInterface := struct{}{} + expectedErr := fmt.Sprintf(setUserPanic, nonTestingInterface) + + defer func() { + if r := recover(); r == nil || r.(string) != expectedErr { + t.Errorf("SetUser failed to panic with the expected message."+ + "\nexpected: %s\nreceived: %+v", expectedErr, r) + } + }() + + store.SetUser(user, nonTestingInterface) +} diff --git a/groupChat/groupStore/utils_test.go b/groupChat/groupStore/utils_test.go new file mode 100644 index 0000000000000000000000000000000000000000..9a6460800498b11fdbcb8a4b9b5aa7d4e391b2f3 --- /dev/null +++ b/groupChat/groupStore/utils_test.go @@ -0,0 +1,139 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupStore + +import ( + "gitlab.com/elixxir/crypto/contact" + "gitlab.com/elixxir/crypto/cyclic" + "gitlab.com/elixxir/crypto/group" + "gitlab.com/xx_network/crypto/large" + "gitlab.com/xx_network/primitives/id" + "math/rand" + "testing" +) + +const ( + groupName = "groupName" + groupSalt = "salt" + groupKey = "key" + groupIdPreimage = "idPreimage" + groupKeyPreimage = "keyPreimage" + initMessage = "initMessage" +) + +// createTestGroup generates a new group for testing. +func createTestGroup(rng *rand.Rand, t *testing.T) Group { + members := createMembership(rng, 10, t) + dkl := GenerateDhKeyList(members[0].ID, randCycInt(rng), members, getGroup()) + return NewGroup( + []byte(groupName), + id.NewIdFromUInt(rng.Uint64(), id.Group, t), + newKey(groupKey), + newIdPreimage(groupIdPreimage), + newKeyPreimage(groupKeyPreimage), + []byte(initMessage), + members, + dkl, + ) +} + +// createMembership creates a new membership with the specified number of +// randomly generated members. +func createMembership(rng *rand.Rand, size int, t *testing.T) group.Membership { + contacts := make([]contact.Contact, size) + for i := range contacts { + contacts[i] = randContact(rng) + } + + membership, err := group.NewMembership(contacts[0], contacts[1:]...) + if err != nil { + t.Errorf("Failed to create new membership: %+v", err) + } + + return membership +} + +// createDhKeyList creates a new DhKeyList with the specified number of randomly +// generated members. +func createDhKeyList(rng *rand.Rand, size int, _ *testing.T) DhKeyList { + dkl := make(DhKeyList, size) + for i := 0; i < size; i++ { + dkl[*randID(rng, id.User)] = randCycInt(rng) + } + + return dkl +} + +// randMember returns a Member with a random ID and DH public key. +func randMember(rng *rand.Rand) group.Member { + return group.Member{ + ID: randID(rng, id.User), + DhKey: randCycInt(rng), + } +} + +// randContact returns a contact with a random ID and DH public key. +func randContact(rng *rand.Rand) contact.Contact { + return contact.Contact{ + ID: randID(rng, id.User), + DhPubKey: randCycInt(rng), + } +} + +// randID returns a new random ID of the specified type. +func randID(rng *rand.Rand, t id.Type) *id.ID { + newID, _ := id.NewRandomID(rng, t) + return newID +} + +// randCycInt returns a random cyclic int. +func randCycInt(rng *rand.Rand) *cyclic.Int { + return getGroup().NewInt(rng.Int63()) +} + +func getGroup() *cyclic.Group { + return cyclic.NewGroup( + large.NewIntFromString("E2EE983D031DC1DB6F1A7A67DF0E9A8E5561DB8E8D4941"+ + "3394C049B7A8ACCEDC298708F121951D9CF920EC5D146727AA4AE535B0922C688"+ + "B55B3DD2AEDF6C01C94764DAB937935AA83BE36E67760713AB44A6337C20E7861"+ + "575E745D31F8B9E9AD8412118C62A3E2E29DF46B0864D0C951C394A5CBBDC6ADC"+ + "718DD2A3E041023DBB5AB23EBB4742DE9C1687B5B34FA48C3521632C4A530E8FF"+ + "B1BC51DADDF453B0B2717C2BC6669ED76B4BDD5C9FF558E88F26E5785302BEDBC"+ + "A23EAC5ACE92096EE8A60642FB61E8F3D24990B8CB12EE448EEF78E184C7242DD"+ + "161C7738F32BF29A841698978825B4111B4BC3E1E198455095958333D776D8B2B"+ + "EEED3A1A1A221A6E37E664A64B83981C46FFDDC1A45E3D5211AAF8BFBC072768C"+ + "4F50D7D7803D2D4F278DE8014A47323631D7E064DE81C0C6BFA43EF0E6998860F"+ + "1390B5D3FEACAF1696015CB79C3F9C2D93D961120CD0E5F12CBB687EAB045241F"+ + "96789C38E89D796138E6319BE62E35D87B1048CA28BE389B575E994DCA7554715"+ + "84A09EC723742DC35873847AEF49F66E43873", 16), + large.NewIntFromString("2", 16)) +} + +func newSalt(s string) [group.SaltLen]byte { + var salt [group.SaltLen]byte + copy(salt[:], s) + return salt +} + +func newKey(s string) group.Key { + var key group.Key + copy(key[:], s) + return key +} + +func newIdPreimage(s string) group.IdPreimage { + var preimage group.IdPreimage + copy(preimage[:], s) + return preimage +} + +func newKeyPreimage(s string) group.KeyPreimage { + var preimage group.KeyPreimage + copy(preimage[:], s) + return preimage +} diff --git a/groupChat/internalFormat.go b/groupChat/internalFormat.go new file mode 100644 index 0000000000000000000000000000000000000000..2502a9c8c29c9f940a93bebb1101c5d5a5ae8ef2 --- /dev/null +++ b/groupChat/internalFormat.go @@ -0,0 +1,156 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupChat + +import ( + "encoding/binary" + "fmt" + "github.com/pkg/errors" + "gitlab.com/xx_network/primitives/id" + "strconv" + "time" +) + +// Sizes of marshaled data, in bytes. +const ( + timestampLen = 8 + idLen = id.ArrIDLen + internalPayloadSizeLen = 2 + internalMinLen = timestampLen + idLen + internalPayloadSizeLen +) + +// Error messages +const ( + newInternalSizeErr = "max message size %d < %d minimum required" + unmarshalInternalSizeErr = "size of data %d < %d minimum required" +) + +// internalMsg is the internal, unencrypted data in a group message. +// +// +-------------------------------------------+ +// | data | +// +-----------+----------+---------+----------+ +// | timestamp | senderID | size | payload | +// | 8 bytes | 32 bytes | 2 bytes | variable | +// +-----------+----------+---------+----------+ +type internalMsg struct { + data []byte // Serial of all the parts of the message + timestamp []byte // 64-bit Unix time timestamp stored in nanoseconds + senderID []byte // 264-bit sender ID + size []byte // Size of the payload + payload []byte // Message contents +} + +// newInternalMsg creates a new internalMsg of size maxDataSize. An error is +// returned if the maxDataSize is smaller than the minimum internalMsg size. +func newInternalMsg(maxDataSize int) (internalMsg, error) { + if maxDataSize < internalMinLen { + return internalMsg{}, + errors.Errorf(newInternalSizeErr, maxDataSize, internalMinLen) + } + + return mapInternalMsg(make([]byte, maxDataSize)), nil +} + +// mapInternalMsg maps all the parts of the internalMsg to the passed in data. +func mapInternalMsg(data []byte) internalMsg { + return internalMsg{ + data: data, + timestamp: data[:timestampLen], + senderID: data[timestampLen : timestampLen+idLen], + size: data[timestampLen+idLen : timestampLen+idLen+internalPayloadSizeLen], + payload: data[timestampLen+idLen+internalPayloadSizeLen:], + } +} + +// unmarshalInternalMsg unmarshal the data into an internalMsg. An error is +// returned if the data length is smaller than the minimum allowed size. +func unmarshalInternalMsg(data []byte) (internalMsg, error) { + if len(data) < internalMinLen { + return internalMsg{}, + errors.Errorf(unmarshalInternalSizeErr, len(data), internalMinLen) + } + + return mapInternalMsg(data), nil +} + +// Marshal returns the serial of the internalMsg. +func (im internalMsg) Marshal() []byte { + return im.data +} + +// GetTimestamp returns the timestamp as a time.Time. +func (im internalMsg) GetTimestamp() time.Time { + return time.Unix(0, int64(binary.LittleEndian.Uint64(im.timestamp))) +} + +// SetTimestamp converts the time.Time to Unix nano and save as bytes. +func (im internalMsg) SetTimestamp(t time.Time) { + binary.LittleEndian.PutUint64(im.timestamp, uint64(t.UnixNano())) +} + +// GetSenderID returns the sender ID bytes as a id.ID. +func (im internalMsg) GetSenderID() (*id.ID, error) { + return id.Unmarshal(im.senderID) +} + +// SetSenderID sets the sender ID. +func (im internalMsg) SetSenderID(sid *id.ID) { + copy(im.senderID, sid.Marshal()) +} + +// GetPayload returns the payload truncated to the correct size. +func (im internalMsg) GetPayload() []byte { + return im.payload[:im.GetPayloadSize()] +} + +// SetPayload sets the payload and saves it size. +func (im internalMsg) SetPayload(payload []byte) { + // Save size of payload + binary.LittleEndian.PutUint16(im.size, uint16(len(payload))) + + // Save payload + copy(im.payload, payload) +} + +// GetPayloadSize returns the length of the content in the payload. +func (im internalMsg) GetPayloadSize() int { + return int(binary.LittleEndian.Uint16(im.size)) +} + +// GetPayloadMaxSize returns the maximum size of the payload. +func (im internalMsg) GetPayloadMaxSize() int { + return len(im.payload) +} + +// String prints a string representation of internalMsg. This functions +// satisfies the fmt.Stringer interface. +func (im internalMsg) String() string { + timestamp := "<nil>" + if len(im.timestamp) > 0 { + timestamp = im.GetTimestamp().String() + } + + senderID := "<nil>" + if sid, _ := im.GetSenderID(); sid != nil { + senderID = sid.String() + } + + size := "<nil>" + if len(im.size) > 0 { + size = strconv.Itoa(im.GetPayloadSize()) + } + + payload := "<nil>" + if len(im.size) > 0 { + payload = fmt.Sprintf("%q", im.GetPayload()) + } + + return "{timestamp:" + timestamp + ", senderID:" + senderID + + ", size:" + size + ", payload:" + payload + "}" +} diff --git a/groupChat/internalFormat_test.go b/groupChat/internalFormat_test.go new file mode 100644 index 0000000000000000000000000000000000000000..984d11b8f35ec44935eaea48b5bbd76eec80bdb1 --- /dev/null +++ b/groupChat/internalFormat_test.go @@ -0,0 +1,211 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupChat + +import ( + "bytes" + "encoding/binary" + "fmt" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" + "reflect" + "testing" + "time" +) + +// Unit test of newInternalMsg. +func Test_newInternalMsg(t *testing.T) { + maxDataSize := 2 * internalMinLen + im, err := newInternalMsg(maxDataSize) + if err != nil { + t.Errorf("newInternalMsg() returned an error: %+v", err) + } + + if len(im.data) != maxDataSize { + t.Errorf("newInternalMsg() set data to the wrong length."+ + "\nexpected: %d\nreceived: %d", maxDataSize, len(im.data)) + } +} + +// Error path: the maxDataSize is smaller than the minimum size. +func Test_newInternalMsg_PayloadSizeError(t *testing.T) { + maxDataSize := internalMinLen - 1 + expectedErr := fmt.Sprintf(newInternalSizeErr, maxDataSize, internalMinLen) + + _, err := newInternalMsg(maxDataSize) + if err == nil || err.Error() != expectedErr { + t.Errorf("newInternalMsg() failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Unit test of mapInternalMsg. +func Test_mapInternalMsg(t *testing.T) { + // Create all the expected data + timestamp := make([]byte, timestampLen) + binary.LittleEndian.PutUint64(timestamp, uint64(netTime.Now().UnixNano())) + senderID := id.NewIdFromString("test sender ID", id.User, t).Marshal() + payload := []byte("Sample payload contents.") + size := make([]byte, internalPayloadSizeLen) + binary.LittleEndian.PutUint16(size, uint16(len(payload))) + + // Construct data into single slice + data := bytes.NewBuffer(nil) + data.Write(timestamp) + data.Write(senderID) + data.Write(size) + data.Write(payload) + + // Map data + im := mapInternalMsg(data.Bytes()) + + // Check that the mapped values match the expected values + if !bytes.Equal(timestamp, im.timestamp) { + t.Errorf("mapInternalMsg() did not correctly map timestamp."+ + "\nexpected: %+v\nreceived: %+v", timestamp, im.timestamp) + } + + if !bytes.Equal(senderID, im.senderID) { + t.Errorf("mapInternalMsg() did not correctly map senderID."+ + "\nexpected: %+v\nreceived: %+v", senderID, im.senderID) + } + + if !bytes.Equal(size, im.size) { + t.Errorf("mapInternalMsg() did not correctly map size."+ + "\nexpected: %+v\nreceived: %+v", size, im.size) + } + + if !bytes.Equal(payload, im.payload) { + t.Errorf("mapInternalMsg() did not correctly map payload."+ + "\nexpected: %+v\nreceived: %+v", payload, im.payload) + } +} + +// Tests that a marshaled and unmarshalled internalMsg matches the original. +func TestInternalMsg_Marshal_unmarshalInternalMsg(t *testing.T) { + im, _ := newInternalMsg(internalMinLen * 2) + im.SetTimestamp(netTime.Now()) + im.SetSenderID(id.NewIdFromString("test sender ID", id.User, t)) + im.SetPayload([]byte("Sample payload message.")) + + data := im.Marshal() + + newIm, err := unmarshalInternalMsg(data) + if err != nil { + t.Errorf("unmarshalInternalMsg() returned an error: %+v", err) + } + + if !reflect.DeepEqual(im, newIm) { + t.Errorf("unmarshalInternalMsg() did not return the expected internalMsg."+ + "\nexpected: %s\nreceived: %s", im, newIm) + } +} + +// Error path: error is returned when the data is too short. +func Test_unmarshalInternalMsg_DataLengthError(t *testing.T) { + expectedErr := fmt.Sprintf(unmarshalInternalSizeErr, 0, internalMinLen) + + _, err := unmarshalInternalMsg(nil) + if err == nil || err.Error() != expectedErr { + t.Errorf("unmarshalInternalMsg() failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Happy path. +func TestInternalMsg_SetTimestamp_GetTimestamp(t *testing.T) { + im, _ := newInternalMsg(internalMinLen * 2) + timestamp := netTime.Now() + im.SetTimestamp(timestamp) + testTimestamp := im.GetTimestamp() + + if !timestamp.Equal(testTimestamp) { + t.Errorf("Failed to get original timestamp."+ + "\nexpected: %s\nreceived: %s", timestamp, testTimestamp) + } +} + +// Happy path. +func TestInternalMsg_SetSenderID_GetSenderID(t *testing.T) { + im, _ := newInternalMsg(internalMinLen * 2) + sid := id.NewIdFromString("testSenderID", id.User, t) + im.SetSenderID(sid) + testID, err := im.GetSenderID() + if err != nil { + t.Errorf("GetSenderID() returned an error: %+v", err) + } + + if !sid.Cmp(testID) { + t.Errorf("Failed to get original sender ID."+ + "\nexpected: %s\nreceived: %s", sid, testID) + } +} + +// Tests that the original payload matches the saved one. +func TestInternalMsg_SetPayload_GetPayload(t *testing.T) { + im, _ := newInternalMsg(internalMinLen * 2) + payload := []byte("Test payload message.") + im.SetPayload(payload) + testPayload := im.GetPayload() + + if !bytes.Equal(payload, testPayload) { + t.Errorf("Failed to get original sender payload."+ + "\nexpected: %s\nreceived: %s", payload, testPayload) + } +} + +// Happy path. +func TestInternalMsg_GetPayloadSize(t *testing.T) { + im, _ := newInternalMsg(internalMinLen * 2) + payload := []byte("Test payload message.") + im.SetPayload(payload) + + if len(payload) != im.GetPayloadSize() { + t.Errorf("GetPayloadSize() failed to return the correct size."+ + "\nexpected: %d\nreceived: %d", len(payload), im.GetPayloadSize()) + } +} + +// Happy path. +func TestInternalMsg_GetPayloadMaxSize(t *testing.T) { + im, _ := newInternalMsg(internalMinLen * 2) + + if internalMinLen != im.GetPayloadMaxSize() { + t.Errorf("GetPayloadSize() failed to return the correct size."+ + "\nexpected: %d\nreceived: %d", internalMinLen, im.GetPayloadMaxSize()) + } +} + +// Happy path. +func TestInternalMsg_String(t *testing.T) { + im, _ := newInternalMsg(internalMinLen * 2) + im.SetTimestamp(time.Date(1955, 11, 5, 12, 0, 0, 0, time.UTC)) + im.SetSenderID(id.NewIdFromString("test sender ID", id.User, t)) + payload := []byte("Sample payload message.") + payload = append(payload, 0, 1, 2) + im.SetPayload(payload) + + expected := `{timestamp:` + im.GetTimestamp().String() + `, senderID:dGVzdCBzZW5kZXIgSUQAAAAAAAAAAAAAAAAAAAAAAAAD, size:26, payload:"Sample payload message.\x00\x01\x02"}` + + if im.String() != expected { + t.Errorf("String() failed to return the expected value."+ + "\nexpected: %s\nreceived: %s", expected, im.String()) + } +} + +// Happy path: tests that String returns the expected string for a nil internalMsg. +func TestInternalMsg_String_NilInternalMessage(t *testing.T) { + im := internalMsg{} + + expected := "{timestamp:<nil>, senderID:<nil>, size:<nil>, payload:<nil>}" + + if im.String() != expected { + t.Errorf("String() failed to return the expected value."+ + "\nexpected: %s\nreceived: %s", expected, im.String()) + } +} diff --git a/groupChat/makeGroup.go b/groupChat/makeGroup.go new file mode 100644 index 0000000000000000000000000000000000000000..fa230fadcc99f61a4b4a44d0f62aa549ec9ec306 --- /dev/null +++ b/groupChat/makeGroup.go @@ -0,0 +1,190 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupChat + +import ( + "github.com/pkg/errors" + gs "gitlab.com/elixxir/client/groupChat/groupStore" + "gitlab.com/elixxir/crypto/contact" + "gitlab.com/elixxir/crypto/fastRNG" + "gitlab.com/elixxir/crypto/group" + "gitlab.com/xx_network/primitives/id" + "strconv" +) + +// Error messages. +const ( + maxInitMsgSizeErr = "new group request message length %d > %d maximum size" + getPrivKeyErr = "failed to get private key from partner: %+v" + minMembersErr = "length of membership list %d < %d minimum allowed" + maxMembersErr = "length of membership list %d > %d maximum allowed" + getPartnerErr = "failed to get partner %s: %+v" + makeMembershipErr = "failed to assemble group chat membership: %+v" + newIdPreimageErr = "failed to create group ID preimage: %+v" + newKeyPreimageErr = "failed to create group key preimage: %+v" + addGroupErr = "failed to save new group: %+v" +) + +// MaxInitMessageSize is the maximum allowable length of the initial message +// sent in a group request. +const MaxInitMessageSize = 256 + +// RequestStatus signals the status of the group requests on group creation. +type RequestStatus int + +const ( + NotSent RequestStatus = iota // Error occurred before sending requests + AllFail // Sending of all requests failed + PartialSent // Sending of some request failed + AllSent // Sending of all request succeeded +) + +// MakeGroup sends groupChat requests to all members over an authenticated +// channel. The leader of a groupChat must have an authenticated channel with +// each member of the groupChat to add them to the groupChat. It blocks until +// all the groupChat requests are sent. Returns an error if at least one request +// to a member fails to send. +func (m Manager) MakeGroup(membership []*id.ID, name, msg []byte) (gs.Group, + []id.Round, RequestStatus, error) { + // Return an error if the message is too long + if len(msg) > MaxInitMessageSize { + return gs.Group{}, nil, NotSent, + errors.Errorf(maxInitMsgSizeErr, len(msg), MaxInitMessageSize) + } + + // Build membership and DH key list from list of IDs + mem, dkl, err := m.buildMembership(membership) + if err != nil { + return gs.Group{}, nil, NotSent, err + } + + // Generate ID and key preimages + idPreimage, keyPreimage, err := getPreimages(m.rng) + if err != nil { + return gs.Group{}, nil, NotSent, err + } + + // Create new group ID and key + groupID := group.NewID(idPreimage, mem) + groupKey := group.NewKey(keyPreimage, mem) + + // Create new group and add to manager + g := gs.NewGroup(name, groupID, groupKey, idPreimage, keyPreimage, msg, mem, dkl) + if err := m.gs.Add(g); err != nil { + return gs.Group{}, nil, NotSent, errors.Errorf(addGroupErr, err) + } + + // Send all group requests + roundIDs, status, err := m.sendRequests(g) + + return g, roundIDs, status, err +} + +// buildMembership retrieves the contact object for each member ID and creates a +// new membership from them. The caller is set as the leader. For a member to be +// added, the group leader must have an authenticated channel with the member. +func (m Manager) buildMembership(members []*id.ID) (group.Membership, gs.DhKeyList, error) { + // Return an error if the membership list has too few or too many members + if len(members) < group.MinParticipants { + return nil, nil, + errors.Errorf(minMembersErr, len(members), group.MinParticipants) + } else if len(members) > group.MaxParticipants { + return nil, nil, + errors.Errorf(maxMembersErr, len(members), group.MaxParticipants) + } + + grp := m.store.E2e().GetGroup() + dkl := make(gs.DhKeyList, len(members)) + + // Lookup partner contact objects from their ID + contacts := make([]contact.Contact, len(members)) + var err error + for i, uid := range members { + partner, err := m.store.E2e().GetPartner(uid) + if err != nil { + return nil, nil, errors.Errorf(getPartnerErr, uid, err) + } + + contacts[i] = contact.Contact{ + ID: partner.GetPartnerID(), + DhPubKey: partner.GetPartnerOriginPublicKey(), + } + + dkl.Add(partner.GetMyOriginPrivateKey(), group.Member{ + ID: partner.GetPartnerID(), + DhKey: partner.GetPartnerOriginPublicKey(), + }, grp) + } + + // Create new Membership from contact list and client's own contact. + user := m.gs.GetUser() + leader := contact.Contact{ID: user.ID, DhPubKey: user.DhKey} + mem, err := group.NewMembership(leader, contacts...) + if err != nil { + return nil, nil, errors.Errorf(makeMembershipErr, err) + } + + return mem, dkl, nil +} + +// getPreimages generates and returns the group ID preimage and the group key +// preimage. This function allows the stream to +func getPreimages(streamGen *fastRNG.StreamGenerator) (group.IdPreimage, + group.KeyPreimage, error) { + + // Get new stream and defer its close + rng := streamGen.GetStream() + defer rng.Close() + + idPreimage, err := group.NewIdPreimage(rng) + if err != nil { + return group.IdPreimage{}, group.KeyPreimage{}, + errors.Errorf(newIdPreimageErr, err) + } + + keyPreimage, err := group.NewKeyPreimage(rng) + if err != nil { + return group.IdPreimage{}, group.KeyPreimage{}, + errors.Errorf(newKeyPreimageErr, err) + } + + return idPreimage, keyPreimage, nil +} + +// String prints the description of the status code. This functions satisfies +// the fmt.Stringer interface. +func (rs RequestStatus) String() string { + switch rs { + case NotSent: + return "NotSent" + case AllFail: + return "AllFail" + case PartialSent: + return "PartialSent" + case AllSent: + return "AllSent" + default: + return "INVALID STATUS" + } +} + +// Message prints a full description of the status code. +func (rs RequestStatus) Message() string { + switch rs { + case NotSent: + return "an error occurred before sending any group requests" + case AllFail: + return "all group requests failed to send" + case PartialSent: + return "some group requests failed to send" + case AllSent: + return "all groups requests successfully sent" + default: + return "INVALID STATUS " + strconv.Itoa(int(rs)) + } +} diff --git a/groupChat/makeGroup_test.go b/groupChat/makeGroup_test.go new file mode 100644 index 0000000000000000000000000000000000000000..004bf452664984c8ec8651442ab10a7a64d48fee --- /dev/null +++ b/groupChat/makeGroup_test.go @@ -0,0 +1,302 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupChat + +import ( + "bytes" + "fmt" + gs "gitlab.com/elixxir/client/groupChat/groupStore" + "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/crypto/fastRNG" + "gitlab.com/elixxir/crypto/group" + "gitlab.com/xx_network/crypto/csprng" + "gitlab.com/xx_network/primitives/id" + "math/rand" + "reflect" + "strconv" + "strings" + "testing" +) + +// Tests that Manager.MakeGroup adds a group and returns the expected status. +func TestManager_MakeGroup(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, _ := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + memberIDs, members, dkl := addPartners(m, t) + name := []byte("groupName") + message := []byte("Invite message.") + + g, _, status, err := m.MakeGroup(memberIDs, name, message) + if err != nil { + t.Errorf("MakeGroup() returned an error: %+v", err) + } + + if status != AllSent { + t.Errorf("MakeGroup() did not return the expected status."+ + "\nexpected: %s\nreceived: %s", AllSent, status) + } + + _, exists := m.gs.Get(g.ID) + if !exists { + t.Errorf("Failed to get group %#v.", g) + } + + if !reflect.DeepEqual(members, g.Members) { + t.Errorf("New group does not have expected membership."+ + "\nexpected: %s\nreceived: %s", members, g.Members) + } + + if !reflect.DeepEqual(dkl, g.DhKeys) { + t.Errorf("New group does not have expected DH key list."+ + "\nexpected: %#v\nreceived: %#v", dkl, g.DhKeys) + } + + if !g.ID.Cmp(g.ID) { + t.Errorf("New group does not have expected ID."+ + "\nexpected: %s\nreceived: %s", g.ID, g.ID) + } + + if !bytes.Equal(name, g.Name) { + t.Errorf("New group does not have expected name."+ + "\nexpected: %q\nreceived: %q", name, g.Name) + } + + if !bytes.Equal(message, g.InitMessage) { + t.Errorf("New group does not have expected message."+ + "\nexpected: %q\nreceived: %q", message, g.InitMessage) + } +} + +// Error path: make sure an error and the correct status is returned when the +// message is too large. +func TestManager_MakeGroup_MaxMessageSizeError(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, _ := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + expectedErr := fmt.Sprintf(maxInitMsgSizeErr, MaxInitMessageSize+1, MaxInitMessageSize) + + _, _, status, err := m.MakeGroup(nil, nil, make([]byte, MaxInitMessageSize+1)) + if err == nil || err.Error() != expectedErr { + t.Errorf("MakeGroup() did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } + + if status != NotSent { + t.Errorf("MakeGroup() did not return the expected status."+ + "\nexpected: %s\nreceived: %s", NotSent, status) + } +} + +// Error path: make sure an error and the correct status is returned when the +// membership list is too small. +func TestManager_MakeGroup_MembershipSizeError(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, _ := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + expectedErr := fmt.Sprintf(maxMembersErr, group.MaxParticipants+1, group.MaxParticipants) + + _, _, status, err := m.MakeGroup(make([]*id.ID, group.MaxParticipants+1), + nil, []byte{}) + if err == nil || err.Error() != expectedErr { + t.Errorf("MakeGroup() did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } + + if status != NotSent { + t.Errorf("MakeGroup() did not return the expected status."+ + "\nexpected: %s\nreceived: %s", NotSent, status) + } +} + +// Error path: make sure an error and the correct status is returned when adding +// a group failed because the user is a part of too many groups already. +func TestManager_MakeGroup_AddGroupError(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, _ := newTestManagerWithStore(prng, gs.MaxGroupChats, 0, nil, nil, t) + memberIDs, _, _ := addPartners(m, t) + expectedErr := strings.SplitN(addGroupErr, "%", 2)[0] + + _, _, _, err := m.MakeGroup(memberIDs, []byte{}, []byte{}) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("MakeGroup() did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Unit test of Manager.buildMembership. +func TestManager_buildMembership(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, _ := newTestManager(prng, t) + memberIDs, expected, expectedDKL := addPartners(m, t) + + membership, dkl, err := m.buildMembership(memberIDs) + if err != nil { + t.Errorf("buildMembership() returned an error: %+v", err) + } + + if !reflect.DeepEqual(expected, membership) { + t.Errorf("buildMembership() failed to return the expected membership."+ + "\nexpected: %s\nrecieved: %s", expected, membership) + } + + if !reflect.DeepEqual(expectedDKL, dkl) { + t.Errorf("buildMembership() failed to return the expected DH key list."+ + "\nexpected: %#v\nrecieved: %#v", expectedDKL, dkl) + } +} + +// Error path: an error is returned when the number of members in the membership +// list is too few. +func TestManager_buildMembership_MinParticipantsError(t *testing.T) { + m, _ := newTestManager(rand.New(rand.NewSource(42)), t) + memberIDs := make([]*id.ID, group.MinParticipants-1) + expectedErr := fmt.Sprintf(minMembersErr, len(memberIDs), group.MinParticipants) + + _, _, err := m.buildMembership(memberIDs) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("buildMembership() did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: an error is returned when the number of members in the membership +// list is too many. +func TestManager_buildMembership_MaxParticipantsError(t *testing.T) { + m, _ := newTestManager(rand.New(rand.NewSource(42)), t) + memberIDs := make([]*id.ID, group.MaxParticipants+1) + expectedErr := fmt.Sprintf(maxMembersErr, len(memberIDs), group.MaxParticipants) + + _, _, err := m.buildMembership(memberIDs) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("buildMembership() did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: error returned when a partner cannot be found +func TestManager_buildMembership_GetPartnerContactError(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, _ := newTestManager(prng, t) + memberIDs, _, _ := addPartners(m, t) + expectedErr := strings.SplitN(getPartnerErr, "%", 2)[0] + + // Replace a partner ID + memberIDs[len(memberIDs)/2] = id.NewIdFromString("nonPartnerID", id.User, t) + + _, _, err := m.buildMembership(memberIDs) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("buildMembership() did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: error returned when a member ID appears twice on the list. +func TestManager_buildMembership_DuplicateContactError(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, _ := newTestManager(prng, t) + memberIDs, _, _ := addPartners(m, t) + expectedErr := strings.SplitN(makeMembershipErr, "%", 2)[0] + + // Replace a partner ID + memberIDs[5] = memberIDs[4] + + _, _, err := m.buildMembership(memberIDs) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("buildMembership() did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Test that getPreimages produces unique preimages. +func Test_getPreimages_Unique(t *testing.T) { + streamGen := fastRNG.NewStreamGenerator(1000, 10, csprng.NewSystemRNG) + n := 100 + idPreimages := make(map[group.IdPreimage]bool, n) + keyPreimages := make(map[group.KeyPreimage]bool, n) + + for i := 0; i < n; i++ { + idPreimage, keyPreimage, err := getPreimages(streamGen) + if err != nil { + t.Errorf("getPreimages() returned an error: %+v", err) + } + + if idPreimages[idPreimage] { + t.Errorf("getPreimages() produced a duplicate idPreimage: %s", idPreimage) + } else { + idPreimages[idPreimage] = true + } + + if keyPreimages[keyPreimage] { + t.Errorf("getPreimages() produced a duplicate keyPreimage: %s", keyPreimage) + } else { + keyPreimages[keyPreimage] = true + } + } +} + +// Unit test of RequestStatus.String. +func TestRequestStatus_String(t *testing.T) { + statusCodes := map[RequestStatus]string{ + NotSent: "NotSent", + AllFail: "AllFail", + PartialSent: "PartialSent", + AllSent: "AllSent", + AllSent + 1: "INVALID STATUS", + } + + for status, expected := range statusCodes { + if status.String() != expected { + t.Errorf("String() failed to return the expected name."+ + "\nexpected: %s\nreceived: %s", expected, status.String()) + } + } +} + +// Unit test of RequestStatus.Message. +func TestRequestStatus_Message(t *testing.T) { + statusCodes := map[RequestStatus]string{ + NotSent: "an error occurred before sending any group requests", + AllFail: "all group requests failed to send", + PartialSent: "some group requests failed to send", + AllSent: "all groups requests successfully sent", + AllSent + 1: "INVALID STATUS " + strconv.Itoa(int(AllSent)+1), + } + + for status, expected := range statusCodes { + if status.Message() != expected { + t.Errorf("Message() failed to return the expected message."+ + "\nexpected: %s\nreceived: %s", expected, status.Message()) + } + } +} + +// addPartners returns a list of user IDs and their matching membership and adds +// them as partners. +func addPartners(m *Manager, t *testing.T) ([]*id.ID, group.Membership, gs.DhKeyList) { + memberIDs := make([]*id.ID, 10) + members := group.Membership{m.gs.GetUser()} + dkl := gs.DhKeyList{} + + for i := range memberIDs { + // Build member data + uid := id.NewIdFromUInt(uint64(i), id.User, t) + dhKey := m.store.E2e().GetGroup().NewInt(int64(i + 42)) + + // Add to lists + memberIDs[i] = uid + members = append(members, group.Member{ID: uid, DhKey: dhKey}) + dkl.Add(dhKey, group.Member{ID: uid, DhKey: dhKey}, m.store.E2e().GetGroup()) + + // Add partner + err := m.store.E2e().AddPartner(uid, dhKey, dhKey, + params.GetDefaultE2ESessionParams(), params.GetDefaultE2ESessionParams()) + if err != nil { + t.Errorf("Failed to add partner %d: %+v", i, err) + } + } + + return memberIDs, members, dkl +} diff --git a/groupChat/manager.go b/groupChat/manager.go new file mode 100644 index 0000000000000000000000000000000000000000..f6044ed2269ed83e2ee96899f2768b27ce76e6f6 --- /dev/null +++ b/groupChat/manager.go @@ -0,0 +1,156 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupChat + +import ( + "github.com/pkg/errors" + "gitlab.com/elixxir/client/api" + gs "gitlab.com/elixxir/client/groupChat/groupStore" + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" + "gitlab.com/elixxir/client/storage" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/crypto/cyclic" + "gitlab.com/elixxir/crypto/fastRNG" + "gitlab.com/elixxir/crypto/group" + "gitlab.com/xx_network/primitives/id" +) + +const ( + rawMessageBuffSize = 100 + receiveStoppableName = "GroupChatReceive" + receiveListenerName = "GroupChatReceiveListener" + requestStoppableName = "GroupChatRequest" + requestListenerName = "GroupChatRequestListener" + groupStoppableName = "GroupChat" +) + +// Error messages. +const ( + newGroupStoreErr = "failed to create new group store: %+v" + joinGroupErr = "failed to join new group %s: %+v" + leaveGroupErr = "failed to leave group %s: %+v" +) + +// Manager handles the list of groups a user is a part of. +type Manager struct { + client *api.Client + store *storage.Session + swb interfaces.Switchboard + net interfaces.NetworkManager + rng *fastRNG.StreamGenerator + gs *gs.Store + + requestFunc RequestCallback + receiveFunc ReceiveCallback +} + +// NewManager generates a new group chat manager. This functions satisfies the +// GroupChat interface. +func NewManager(client *api.Client, requestFunc RequestCallback, + receiveFunc ReceiveCallback) (*Manager, error) { + return newManager( + client, + client.GetUser().ReceptionID.DeepCopy(), + client.GetStorage().E2e().GetDHPublicKey(), + client.GetStorage(), + client.GetSwitchboard(), + client.GetNetworkInterface(), + client.GetRng(), + client.GetStorage().GetKV(), + requestFunc, + receiveFunc, + ) +} + +// newManager creates a new group chat manager from api.Client parts for easier +// testing. +func newManager(client *api.Client, userID *id.ID, userDhKey *cyclic.Int, + store *storage.Session, swb interfaces.Switchboard, + net interfaces.NetworkManager, rng *fastRNG.StreamGenerator, + kv *versioned.KV, requestFunc RequestCallback, + receiveFunc ReceiveCallback) (*Manager, error) { + + // Load the group chat storage or create one if one does not exist + gStore, err := gs.NewOrLoadStore(kv, group.Member{ID: userID, DhKey: userDhKey}) + if err != nil { + return nil, errors.Errorf(newGroupStoreErr, err) + } + + return &Manager{ + client: client, + store: store, + swb: swb, + net: net, + rng: rng, + gs: gStore, + requestFunc: requestFunc, + receiveFunc: receiveFunc, + }, nil +} + +// StartProcesses starts the reception worker. +func (m *Manager) StartProcesses() stoppable.Stoppable { + // Start group reception worker + receiveStop := stoppable.NewSingle(receiveStoppableName) + receiveChan := make(chan message.Receive, rawMessageBuffSize) + m.swb.RegisterChannel(receiveListenerName, &id.ID{}, + message.Raw, receiveChan) + go m.receive(receiveChan, receiveStop) + + // Start group request worker + requestStop := stoppable.NewSingle(requestStoppableName) + requestChan := make(chan message.Receive, rawMessageBuffSize) + m.swb.RegisterChannel(requestListenerName, &id.ID{}, + message.GroupCreationRequest, requestChan) + go m.receiveRequest(requestChan, requestStop) + + // Create a multi stoppable + multiStoppable := stoppable.NewMulti(groupStoppableName) + multiStoppable.Add(receiveStop) + multiStoppable.Add(requestStop) + + return multiStoppable +} + +// JoinGroup adds the group to the list of group chats the user is a part of. +// An error is returned if the user is already part of the group or if the +// maximum number of groups have already been joined. +func (m Manager) JoinGroup(g gs.Group) error { + if err := m.gs.Add(g); err != nil { + return errors.Errorf(joinGroupErr, g.ID, err) + } + + return nil +} + +// LeaveGroup removes a group from a list of groups the user is a part of. +func (m Manager) LeaveGroup(groupID *id.ID) error { + if err := m.gs.Remove(groupID); err != nil { + return errors.Errorf(leaveGroupErr, groupID, err) + } + + return nil +} + +// GetGroups returns a list of all registered groupChat IDs. +func (m Manager) GetGroups() []*id.ID { + return m.gs.GroupIDs() +} + +// GetGroup returns the group with the matching ID or returns false if none +// exist. +func (m Manager) GetGroup(groupID *id.ID) (gs.Group, bool) { + return m.gs.Get(groupID) +} + +// NumGroups returns the number of groups the user is a part of. +func (m Manager) NumGroups() int { + return m.gs.Len() +} diff --git a/groupChat/manager_test.go b/groupChat/manager_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0ea0f5341018fac4a24e72b999603d2a93086ed2 --- /dev/null +++ b/groupChat/manager_test.go @@ -0,0 +1,386 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupChat + +import ( + gs "gitlab.com/elixxir/client/groupChat/groupStore" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/crypto/group" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/primitives/id" + "math/rand" + "reflect" + "strings" + "testing" + "time" +) + +// Unit test of Manager.newManager. +func Test_newManager(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + user := group.Member{ + ID: id.NewIdFromString("userID", id.User, t), + DhKey: randCycInt(rand.New(rand.NewSource(42))), + } + requestChan := make(chan gs.Group) + requestFunc := func(g gs.Group) { requestChan <- g } + receiveChan := make(chan MessageReceive) + receiveFunc := func(msg MessageReceive) { receiveChan <- msg } + m, err := newManager(nil, user.ID, user.DhKey, nil, nil, nil, nil, kv, requestFunc, receiveFunc) + if err != nil { + t.Errorf("newManager() returned an error: %+v", err) + } + + if !m.gs.GetUser().Equal(user) { + t.Errorf("newManager() failed to create a store with the correct user."+ + "\nexpected: %s\nreceived: %s", user, m.gs.GetUser()) + } + + if m.gs.Len() != 0 { + t.Errorf("newManager() failed to create an empty store."+ + "\nexpected: %d\nreceived: %d", 0, m.gs.Len()) + } + + // Check if requestFunc works + go m.requestFunc(gs.Group{}) + select { + case <-requestChan: + case <-time.NewTimer(5 * time.Millisecond).C: + t.Errorf("Timed out waiting for requestFunc to be called.") + } + + // Check if receiveFunc works + go m.receiveFunc(MessageReceive{}) + select { + case <-receiveChan: + case <-time.NewTimer(5 * time.Millisecond).C: + t.Errorf("Timed out waiting for receiveFunc to be called.") + } +} + +// Tests that Manager.newManager loads a group storage when it exists. +func Test_newManager_LoadStorage(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + user := group.Member{ + ID: id.NewIdFromString("userID", id.User, t), + DhKey: randCycInt(rand.New(rand.NewSource(42))), + } + + gStore, err := gs.NewStore(kv, user) + if err != nil { + t.Errorf("Failed to create new group storage: %+v", err) + } + + for i := 0; i < 10; i++ { + err := gStore.Add(newTestGroup(getGroup(), getGroup().NewInt(42), prng, t)) + if err != nil { + t.Errorf("Failed to add group %d: %+v", i, err) + } + } + + m, err := newManager(nil, user.ID, user.DhKey, nil, nil, nil, nil, kv, nil, nil) + if err != nil { + t.Errorf("newManager() returned an error: %+v", err) + } + + if !reflect.DeepEqual(gStore, m.gs) { + t.Errorf("newManager() failed to load the expected storage."+ + "\nexpected: %+v\nreceived: %+v", gStore, m.gs) + } +} + +// Error path: an error is returned when a group cannot be loaded from storage. +func Test_newManager_LoadError(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + kv := versioned.NewKV(make(ekv.Memstore)) + user := group.Member{ + ID: id.NewIdFromString("userID", id.User, t), + DhKey: randCycInt(rand.New(rand.NewSource(42))), + } + + gStore, err := gs.NewStore(kv, user) + if err != nil { + t.Errorf("Failed to create new group storage: %+v", err) + } + + g := newTestGroup(getGroup(), getGroup().NewInt(42), prng, t) + err = gStore.Add(g) + if err != nil { + t.Errorf("Failed to add group: %+v", err) + } + _ = kv.Prefix("GroupChatListStore").Delete("GroupChat/"+g.ID.String(), 0) + + expectedErr := strings.SplitN(newGroupStoreErr, "%", 2)[0] + + _, err = newManager(nil, user.ID, user.DhKey, nil, nil, nil, nil, kv, nil, nil) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("newManager() did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// +// func TestManager_StartProcesses(t *testing.T) { +// jww.SetLogThreshold(jww.LevelTrace) +// jww.SetStdoutThreshold(jww.LevelTrace) +// prng := rand.New(rand.NewSource(42)) +// requestChan1 := make(chan gs.Group) +// requestFunc1 := func(g gs.Group) { requestChan1 <- g } +// receiveChan1 := make(chan MessageReceive) +// receiveFunc1 := func(msg MessageReceive) { receiveChan1 <- msg } +// requestChan2 := make(chan gs.Group) +// requestFunc2 := func(g gs.Group) { requestChan2 <- g } +// receiveChan2 := make(chan MessageReceive) +// receiveFunc2 := func(msg MessageReceive) { receiveChan2 <- msg } +// requestChan3 := make(chan gs.Group) +// requestFunc3 := func(g gs.Group) { requestChan3 <- g } +// receiveChan3 := make(chan MessageReceive) +// receiveFunc3 := func(msg MessageReceive) { receiveChan3 <- msg } +// +// m1, _ := newTestManagerWithStore(prng, 10, 0, requestFunc1, receiveFunc1, t) +// m2, _ := newTestManagerWithStore(prng, 10, 0, requestFunc2, receiveFunc2, t) +// m3, _ := newTestManagerWithStore(prng, 10, 0, requestFunc3, receiveFunc3, t) +// +// membership, err := group.NewMembership(m1.store.GetUser().GetContact(), +// m2.store.GetUser().GetContact(), m3.store.GetUser().GetContact()) +// if err != nil { +// t.Errorf("Failed to generate new membership: %+v", err) +// } +// +// dhKeys := gs.GenerateDhKeyList(m1.gs.GetUser().ID, +// m1.store.GetUser().E2eDhPrivateKey, membership, m1.store.E2e().GetGroup()) +// +// grp1 := newTestGroup(m1.store.E2e().GetGroup(), m1.store.GetUser().E2eDhPrivateKey, prng, t) +// grp1.Members = membership +// grp1.DhKeys = dhKeys +// grp1.ID = group.NewID(grp1.IdPreimage, grp1.Members) +// grp1.Key = group.NewKey(grp1.KeyPreimage, grp1.Members) +// grp2 := grp1.DeepCopy() +// grp2.DhKeys = gs.GenerateDhKeyList(m2.gs.GetUser().ID, +// m2.store.GetUser().E2eDhPrivateKey, membership, m2.store.E2e().GetGroup()) +// grp3 := grp1.DeepCopy() +// grp3.DhKeys = gs.GenerateDhKeyList(m3.gs.GetUser().ID, +// m3.store.GetUser().E2eDhPrivateKey, membership, m3.store.E2e().GetGroup()) +// +// err = m1.gs.Add(grp1) +// if err != nil { +// t.Errorf("Failed to add group to member 1: %+v", err) +// } +// err = m2.gs.Add(grp2) +// if err != nil { +// t.Errorf("Failed to add group to member 2: %+v", err) +// } +// err = m3.gs.Add(grp3) +// if err != nil { +// t.Errorf("Failed to add group to member 3: %+v", err) +// } +// +// _ = m1.StartProcesses() +// _ = m2.StartProcesses() +// _ = m3.StartProcesses() +// +// // Build request message +// requestMarshaled, err := proto.Marshal(&Request{ +// Name: grp1.Name, +// IdPreimage: grp1.IdPreimage.Bytes(), +// KeyPreimage: grp1.KeyPreimage.Bytes(), +// Members: grp1.Members.Serialize(), +// Message: grp1.InitMessage, +// }) +// if err != nil { +// t.Errorf("Failed to proto marshal message: %+v", err) +// } +// msg := message.Receive{ +// Payload: requestMarshaled, +// MessageType: message.GroupCreationRequest, +// Sender: m1.gs.GetUser().ID, +// } +// +// m2.swb.(*switchboard.Switchboard).Speak(msg) +// m3.swb.(*switchboard.Switchboard).Speak(msg) +// +// select { +// case received := <-requestChan2: +// if !reflect.DeepEqual(grp2, received) { +// t.Errorf("Failed to receive expected group on requestChan."+ +// "\nexpected: %#v\nreceived: %#v", grp2, received) +// } +// case <-time.NewTimer(5 * time.Millisecond).C: +// t.Error("Timed out waiting for request callback.") +// } +// +// select { +// case received := <-requestChan3: +// if !reflect.DeepEqual(grp3, received) { +// t.Errorf("Failed to receive expected group on requestChan."+ +// "\nexpected: %#v\nreceived: %#v", grp3, received) +// } +// case <-time.NewTimer(5 * time.Millisecond).C: +// t.Error("Timed out waiting for request callback.") +// } +// +// contents := []byte("Test group message.") +// timestamp := netTime.Now() +// +// // Create cMix message and get public message +// cMixMsg, err := m1.newCmixMsg(grp1, contents, timestamp, m2.gs.GetUser(), prng) +// if err != nil { +// t.Errorf("Failed to create new cMix message: %+v", err) +// } +// +// internalMsg, _ := newInternalMsg(cMixMsg.ContentsSize() - publicMinLen) +// internalMsg.SetTimestamp(timestamp) +// internalMsg.SetSenderID(m1.gs.GetUser().ID) +// internalMsg.SetPayload(contents) +// expectedMsgID := group.NewMessageID(grp1.ID, internalMsg.Marshal()) +// +// expectedMsg := MessageReceive{ +// GroupID: grp1.ID, +// ID: expectedMsgID, +// Payload: contents, +// SenderID: m1.gs.GetUser().ID, +// RoundTimestamp: timestamp.Local(), +// } +// +// msg = message.Receive{ +// Payload: cMixMsg.Marshal(), +// MessageType: message.Raw, +// Sender: m1.gs.GetUser().ID, +// RoundTimestamp: timestamp.Local(), +// } +// m2.swb.(*switchboard.Switchboard).Speak(msg) +// +// select { +// case received := <-receiveChan2: +// if !reflect.DeepEqual(expectedMsg, received) { +// t.Errorf("Failed to receive expected group on receiveChan."+ +// "\nexpected: %+v\nreceived: %+v", expectedMsg, received) +// } +// case <-time.NewTimer(5 * time.Millisecond).C: +// t.Error("Timed out waiting for receive callback.") +// } +// } + +// Unit test of Manager.JoinGroup. +func TestManager_JoinGroup(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, _ := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + g := newTestGroup(m.store.E2e().GetGroup(), m.store.GetUser().E2eDhPrivateKey, prng, t) + + err := m.JoinGroup(g) + if err != nil { + t.Errorf("JoinGroup() returned an error: %+v", err) + } + + if _, exists := m.gs.Get(g.ID); !exists { + t.Errorf("JoinGroup() failed to add the group %s.", g.ID) + } +} + +// Error path: an error is returned when a group is joined twice. +func TestManager_JoinGroup_AddErr(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, g := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + expectedErr := strings.SplitN(joinGroupErr, "%", 2)[0] + + err := m.JoinGroup(g) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("JoinGroup() failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Unit test of Manager.LeaveGroup. +func TestManager_LeaveGroup(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, g := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + + err := m.LeaveGroup(g.ID) + if err != nil { + t.Errorf("LeaveGroup() returned an error: %+v", err) + } + + if _, exists := m.GetGroup(g.ID); exists { + t.Error("LeaveGroup() failed to delete the group.") + } +} + +// Error path: an error is returned when no group with the ID exists +func TestManager_LeaveGroup_NoGroupError(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, _ := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + expectedErr := strings.SplitN(leaveGroupErr, "%", 2)[0] + + err := m.LeaveGroup(id.NewIdFromString("invalidID", id.Group, t)) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("LeaveGroup() failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Unit test of Manager.GetGroups. +func TestManager_GetGroups(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, _ := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + + list := m.GetGroups() + for i, gid := range list { + if err := m.gs.Remove(gid); err != nil { + t.Errorf("Group %s does not exist (%d): %+v", gid, i, err) + } + } + + if m.gs.Len() != 0 { + t.Errorf("GetGroups() returned %d IDs, which is %d less than is in "+ + "memory.", len(list), m.gs.Len()) + } +} + +// Unit test of Manager.GetGroup. +func TestManager_GetGroup(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, g := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + + testGrp, exists := m.GetGroup(g.ID) + if !exists { + t.Error("GetGroup() failed to find a group that should exist.") + } + + if !reflect.DeepEqual(g, testGrp) { + t.Errorf("GetGroup() failed to return the expected group."+ + "\nexpected: %#v\nreceived: %#v", g, testGrp) + } + + testGrp, exists = m.GetGroup(id.NewIdFromString("invalidID", id.Group, t)) + if exists { + t.Errorf("GetGroup() returned a group that should not exist: %#v", testGrp) + } +} + +// Unit test of Manager.NumGroups. First a manager is created with 10 groups +// and the initial number is checked. Then the number of groups is checked after +// leaving each until the number left is 0. +func TestManager_NumGroups(t *testing.T) { + expectedNum := 10 + m, _ := newTestManagerWithStore(rand.New(rand.NewSource(42)), expectedNum, + 0, nil, nil, t) + + groups := append([]*id.ID{{}}, m.GetGroups()...) + + for i, gid := range groups { + _ = m.LeaveGroup(gid) + + if m.NumGroups() != expectedNum-i { + t.Errorf("NumGroups() failed to return the expected number of "+ + "groups (%d).\nexpected: %d\nreceived: %d", + i, expectedNum-i, m.NumGroups()) + } + } + +} diff --git a/groupChat/messageReceive.go b/groupChat/messageReceive.go new file mode 100644 index 0000000000000000000000000000000000000000..e607e7f01fcd1aa2e6bb4cee11356e3ea71814f5 --- /dev/null +++ b/groupChat/messageReceive.go @@ -0,0 +1,69 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupChat + +import ( + "fmt" + "gitlab.com/elixxir/crypto/group" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/id/ephemeral" + "strconv" + "strings" + "time" +) + +// MessageReceive contains the GroupChat message and associated data that a user +// receives when getting a group message. +type MessageReceive struct { + GroupID *id.ID + ID group.MessageID + Payload []byte + SenderID *id.ID + RecipientID *id.ID + EphemeralID ephemeral.Id + Timestamp time.Time + RoundID id.Round + RoundTimestamp time.Time +} + +// String returns the MessageReceive as readable text. This functions satisfies +// the fmt.Stringer interface. +func (mr MessageReceive) String() string { + groupID := "<nil>" + if mr.GroupID != nil { + groupID = mr.GroupID.String() + } + + payload := "<nil>" + if mr.Payload != nil { + payload = fmt.Sprintf("%q", mr.Payload) + } + + senderID := "<nil>" + if mr.SenderID != nil { + senderID = mr.SenderID.String() + } + + recipientID := "<nil>" + if mr.RecipientID != nil { + recipientID = mr.RecipientID.String() + } + + str := make([]string, 0, 9) + str = append(str, "GroupID:"+groupID) + str = append(str, "ID:"+mr.ID.String()) + str = append(str, "Payload:"+payload) + str = append(str, "SenderID:"+senderID) + str = append(str, "RecipientID:"+recipientID) + str = append(str, "EphemeralID:"+strconv.FormatInt(mr.EphemeralID.Int64(), 10)) + str = append(str, "Timestamp:"+mr.Timestamp.String()) + str = append(str, "RoundID:"+strconv.FormatUint(uint64(mr.RoundID), 10)) + str = append(str, "RoundTimestamp:"+mr.RoundTimestamp.String()) + + return "{" + strings.Join(str, " ") + "}" +} diff --git a/groupChat/messageReceive_test.go b/groupChat/messageReceive_test.go new file mode 100644 index 0000000000000000000000000000000000000000..343f53774a08e59caed53a37d31a1a35f7047cd7 --- /dev/null +++ b/groupChat/messageReceive_test.go @@ -0,0 +1,70 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// +package groupChat + +import ( + "gitlab.com/elixxir/crypto/group" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/id/ephemeral" + "testing" + "time" +) + +// Unit test of MessageReceive.String. +func TestMessageReceive_String(t *testing.T) { + msg := MessageReceive{ + GroupID: id.NewIdFromString("GroupID", id.Group, t), + ID: group.MessageID{0, 1, 2, 3}, + Payload: []byte("Group message."), + SenderID: id.NewIdFromString("SenderID", id.User, t), + RecipientID: id.NewIdFromString("RecipientID", id.User, t), + EphemeralID: ephemeral.Id{0, 1, 2, 3}, + Timestamp: time.Date(1955, 11, 5, 12, 0, 0, 0, time.UTC), + RoundID: 42, + RoundTimestamp: time.Date(1955, 11, 5, 12, 1, 0, 0, time.UTC), + } + + expected := "{" + + "GroupID:R3JvdXBJRAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAE " + + "ID:AAECAwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= " + + "Payload:\"Group message.\" " + + "SenderID:U2VuZGVySUQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAD " + + "RecipientID:UmVjaXBpZW50SUQAAAAAAAAAAAAAAAAAAAAAAAAAAAAD " + + "EphemeralID:141843442434048 " + + "Timestamp:" + msg.Timestamp.String() + " " + + "RoundID:42 " + + "RoundTimestamp:" + msg.RoundTimestamp.String() + + "}" + + if msg.String() != expected { + t.Errorf("String() returned the incorrect string."+ + "\nexpected: %s\nreceived: %s", expected, msg.String()) + } +} + +// Tests that MessageReceive.String returns the expected value for a message +// with nil values. +func TestMessageReceive_String_NilMessageReceive(t *testing.T) { + msg := MessageReceive{} + + expected := "{" + + "GroupID:<nil> " + + "ID:AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= " + + "Payload:<nil> " + + "SenderID:<nil> " + + "RecipientID:<nil> " + + "EphemeralID:0 " + + "Timestamp:0001-01-01 00:00:00 +0000 UTC " + + "RoundID:0 " + + "RoundTimestamp:0001-01-01 00:00:00 +0000 UTC" + + "}" + + if msg.String() != expected { + t.Errorf("String() returned the incorrect string."+ + "\nexpected: %s\nreceived: %s", expected, msg.String()) + } +} diff --git a/groupChat/publicFormat.go b/groupChat/publicFormat.go new file mode 100644 index 0000000000000000000000000000000000000000..ab88d9e09f9c7e5110404fca5fc473070b45c088 --- /dev/null +++ b/groupChat/publicFormat.go @@ -0,0 +1,120 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupChat + +import ( + "encoding/base64" + "fmt" + "github.com/pkg/errors" + "gitlab.com/elixxir/crypto/group" +) + +// Sizes of marshaled data, in bytes. +const ( + saltLen = group.SaltLen + publicMinLen = saltLen +) + +// Error messages +const ( + newPublicSizeErr = "max message size %d < %d minimum required" + unmarshalPublicSizeErr = "size of data %d < %d minimum required" +) + +// publicMsg is contains the salt and encrypted data in a group message. +// +// +---------------------+ +// | data | +// +----------+----------+ +// | salt | payload | +// | 32 bytes | variable | +// +----------+----------+ +type publicMsg struct { + data []byte // Serial of all the parts of the message + salt []byte // 256-bit sender salt + payload []byte // Encrypted internalMsg +} + +// newPublicMsg creates a new publicMsg of size maxDataSize. An error is +// returned if the maxDataSize is smaller than the minimum newPublicMsg size. +func newPublicMsg(maxDataSize int) (publicMsg, error) { + if maxDataSize < publicMinLen { + return publicMsg{}, + errors.Errorf(newPublicSizeErr, maxDataSize, publicMinLen) + } + + return mapPublicMsg(make([]byte, maxDataSize)), nil +} + +// mapPublicMsg maps all the parts of the publicMsg to the passed in data. +func mapPublicMsg(data []byte) publicMsg { + return publicMsg{ + data: data, + salt: data[:saltLen], + payload: data[saltLen:], + } +} + +// unmarshalPublicMsg unmarshal the data into an publicMsg. An error is +// returned if the data length is smaller than the minimum allowed size. +func unmarshalPublicMsg(data []byte) (publicMsg, error) { + if len(data) < publicMinLen { + return publicMsg{}, + errors.Errorf(unmarshalPublicSizeErr, len(data), publicMinLen) + } + + return mapPublicMsg(data), nil +} + +// Marshal returns the serial of the publicMsg. +func (pm publicMsg) Marshal() []byte { + return pm.data +} + +// GetSalt returns the 256-bit salt. +func (pm publicMsg) GetSalt() [group.SaltLen]byte { + var salt [group.SaltLen]byte + copy(salt[:], pm.salt) + return salt +} + +// SetSalt sets the 256-bit salt. +func (pm publicMsg) SetSalt(salt [group.SaltLen]byte) { + copy(pm.salt, salt[:]) +} + +// GetPayload returns the payload truncated to the correct size. +func (pm publicMsg) GetPayload() []byte { + return pm.payload +} + +// SetPayload sets the payload and saves it size. +func (pm publicMsg) SetPayload(payload []byte) { + copy(pm.payload, payload) +} + +// GetPayloadSize returns the maximum size of the payload. +func (pm publicMsg) GetPayloadSize() int { + return len(pm.payload) +} + +// String prints a string representation of publicMsg. This functions satisfies +// the fmt.Stringer interface. +func (pm publicMsg) String() string { + salt := "<nil>" + if len(pm.salt) > 0 { + salt = base64.StdEncoding.EncodeToString(pm.salt) + } + + payload := "<nil>" + if len(pm.payload) > 0 { + payload = fmt.Sprintf("%q", pm.GetPayload()) + } + + return "{salt:" + salt + ", payload:" + payload + "}" +} diff --git a/groupChat/publicFormat_test.go b/groupChat/publicFormat_test.go new file mode 100644 index 0000000000000000000000000000000000000000..69884ff76856562e0d6f9ee3af03ad63e6eecb74 --- /dev/null +++ b/groupChat/publicFormat_test.go @@ -0,0 +1,162 @@ +package groupChat + +import ( + "bytes" + "fmt" + "math/rand" + "reflect" + "testing" +) + +// Unit test of newPublicMsg. +func Test_newPublicMsg(t *testing.T) { + maxDataSize := 2 * publicMinLen + im, err := newPublicMsg(maxDataSize) + if err != nil { + t.Errorf("newPublicMsg() returned an error: %+v", err) + } + + if len(im.data) != maxDataSize { + t.Errorf("newPublicMsg() set data to the wrong length."+ + "\nexpected: %d\nreceived: %d", maxDataSize, len(im.data)) + } +} + +// Error path: the maxDataSize is smaller than the minimum size. +func Test_newPublicMsg_PayloadSizeError(t *testing.T) { + maxDataSize := publicMinLen - 1 + expectedErr := fmt.Sprintf(newPublicSizeErr, maxDataSize, publicMinLen) + + _, err := newPublicMsg(maxDataSize) + if err == nil || err.Error() != expectedErr { + t.Errorf("newPublicMsg() failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Unit test of mapPublicMsg. +func Test_mapPublicMsg(t *testing.T) { + // Create all the expected data + var salt [saltLen]byte + rand.New(rand.NewSource(42)).Read(salt[:]) + payload := []byte("Sample payload contents.") + + // Construct data into single slice + data := bytes.NewBuffer(nil) + data.Write(salt[:]) + data.Write(payload) + + // Map data + im := mapPublicMsg(data.Bytes()) + + // Check that the mapped values match the expected values + if !bytes.Equal(salt[:], im.salt) { + t.Errorf("mapPublicMsg() did not correctly map salt."+ + "\nexpected: %+v\nreceived: %+v", salt, im.salt) + } + + if !bytes.Equal(payload, im.payload) { + t.Errorf("mapPublicMsg() did not correctly map payload."+ + "\nexpected: %+v\nreceived: %+v", payload, im.payload) + } +} + +// Tests that a marshaled and unmarshalled publicMsg matches the original. +func Test_publicMsg_Marshal_unmarshalPublicMsg(t *testing.T) { + pm, _ := newPublicMsg(publicMinLen * 2) + var salt [saltLen]byte + rand.New(rand.NewSource(42)).Read(salt[:]) + pm.SetSalt(salt) + pm.SetPayload([]byte("Sample payload message.")) + + data := pm.Marshal() + + newPm, err := unmarshalPublicMsg(data) + if err != nil { + t.Errorf("unmarshalPublicMsg() returned an error: %+v", err) + } + + if !reflect.DeepEqual(pm, newPm) { + t.Errorf("unmarshalPublicMsg() did not return the expected publicMsg."+ + "\nexpected: %s\nreceived: %s", pm, newPm) + } +} + +// Error path: error is returned when the data is too short. +func Test_unmarshalPublicMsg(t *testing.T) { + expectedErr := fmt.Sprintf(unmarshalPublicSizeErr, 0, publicMinLen) + + _, err := unmarshalPublicMsg(nil) + if err == nil || err.Error() != expectedErr { + t.Errorf("unmarshalPublicMsg() failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Happy path. +func Test_publicMsg_SetSalt_GetSalt(t *testing.T) { + pm, _ := newPublicMsg(publicMinLen * 2) + var salt [saltLen]byte + rand.New(rand.NewSource(42)).Read(salt[:]) + pm.SetSalt(salt) + + testSalt := pm.GetSalt() + if salt != testSalt { + t.Errorf("Failed to get original salt."+ + "\nexpected: %+v\nreceived: %+v", salt, testSalt) + } +} + +// Tests that the original payload matches the saved one. +func Test_publicMsg_SetPayload_GetPayload(t *testing.T) { + pm, _ := newPublicMsg(publicMinLen * 2) + payload := make([]byte, pm.GetPayloadSize()) + copy(payload, "Test payload message.") + pm.SetPayload(payload) + testPayload := pm.GetPayload() + + if !bytes.Equal(payload, testPayload) { + t.Errorf("Failed to get original sender payload."+ + "\nexpected: %q\nreceived: %q", payload, testPayload) + } +} + +// Happy path. +func Test_publicMsg_GetPayloadSize(t *testing.T) { + pm, _ := newPublicMsg(publicMinLen * 2) + + if publicMinLen != pm.GetPayloadSize() { + t.Errorf("GetPayloadSize() failed to return the correct size."+ + "\nexpected: %d\nreceived: %d", publicMinLen, pm.GetPayloadSize()) + } +} + +// Happy path. +func Test_publicMsg_String(t *testing.T) { + pm, _ := newPublicMsg(publicMinLen * 2) + var salt [saltLen]byte + rand.New(rand.NewSource(42)).Read(salt[:]) + pm.SetSalt(salt) + payload := []byte("Sample payload message.") + payload = append(payload, 0, 1, 2) + pm.SetPayload(payload) + + expected := `{salt:U4x/lrFkvxuXu59LtHLon1sUhPJSCcnZND6SugndnVI=, payload:"Sample payload message.\x00\x01\x02\x00\x00\x00\x00\x00\x00"}` + + if pm.String() != expected { + t.Errorf("String() failed to return the expected value."+ + "\nexpected: %s\nreceived: %s", expected, pm.String()) + } +} + +// Happy path: tests that String returns the expected string for a nil publicMsg. +func Test_publicMsg_String_NilInternalMessage(t *testing.T) { + pm := publicMsg{} + + expected := "{salt:<nil>, payload:<nil>}" + + if pm.String() != expected { + t.Errorf("String() failed to return the expected value."+ + "\nexpected: %s\nreceived: %s", expected, pm.String()) + } +} diff --git a/groupChat/receive.go b/groupChat/receive.go new file mode 100644 index 0000000000000000000000000000000000000000..64cf10b789b3d07bf9d1dbd82d6d76b15c91fe53 --- /dev/null +++ b/groupChat/receive.go @@ -0,0 +1,168 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupChat + +import ( + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + gs "gitlab.com/elixxir/client/groupChat/groupStore" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" + "gitlab.com/elixxir/crypto/group" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/primitives/id" + "time" +) + +// Error messages. +const ( + newDecryptKeyErr = "failed to generate key for decrypting group payload: %+v" + unmarshalInternalMsgErr = "failed to unmarshal group internal message: %+v" + unmarshalSenderIdErr = "failed to unmarshal sender ID: %+v" + unmarshalPublicMsgErr = "failed to unmarshal group cMix message contents: %+v" + findGroupKeyFpErr = "failed to find group with key fingerprint matching %s" + genCryptKeyMacErr = "failed to generate encryption key for group " + + "cMix message because MAC verification failed (epoch %d could be off)" +) + +// receive starts the group message reception worker that waits for new group +// messages to arrive. +func (m Manager) receive(rawMsgs chan message.Receive, stop *stoppable.Single) { + jww.DEBUG.Print("Starting group message reception worker.") + + for { + select { + case <-stop.Quit(): + jww.DEBUG.Print("Stopping group message reception worker.") + stop.ToStopped() + return + case receiveMsg := <-rawMsgs: + jww.DEBUG.Print("Group message reception received cMix message.") + + // Attempt to read the message + g, msgID, timestamp, senderID, msg, err := m.readMessage(receiveMsg) + if err != nil { + jww.WARN.Printf("Group message reception failed to read cMix "+ + "message: %+v", err) + continue + } + + // If the message was read correctly, send it to the callback + go m.receiveFunc(MessageReceive{ + GroupID: g.ID, + ID: msgID, + Payload: msg, + SenderID: senderID, + RecipientID: receiveMsg.RecipientID, + EphemeralID: receiveMsg.EphemeralID, + Timestamp: receiveMsg.Timestamp, + RoundID: receiveMsg.RoundId, + RoundTimestamp: timestamp, + }) + } + } +} + +// readMessage returns the group, message ID, timestamp, sender ID, and message +// of a group message. The encrypted group message data is unmarshaled from a +// cMix message in the message.Receive and then decrypted and the MAC is +// verified. The group is found by finding the group with a matching key +// fingerprint. +func (m *Manager) readMessage(msg message.Receive) (gs.Group, group.MessageID, + time.Time, *id.ID, []byte, error) { + // Unmarshal payload into cMix message + cMixMsg := format.Unmarshal(msg.Payload) + + // Unmarshal cMix message contents to get public message format + publicMsg, err := unmarshalPublicMsg(cMixMsg.GetContents()) + if err != nil { + return gs.Group{}, group.MessageID{}, time.Time{}, nil, nil, + errors.Errorf(unmarshalPublicMsgErr, err) + } + + // Get the group from storage via key fingerprint lookup + g, exists := m.gs.GetByKeyFp(cMixMsg.GetKeyFP(), publicMsg.GetSalt()) + if !exists { + return gs.Group{}, group.MessageID{}, time.Time{}, nil, nil, + errors.Errorf(findGroupKeyFpErr, cMixMsg.GetKeyFP()) + } + + // Decrypt the payload and return the messages timestamp, sender ID, and + // message contents + messageID, timestamp, senderID, contents, err := m.decryptMessage( + g, cMixMsg, publicMsg, msg.RoundTimestamp) + return g, messageID, timestamp, senderID, contents, err +} + +// decryptMessage decrypts the group message payload and returns its message ID, +// timestamp, sender ID, and message contents. +func (m *Manager) decryptMessage(g gs.Group, cMixMsg format.Message, + publicMsg publicMsg, roundTimestamp time.Time) (group.MessageID, time.Time, + *id.ID, []byte, error) { + + key, err := getCryptKey(g.Key, publicMsg.GetSalt(), cMixMsg.GetMac(), + publicMsg.GetPayload(), g.DhKeys, roundTimestamp) + if err != nil { + return group.MessageID{}, time.Time{}, nil, nil, err + } + + // Decrypt internal message + decryptedPayload := group.Decrypt(key, cMixMsg.GetKeyFP(), + publicMsg.GetPayload()) + + // Unmarshal internal message + internalMsg, err := unmarshalInternalMsg(decryptedPayload) + if err != nil { + return group.MessageID{}, time.Time{}, nil, nil, + errors.Errorf(unmarshalInternalMsgErr, err) + } + + // Unmarshal sender ID + senderID, err := internalMsg.GetSenderID() + if err != nil { + return group.MessageID{}, time.Time{}, nil, nil, + errors.Errorf(unmarshalSenderIdErr, err) + } + + messageID := group.NewMessageID(g.ID, internalMsg.Marshal()) + + return messageID, internalMsg.GetTimestamp(), senderID, + internalMsg.GetPayload(), nil +} + +// getCryptKey generates the decryption key for a group internal message. The +// key is generated using the group key, an epoch, and a salt. The epoch is +// based off the round timestamp. So, to avoid missing the correct epoch, the +// current, past, and next epochs are checked until one of them produces a key +// that matches the message's MAC. The DH key is also unknown, so each member's +// DH key is tried until there is a match. +func getCryptKey(key group.Key, salt [group.SaltLen]byte, mac, payload []byte, + dhKeys gs.DhKeyList, roundTimestamp time.Time) (group.CryptKey, error) { + // Compute the current epoch + epoch := group.ComputeEpoch(roundTimestamp) + + for _, dhKey := range dhKeys { + + // Create a key with the correct epoch + for _, epoch := range []uint32{epoch, epoch - 1, epoch + 1} { + // Generate key + cryptKey, err := group.NewKdfKey(key, epoch, salt) + if err != nil { + return group.CryptKey{}, errors.Errorf(newDecryptKeyErr, err) + } + + // Return the key if the MAC matches + if group.CheckMAC(mac, cryptKey, payload, dhKey) { + return cryptKey, nil + } + } + } + + // Return an error if none of the epochs worked + return group.CryptKey{}, errors.Errorf(genCryptKeyMacErr, epoch) +} diff --git a/groupChat/receiveRequest.go b/groupChat/receiveRequest.go new file mode 100644 index 0000000000000000000000000000000000000000..e5c7576f174f793d29ef337e092c96e4d782c371 --- /dev/null +++ b/groupChat/receiveRequest.go @@ -0,0 +1,111 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupChat + +import ( + "github.com/golang/protobuf/proto" + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + gs "gitlab.com/elixxir/client/groupChat/groupStore" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" + "gitlab.com/elixxir/crypto/group" +) + +// Error message. +const ( + sendMessageTypeErr = "message not of type GroupCreationRequest" + protoUnmarshalErr = "failed to unmarshal request: %+v" + deserializeMembershipErr = "failed to deserialize membership: %+v" +) + +// receiveRequest starts the group request reception worker that waits for new +// group requests to arrive. +func (m Manager) receiveRequest(rawMsgs chan message.Receive, stop *stoppable.Single) { + jww.DEBUG.Print("Starting group message request reception worker.") + + for { + select { + case <-stop.Quit(): + jww.DEBUG.Print("Stopping group message request reception worker.") + stop.ToStopped() + return + case sendMsg := <-rawMsgs: + jww.DEBUG.Print("Group message request received send message.") + + // Generate the group from the request message + g, err := m.readRequest(sendMsg) + if err != nil { + jww.WARN.Printf("Failed to read message as group request: %+v", + err) + continue + } + + // Call request callback with the new group if it does not already + // exist + if _, exists := m.GetGroup(g.ID); !exists { + go m.requestFunc(g) + } + } + } +} + +// readRequest returns the group describes in the group request message. An +// error is returned if the request is of the wrong type or cannot be read. +func (m *Manager) readRequest(msg message.Receive) (gs.Group, error) { + // Return an error if the message is not of the right type + if msg.MessageType != message.GroupCreationRequest { + return gs.Group{}, errors.New(sendMessageTypeErr) + } + + // Unmarshal the request message + request := &Request{} + err := proto.Unmarshal(msg.Payload, request) + if err != nil { + return gs.Group{}, errors.Errorf(protoUnmarshalErr, err) + } + + // Deserialize membership list + membership, err := group.DeserializeMembership(request.Members) + if err != nil { + return gs.Group{}, errors.Errorf(deserializeMembershipErr, err) + } + + // Get the relationship with the group leader + partner, err := m.store.E2e().GetPartner(membership[0].ID) + if err != nil { + return gs.Group{}, errors.Errorf(getPrivKeyErr, err) + } + + // Replace leader's public key with the one from the partnership + leaderPubKey := membership[0].DhKey.DeepCopy() + membership[0].DhKey = partner.GetPartnerOriginPublicKey() + + // Generate the DH keys with each group member + privKey := partner.GetMyOriginPrivateKey() + grp := m.store.E2e().GetGroup() + dkl := gs.GenerateDhKeyList(m.gs.GetUser().ID, privKey, membership, grp) + + // Restore the original public key for the leader so that the membership + // digest generated later is correct + membership[0].DhKey = leaderPubKey + + // Copy preimages + var idPreimage group.IdPreimage + copy(idPreimage[:], request.IdPreimage) + var keyPreimage group.KeyPreimage + copy(keyPreimage[:], request.KeyPreimage) + + // Create group ID and key + groupID := group.NewID(idPreimage, membership) + groupKey := group.NewKey(keyPreimage, membership) + + // Return the new group + return gs.NewGroup(request.Name, groupID, groupKey, idPreimage, keyPreimage, + request.Message, membership, dkl), nil +} diff --git a/groupChat/receiveRequest_test.go b/groupChat/receiveRequest_test.go new file mode 100644 index 0000000000000000000000000000000000000000..6925853a325948b005c7c38e11f2a28621ab2b11 --- /dev/null +++ b/groupChat/receiveRequest_test.go @@ -0,0 +1,241 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupChat + +import ( + "github.com/golang/protobuf/proto" + gs "gitlab.com/elixxir/client/groupChat/groupStore" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" + "math/rand" + "strings" + "testing" + "time" +) + +// // Tests that the correct group is received from the request. +// func TestManager_receiveRequest(t *testing.T) { +// prng := rand.New(rand.NewSource(42)) +// requestChan := make(chan gs.Group) +// requestFunc := func(g gs.Group) { requestChan <- g } +// m, _ := newTestManagerWithStore(prng, 10, 0, requestFunc, nil, t) +// g := newTestGroupWithUser(m.store.E2e().GetGroup(), +// m.store.GetUser().ReceptionID, m.store.GetUser().E2eDhPublicKey, +// m.store.GetUser().E2eDhPrivateKey, prng, t) +// +// requestMarshaled, err := proto.Marshal(&Request{ +// Name: g.Name, +// IdPreimage: g.IdPreimage.Bytes(), +// KeyPreimage: g.KeyPreimage.Bytes(), +// Members: g.Members.Serialize(), +// Message: g.InitMessage, +// }) +// if err != nil { +// t.Errorf("Failed to marshal proto message: %+v", err) +// } +// +// msg := message.Receive{ +// Payload: requestMarshaled, +// MessageType: message.GroupCreationRequest, +// } +// +// rawMessages := make(chan message.Receive) +// quit := make(chan struct{}) +// go m.receiveRequest(rawMessages, quit) +// rawMessages <- msg +// +// select { +// case receivedGrp := <-requestChan: +// if !reflect.DeepEqual(g, receivedGrp) { +// t.Errorf("receiveRequest() failed to return the expected group."+ +// "\nexpected: %#v\nreceived: %#v", g, receivedGrp) +// } +// case <-time.NewTimer(5 * time.Millisecond).C: +// t.Error("Timed out while waiting for callback.") +// } +// } + +// Tests that the callback is not called when the group already exists in the +// manager. +func TestManager_receiveRequest_GroupExists(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + requestChan := make(chan gs.Group) + requestFunc := func(g gs.Group) { requestChan <- g } + m, g := newTestManagerWithStore(prng, 10, 0, requestFunc, nil, t) + + requestMarshaled, err := proto.Marshal(&Request{ + Name: g.Name, + IdPreimage: g.IdPreimage.Bytes(), + KeyPreimage: g.KeyPreimage.Bytes(), + Members: g.Members.Serialize(), + Message: g.InitMessage, + }) + if err != nil { + t.Errorf("Failed to marshal proto message: %+v", err) + } + + msg := message.Receive{ + Payload: requestMarshaled, + MessageType: message.GroupCreationRequest, + } + + rawMessages := make(chan message.Receive) + stop := stoppable.NewSingle("testStoppable") + go m.receiveRequest(rawMessages, stop) + rawMessages <- msg + + select { + case <-requestChan: + t.Error("receiveRequest() called the callback when the group already " + + "exists in the list.") + case <-time.NewTimer(5 * time.Millisecond).C: + } +} + +// Tests that the quit channel quits the worker. +func TestManager_receiveRequest_QuitChan(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + requestChan := make(chan gs.Group) + requestFunc := func(g gs.Group) { requestChan <- g } + m, _ := newTestManagerWithStore(prng, 10, 0, requestFunc, nil, t) + + rawMessages := make(chan message.Receive) + stop := stoppable.NewSingle("testStoppable") + done := make(chan struct{}) + go func() { + m.receiveRequest(rawMessages, stop) + done <- struct{}{} + }() + if err := stop.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } + + select { + case <-done: + case <-time.NewTimer(5 * time.Millisecond).C: + t.Error("receiveRequest() failed to close when the quit.") + } +} + +// Tests that the callback is not called when the send message is not of the +// correct type. +func TestManager_receiveRequest_SendMessageTypeError(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + requestChan := make(chan gs.Group) + requestFunc := func(g gs.Group) { requestChan <- g } + m, _ := newTestManagerWithStore(prng, 10, 0, requestFunc, nil, t) + + msg := message.Receive{ + MessageType: message.NoType, + } + + rawMessages := make(chan message.Receive) + stop := stoppable.NewSingle("singleStoppable") + go m.receiveRequest(rawMessages, stop) + rawMessages <- msg + + select { + case receivedGrp := <-requestChan: + t.Errorf("Callback called when the message should have been skipped: %#v", + receivedGrp) + case <-time.NewTimer(5 * time.Millisecond).C: + } +} + +// // Unit test of readRequest. +// func TestManager_readRequest(t *testing.T) { +// m, g := newTestManager(rand.New(rand.NewSource(42)), t) +// _ = m.store.E2e().AddPartner( +// g.Members[0].ID, +// g.Members[0].DhKey, +// m.store.E2e().GetGroup().NewInt(43), +// params.GetDefaultE2ESessionParams(), +// params.GetDefaultE2ESessionParams(), +// ) +// +// requestMarshaled, err := proto.Marshal(&Request{ +// Name: g.Name, +// IdPreimage: g.IdPreimage.Bytes(), +// KeyPreimage: g.KeyPreimage.Bytes(), +// Members: g.Members.Serialize(), +// Message: g.InitMessage, +// }) +// if err != nil { +// t.Errorf("Failed to marshal proto message: %+v", err) +// } +// +// msg := message.Receive{ +// Payload: requestMarshaled, +// MessageType: message.GroupCreationRequest, +// } +// +// newGrp, err := m.readRequest(msg) +// if err != nil { +// t.Errorf("readRequest() returned an error: %+v", err) +// } +// +// if !reflect.DeepEqual(g, newGrp) { +// t.Errorf("readRequest() returned the wrong group."+ +// "\nexpected: %#v\nreceived: %#v", g, newGrp) +// } +// } + +// Error path: an error is returned if the message type is incorrect. +func TestManager_readRequest_MessageTypeError(t *testing.T) { + m, _ := newTestManager(rand.New(rand.NewSource(42)), t) + expectedErr := sendMessageTypeErr + msg := message.Receive{ + MessageType: message.NoType, + } + + _, err := m.readRequest(msg) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("readRequest() did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: an error is returned if the proto message cannot be unmarshalled. +func TestManager_readRequest_ProtoUnmarshalError(t *testing.T) { + expectedErr := strings.SplitN(deserializeMembershipErr, "%", 2)[0] + m, _ := newTestManager(rand.New(rand.NewSource(42)), t) + + requestMarshaled, err := proto.Marshal(&Request{ + Members: []byte("Invalid membership serial."), + }) + if err != nil { + t.Errorf("Failed to marshal proto message: %+v", err) + } + + msg := message.Receive{ + Payload: requestMarshaled, + MessageType: message.GroupCreationRequest, + } + + _, err = m.readRequest(msg) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("readRequest() did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: an error is returned if the membership cannot be deserialized. +func TestManager_readRequest_DeserializeMembershipError(t *testing.T) { + m, _ := newTestManager(rand.New(rand.NewSource(42)), t) + expectedErr := strings.SplitN(protoUnmarshalErr, "%", 2)[0] + msg := message.Receive{ + Payload: []byte("Invalid message."), + MessageType: message.GroupCreationRequest, + } + + _, err := m.readRequest(msg) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("readRequest() did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} diff --git a/groupChat/receive_test.go b/groupChat/receive_test.go new file mode 100644 index 0000000000000000000000000000000000000000..36ea8ed2dbad4c10630630f198d07d4b6a96bf54 --- /dev/null +++ b/groupChat/receive_test.go @@ -0,0 +1,409 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupChat + +import ( + "bytes" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" + "gitlab.com/elixxir/crypto/e2e" + "gitlab.com/elixxir/crypto/group" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/primitives/netTime" + "math/rand" + "reflect" + "strings" + "testing" + "time" +) + +// Tests that Manager.receive returns the correct message on the callback. +func TestManager_receive(t *testing.T) { + // Setup callback + msgChan := make(chan MessageReceive) + receiveFunc := func(msg MessageReceive) { msgChan <- msg } + + // Create new test Manager and Group + prng := rand.New(rand.NewSource(42)) + m, g := newTestManagerWithStore(prng, 10, 0, nil, receiveFunc, t) + + // Create test parameters + contents := []byte("Test group message.") + timestamp := netTime.Now() + sender := m.gs.GetUser() + + expectedMsg := MessageReceive{ + GroupID: g.ID, + ID: group.MessageID{0, 1, 2, 3}, + Payload: contents, + SenderID: sender.ID, + RoundTimestamp: timestamp.Local(), + } + + // Create cMix message and get public message + cMixMsg, err := m.newCmixMsg(g, contents, timestamp, g.Members[4], prng) + if err != nil { + t.Errorf("Failed to create new cMix message: %+v", err) + } + + internalMsg, _ := newInternalMsg(cMixMsg.ContentsSize() - publicMinLen) + internalMsg.SetTimestamp(timestamp) + internalMsg.SetSenderID(m.gs.GetUser().ID) + internalMsg.SetPayload(contents) + expectedMsg.ID = group.NewMessageID(g.ID, internalMsg.Marshal()) + + receiveChan := make(chan message.Receive, 1) + stop := stoppable.NewSingle("singleStoppable") + + m.gs.SetUser(g.Members[4], t) + go m.receive(receiveChan, stop) + + receiveChan <- message.Receive{ + Payload: cMixMsg.Marshal(), + RoundTimestamp: timestamp, + } + + select { + case msg := <-msgChan: + if !reflect.DeepEqual(expectedMsg, msg) { + t.Errorf("Failed to received expected message."+ + "\nexpected: %+v\nreceived: %+v", expectedMsg, msg) + } + case <-time.NewTimer(10 * time.Millisecond).C: + t.Errorf("Timed out waiting to receive group message.") + } +} + +// Tests that the callback is not called when the message cannot be read. +func TestManager_receive_ReadMessageError(t *testing.T) { + // Setup callback + msgChan := make(chan MessageReceive) + receiveFunc := func(msg MessageReceive) { msgChan <- msg } + + // Create new test Manager and Group + prng := rand.New(rand.NewSource(42)) + m, _ := newTestManagerWithStore(prng, 10, 0, nil, receiveFunc, t) + + receiveChan := make(chan message.Receive, 1) + stop := stoppable.NewSingle("singleStoppable") + + go m.receive(receiveChan, stop) + + receiveChan <- message.Receive{ + Payload: make([]byte, format.MinimumPrimeSize*2), + } + + select { + case <-msgChan: + t.Error("Callback called when message should have errored.") + case <-time.NewTimer(5 * time.Millisecond).C: + } +} + +// Tests that the quit channel exits the function. +func TestManager_receive_QuitChan(t *testing.T) { + // Create new test Manager and Group + prng := rand.New(rand.NewSource(42)) + m, _ := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + + receiveChan := make(chan message.Receive, 1) + stop := stoppable.NewSingle("singleStoppable") + doneChan := make(chan struct{}) + + go func() { + m.receive(receiveChan, stop) + doneChan <- struct{}{} + }() + + if err := stop.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } + + select { + case <-doneChan: + case <-time.NewTimer(10 * time.Millisecond).C: + t.Errorf("Timed out waiting for thread to quit.") + } +} + +// Tests that Manager.readMessage returns the message data for the correct group. +func TestManager_readMessage(t *testing.T) { + // Create new test Manager and Group + prng := rand.New(rand.NewSource(42)) + m, expectedGrp := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + + // Create test parameters + expectedContents := []byte("Test group message.") + expectedTimestamp := netTime.Now() + sender := m.gs.GetUser() + + // Create cMix message and get public message + cMixMsg, err := m.newCmixMsg(expectedGrp, expectedContents, + expectedTimestamp, expectedGrp.Members[4], prng) + if err != nil { + t.Errorf("Failed to create new cMix message: %+v", err) + } + + internalMsg, _ := newInternalMsg(cMixMsg.ContentsSize() - publicMinLen) + internalMsg.SetTimestamp(expectedTimestamp) + internalMsg.SetSenderID(sender.ID) + internalMsg.SetPayload(expectedContents) + expectedMsgID := group.NewMessageID(expectedGrp.ID, internalMsg.Marshal()) + + // Build message.Receive + receiveMsg := message.Receive{ + ID: e2e.MessageID{}, + Payload: cMixMsg.Marshal(), + RoundTimestamp: expectedTimestamp, + } + + m.gs.SetUser(expectedGrp.Members[4], t) + g, messageID, timestamp, senderID, contents, err := m.readMessage(receiveMsg) + if err != nil { + t.Errorf("readMessage() returned an error: %+v", err) + } + + if !reflect.DeepEqual(expectedGrp, g) { + t.Errorf("readMessage() returned incorrect group."+ + "\nexpected: %#v\nreceived: %#v", expectedGrp, g) + } + + if expectedMsgID != messageID { + t.Errorf("readMessage() returned incorrect message ID."+ + "\nexpected: %s\nreceived: %s", expectedMsgID, messageID) + } + + if !expectedTimestamp.Equal(timestamp) { + t.Errorf("readMessage() returned incorrect timestamp."+ + "\nexpected: %s\nreceived: %s", expectedTimestamp, timestamp) + } + + if !sender.ID.Cmp(senderID) { + t.Errorf("readMessage() returned incorrect sender ID."+ + "\nexpected: %s\nreceived: %s", sender.ID, senderID) + } + + if !bytes.Equal(expectedContents, contents) { + t.Errorf("readMessage() returned incorrect message."+ + "\nexpected: %s\nreceived: %s", expectedContents, contents) + } +} + +// Error path: an error is returned when a group with a matching group +// fingerprint cannot be found. +func TestManager_readMessage_FindGroupKpError(t *testing.T) { + // Create new test Manager and Group + prng := rand.New(rand.NewSource(42)) + m, g := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + + // Create test parameters + expectedContents := []byte("Test group message.") + expectedTimestamp := netTime.Now() + + // Create cMix message and get public message + cMixMsg, err := m.newCmixMsg(g, expectedContents, expectedTimestamp, g.Members[4], prng) + if err != nil { + t.Errorf("Failed to create new cMix message: %+v", err) + } + + cMixMsg.SetKeyFP(format.NewFingerprint([]byte("invalid Fingerprint"))) + + // Build message.Receive + receiveMsg := message.Receive{ + ID: e2e.MessageID{}, + Payload: cMixMsg.Marshal(), + RoundTimestamp: expectedTimestamp, + } + + expectedErr := strings.SplitN(findGroupKeyFpErr, "%", 2)[0] + + m.gs.SetUser(g.Members[4], t) + _, _, _, _, _, err = m.readMessage(receiveMsg) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("readMessage() failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that a cMix message created by Manager.newCmixMsg can be read by +// Manager.readMessage. +func TestManager_decryptMessage(t *testing.T) { + // Create new test Manager and Group + prng := rand.New(rand.NewSource(42)) + m, g := newTestManager(prng, t) + + // Create test parameters + expectedContents := []byte("Test group message.") + expectedTimestamp := netTime.Now() + + // Create cMix message and get public message + msg, err := m.newCmixMsg(g, expectedContents, expectedTimestamp, g.Members[4], prng) + if err != nil { + t.Errorf("Failed to create new cMix message: %+v", err) + } + publicMsg, err := unmarshalPublicMsg(msg.GetContents()) + if err != nil { + t.Errorf("Failed to unmarshal publicMsg: %+v", err) + } + + internalMsg, _ := newInternalMsg(publicMsg.GetPayloadSize()) + internalMsg.SetTimestamp(expectedTimestamp) + internalMsg.SetSenderID(m.gs.GetUser().ID) + internalMsg.SetPayload(expectedContents) + expectedMsgID := group.NewMessageID(g.ID, internalMsg.Marshal()) + + // Read message and check if the outputs are correct + messageID, timestamp, senderID, contents, err := m.decryptMessage(g, msg, + publicMsg, expectedTimestamp) + if err != nil { + t.Errorf("decryptMessage() returned an error: %+v", err) + } + + if expectedMsgID != messageID { + t.Errorf("decryptMessage() returned incorrect message ID."+ + "\nexpected: %s\nreceived: %s", expectedMsgID, messageID) + } + + if !expectedTimestamp.Equal(timestamp) { + t.Errorf("decryptMessage() returned incorrect timestamp."+ + "\nexpected: %s\nreceived: %s", expectedTimestamp, timestamp) + } + + if !m.gs.GetUser().ID.Cmp(senderID) { + t.Errorf("decryptMessage() returned incorrect sender ID."+ + "\nexpected: %s\nreceived: %s", m.gs.GetUser().ID, senderID) + } + + if !bytes.Equal(expectedContents, contents) { + t.Errorf("decryptMessage() returned incorrect message."+ + "\nexpected: %s\nreceived: %s", expectedContents, contents) + } +} + +// Error path: an error is returned when the wrong timestamp is passed in and +// the decryption key cannot be generated because of the wrong epoch. +func TestManager_decryptMessage_GetCryptKeyError(t *testing.T) { + // Create new test Manager and Group + prng := rand.New(rand.NewSource(42)) + m, g := newTestManager(prng, t) + + // Create test parameters + contents := []byte("Test group message.") + timestamp := netTime.Now() + + // Create cMix message and get public message + msg, err := m.newCmixMsg(g, contents, timestamp, g.Members[4], prng) + if err != nil { + t.Errorf("Failed to create new cMix message: %+v", err) + } + publicMsg, err := unmarshalPublicMsg(msg.GetContents()) + if err != nil { + t.Errorf("Failed to unmarshal publicMsg: %+v", err) + } + + // Check if error is correct + expectedErr := strings.SplitN(genCryptKeyMacErr, "%", 2)[0] + _, _, _, _, err = m.decryptMessage(g, msg, publicMsg, timestamp.Add(time.Hour)) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("decryptMessage() failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: an error is returned when the decrypted payload cannot be +// unmarshaled. +func TestManager_decryptMessage_UnmarshalInternalMsgError(t *testing.T) { + // Create new test Manager and Group + prng := rand.New(rand.NewSource(42)) + m, g := newTestManager(prng, t) + + // Create test parameters + contents := []byte("Test group message.") + timestamp := netTime.Now() + + // Create cMix message and get public message + msg, err := m.newCmixMsg(g, contents, timestamp, g.Members[4], prng) + if err != nil { + t.Errorf("Failed to create new cMix message: %+v", err) + } + publicMsg, err := unmarshalPublicMsg(msg.GetContents()) + if err != nil { + t.Errorf("Failed to unmarshal publicMsg: %+v", err) + } + + // Modify publicMsg to have invalid payload + publicMsg = mapPublicMsg(publicMsg.Marshal()[:33]) + key, err := group.NewKdfKey(g.Key, group.ComputeEpoch(timestamp), publicMsg.GetSalt()) + if err != nil { + t.Errorf("failed to create new key: %+v", err) + } + msg.SetMac(group.NewMAC(key, publicMsg.GetPayload(), g.DhKeys[*g.Members[4].ID])) + + // Check if error is correct + expectedErr := strings.SplitN(unmarshalInternalMsgErr, "%", 2)[0] + _, _, _, _, err = m.decryptMessage(g, msg, publicMsg, timestamp) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("decryptMessage() failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Unit test of getCryptKey. +func Test_getCryptKey(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + g := newTestGroup(getGroup(), getGroup().NewInt(42), prng, t) + salt, err := newSalt(prng) + if err != nil { + t.Errorf("failed to create new salt: %+v", err) + } + payload := []byte("payload") + ts := netTime.Now() + + expectedKey, err := group.NewKdfKey(g.Key, group.ComputeEpoch(ts.Add(5*time.Minute)), salt) + if err != nil { + t.Errorf("failed to create new key: %+v", err) + } + mac := group.NewMAC(expectedKey, payload, g.DhKeys[*g.Members[4].ID]) + + key, err := getCryptKey(g.Key, salt, mac, payload, g.DhKeys, ts) + if err != nil { + t.Errorf("getCryptKey() returned an error: %+v", err) + } + + if expectedKey != key { + t.Errorf("getCryptKey() did not return the expected key."+ + "\nexpected: %v\nreceived: %v", expectedKey, key) + } +} + +// Error path: return an error when the MAC cannot be verified because the +// timestamp is incorrect and generates the wrong epoch. +func Test_getCryptKey_EpochError(t *testing.T) { + expectedErr := strings.SplitN(genCryptKeyMacErr, "%", 2)[0] + + prng := rand.New(rand.NewSource(42)) + g := newTestGroup(getGroup(), getGroup().NewInt(42), prng, t) + salt, err := newSalt(prng) + if err != nil { + t.Errorf("failed to create new salt: %+v", err) + } + payload := []byte("payload") + ts := netTime.Now() + + key, err := group.NewKdfKey(g.Key, group.ComputeEpoch(ts), salt) + if err != nil { + t.Errorf("getCryptKey() returned an error: %+v", err) + } + mac := group.NewMAC(key, payload, g.Members[4].DhKey) + + _, err = getCryptKey(g.Key, salt, mac, payload, g.DhKeys, ts.Add(time.Hour)) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("getCryptKey() failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} diff --git a/groupChat/send.go b/groupChat/send.go new file mode 100644 index 0000000000000000000000000000000000000000..f2baa0953555b13f335111af4d1dd7ae0e01731e --- /dev/null +++ b/groupChat/send.go @@ -0,0 +1,222 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupChat + +import ( + "github.com/pkg/errors" + gs "gitlab.com/elixxir/client/groupChat/groupStore" + "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/crypto/group" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" + "io" + "time" +) + +// Error messages. +const ( + newCmixMsgErr = "failed to generate cMix messages for group chat: %+v" + sendManyCmixErr = "failed to send group chat message from member %s to group %s: %+v" + newCmixErr = "failed to generate cMix message for member %d with ID %s in group %s: %+v" + messageLenErr = "message length %d is greater than maximum message space %d" + newNoGroupErr = "failed to create message for group %s that cannot be found" + newKeyErr = "failed to generate key for encrypting group payload" + newPublicMsgErr = "failed to create new public group message for cMix message: %+v" + newInternalMsgErr = "failed to create new internal group message for cMix message: %+v" + saltReadErr = "failed to generate salt for group message: %+v" + saltReadLengthErr = "length of generated salt %d != %d required" +) + +// Send sends a message to all group members using Client.SendManyCMIX. The +// send fails if the message is too long. +func (m *Manager) Send(groupID *id.ID, message []byte) (id.Round, error) { + + // Create a cMix message for each group member + messages, err := m.createMessages(groupID, message) + if err != nil { + return 0, errors.Errorf(newCmixMsgErr, err) + } + + rid, _, err := m.net.SendManyCMIX(messages, params.GetDefaultCMIX()) + if err != nil { + return 0, errors.Errorf(sendManyCmixErr, m.gs.GetUser().ID, groupID, err) + } + + return rid, nil +} + +// createMessages generates a list of cMix messages and a list of corresponding +// recipient IDs. +func (m *Manager) createMessages(groupID *id.ID, msg []byte) (map[id.ID]format.Message, error) { + timeNow := netTime.Now() + + g, exists := m.gs.Get(groupID) + if !exists { + return map[id.ID]format.Message{}, errors.Errorf(newNoGroupErr, groupID) + } + + return m.newMessages(g, msg, timeNow) +} + +// newMessages is a private function that allows the passing in of a timestamp +// and streamGen instead of a fastRNG.StreamGenerator for easier testing. +func (m *Manager) newMessages(g gs.Group, msg []byte, + timestamp time.Time) (map[id.ID]format.Message, error) { + // Create list of cMix messages + messages := make(map[id.ID]format.Message) + + // Create channels to receive messages and errors on + type msgInfo struct { + msg format.Message + id *id.ID + } + msgChan := make(chan msgInfo, len(g.Members)-1) + errChan := make(chan error, len(g.Members)-1) + + // Create cMix messages in parallel + for i, member := range g.Members { + // Do not send to the sender + if m.gs.GetUser().ID.Cmp(member.ID) { + continue + } + + // Start thread to build cMix message + go func(member group.Member, i int) { + // Create new stream + rng := m.rng.GetStream() + defer rng.Close() + + // Add cMix message to list + cMixMsg, err := m.newCmixMsg(g, msg, timestamp, member, rng) + if err != nil { + errChan <- errors.Errorf(newCmixErr, i, member.ID, g.ID, err) + } + msgChan <- msgInfo{cMixMsg, member.ID} + + }(member, i) + } + + // Wait for messages or errors + for len(messages) < len(g.Members)-1 { + select { + case err := <-errChan: + // Return on the first error that occurs + return nil, err + case info := <-msgChan: + messages[*info.id] = info.msg + } + } + + return messages, nil +} + +// newCmixMsg generates a new cMix message to be sent to a group member. +func (m *Manager) newCmixMsg(g gs.Group, msg []byte, timestamp time.Time, + mem group.Member, rng io.Reader) (format.Message, error) { + + // Create three message layers + cmixMsg := format.NewMessage(m.store.Cmix().GetGroup().GetP().ByteLen()) + publicMsg, internalMsg, err := newMessageParts(cmixMsg.ContentsSize()) + if err != nil { + return cmixMsg, err + } + + // Return an error if the message is too large to fit in the payload + if internalMsg.GetPayloadMaxSize() < len(msg) { + return cmixMsg, errors.Errorf(messageLenErr, len(msg), + internalMsg.GetPayloadMaxSize()) + } + + // Generate 256-bit salt + salt, err := newSalt(rng) + if err != nil { + return cmixMsg, err + } + + // Generate key fingerprint + keyFp := group.NewKeyFingerprint(g.Key, salt, mem.ID) + + // Generate key + key, err := group.NewKdfKey(g.Key, group.ComputeEpoch(timestamp), salt) + if err != nil { + return cmixMsg, errors.WithMessage(err, newKeyErr) + } + + // Generate internal message + payload := setInternalPayload(internalMsg, timestamp, m.gs.GetUser().ID, msg) + + // Encrypt internal message + encryptedPayload := group.Encrypt(key, keyFp, payload) + + // Generate public message + publicPayload := setPublicPayload(publicMsg, salt, encryptedPayload) + + // Generate MAC + mac := group.NewMAC(key, encryptedPayload, g.DhKeys[*mem.ID]) + + // Construct cMix message + cmixMsg.SetContents(publicPayload) + cmixMsg.SetKeyFP(keyFp) + cmixMsg.SetMac(mac) + + return cmixMsg, nil +} + +// newMessageParts generates a public payload message and the internal payload +// message. An error is returned if the messages cannot fit in the payloadSize. +func newMessageParts(payloadSize int) (publicMsg, internalMsg, error) { + publicMsg, err := newPublicMsg(payloadSize) + if err != nil { + return publicMsg, internalMsg{}, errors.Errorf(newPublicMsgErr, err) + } + + internalMsg, err := newInternalMsg(publicMsg.GetPayloadSize()) + if err != nil { + return publicMsg, internalMsg, errors.Errorf(newInternalMsgErr, err) + } + + return publicMsg, internalMsg, nil +} + +// newSalt generates a new salt of the specified size. +func newSalt(rng io.Reader) ([group.SaltLen]byte, error) { + var salt [group.SaltLen]byte + n, err := rng.Read(salt[:]) + if err != nil { + return salt, errors.Errorf(saltReadErr, err) + } else if n != group.SaltLen { + return salt, errors.Errorf(saltReadLengthErr, group.SaltLen, n) + } + + return salt, nil +} + +// setInternalPayload sets the timestamp, sender ID, and message of the +// internalMsg and returns the marshal bytes. +func setInternalPayload(internalMsg internalMsg, timestamp time.Time, + sender *id.ID, msg []byte) []byte { + // Set timestamp, sender ID, and message to the internalMsg + internalMsg.SetTimestamp(timestamp) + internalMsg.SetSenderID(sender) + internalMsg.SetPayload(msg) + + // Return the payload marshaled + return internalMsg.Marshal() +} + +// setPublicPayload sets the salt and encrypted payload of the publicMsg and +// returns the marshal bytes. +func setPublicPayload(publicMsg publicMsg, salt [group.SaltLen]byte, + encryptedPayload []byte) []byte { + // Set salt and payload + publicMsg.SetSalt(salt) + publicMsg.SetPayload(encryptedPayload) + + return publicMsg.Marshal() +} diff --git a/groupChat/sendRequests.go b/groupChat/sendRequests.go new file mode 100644 index 0000000000000000000000000000000000000000..3e4cc39f2995dbe870b4caa771cbd3a075c77eb7 --- /dev/null +++ b/groupChat/sendRequests.go @@ -0,0 +1,129 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupChat + +import ( + "github.com/golang/protobuf/proto" + "github.com/pkg/errors" + gs "gitlab.com/elixxir/client/groupChat/groupStore" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/crypto/group" + "gitlab.com/xx_network/primitives/id" + "strings" +) + +// Error messages. +const ( + resendGroupIdErr = "cannot resend request to nonexistent group with ID %s" + protoMarshalErr = "failed to form outgoing group chat request: %+v" + sendE2eErr = "failed to send group request via E2E to member %s: %+v" + sendRequestAllErr = "failed to send all %d group request messages: %s" + sendRequestPartialErr = "failed to send %d/%d group request messages: %s" +) + +// ResendRequest allows a groupChat request to be sent again. +func (m Manager) ResendRequest(groupID *id.ID) ([]id.Round, RequestStatus, error) { + g, exists := m.gs.Get(groupID) + if !exists { + return nil, NotSent, errors.Errorf(resendGroupIdErr, groupID) + } + + return m.sendRequests(g) +} + +// sendRequests sends group requests to each member in the group except for the +// leader/sender +func (m Manager) sendRequests(g gs.Group) ([]id.Round, RequestStatus, error) { + // Build request message + requestMarshaled, err := proto.Marshal(&Request{ + Name: g.Name, + IdPreimage: g.IdPreimage.Bytes(), + KeyPreimage: g.KeyPreimage.Bytes(), + Members: g.Members.Serialize(), + Message: g.InitMessage, + }) + if err != nil { + return nil, NotSent, errors.Errorf(protoMarshalErr, err) + } + + // Create channel to return the results of each send on + n := len(g.Members) - 1 + type sendResults struct { + rounds []id.Round + err error + } + resultsChan := make(chan sendResults, n) + + // Send request to each member in the group except the leader/sender + for _, member := range g.Members[1:] { + go func(member group.Member) { + rounds, err := m.sendRequest(member.ID, requestMarshaled) + resultsChan <- sendResults{rounds, err} + }(member) + } + + // Block until each send returns + roundIDs := make(map[id.Round]struct{}) + var errs []string + for i := 0; i < n; { + select { + case results := <-resultsChan: + for _, rid := range results.rounds { + roundIDs[rid] = struct{}{} + } + if results.err != nil { + errs = append(errs, results.err.Error()) + } + i++ + } + } + + // If all sends returned an error, then return AllFail with a list of errors + if len(errs) == n { + return nil, AllFail, + errors.Errorf(sendRequestAllErr, len(errs), strings.Join(errs, "\n")) + } + + // If some sends returned an error, then return a list of round IDs for the + // successful sends and a list of errors for the failed ones + if len(errs) > 0 { + return roundIdMap2List(roundIDs), PartialSent, + errors.Errorf(sendRequestPartialErr, len(errs), n, + strings.Join(errs, "\n")) + } + + // If all sends succeeded, return a list of roundIDs + return roundIdMap2List(roundIDs), AllSent, nil +} + +// sendRequest sends the group request to the user via E2E. +func (m Manager) sendRequest(memberID *id.ID, request []byte) ([]id.Round, error) { + sendMsg := message.Send{ + Recipient: memberID, + Payload: request, + MessageType: message.GroupCreationRequest, + } + + rounds, _, err := m.net.SendE2E(sendMsg, params.GetDefaultE2E(), nil) + if err != nil { + return nil, errors.Errorf(sendE2eErr, memberID, err) + } + + return rounds, nil +} + +// roundIdMap2List converts the map of round IDs to a list of round IDs. +func roundIdMap2List(m map[id.Round]struct{}) []id.Round { + roundIDs := make([]id.Round, 0, len(m)) + for rid := range m { + roundIDs = append(roundIDs, rid) + } + + return roundIDs +} diff --git a/groupChat/sendRequests_test.go b/groupChat/sendRequests_test.go new file mode 100644 index 0000000000000000000000000000000000000000..56ca284fbc66cb78622388bd11d0454db3280bce --- /dev/null +++ b/groupChat/sendRequests_test.go @@ -0,0 +1,274 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupChat + +import ( + "fmt" + "github.com/golang/protobuf/proto" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/xx_network/primitives/id" + "math/rand" + "reflect" + "sort" + "strings" + "testing" +) + +// Tests that Manager.ResendRequest sends all expected requests successfully. +func TestManager_ResendRequest(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, g := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + + expected := &Request{ + Name: g.Name, + IdPreimage: g.IdPreimage.Bytes(), + KeyPreimage: g.KeyPreimage.Bytes(), + Members: g.Members.Serialize(), + Message: g.InitMessage, + } + + _, status, err := m.ResendRequest(g.ID) + if err != nil { + t.Errorf("ResendRequest() returned an error: %+v", err) + } + + if status != AllSent { + t.Errorf("ResendRequest() failed to return the expected status."+ + "\nexpected: %s\nreceived: %s", AllSent, status) + } + + if len(m.net.(*testNetworkManager).e2eMessages) < len(g.Members)-1 { + t.Errorf("ResendRequest() failed to send the correct number of requests."+ + "\nexpected: %d\nreceived: %d", len(g.Members)-1, + len(m.net.(*testNetworkManager).e2eMessages)) + } + + for i := 0; i < len(m.net.(*testNetworkManager).e2eMessages); i++ { + msg := m.net.(*testNetworkManager).GetE2eMsg(i) + + // Check if the message recipient is a member in the group + matchesMember := false + for j, m := range g.Members { + if msg.Recipient.Cmp(m.ID) { + matchesMember = true + g.Members = append(g.Members[:j], g.Members[j+1:]...) + break + } + } + if !matchesMember { + t.Errorf("Message %d has recipient ID %s that is not in membership.", + i, msg.Recipient) + } + + testRequest := &Request{} + err = proto.Unmarshal(msg.Payload, testRequest) + if err != nil { + t.Errorf("Failed to unmarshal proto message (%d): %+v", i, err) + } + + if expected.String() != testRequest.String() { + t.Errorf("Message %d has unexpected payload."+ + "\nexpected: %s\nreceived: %s", i, expected, testRequest) + } + } +} + +// Error path: an error is returned when no group with the corresponding group +// ID exists. +func TestManager_ResendRequest_GetGroupError(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, _ := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + expectedErr := strings.SplitN(resendGroupIdErr, "%", 2)[0] + + _, status, err := m.ResendRequest(id.NewIdFromString("invalidID", id.Group, t)) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("ResendRequest() failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } + + if status != NotSent { + t.Errorf("ResendRequest() failed to return the expected status."+ + "\nexpected: %s\nreceived: %s", NotSent, status) + } +} + +// Tests that Manager.sendRequests sends all expected requests successfully. +func TestManager_sendRequests(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, g := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + + expected := &Request{ + Name: g.Name, + IdPreimage: g.IdPreimage.Bytes(), + KeyPreimage: g.KeyPreimage.Bytes(), + Members: g.Members.Serialize(), + Message: g.InitMessage, + } + + _, status, err := m.sendRequests(g) + if err != nil { + t.Errorf("sendRequests() returned an error: %+v", err) + } + + if status != AllSent { + t.Errorf("sendRequests() failed to return the expected status."+ + "\nexpected: %s\nreceived: %s", AllSent, status) + } + + if len(m.net.(*testNetworkManager).e2eMessages) < len(g.Members)-1 { + t.Errorf("sendRequests() failed to send the correct number of requests."+ + "\nexpected: %d\nreceived: %d", len(g.Members)-1, + len(m.net.(*testNetworkManager).e2eMessages)) + } + + for i := 0; i < len(m.net.(*testNetworkManager).e2eMessages); i++ { + msg := m.net.(*testNetworkManager).GetE2eMsg(i) + + // Check if the message recipient is a member in the group + matchesMember := false + for j, m := range g.Members { + if msg.Recipient.Cmp(m.ID) { + matchesMember = true + g.Members = append(g.Members[:j], g.Members[j+1:]...) + break + } + } + if !matchesMember { + t.Errorf("Message %d has recipient ID %s that is not in membership.", + i, msg.Recipient) + } + + testRequest := &Request{} + err = proto.Unmarshal(msg.Payload, testRequest) + if err != nil { + t.Errorf("Failed to unmarshal proto message (%d): %+v", i, err) + } + + if expected.String() != testRequest.String() { + t.Errorf("Message %d has unexpected payload."+ + "\nexpected: %s\nreceived: %s", i, expected, testRequest) + } + } +} + +// Tests that Manager.sendRequests returns the correct status when all sends +// fail. +func TestManager_sendRequests_SendAllFail(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, g := newTestManagerWithStore(prng, 10, 1, nil, nil, t) + expectedErr := fmt.Sprintf(sendRequestAllErr, len(g.Members)-1, "") + + rounds, status, err := m.sendRequests(g) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("sendRequests() failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } + + if status != AllFail { + t.Errorf("sendRequests() failed to return the expected status."+ + "\nexpected: %s\nreceived: %s", AllFail, status) + } + + if rounds != nil { + t.Errorf("sendRequests() returned rounds on failure."+ + "\nexpected: %v\nreceived: %v", nil, rounds) + } + + if len(m.net.(*testNetworkManager).e2eMessages) != 0 { + t.Errorf("sendRequests() sent %d messages when sending should have failed.", + len(m.net.(*testNetworkManager).e2eMessages)) + } +} + +// Tests that Manager.sendRequests returns the correct status when some of the +// sends fail. +func TestManager_sendRequests_SendPartialSent(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, g := newTestManagerWithStore(prng, 10, 2, nil, nil, t) + expectedErr := fmt.Sprintf(sendRequestPartialErr, (len(g.Members)-1)/2, + len(g.Members)-1, "") + + _, status, err := m.sendRequests(g) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("sendRequests() failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } + + if status != PartialSent { + t.Errorf("sendRequests() failed to return the expected status."+ + "\nexpected: %s\nreceived: %s", PartialSent, status) + } + + if len(m.net.(*testNetworkManager).e2eMessages) != (len(g.Members)-1)/2+1 { + t.Errorf("sendRequests() sent %d out of %d expected messages.", + len(m.net.(*testNetworkManager).e2eMessages), (len(g.Members)-1)/2+1) + } +} + +// Unit test of Manager.sendRequest. +func TestManager_sendRequest(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, g := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + + expected := message.Send{ + Recipient: g.Members[0].ID, + Payload: []byte("request message"), + MessageType: message.GroupCreationRequest, + } + _, err := m.sendRequest(expected.Recipient, expected.Payload) + if err != nil { + t.Errorf("sendRequest() returned an error: %+v", err) + } + + received := m.net.(*testNetworkManager).GetE2eMsg(0) + + if !reflect.DeepEqual(expected, received) { + t.Errorf("sendRequest() did not send the correct message."+ + "\nexpected: %+v\nreceived: %+v", expected, received) + } +} + +// Error path: an error is returned when SendE2E fails +func TestManager_sendRequest_SendE2eError(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, _ := newTestManagerWithStore(prng, 10, 1, nil, nil, t) + expectedErr := strings.SplitN(sendE2eErr, "%", 2)[0] + + _, err := m.sendRequest(id.NewIdFromString("memberID", id.User, t), nil) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("sendRequest() failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Unit test of roundIdMap2List. +func Test_roundIdMap2List(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + + // Construct map and expected list + n := 100 + expected := make([]id.Round, n) + ridMap := make(map[id.Round]struct{}, n) + for i := 0; i < n; i++ { + expected[i] = id.Round(prng.Uint64()) + ridMap[expected[i]] = struct{}{} + } + + // Create list of IDs from map + ridList := roundIdMap2List(ridMap) + + // Sort expected and received slices to see if they match + sort.Slice(expected, func(i, j int) bool { return expected[i] < expected[j] }) + sort.Slice(ridList, func(i, j int) bool { return ridList[i] < ridList[j] }) + + if !reflect.DeepEqual(expected, ridList) { + t.Errorf("roundIdMap2List() failed to return the expected list."+ + "\nexpected: %v\nreceived: %v", expected, ridList) + } + +} diff --git a/groupChat/send_test.go b/groupChat/send_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1db5cec791e7790bf519ab98030a088f423b076e --- /dev/null +++ b/groupChat/send_test.go @@ -0,0 +1,557 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupChat + +import ( + "bytes" + "encoding/base64" + gs "gitlab.com/elixxir/client/groupChat/groupStore" + "gitlab.com/elixxir/client/storage" + "gitlab.com/elixxir/crypto/group" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" + "math/rand" + "strings" + "testing" + "time" +) + +// Unit test of Manager.Send. +func TestManager_Send(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, g := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + message := []byte("Group chat message.") + sender := m.gs.GetUser().DeepCopy() + + _, err := m.Send(g.ID, message) + if err != nil { + t.Errorf("Send() returned an error: %+v", err) + } + + // Get messages sent with or return an error if no messages were sent + var messages map[id.ID]format.Message + if len(m.net.(*testNetworkManager).messages) > 0 { + messages = m.net.(*testNetworkManager).GetMsgMap(0) + } else { + t.Error("No group cMix messages received.") + } + + timeNow := netTime.Now() + + // Loop through each message and make sure the recipient ID matches a member + // in the group and that each message can be decrypted and have the expected + // values + for rid, msg := range messages { + // Check if recipient ID is in member list + var foundMember group.Member + for _, mem := range g.Members { + if rid.Cmp(mem.ID) { + foundMember = mem + } + } + + // Error if the recipient ID is not found in the member list + if foundMember == (group.Member{}) { + t.Errorf("Failed to find ID %s in memorship list.", rid) + continue + } + + publicMsg, err := unmarshalPublicMsg(msg.GetContents()) + if err != nil { + t.Errorf("Failed to unmarshal publicMsg: %+v", err) + } + // Attempt to read the message + messageID, timestamp, senderID, readMsg, err := m.decryptMessage( + g, msg, publicMsg, timeNow) + if err != nil { + t.Errorf("Failed to read message for %s: %+v", rid.String(), err) + } + + internalMsg, _ := newInternalMsg(publicMsg.GetPayloadSize()) + internalMsg.SetTimestamp(timestamp) + internalMsg.SetSenderID(m.gs.GetUser().ID) + internalMsg.SetPayload(message) + expectedMsgID := group.NewMessageID(g.ID, internalMsg.Marshal()) + + if expectedMsgID != messageID { + t.Errorf("Message ID received for %s too different from expected."+ + "\nexpected: %s\nreceived: %s", &rid, expectedMsgID, messageID) + } + + if !timestamp.Round(5 * time.Second).Equal(timeNow.Round(5 * time.Second)) { + t.Errorf("Timestamp received for %s too different from expected."+ + "\nexpected: %s\nreceived: %s", &rid, timeNow, timestamp) + } + + if !senderID.Cmp(sender.ID) { + t.Errorf("Sender ID received for %s incorrect."+ + "\nexpected: %s\nreceived: %s", &rid, sender.ID, senderID) + } + + if !bytes.Equal(readMsg, message) { + t.Errorf("Message received for %s incorrect."+ + "\nexpected: %q\nreceived: %q", &rid, message, readMsg) + } + } +} + +// Error path: error is returned when the message is too large. +func TestManager_Send_CmixMessageError(t *testing.T) { + // Set up new test manager that will make SendManyCMIX error + prng := rand.New(rand.NewSource(42)) + m, g := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + expectedErr := strings.SplitN(newCmixMsgErr, "%", 2)[0] + + // Send message + _, err := m.Send(g.ID, make([]byte, 400)) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("Send() failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: SendManyCMIX returns an error. +func TestManager_Send_SendManyCMIXError(t *testing.T) { + // Set up new test manager that will make SendManyCMIX error + prng := rand.New(rand.NewSource(42)) + m, g := newTestManagerWithStore(prng, 10, 1, nil, nil, t) + expectedErr := strings.SplitN(sendManyCmixErr, "%", 2)[0] + + // Send message + _, err := m.Send(g.ID, []byte("message")) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("Send() failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } + + // If messages were added, then error + if len(m.net.(*testNetworkManager).messages) > 0 { + t.Error("Group cMix messages received when SendManyCMIX errors.") + } +} + +// Tests that Manager.createMessages generates the messages for the correct group. +func TestManager_createMessages(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, g := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + + message := []byte("Test group message.") + sender := m.gs.GetUser() + messages, err := m.createMessages(g.ID, message) + if err != nil { + t.Errorf("createMessages() returned an error: %+v", err) + } + + recipients := append(g.Members[:2], g.Members[3:]...) + + i := 0 + for rid, msg := range messages { + for _, recipient := range recipients { + if !rid.Cmp(recipient.ID) { + continue + } + + publicMsg, err := unmarshalPublicMsg(msg.GetContents()) + if err != nil { + t.Errorf("Failed to unmarshal publicMsg: %+v", err) + } + + messageID, timestamp, testSender, testMessage, err := m.decryptMessage( + g, msg, publicMsg, netTime.Now()) + if err != nil { + t.Errorf("Failed to find member to read message %d: %+v", i, err) + } + + internalMsg, _ := newInternalMsg(publicMsg.GetPayloadSize()) + internalMsg.SetTimestamp(timestamp) + internalMsg.SetSenderID(m.gs.GetUser().ID) + internalMsg.SetPayload(message) + expectedMsgID := group.NewMessageID(g.ID, internalMsg.Marshal()) + + if messageID != expectedMsgID { + t.Errorf("Failed to read correct message ID for message %d."+ + "\nexpected: %s\nreceived: %s", i, expectedMsgID, messageID) + } + + if !sender.ID.Cmp(testSender) { + t.Errorf("Failed to read correct sender ID for message %d."+ + "\nexpected: %s\nreceived: %s", i, sender.ID, testSender) + } + + if !bytes.Equal(message, testMessage) { + t.Errorf("Failed to read correct message for message %d."+ + "\nexpected: %s\nreceived: %s", i, message, testMessage) + } + } + i++ + } +} + +// Error path: test that an error is returned when the group ID does not match a +// group in storage. +func TestManager_createMessages_InvalidGroupIdError(t *testing.T) { + expectedErr := strings.SplitN(newNoGroupErr, "%", 2)[0] + + // Create new test Manager and Group + prng := rand.New(rand.NewSource(42)) + m, _ := newTestManagerWithStore(prng, 10, 0, nil, nil, t) + + // Read message and make sure the error is expected + _, err := m.createMessages(id.NewIdFromString("invalidID", id.Group, t), nil) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("createMessages() did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that Manager.newMessage returns messages with correct data. +func TestGroup_newMessages(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + m, g := newTestManager(prng, t) + + message := []byte("Test group message.") + sender := m.gs.GetUser() + timestamp := netTime.Now() + messages, err := m.newMessages(g, message, timestamp) + if err != nil { + t.Errorf("newMessages() returned an error: %+v", err) + } + + recipients := append(g.Members[:2], g.Members[3:]...) + + i := 0 + for rid, msg := range messages { + for _, recipient := range recipients { + if !rid.Cmp(recipient.ID) { + continue + } + + publicMsg, err := unmarshalPublicMsg(msg.GetContents()) + if err != nil { + t.Errorf("Failed to unmarshal publicMsg: %+v", err) + } + + messageID, testTimestamp, testSender, testMessage, err := m.decryptMessage( + g, msg, publicMsg, netTime.Now()) + if err != nil { + t.Errorf("Failed to find member to read message %d.", i) + } + + internalMsg, _ := newInternalMsg(publicMsg.GetPayloadSize()) + internalMsg.SetTimestamp(timestamp) + internalMsg.SetSenderID(m.gs.GetUser().ID) + internalMsg.SetPayload(message) + expectedMsgID := group.NewMessageID(g.ID, internalMsg.Marshal()) + + if messageID != expectedMsgID { + t.Errorf("Failed to read correct message ID for message %d."+ + "\nexpected: %s\nreceived: %s", i, expectedMsgID, messageID) + } + + if !timestamp.Equal(testTimestamp) { + t.Errorf("Failed to read correct timeout for message %d."+ + "\nexpected: %s\nreceived: %s", i, timestamp, testTimestamp) + } + + if !sender.ID.Cmp(testSender) { + t.Errorf("Failed to read correct sender ID for message %d."+ + "\nexpected: %s\nreceived: %s", i, sender.ID, testSender) + } + + if !bytes.Equal(message, testMessage) { + t.Errorf("Failed to read correct message for message %d."+ + "\nexpected: %s\nreceived: %s", i, message, testMessage) + } + } + i++ + } +} + +// Error path: an error is returned when Manager.neCmixMsg returns an error. +func TestGroup_newMessages_NewCmixMsgError(t *testing.T) { + expectedErr := strings.SplitN(newCmixErr, "%", 2)[0] + prng := rand.New(rand.NewSource(42)) + m, g := newTestManager(prng, t) + + _, err := m.newMessages(g, make([]byte, 1000), netTime.Now()) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("newMessages() failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that the message returned by newCmixMsg has all the expected parts. +func TestGroup_newCmixMsg(t *testing.T) { + // Create new test Manager and Group + prng := rand.New(rand.NewSource(42)) + m, g := newTestManager(prng, t) + + // Create test parameters + message := []byte("Test group message.") + mem := g.Members[3] + timeNow := netTime.Now() + + // Create cMix message + prng = rand.New(rand.NewSource(42)) + msg, err := m.newCmixMsg(g, message, timeNow, mem, prng) + if err != nil { + t.Errorf("newCmixMsg() returned an error: %+v", err) + } + + // Create expected salt + prng = rand.New(rand.NewSource(42)) + var salt [group.SaltLen]byte + prng.Read(salt[:]) + + // Create expected key + key, _ := group.NewKdfKey(g.Key, group.ComputeEpoch(timeNow), salt) + + // Create expected messages + cmixMsg := format.NewMessage(m.store.Cmix().GetGroup().GetP().ByteLen()) + publicMsg, _ := newPublicMsg(cmixMsg.ContentsSize()) + internalMsg, _ := newInternalMsg(publicMsg.GetPayloadSize()) + internalMsg.SetTimestamp(timeNow) + internalMsg.SetSenderID(m.gs.GetUser().ID) + internalMsg.SetPayload(message) + payload := internalMsg.Marshal() + + // Check if key fingerprint is correct + expectedFp := group.NewKeyFingerprint(g.Key, salt, mem.ID) + if expectedFp != msg.GetKeyFP() { + t.Errorf("newCmixMsg() returned message with wrong key fingerprint."+ + "\nexpected: %s\nreceived: %s", expectedFp, msg.GetKeyFP()) + } + + // Check if key MAC is correct + encryptedPayload := group.Encrypt(key, expectedFp, payload) + expectedMAC := group.NewMAC(key, encryptedPayload, g.DhKeys[*mem.ID]) + if !bytes.Equal(expectedMAC, msg.GetMac()) { + t.Errorf("newCmixMsg() returned message with wrong MAC."+ + "\nexpected: %+v\nreceived: %+v", expectedMAC, msg.GetMac()) + } + + // Attempt to unmarshal public group message + publicMsg, err = unmarshalPublicMsg(msg.GetContents()) + if err != nil { + t.Errorf("Failed to unmarshal cMix message contents: %+v", err) + } + + // Attempt to decrypt payload + decryptedPayload := group.Decrypt(key, expectedFp, publicMsg.GetPayload()) + internalMsg, err = unmarshalInternalMsg(decryptedPayload) + if err != nil { + t.Errorf("Failed to unmarshal decrypted payload contents: %+v", err) + } + + // Check for expected values in internal message + if !internalMsg.GetTimestamp().Equal(timeNow) { + t.Errorf("Internal message has wrong timestamp."+ + "\nexpected: %s\nreceived: %s", timeNow, internalMsg.GetTimestamp()) + } + sid, err := internalMsg.GetSenderID() + if err != nil { + t.Fatalf("Failed to get sender ID from internal message: %+v", err) + } + if !sid.Cmp(m.gs.GetUser().ID) { + t.Errorf("Internal message has wrong sender ID."+ + "\nexpected: %s\nreceived: %s", m.gs.GetUser().ID, sid) + } + if !bytes.Equal(internalMsg.GetPayload(), message) { + t.Errorf("Internal message has wrong payload."+ + "\nexpected: %s\nreceived: %s", message, internalMsg.GetPayload()) + } +} + +// Error path: reader returns an error. +func TestGroup_newCmixMsg_SaltReaderError(t *testing.T) { + expectedErr := strings.SplitN(saltReadErr, "%", 2)[0] + m := &Manager{store: storage.InitTestingSession(t)} + + _, err := m.newCmixMsg(gs.Group{}, []byte{}, time.Time{}, group.Member{}, strings.NewReader("")) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("newCmixMsg() failed to return the expected error"+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: size of message is too large for the internalMsg. +func TestGroup_newCmixMsg_InternalMsgSizeError(t *testing.T) { + expectedErr := strings.SplitN(messageLenErr, "%", 2)[0] + + // Create new test Manager and Group + prng := rand.New(rand.NewSource(42)) + m, g := newTestManager(prng, t) + + // Create test parameters + message := make([]byte, 341) + mem := group.Member{ID: id.NewIdFromString("memberID", id.User, t)} + + // Create cMix message + prng = rand.New(rand.NewSource(42)) + _, err := m.newCmixMsg(g, message, netTime.Now(), mem, prng) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("newCmixMsg() failed to return the expected error"+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: payload size too small to fit publicMsg. +func Test_newMessageParts_PublicMsgSizeErr(t *testing.T) { + expectedErr := strings.SplitN(newPublicMsgErr, "%", 2)[0] + + _, _, err := newMessageParts(publicMinLen - 1) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("newMessageParts() did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: payload size too small to fit internalMsg. +func Test_newMessageParts_InternalMsgSizeErr(t *testing.T) { + expectedErr := strings.SplitN(newInternalMsgErr, "%", 2)[0] + + _, _, err := newMessageParts(publicMinLen) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("newMessageParts() did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests the consistency of newSalt. +func Test_newSalt_Consistency(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + expectedSalts := []string{ + "U4x/lrFkvxuXu59LtHLon1sUhPJSCcnZND6SugndnVI=", + "39ebTXZCm2F6DJ+fDTulWwzA1hRMiIU1hBrL4HCbB1g=", + "CD9h03W8ArQd9PkZKeGP2p5vguVOdI6B555LvW/jTNw=", + "uoQ+6NY+jE/+HOvqVG2PrBPdGqwEzi6ih3xVec+ix44=", + "GwuvrogbgqdREIpC7TyQPKpDRlp4YgYWl4rtDOPGxPM=", + "rnvD4ElbVxL+/b4MECiH4QDazS2IX2kstgfaAKEcHHA=", + "ceeWotwtwlpbdLLhKXBeJz8FySMmgo4rBW44F2WOEGE=", + "SYlH/fNEQQ7UwRYCP6jjV2tv7Sf/iXS6wMr9mtBWkrE=", + "NhnnOJZN/ceejVNDc2Yc/WbXT+weG4lJGrcjbkt1IWI=", + } + + for i, expected := range expectedSalts { + salt, err := newSalt(prng) + if err != nil { + t.Errorf("newSalt() returned an error (%d): %+v", i, err) + } + + saltString := base64.StdEncoding.EncodeToString(salt[:]) + + if expected != saltString { + t.Errorf("newSalt() did not return the expected salt (%d)."+ + "\nexpected: %s\nreceived: %s", i, expected, saltString) + } + + // fmt.Printf("\"%s\",\n", saltString) + } +} + +// Error path: reader returns an error. +func Test_newSalt_ReadError(t *testing.T) { + expectedErr := strings.SplitN(saltReadErr, "%", 2)[0] + + _, err := newSalt(strings.NewReader("")) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("newSalt() failed to return the expected error"+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: reader fails to return enough bytes. +func Test_newSalt_ReadLengthError(t *testing.T) { + expectedErr := strings.SplitN(saltReadLengthErr, "%", 2)[0] + + _, err := newSalt(strings.NewReader("A")) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("newSalt() failed to return the expected error"+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that the marshaled internalMsg can be unmarshaled and has all the +// original values. +func Test_setInternalPayload(t *testing.T) { + internalMsg, err := newInternalMsg(internalMinLen * 2) + if err != nil { + t.Errorf("Failed to create a new internalMsg: %+v", err) + } + + timestamp := netTime.Now() + sender := id.NewIdFromString("sender ID", id.User, t) + message := []byte("This is an internal message.") + + payload := setInternalPayload(internalMsg, timestamp, sender, message) + if err != nil { + t.Errorf("setInternalPayload() returned an error: %+v", err) + } + + // Attempt to unmarshal and check all values + unmarshalled, err := unmarshalInternalMsg(payload) + if err != nil { + t.Errorf("Failed to unmarshal internalMsg: %+v", err) + } + + if !timestamp.Equal(unmarshalled.GetTimestamp()) { + t.Errorf("Timestamp does not match original.\nexpected: %s\nreceived: %s", + timestamp, unmarshalled.GetTimestamp()) + } + + testSender, err := unmarshalled.GetSenderID() + if err != nil { + t.Errorf("Failed to get sender ID: %+v", err) + } + if !sender.Cmp(testSender) { + t.Errorf("Sender ID does not match original.\nexpected: %s\nreceived: %s", + sender, testSender) + } + + if !bytes.Equal(message, unmarshalled.GetPayload()) { + t.Errorf("Payload does not match original.\nexpected: %v\nreceived: %v", + message, unmarshalled.GetPayload()) + } +} + +// Tests that the marshaled publicMsg can be unmarshaled and has all the +// original values. +func Test_setPublicPayload(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + publicMsg, err := newPublicMsg(publicMinLen * 2) + if err != nil { + t.Errorf("Failed to create a new publicMsg: %+v", err) + } + + var salt [group.SaltLen]byte + prng.Read(salt[:]) + encryptedPayload := make([]byte, publicMsg.GetPayloadSize()) + copy(encryptedPayload, "This is an internal message.") + + payload := setPublicPayload(publicMsg, salt, encryptedPayload) + if err != nil { + t.Errorf("setPublicPayload() returned an error: %+v", err) + } + + // Attempt to unmarshal and check all values + unmarshalled, err := unmarshalPublicMsg(payload) + if err != nil { + t.Errorf("Failed to unmarshal publicMsg: %+v", err) + } + + if salt != unmarshalled.GetSalt() { + t.Errorf("Salt does not match original.\nexpected: %v\nreceived: %v", + salt, unmarshalled.GetSalt()) + } + + if !bytes.Equal(encryptedPayload, unmarshalled.GetPayload()) { + t.Errorf("Payload does not match original.\nexpected: %v\nreceived: %v", + encryptedPayload, unmarshalled.GetPayload()) + } +} diff --git a/groupChat/utils_test.go b/groupChat/utils_test.go new file mode 100644 index 0000000000000000000000000000000000000000..eac797af004142994c37598337b4b98d79299b75 --- /dev/null +++ b/groupChat/utils_test.go @@ -0,0 +1,334 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package groupChat + +import ( + "encoding/base64" + "github.com/pkg/errors" + gs "gitlab.com/elixxir/client/groupChat/groupStore" + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/network/gateway" + "gitlab.com/elixxir/client/stoppable" + "gitlab.com/elixxir/client/storage" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/client/switchboard" + "gitlab.com/elixxir/comms/network" + "gitlab.com/elixxir/crypto/contact" + "gitlab.com/elixxir/crypto/cyclic" + "gitlab.com/elixxir/crypto/e2e" + "gitlab.com/elixxir/crypto/fastRNG" + "gitlab.com/elixxir/crypto/group" + "gitlab.com/elixxir/ekv" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/comms/connect" + "gitlab.com/xx_network/crypto/csprng" + "gitlab.com/xx_network/crypto/large" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/id/ephemeral" + "gitlab.com/xx_network/primitives/ndf" + "math/rand" + "sync" + "testing" +) + +// newTestManager creates a new Manager for testing. +func newTestManager(rng *rand.Rand, t *testing.T) (*Manager, gs.Group) { + store := storage.InitTestingSession(t) + user := group.Member{ + ID: store.GetUser().ReceptionID, + DhKey: store.GetUser().E2eDhPublicKey, + } + + g := newTestGroupWithUser(store.E2e().GetGroup(), user.ID, user.DhKey, + store.GetUser().E2eDhPrivateKey, rng, t) + gStore, err := gs.NewStore(versioned.NewKV(make(ekv.Memstore)), user) + if err != nil { + t.Fatalf("Failed to create new group store: %+v", err) + } + m := &Manager{ + store: store, + rng: fastRNG.NewStreamGenerator(1000, 10, csprng.NewSystemRNG), + gs: gStore, + } + return m, g +} + +// newTestManager creates a new Manager that has groups stored for testing. One +// of the groups in the list is also returned. +func newTestManagerWithStore(rng *rand.Rand, numGroups int, sendErr int, + requestFunc RequestCallback, receiveFunc ReceiveCallback, + t *testing.T) (*Manager, gs.Group) { + + store := storage.InitTestingSession(t) + + user := group.Member{ + ID: store.GetUser().ReceptionID, + DhKey: store.GetUser().E2eDhPublicKey, + } + + gStore, err := gs.NewStore(versioned.NewKV(make(ekv.Memstore)), user) + if err != nil { + t.Fatalf("Failed to create new group store: %+v", err) + } + + var g gs.Group + for i := 0; i < numGroups; i++ { + g = newTestGroupWithUser(store.E2e().GetGroup(), user.ID, user.DhKey, + store.GetUser().E2eDhPrivateKey, rng, t) + if err = gStore.Add(g); err != nil { + t.Fatalf("Failed to add group %d to group store: %+v", i, err) + } + } + + m := &Manager{ + store: store, + swb: switchboard.New(), + net: newTestNetworkManager(sendErr, t), + rng: fastRNG.NewStreamGenerator(1000, 10, csprng.NewSystemRNG), + gs: gStore, + requestFunc: requestFunc, + receiveFunc: receiveFunc, + } + return m, g +} + +// getMembership returns a Membership with random members for testing. +func getMembership(size int, uid *id.ID, pubKey *cyclic.Int, grp *cyclic.Group, prng *rand.Rand, t *testing.T) group.Membership { + contacts := make([]contact.Contact, size) + for i := range contacts { + randId, _ := id.NewRandomID(prng, id.User) + contacts[i] = contact.Contact{ + ID: randId, + DhPubKey: grp.NewInt(int64(prng.Int31() + 1)), + } + } + + contacts[2].ID = uid + contacts[2].DhPubKey = pubKey + + membership, err := group.NewMembership(contacts[0], contacts[1:]...) + if err != nil { + t.Errorf("Failed to create new membership: %+v", err) + } + + return membership +} + +// newTestGroup generates a new group with random values for testing. +func newTestGroup(grp *cyclic.Group, privKey *cyclic.Int, rng *rand.Rand, t *testing.T) gs.Group { + // Generate name from base 64 encoded random data + nameBytes := make([]byte, 16) + rng.Read(nameBytes) + name := []byte(base64.StdEncoding.EncodeToString(nameBytes)) + + // Generate the message from base 64 encoded random data + msgBytes := make([]byte, 128) + rng.Read(msgBytes) + msg := []byte(base64.StdEncoding.EncodeToString(msgBytes)) + + membership := getMembership(10, id.NewIdFromString("userID", id.User, t), + randCycInt(rng), grp, rng, t) + + dkl := gs.GenerateDhKeyList(id.NewIdFromString("userID", id.User, t), privKey, membership, grp) + + idPreimage, err := group.NewIdPreimage(rng) + if err != nil { + t.Fatalf("Failed to generate new group ID preimage: %+v", err) + } + + keyPreimage, err := group.NewKeyPreimage(rng) + if err != nil { + t.Fatalf("Failed to generate new group key preimage: %+v", err) + } + + groupID := group.NewID(idPreimage, membership) + groupKey := group.NewKey(keyPreimage, membership) + + return gs.NewGroup(name, groupID, groupKey, idPreimage, keyPreimage, msg, + membership, dkl) +} + +// newTestGroup generates a new group with random values for testing. +func newTestGroupWithUser(grp *cyclic.Group, uid *id.ID, pubKey, + privKey *cyclic.Int, rng *rand.Rand, t *testing.T) gs.Group { + // Generate name from base 64 encoded random data + nameBytes := make([]byte, 16) + rng.Read(nameBytes) + name := []byte(base64.StdEncoding.EncodeToString(nameBytes)) + + // Generate the message from base 64 encoded random data + msgBytes := make([]byte, 128) + rng.Read(msgBytes) + msg := []byte(base64.StdEncoding.EncodeToString(msgBytes)) + + membership := getMembership(10, uid, pubKey, grp, rng, t) + + dkl := gs.GenerateDhKeyList(uid, privKey, membership, grp) + + idPreimage, err := group.NewIdPreimage(rng) + if err != nil { + t.Fatalf("Failed to generate new group ID preimage: %+v", err) + } + + keyPreimage, err := group.NewKeyPreimage(rng) + if err != nil { + t.Fatalf("Failed to generate new group key preimage: %+v", err) + } + + groupID := group.NewID(idPreimage, membership) + groupKey := group.NewKey(keyPreimage, membership) + + return gs.NewGroup(name, groupID, groupKey, idPreimage, keyPreimage, msg, + membership, dkl) +} + +// randCycInt returns a random cyclic int. +func randCycInt(rng *rand.Rand) *cyclic.Int { + return getGroup().NewInt(int64(rng.Int31() + 1)) +} + +func getGroup() *cyclic.Group { + return cyclic.NewGroup( + large.NewIntFromString(getNDF().E2E.Prime, 16), + large.NewIntFromString(getNDF().E2E.Generator, 16)) +} + +func newTestNetworkManager(sendErr int, t *testing.T) interfaces.NetworkManager { + instanceComms := &connect.ProtoComms{ + Manager: connect.NewManagerTesting(t), + } + + thisInstance, err := network.NewInstanceTesting(instanceComms, getNDF(), + getNDF(), nil, nil, t) + if err != nil { + t.Fatalf("Failed to create new test instance: %v", err) + } + + return &testNetworkManager{ + instance: thisInstance, + messages: []map[id.ID]format.Message{}, + sendErr: sendErr, + } +} + +// testNetworkManager is a test implementation of NetworkManager interface. +type testNetworkManager struct { + instance *network.Instance + messages []map[id.ID]format.Message + e2eMessages []message.Send + errSkip int + sendErr int + sync.RWMutex +} + +func (tnm *testNetworkManager) GetMsgMap(i int) map[id.ID]format.Message { + tnm.RLock() + defer tnm.RUnlock() + return tnm.messages[i] +} + +func (tnm *testNetworkManager) GetE2eMsg(i int) message.Send { + tnm.RLock() + defer tnm.RUnlock() + return tnm.e2eMessages[i] +} + +func (tnm *testNetworkManager) SendE2E(msg message.Send, _ params.E2E, _ *stoppable.Single) ([]id.Round, e2e.MessageID, error) { + tnm.Lock() + defer tnm.Unlock() + + tnm.errSkip++ + if tnm.sendErr == 1 { + return nil, e2e.MessageID{}, errors.New("SendE2E error") + } else if tnm.sendErr == 2 && tnm.errSkip%2 == 0 { + return nil, e2e.MessageID{}, errors.New("SendE2E error") + } + + tnm.e2eMessages = append(tnm.e2eMessages, msg) + + return []id.Round{0, 1, 2, 3}, e2e.MessageID{}, nil +} + +func (tnm *testNetworkManager) SendUnsafe(message.Send, params.Unsafe) ([]id.Round, error) { + return []id.Round{}, nil +} + +func (tnm *testNetworkManager) SendCMIX(format.Message, *id.ID, params.CMIX) (id.Round, ephemeral.Id, error) { + return 0, ephemeral.Id{}, nil +} + +func (tnm *testNetworkManager) SendManyCMIX(messages map[id.ID]format.Message, _ params.CMIX) (id.Round, []ephemeral.Id, error) { + if tnm.sendErr == 1 { + return 0, nil, errors.New("SendManyCMIX error") + } + + tnm.Lock() + defer tnm.Unlock() + + tnm.messages = append(tnm.messages, messages) + + return 0, nil, nil +} + +func (tnm *testNetworkManager) GetInstance() *network.Instance { return tnm.instance } +func (tnm *testNetworkManager) GetHealthTracker() interfaces.HealthTracker { return nil } +func (tnm *testNetworkManager) Follow(interfaces.ClientErrorReport) (stoppable.Stoppable, error) { + return nil, nil +} +func (tnm *testNetworkManager) CheckGarbledMessages() {} +func (tnm *testNetworkManager) InProgressRegistrations() int { return 0 } +func (tnm *testNetworkManager) GetSender() *gateway.Sender { return nil } +func (tnm *testNetworkManager) GetAddressSize() uint8 { return 0 } +func (tnm *testNetworkManager) RegisterAddressSizeNotification(string) (chan uint8, error) { + return nil, nil +} +func (tnm *testNetworkManager) UnregisterAddressSizeNotification(string) {} + +func getNDF() *ndf.NetworkDefinition { + return &ndf.NetworkDefinition{ + E2E: ndf.Group{ + Prime: "E2EE983D031DC1DB6F1A7A67DF0E9A8E5561DB8E8D49413394C049B7A" + + "8ACCEDC298708F121951D9CF920EC5D146727AA4AE535B0922C688B55B3D" + + "D2AEDF6C01C94764DAB937935AA83BE36E67760713AB44A6337C20E78615" + + "75E745D31F8B9E9AD8412118C62A3E2E29DF46B0864D0C951C394A5CBBDC" + + "6ADC718DD2A3E041023DBB5AB23EBB4742DE9C1687B5B34FA48C3521632C" + + "4A530E8FFB1BC51DADDF453B0B2717C2BC6669ED76B4BDD5C9FF558E88F2" + + "6E5785302BEDBCA23EAC5ACE92096EE8A60642FB61E8F3D24990B8CB12EE" + + "448EEF78E184C7242DD161C7738F32BF29A841698978825B4111B4BC3E1E" + + "198455095958333D776D8B2BEEED3A1A1A221A6E37E664A64B83981C46FF" + + "DDC1A45E3D5211AAF8BFBC072768C4F50D7D7803D2D4F278DE8014A47323" + + "631D7E064DE81C0C6BFA43EF0E6998860F1390B5D3FEACAF1696015CB79C" + + "3F9C2D93D961120CD0E5F12CBB687EAB045241F96789C38E89D796138E63" + + "19BE62E35D87B1048CA28BE389B575E994DCA755471584A09EC723742DC3" + + "5873847AEF49F66E43873", + Generator: "2", + }, + CMIX: ndf.Group{ + Prime: "9DB6FB5951B66BB6FE1E140F1D2CE5502374161FD6538DF1648218642" + + "F0B5C48C8F7A41AADFA187324B87674FA1822B00F1ECF8136943D7C55757" + + "264E5A1A44FFE012E9936E00C1D3E9310B01C7D179805D3058B2A9F4BB6F" + + "9716BFE6117C6B5B3CC4D9BE341104AD4A80AD6C94E005F4B993E14F091E" + + "B51743BF33050C38DE235567E1B34C3D6A5C0CEAA1A0F368213C3D19843D" + + "0B4B09DCB9FC72D39C8DE41F1BF14D4BB4563CA28371621CAD3324B6A2D3" + + "92145BEBFAC748805236F5CA2FE92B871CD8F9C36D3292B5509CA8CAA77A" + + "2ADFC7BFD77DDA6F71125A7456FEA153E433256A2261C6A06ED3693797E7" + + "995FAD5AABBCFBE3EDA2741E375404AE25B", + Generator: "5C7FF6B06F8F143FE8288433493E4769C4D988ACE5BE25A0E2480" + + "9670716C613D7B0CEE6932F8FAA7C44D2CB24523DA53FBE4F6EC3595892D" + + "1AA58C4328A06C46A15662E7EAA703A1DECF8BBB2D05DBE2EB956C142A33" + + "8661D10461C0D135472085057F3494309FFA73C611F78B32ADBB5740C361" + + "C9F35BE90997DB2014E2EF5AA61782F52ABEB8BD6432C4DD097BC5423B28" + + "5DAFB60DC364E8161F4A2A35ACA3A10B1C4D203CC76A470A33AFDCBDD929" + + "59859ABD8B56E1725252D78EAC66E71BA9AE3F1DD2487199874393CD4D83" + + "2186800654760E1E34C09E4D155179F9EC0DC4473F996BDCE6EED1CABED8" + + "B6F116F7AD9CF505DF0F998E34AB27514B0FFE7", + }, + } +} diff --git a/interfaces/IsRunning.go b/interfaces/IsRunning.go deleted file mode 100644 index 950b842864712858b3d1dcf769f4765e5af86159..0000000000000000000000000000000000000000 --- a/interfaces/IsRunning.go +++ /dev/null @@ -1,8 +0,0 @@ -package interfaces - -// this interface is used to allow the follower to to be stopped later if it -// fails - -type Running interface{ - IsRunning()bool -} diff --git a/interfaces/message/receiveMessage.go b/interfaces/message/receiveMessage.go index fad6fb750ccc21cc9f03391056a8559a53cd5159..d11f8880865e8c6db95086c054f554126835b5c2 100644 --- a/interfaces/message/receiveMessage.go +++ b/interfaces/message/receiveMessage.go @@ -15,12 +15,14 @@ import ( ) type Receive struct { - ID e2e.MessageID - Payload []byte - MessageType Type - Sender *id.ID - RecipientID *id.ID - EphemeralID ephemeral.Id - Timestamp time.Time - Encryption EncryptionType + ID e2e.MessageID + Payload []byte + MessageType Type + Sender *id.ID + RecipientID *id.ID + EphemeralID ephemeral.Id + RoundId id.Round + RoundTimestamp time.Time + Timestamp time.Time // Message timestamp of when the user sent + Encryption EncryptionType } diff --git a/interfaces/message/type.go b/interfaces/message/type.go index 71a1c72ff431ab1c6164856fe892e72af8c21a68..5c8012fea55a6f7c6afcc953ef37dbf0daa7b6e0 100644 --- a/interfaces/message/type.go +++ b/interfaces/message/type.go @@ -49,4 +49,8 @@ const ( KeyExchangeTrigger = 30 // Rekey confirmation message. Sent by partner to confirm completion of a rekey KeyExchangeConfirm = 31 + + /* Group chat message types */ + // A group chat request message sent to all members in a group. + GroupCreationRequest = 40 ) diff --git a/interfaces/networkManager.go b/interfaces/networkManager.go index ad9e03caf6f53e13e9d8e9190f2417749c99f490..710700144bc2641b0f3caf87c00b65879cc0cb1a 100644 --- a/interfaces/networkManager.go +++ b/interfaces/networkManager.go @@ -20,9 +20,11 @@ import ( ) type NetworkManager interface { - SendE2E(m message.Send, p params.E2E) ([]id.Round, e2e.MessageID, error) + // The stoppable can be nil. + SendE2E(m message.Send, p params.E2E, stop *stoppable.Single) ([]id.Round, e2e.MessageID, error) SendUnsafe(m message.Send, p params.Unsafe) ([]id.Round, error) SendCMIX(message format.Message, recipient *id.ID, p params.CMIX) (id.Round, ephemeral.Id, error) + SendManyCMIX(messages map[id.ID]format.Message, p params.CMIX) (id.Round, []ephemeral.Id, error) GetInstance() *network.Instance GetHealthTracker() HealthTracker GetSender() *gateway.Sender @@ -45,4 +47,4 @@ type NetworkManager interface { } //for use in key exchange which needs to be callable inside of network -type SendE2E func(m message.Send, p params.E2E) ([]id.Round, e2e.MessageID, error) +type SendE2E func(m message.Send, p params.E2E, stop *stoppable.Single) ([]id.Round, e2e.MessageID, error) diff --git a/keyExchange/confirm.go b/keyExchange/confirm.go index b89b7c8dbeca6bcc7eaded004a301a511c332365..45224193f80a302ae24e4e4f01ec0f7afe647638 100644 --- a/keyExchange/confirm.go +++ b/keyExchange/confirm.go @@ -23,6 +23,7 @@ func startConfirm(sess *storage.Session, c chan message.Receive, select { case <-stop.Quit(): cleanup() + stop.ToStopped() return case confirmation := <-c: handleConfirm(sess, confirmation) diff --git a/keyExchange/rekey.go b/keyExchange/rekey.go index f2aab4c42ebedebd258e83168461b719821e1ef2..f4caacd2586bd7a8c0c8c296ba7f8ae19377db4f 100644 --- a/keyExchange/rekey.go +++ b/keyExchange/rekey.go @@ -15,6 +15,7 @@ import ( "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/client/interfaces/utility" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage" "gitlab.com/elixxir/client/storage/e2e" "gitlab.com/elixxir/comms/network" @@ -25,10 +26,11 @@ import ( ) func CheckKeyExchanges(instance *network.Instance, sendE2E interfaces.SendE2E, - sess *storage.Session, manager *e2e.Manager, sendTimeout time.Duration) { + sess *storage.Session, manager *e2e.Manager, sendTimeout time.Duration, + stop *stoppable.Single) { sessions := manager.TriggerNegotiations() for _, session := range sessions { - go trigger(instance, sendE2E, sess, manager, session, sendTimeout) + go trigger(instance, sendE2E, sess, manager, session, sendTimeout, stop) } } @@ -38,7 +40,7 @@ func CheckKeyExchanges(instance *network.Instance, sendE2E interfaces.SendE2E, // session while the latter on an extant session func trigger(instance *network.Instance, sendE2E interfaces.SendE2E, sess *storage.Session, manager *e2e.Manager, session *e2e.Session, - sendTimeout time.Duration) { + sendTimeout time.Duration, stop *stoppable.Single) { var negotiatingSession *e2e.Session jww.INFO.Printf("Negotation triggered for session %s with "+ "status: %s", session, session.NegotiationStatus()) @@ -61,7 +63,7 @@ func trigger(instance *network.Instance, sendE2E interfaces.SendE2E, } // send the rekey notification to the partner - err := negotiate(instance, sendE2E, sess, negotiatingSession, sendTimeout) + err := negotiate(instance, sendE2E, sess, negotiatingSession, sendTimeout, stop) // if sending the negotiation fails, revert the state of the session to // unconfirmed so it will be triggered in the future if err != nil { @@ -71,8 +73,8 @@ func trigger(instance *network.Instance, sendE2E interfaces.SendE2E, } func negotiate(instance *network.Instance, sendE2E interfaces.SendE2E, - sess *storage.Session, session *e2e.Session, - sendTimeout time.Duration) error { + sess *storage.Session, session *e2e.Session, sendTimeout time.Duration, + stop *stoppable.Single) error { e2eStore := sess.E2e() //generate public key @@ -102,7 +104,7 @@ func negotiate(instance *network.Instance, sendE2E interfaces.SendE2E, e2eParams := params.GetDefaultE2E() e2eParams.Type = params.KeyExchange - rounds, _, err := sendE2E(m, e2eParams) + rounds, _, err := sendE2E(m, e2eParams, stop) // If the send fails, returns the error so it can be handled. The caller // should ensure the calling session is in a state where the Rekey will // be triggered next time a key is used diff --git a/keyExchange/trigger.go b/keyExchange/trigger.go index e22dd76f0b30974254de9fb6bd356e8e3f69d88e..c88d068fe8d1463ea4c2b5408dd9b18c4030e488 100644 --- a/keyExchange/trigger.go +++ b/keyExchange/trigger.go @@ -32,14 +32,15 @@ const ( func startTrigger(sess *storage.Session, net interfaces.NetworkManager, c chan message.Receive, stop *stoppable.Single, params params.Rekey, cleanup func()) { - for true { + for { select { case <-stop.Quit(): cleanup() + stop.ToStopped() return case request := <-c: go func() { - err := handleTrigger(sess, net, request, params) + err := handleTrigger(sess, net, request, params, stop) if err != nil { jww.ERROR.Printf(errFailed, err) } @@ -49,7 +50,7 @@ func startTrigger(sess *storage.Session, net interfaces.NetworkManager, } func handleTrigger(sess *storage.Session, net interfaces.NetworkManager, - request message.Receive, param params.Rekey) error { + request message.Receive, param params.Rekey, stop *stoppable.Single) error { //ensure the message was encrypted properly if request.Encryption != message.E2E { errMsg := fmt.Sprintf(errBadTrigger, request.Sender) @@ -126,7 +127,10 @@ func handleTrigger(sess *storage.Session, net interfaces.NetworkManager, // send fails sess.GetCriticalMessages().AddProcessing(m, e2eParams) - rounds, _, err := net.SendE2E(m, e2eParams) + rounds, _, err := net.SendE2E(m, e2eParams, stop) + if err != nil { + return err + } //Register the event for all rounds sendResults := make(chan ds.EventReturn, len(rounds)) diff --git a/keyExchange/trigger_test.go b/keyExchange/trigger_test.go index 5dce2d6a7c50c8e774c3ecd88dd00f5e3db1cacc..d01f9101396d573c8c97892aa3256ede2ba91d49 100644 --- a/keyExchange/trigger_test.go +++ b/keyExchange/trigger_test.go @@ -10,6 +10,7 @@ package keyExchange import ( "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage/e2e" dh "gitlab.com/elixxir/crypto/diffieHellman" "gitlab.com/xx_network/crypto/csprng" @@ -66,8 +67,9 @@ func TestHandleTrigger(t *testing.T) { // Handle the trigger and check for an error rekeyParams := params.GetDefaultRekey() + stop := stoppable.NewSingle("stoppable") rekeyParams.RoundTimeout = 0 * time.Second - err = handleTrigger(aliceSession, aliceManager, receiveMsg, rekeyParams) + err = handleTrigger(aliceSession, aliceManager, receiveMsg, rekeyParams, stop) if err != nil { t.Errorf("Handle trigger error: %v", err) } diff --git a/keyExchange/utils_test.go b/keyExchange/utils_test.go index d9dd4ef14b2dea6fd6c3f06db794e14a87b44cc8..876a16576bfd95871fca6e099fb77caff25f6c02 100644 --- a/keyExchange/utils_test.go +++ b/keyExchange/utils_test.go @@ -66,7 +66,7 @@ func (t *testNetworkManagerGeneric) CheckGarbledMessages() { return } -func (t *testNetworkManagerGeneric) SendE2E(m message.Send, p params.E2E) ( +func (t *testNetworkManagerGeneric) SendE2E(message.Send, params.E2E, *stoppable.Single) ( []id.Round, cE2e.MessageID, error) { rounds := []id.Round{id.Round(0), id.Round(1), id.Round(2)} return rounds, cE2e.MessageID{}, nil @@ -84,6 +84,10 @@ func (t *testNetworkManagerGeneric) SendCMIX(message format.Message, rid *id.ID, } +func (t *testNetworkManagerGeneric) SendManyCMIX(messages map[id.ID]format.Message, p params.CMIX) (id.Round, []ephemeral.Id, error) { + return id.Round(0), []ephemeral.Id{}, nil +} + func (t *testNetworkManagerGeneric) GetInstance() *network.Instance { return t.instance @@ -163,7 +167,7 @@ func (t *testNetworkManagerFullExchange) CheckGarbledMessages() { // Intended for alice to send to bob. Trigger's Bob's confirmation, chaining the operation // together -func (t *testNetworkManagerFullExchange) SendE2E(m message.Send, p params.E2E) ( +func (t *testNetworkManagerFullExchange) SendE2E(message.Send, params.E2E, *stoppable.Single) ( []id.Round, cE2e.MessageID, error) { rounds := []id.Round{id.Round(0), id.Round(1), id.Round(2)} @@ -189,18 +193,18 @@ func (t *testNetworkManagerFullExchange) SendE2E(m message.Send, p params.E2E) ( bobSwitchboard.Speak(confirmMessage) return rounds, cE2e.MessageID{}, nil - } func (t *testNetworkManagerFullExchange) SendUnsafe(m message.Send, p params.Unsafe) ([]id.Round, error) { - return nil, nil } func (t *testNetworkManagerFullExchange) SendCMIX(message format.Message, eid *id.ID, p params.CMIX) (id.Round, ephemeral.Id, error) { - return id.Round(0), ephemeral.Id{}, nil +} +func (t *testNetworkManagerFullExchange) SendManyCMIX(messages map[id.ID]format.Message, p params.CMIX) (id.Round, []ephemeral.Id, error) { + return id.Round(0), []ephemeral.Id{}, nil } func (t *testNetworkManagerFullExchange) GetInstance() *network.Instance { diff --git a/network/ephemeral/addressSpace_test.go b/network/ephemeral/addressSpace_test.go index e08f08d844a6aff3bf387db5456a2147fe89f16b..03cca463bc829e9db4ff6bb916fb6738e9b12e19 100644 --- a/network/ephemeral/addressSpace_test.go +++ b/network/ephemeral/addressSpace_test.go @@ -72,7 +72,7 @@ func Test_addressSpace_Get_WaitBroadcast(t *testing.T) { t.Errorf("Get returned the wrong size.\nexpected: %d\nreceived: %d", initSize, size) } - case <-time.NewTimer(15 * time.Millisecond).C: + case <-time.NewTimer(25 * time.Millisecond).C: t.Error("Get blocking when the Cond has broadcast.") } }() @@ -137,7 +137,7 @@ func Test_addressSpace_update_GetAndChannels(t *testing.T) { t.Errorf("Thread %d received unexpected size."+ "\nexpected: %d\nreceived: %d", i, expectedSize, size) } - case <-time.NewTimer(15 * time.Millisecond).C: + case <-time.NewTimer(20 * time.Millisecond).C: t.Errorf("Timed out waiting for Get to return on thread %d.", i) } }(i, waitChan) @@ -166,7 +166,7 @@ func Test_addressSpace_update_GetAndChannels(t *testing.T) { case size := <-notifyChan: t.Errorf("Received size %d on channel %s when it should not have.", size, chanID) - case <-time.NewTimer(15 * time.Millisecond).C: + case <-time.NewTimer(20 * time.Millisecond).C: } }(chanID, notifyChan) } @@ -196,7 +196,7 @@ func Test_addressSpace_update_GetAndChannels(t *testing.T) { t.Errorf("Failed to receive expected size on channel %s."+ "\nexpected: %d\nreceived: %d", chanID, expectedSize, size) } - case <-time.NewTimer(15 * time.Millisecond).C: + case <-time.NewTimer(20 * time.Millisecond).C: t.Errorf("Timed out waiting on channel %s", chanID) } }(chanID, notifyChan) @@ -211,7 +211,7 @@ func Test_addressSpace_update_GetAndChannels(t *testing.T) { case size := <-notifyChan: t.Errorf("Received size %d on channel %s when it should not have.", size, chanID) - case <-time.NewTimer(15 * time.Millisecond).C: + case <-time.NewTimer(20 * time.Millisecond).C: } }() diff --git a/network/ephemeral/testutil.go b/network/ephemeral/testutil.go index c213078eeecaca6178d6f2cd7617caf468da792e..92eb68c408a0ecee0071651b5660c922d731934d 100644 --- a/network/ephemeral/testutil.go +++ b/network/ephemeral/testutil.go @@ -33,7 +33,7 @@ type testNetworkManager struct { msg message.Send } -func (t *testNetworkManager) SendE2E(m message.Send, _ params.E2E) ([]id.Round, +func (t *testNetworkManager) SendE2E(m message.Send, _ params.E2E, _ *stoppable.Single) ([]id.Round, e2e.MessageID, error) { rounds := []id.Round{ id.Round(0), @@ -62,6 +62,10 @@ func (t *testNetworkManager) SendCMIX(format.Message, *id.ID, params.CMIX) (id.R return 0, ephemeral.Id{}, nil } +func (t *testNetworkManager) SendManyCMIX(messages map[id.ID]format.Message, p params.CMIX) (id.Round, []ephemeral.Id, error) { + return 0, []ephemeral.Id{}, nil +} + func (t *testNetworkManager) GetInstance() *network.Instance { return t.instance } diff --git a/network/ephemeral/tracker.go b/network/ephemeral/tracker.go index da81a326914389233631d15fd19bf9ee65744ed0..83f855375c23ad6fa0d95d175802de5d12f5659c 100644 --- a/network/ephemeral/tracker.go +++ b/network/ephemeral/tracker.go @@ -122,6 +122,7 @@ func track(session *storage.Session, addrSpace *AddressSpace, ourId *id.ID, stop case addressSize = <-addressSizeUpdate: receptionStore.SetToExpire(addressSize) case <-stop.Quit(): + stop.ToStopped() return } } diff --git a/network/ephemeral/tracker_test.go b/network/ephemeral/tracker_test.go index 3b9e03b29db0b470eebb6a56e490d5f98ac0385b..3307446bc17f25cbf8d8bc119e01a2889b1c2ae2 100644 --- a/network/ephemeral/tracker_test.go +++ b/network/ephemeral/tracker_test.go @@ -47,7 +47,7 @@ func TestCheck(t *testing.T) { ourId := id.NewIdFromBytes([]byte("Sauron"), t) stop := Track(session, NewTestAddressSpace(15, t), ourId) - err = stop.Close(3 * time.Second) + err = stop.Close() if err != nil { t.Errorf("Could not close thread: %+v", err) } @@ -83,7 +83,7 @@ func TestCheck_Thread(t *testing.T) { }() time.Sleep(3 * time.Second) - err = stop.Close(3 * time.Second) + err = stop.Close() if err != nil { t.Errorf("Could not close thread: %v", err) } diff --git a/network/follow.go b/network/follow.go index 8669ea162a38dc51bdbb7601f8023a19406b7351..aa505deecbb5cda93f1ebed0124ef070cf234705 100644 --- a/network/follow.go +++ b/network/follow.go @@ -28,6 +28,7 @@ import ( jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces" "gitlab.com/elixxir/client/network/rounds" + "gitlab.com/elixxir/client/stoppable" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/primitives/knownRounds" "gitlab.com/elixxir/primitives/states" @@ -49,19 +50,20 @@ type followNetworkComms interface { // followNetwork polls the network to get updated on the state of nodes, the // round status, and informs the client when messages can be retrieved. -func (m *manager) followNetwork(report interfaces.ClientErrorReport, quitCh <-chan struct{}, isRunning interfaces.Running) { +func (m *manager) followNetwork(report interfaces.ClientErrorReport, + stop *stoppable.Single) { ticker := time.NewTicker(m.param.TrackNetworkPeriod) TrackTicker := time.NewTicker(debugTrackPeriod) rng := m.Rng.GetStream() - done := false - for !done { + for { select { - case <-quitCh: + case <-stop.Quit(): rng.Close() - done = true + stop.ToStopped() + return case <-ticker.C: - m.follow(report, rng, m.Comms, isRunning) + m.follow(report, rng, m.Comms, stop) case <-TrackTicker.C: numPolls := atomic.SwapUint64(m.tracker, 0) if m.numLatencies != 0 { @@ -75,19 +77,13 @@ func (m *manager) followNetwork(report interfaces.ClientErrorReport, quitCh <-ch jww.INFO.Printf("Polled the network %d times in the "+ "last %s", numPolls, debugTrackPeriod) } - - } - if !isRunning.IsRunning() { - jww.ERROR.Printf("Killing network follower " + - "due to failed exit") - return } } } // executes each iteration of the follower func (m *manager) follow(report interfaces.ClientErrorReport, rng csprng.Source, - comms followNetworkComms, isRunning interfaces.Running) { + comms followNetworkComms, stop *stoppable.Single) { //Get the identity we will poll for identity, err := m.Session.Reception().GetIdentity(rng, m.addrSpace.GetWithoutWait()) @@ -119,10 +115,11 @@ func (m *manager) follow(report interfaces.ClientErrorReport, rng csprng.Source, identity.EphId.Int64(), identity.Source, identity.StartRequest, identity.EndRequest, identity.EndRequest.Sub(identity.StartRequest), host.GetId()) return comms.SendPoll(host, &pollReq) - }) - if !isRunning.IsRunning() { - jww.ERROR.Printf("Killing network follower " + - "due to failed exit") + }, stop) + + // Exit if the thread has been stopped + if stoppable.CheckErr(err) { + jww.INFO.Print(err) return } @@ -167,9 +164,9 @@ func (m *manager) follow(report interfaces.ClientErrorReport, rng csprng.Source, // Update the address space size // todo: this is a fix for incompatibility with the live network // remove once the live network has been pushed to - if len(m.Instance.GetPartialNdf().Get().AddressSpace)!=0{ + if len(m.Instance.GetPartialNdf().Get().AddressSpace) != 0 { m.addrSpace.Update(m.Instance.GetPartialNdf().Get().AddressSpace[0].Size) - }else{ + } else { m.addrSpace.Update(18) } diff --git a/network/gateway/sender.go b/network/gateway/sender.go index 7e07b247353156960992382eb93bbe718ab9851a..e8943e21549e34bdbe692a689f6f789e7b01af73 100644 --- a/network/gateway/sender.go +++ b/network/gateway/sender.go @@ -11,6 +11,7 @@ package gateway import ( "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage" "gitlab.com/elixxir/comms/network" "gitlab.com/elixxir/crypto/fastRNG" @@ -36,12 +37,14 @@ func NewSender(poolParams PoolParams, rng *fastRNG.StreamGenerator, ndf *ndf.Net } // SendToAny Call given sendFunc to any Host in the HostPool, attempting with up to numProxies destinations -func (s *Sender) SendToAny(sendFunc func(host *connect.Host) (interface{}, error)) (interface{}, error) { +func (s *Sender) SendToAny(sendFunc func(host *connect.Host) (interface{}, error), stop *stoppable.Single) (interface{}, error) { proxies := s.getAny(s.poolParams.ProxyAttempts, nil) for i := range proxies { result, err := sendFunc(proxies[i]) - if err == nil { + if stop != nil && !stop.IsRunning() { + return nil, errors.Errorf(stoppable.ErrMsg, stop.Name(), "SendToAny") + } else if err == nil { return result, nil } else { jww.WARN.Printf("Unable to SendToAny %s: %s", proxies[i].GetId().String(), err) @@ -57,7 +60,8 @@ func (s *Sender) SendToAny(sendFunc func(host *connect.Host) (interface{}, error // SendToPreferred Call given sendFunc to any Host in the HostPool, attempting with up to numProxies destinations func (s *Sender) SendToPreferred(targets []*id.ID, - sendFunc func(host *connect.Host, target *id.ID) (interface{}, bool, error)) (interface{}, error) { + sendFunc func(host *connect.Host, target *id.ID) (interface{}, bool, error), + stop *stoppable.Single) (interface{}, error) { // Get the hosts and shuffle randomly targetHosts := s.getPreferred(targets) @@ -65,7 +69,9 @@ func (s *Sender) SendToPreferred(targets []*id.ID, // Attempt to send directly to targets if they are in the HostPool for i := range targetHosts { result, didAbort, err := sendFunc(targetHosts[i], targets[i]) - if err == nil { + if stop != nil && !stop.IsRunning() { + return nil, errors.Errorf(stoppable.ErrMsg, stop.Name(), "SendToPreferred") + } else if err == nil { return result, nil } else { if didAbort { @@ -103,7 +109,9 @@ func (s *Sender) SendToPreferred(targets []*id.ID, } result, didAbort, err := sendFunc(targetProxies[proxyIdx], target) - if err == nil { + if stop != nil && !stop.IsRunning() { + return nil, errors.Errorf(stoppable.ErrMsg, stop.Name(), "SendToPreferred") + } else if err == nil { return result, nil } else { if didAbort { diff --git a/network/gateway/sender_test.go b/network/gateway/sender_test.go index 35868f1d6c32f83d28cfb3c534d1877cc94b53a4..d8dbf16227e8724ef509589bb7133efea888cc26 100644 --- a/network/gateway/sender_test.go +++ b/network/gateway/sender_test.go @@ -78,7 +78,7 @@ func TestSender_SendToAny(t *testing.T) { } // Test sendToAny with test interfaces - result, err := sender.SendToAny(SendToAny_HappyPath) + result, err := sender.SendToAny(SendToAny_HappyPath, nil) if err != nil { t.Errorf("Should not error in SendToAny happy path: %v", err) } @@ -89,12 +89,12 @@ func TestSender_SendToAny(t *testing.T) { "\n\tReceived: %v", happyPathReturn, result) } - _, err = sender.SendToAny(SendToAny_KnownError) + _, err = sender.SendToAny(SendToAny_KnownError, nil) if err == nil { t.Fatalf("Expected error path did not receive error") } - _, err = sender.SendToAny(SendToAny_UnknownError) + _, err = sender.SendToAny(SendToAny_UnknownError, nil) if err == nil { t.Fatalf("Expected error path did not receive error") } @@ -139,7 +139,7 @@ func TestSender_SendToPreferred(t *testing.T) { preferredHost := sender.hostList[preferredIndex] // Happy path - result, err := sender.SendToPreferred([]*id.ID{preferredHost.GetId()}, SendToPreferred_HappyPath) + result, err := sender.SendToPreferred([]*id.ID{preferredHost.GetId()}, SendToPreferred_HappyPath, nil) if err != nil { t.Errorf("Should not error in SendToPreferred happy path: %v", err) } @@ -151,7 +151,7 @@ func TestSender_SendToPreferred(t *testing.T) { } // Call a send which returns an error which triggers replacement - _, err = sender.SendToPreferred([]*id.ID{preferredHost.GetId()}, SendToPreferred_KnownError) + _, err = sender.SendToPreferred([]*id.ID{preferredHost.GetId()}, SendToPreferred_KnownError, nil) if err == nil { t.Fatalf("Expected error path did not receive error") } @@ -171,7 +171,7 @@ func TestSender_SendToPreferred(t *testing.T) { preferredHost = sender.hostList[preferredIndex] // Unknown error return will not trigger replacement - _, err = sender.SendToPreferred([]*id.ID{preferredHost.GetId()}, SendToPreferred_UnknownError) + _, err = sender.SendToPreferred([]*id.ID{preferredHost.GetId()}, SendToPreferred_UnknownError, nil) if err == nil { t.Fatalf("Expected error path did not receive error") } diff --git a/network/health/tracker.go b/network/health/tracker.go index ff53904f2015fe81af4938efd9ca8222fd68eba1..1fc5e8cfb7285916192e34918e5fa9068b06a928 100644 --- a/network/health/tracker.go +++ b/network/health/tracker.go @@ -114,25 +114,26 @@ func (t *Tracker) Start() (stoppable.Stoppable, error) { stop := stoppable.NewSingle("Health Tracker") - go t.start(stop.Quit()) + go t.start(stop) return stop, nil } // Long-running thread used to monitor and report on network health -func (t *Tracker) start(quitCh <-chan struct{}) { +func (t *Tracker) start(stop *stoppable.Single) { timer := time.NewTimer(t.timeout) for { var heartbeat network.Heartbeat select { - case <-quitCh: + case <-stop.Quit(): t.mux.Lock() t.isHealthy = false t.running = false t.mux.Unlock() t.transmit(false) - break + stop.ToStopped() + return case heartbeat = <-t.heartbeat: if healthy(heartbeat) { // Stop and reset timer @@ -146,10 +147,9 @@ func (t *Tracker) start(quitCh <-chan struct{}) { timer.Reset(t.timeout) t.setHealth(true) } - break case <-timer.C: t.setHealth(false) - break + return } } } diff --git a/network/manager.go b/network/manager.go index c072da4eb05bd74f60fd95dfdcfcdd7bd47e7428..fcb20f2d77f94a57c4c2f01e581b03fca06b6178 100644 --- a/network/manager.go +++ b/network/manager.go @@ -136,7 +136,7 @@ func (m *manager) Follow(report interfaces.ClientErrorReport) (stoppable.Stoppab // Start the Network Tracker trackNetworkStopper := stoppable.NewSingle("TrackNetwork") - go m.followNetwork(report, trackNetworkStopper.Quit(), trackNetworkStopper) + go m.followNetwork(report, trackNetworkStopper) multi.Add(trackNetworkStopper) // Message reception diff --git a/network/message/bundle.go b/network/message/bundle.go index 56f1618d643641da6e9c2550e998a37a1269344c..81c649bd3798d693cd451dd7322d9d83e95510d9 100644 --- a/network/message/bundle.go +++ b/network/message/bundle.go @@ -9,13 +9,15 @@ package message import ( "gitlab.com/elixxir/client/storage/reception" + pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/primitives/format" "gitlab.com/xx_network/primitives/id" ) type Bundle struct { - Round id.Round - Messages []format.Message - Finish func() - Identity reception.IdentityUse + Round id.Round + RoundInfo *pb.RoundInfo + Messages []format.Message + Finish func() + Identity reception.IdentityUse } diff --git a/network/message/critical.go b/network/message/critical.go index 384ac9981f1dec0471b72aa3083880c26fc56382..84b88da602ac8fb87a1562d25ab8a52c6f1c23ed 100644 --- a/network/message/critical.go +++ b/network/message/critical.go @@ -12,6 +12,7 @@ import ( "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/client/interfaces/utility" + "gitlab.com/elixxir/client/stoppable" ds "gitlab.com/elixxir/comms/network/dataStructures" "gitlab.com/elixxir/primitives/format" "gitlab.com/elixxir/primitives/states" @@ -27,22 +28,22 @@ import ( // Tracker (/network/Health/Tracker.g0) //Thread loop for processing critical messages -func (m *Manager) processCriticalMessages(quitCh <-chan struct{}) { - done := false - for !done { +func (m *Manager) processCriticalMessages(stop *stoppable.Single) { + for { select { - case <-quitCh: - done = true + case <-stop.Quit(): + stop.ToStopped() + return case isHealthy := <-m.networkIsHealthy: if isHealthy { - m.criticalMessages() + m.criticalMessages(stop) } } } } // processes all critical messages -func (m *Manager) criticalMessages() { +func (m *Manager) criticalMessages(stop *stoppable.Single) { critMsgs := m.Session.GetCriticalMessages() // try to send every message in the critical messages and the raw critical // messages buffer in parallel @@ -53,7 +54,7 @@ func (m *Manager) criticalMessages() { jww.INFO.Printf("Resending critical message to %s ", msg.Recipient) //send the message - rounds, _, err := m.SendE2E(msg, param) + rounds, _, err := m.SendE2E(msg, param, stop) //if the message fail to send, notify the buffer so it can be handled //in the future and exit if err != nil { @@ -95,7 +96,7 @@ func (m *Manager) criticalMessages() { jww.INFO.Printf("Resending critical raw message to %s "+ "(msgDigest: %s)", rid, msg.Digest()) //send the message - round, _, err := m.SendCMIX(m.sender, msg, rid, param) + round, _, err := m.SendCMIX(m.sender, msg, rid, param, stop) //if the message fail to send, notify the buffer so it can be handled //in the future and exit if err != nil { diff --git a/network/message/garbled.go b/network/message/garbled.go index d9e1ac6be427f2ca0e9c8e89572fb6eed2285483..3a828b5eed2584ff323df7cdbb6016726c9dfa18 100644 --- a/network/message/garbled.go +++ b/network/message/garbled.go @@ -10,6 +10,7 @@ package message import ( jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/primitives/format" "time" ) @@ -33,12 +34,12 @@ func (m *Manager) CheckGarbledMessages() { } //long running thread which processes garbled messages -func (m *Manager) processGarbledMessages(quitCh <-chan struct{}) { - done := false - for !done { +func (m *Manager) processGarbledMessages(stop *stoppable.Single) { + for { select { - case <-quitCh: - done = true + case <-stop.Quit(): + stop.ToStopped() + return case <-m.triggerGarbled: m.handleGarbledMessages() } diff --git a/network/message/garbled_test.go b/network/message/garbled_test.go index ab4919f16d674ea10a609ca6d0e46a442259789f..d254c56c48329187955596c8367b106fab77cc08 100644 --- a/network/message/garbled_test.go +++ b/network/message/garbled_test.go @@ -7,6 +7,7 @@ import ( "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/network/internal" "gitlab.com/elixxir/client/network/message/parse" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage" "gitlab.com/elixxir/client/switchboard" "gitlab.com/elixxir/comms/client" @@ -120,8 +121,8 @@ func TestManager_CheckGarbledMessages(t *testing.T) { encryptedMsg := key.Encrypt(msg) i.Session.GetGarbledMessages().Add(encryptedMsg) - quitch := make(chan struct{}) - go m.processGarbledMessages(quitch) + stop := stoppable.NewSingle("stop") + go m.processGarbledMessages(stop) m.CheckGarbledMessages() diff --git a/network/message/handler.go b/network/message/handler.go index b8892cff56f51009558998fcd74b1fa027c4a712..060332e8cb982744e9925c539ac455a499bf2e18 100644 --- a/network/message/handler.go +++ b/network/message/handler.go @@ -10,23 +10,24 @@ package message import ( jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces/message" - "gitlab.com/elixxir/client/storage/reception" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/crypto/e2e" fingerprint2 "gitlab.com/elixxir/crypto/fingerprint" "gitlab.com/elixxir/primitives/format" + "gitlab.com/elixxir/primitives/states" "gitlab.com/xx_network/primitives/id" "time" ) -func (m *Manager) handleMessages(quitCh <-chan struct{}) { - done := false - for !done { +func (m *Manager) handleMessages(stop *stoppable.Single) { + for { select { - case <-quitCh: - done = true + case <-stop.Quit(): + stop.ToStopped() + return case bundle := <-m.messageReception: for _, msg := range bundle.Messages { - m.handleMessage(msg, bundle.Identity) + m.handleMessage(msg, bundle) } bundle.Finish() } @@ -34,10 +35,11 @@ func (m *Manager) handleMessages(quitCh <-chan struct{}) { } -func (m *Manager) handleMessage(ecrMsg format.Message, identity reception.IdentityUse) { +func (m *Manager) handleMessage(ecrMsg format.Message, bundle Bundle) { // We've done all the networking, now process the message fingerprint := ecrMsg.GetKeyFP() msgDigest := ecrMsg.Digest() + identity := bundle.Identity e2eKv := m.Session.E2e() @@ -90,18 +92,16 @@ func (m *Manager) handleMessage(ecrMsg format.Message, identity reception.Identi // if it doesnt match any form of encrypted, hear it as a raw message // and add it to garbled messages to be handled later msg = ecrMsg - if err != nil { - jww.DEBUG.Printf("Failed to unmarshal ephemeral ID "+ - "on unknown message: %+v", err) - } raw := message.Receive{ - Payload: msg.Marshal(), - MessageType: message.Raw, - Sender: &id.ID{}, - EphemeralID: identity.EphId, - Timestamp: time.Time{}, - Encryption: message.None, - RecipientID: identity.Source, + Payload: msg.Marshal(), + MessageType: message.Raw, + Sender: &id.ID{}, + EphemeralID: identity.EphId, + Timestamp: time.Time{}, + Encryption: message.None, + RecipientID: identity.Source, + RoundId: id.Round(bundle.RoundInfo.ID), + RoundTimestamp: time.Unix(0, int64(bundle.RoundInfo.Timestamps[states.QUEUED])), } jww.INFO.Printf("Garbled/RAW Message: keyFP: %v, msgDigest: %s", msg.GetKeyFP(), msg.Digest()) @@ -124,6 +124,8 @@ func (m *Manager) handleMessage(ecrMsg format.Message, identity reception.Identi xxMsg.RecipientID = identity.Source xxMsg.EphemeralID = identity.EphId xxMsg.Encryption = encTy + xxMsg.RoundId = id.Round(bundle.RoundInfo.ID) + xxMsg.RoundTimestamp = time.Unix(0, int64(bundle.RoundInfo.Timestamps[states.QUEUED])) if xxMsg.MessageType == message.Raw { jww.WARN.Panicf("Recieved a message of type 'Raw' from %s."+ "Message Ignored, 'Raw' is a reserved type. Message supressed.", diff --git a/network/message/manager.go b/network/message/manager.go index 7728910aa7e62186c682dcb97b297cb470dcac58..d5bf21a10db947a10f6566c93d005481770789ef 100644 --- a/network/message/manager.go +++ b/network/message/manager.go @@ -58,19 +58,19 @@ func (m *Manager) StartProcessies() stoppable.Stoppable { //create the message handler workers for i := uint(0); i < m.param.MessageReceptionWorkerPoolSize; i++ { stop := stoppable.NewSingle(fmt.Sprintf("MessageReception Worker %v", i)) - go m.handleMessages(stop.Quit()) + go m.handleMessages(stop) multi.Add(stop) } //create the critical messages thread critStop := stoppable.NewSingle("CriticalMessages") - go m.processCriticalMessages(critStop.Quit()) + go m.processCriticalMessages(critStop) m.Health.AddChannel(m.networkIsHealthy) multi.Add(critStop) //create the garbled messages thread garbledStop := stoppable.NewSingle("GarbledMessages") - go m.processGarbledMessages(garbledStop.Quit()) + go m.processGarbledMessages(garbledStop) multi.Add(garbledStop) return multi diff --git a/network/message/sendCmix.go b/network/message/sendCmix.go index 2263b903f2d05491ede47501802d7e3e3fab5135..d011a533cb05aad5ecdf801267cb0db7a5f6055d 100644 --- a/network/message/sendCmix.go +++ b/network/message/sendCmix.go @@ -13,37 +13,30 @@ import ( jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/client/network/gateway" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/comms/network" "gitlab.com/elixxir/crypto/fastRNG" - "gitlab.com/elixxir/crypto/fingerprint" "gitlab.com/elixxir/primitives/format" - "gitlab.com/elixxir/primitives/states" "gitlab.com/xx_network/comms/connect" "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/id/ephemeral" "gitlab.com/xx_network/primitives/netTime" "strings" - "time" ) -// interface for SendCMIX comms; allows mocking this in testing -type sendCmixCommsInterface interface { - SendPutMessage(host *connect.Host, message *pb.GatewaySlot) (*pb.GatewaySlotResponse, error) -} - -// 1.5 seconds -const sendTimeBuffer = 500 * time.Millisecond - // WARNING: Potentially Unsafe // Public manager function to send a message over CMIX -func (m *Manager) SendCMIX(sender *gateway.Sender, msg format.Message, recipient *id.ID, param params.CMIX) (id.Round, ephemeral.Id, error) { +func (m *Manager) SendCMIX(sender *gateway.Sender, msg format.Message, + recipient *id.ID, param params.CMIX, stop *stoppable.Single) (id.Round, ephemeral.Id, error) { msgCopy := msg.Copy() - return sendCmixHelper(sender, msgCopy, recipient, m.param, param, m.Instance, m.Session, m.nodeRegistration, m.Rng, m.TransmissionID, m.Comms) + return sendCmixHelper(sender, msgCopy, recipient, param, m.Instance, + m.Session, m.nodeRegistration, m.Rng, m.TransmissionID, m.Comms, stop) } -// Payloads send are not End to End encrypted, MetaData is NOT protected with +// Helper function for sendCmix +// NOTE: Payloads send are not End to End encrypted, MetaData is NOT protected with // this call, see SendE2E for End to End encryption and full privacy protection // Internal SendCmix which bypasses the network check, will attempt to send to // the network without checking state. It has a built in retry system which can @@ -51,9 +44,11 @@ func (m *Manager) SendCMIX(sender *gateway.Sender, msg format.Message, recipient // If the message is successfully sent, the id of the round sent it is returned, // which can be registered with the network instance to get a callback on // its status -func sendCmixHelper(sender *gateway.Sender, msg format.Message, recipient *id.ID, messageParams params.Messages, cmixParams params.CMIX, instance *network.Instance, - session *storage.Session, nodeRegistration chan network.NodeGateway, rng *fastRNG.StreamGenerator, senderId *id.ID, - comms sendCmixCommsInterface) (id.Round, ephemeral.Id, error) { +func sendCmixHelper(sender *gateway.Sender, msg format.Message, + recipient *id.ID, cmixParams params.CMIX, instance *network.Instance, + session *storage.Session, nodeRegistration chan network.NodeGateway, + rng *fastRNG.StreamGenerator, senderId *id.ID, comms sendCmixCommsInterface, + stop *stoppable.Single) (id.Round, ephemeral.Id, error) { timeStart := netTime.Now() attempted := set.New() @@ -86,89 +81,24 @@ func sendCmixHelper(sender *gateway.Sender, msg format.Message, recipient *id.ID //add the round on to the list of attempted so it is not tried again attempted.Insert(bestRound) - //set the ephemeral ID - ephID, _, _, err := ephemeral.GetId(recipient, - uint(bestRound.AddressSpaceSize), - int64(bestRound.Timestamps[states.QUEUED])) + // Retrieve host and key information from round + firstGateway, roundKeys, err := processRound(instance, session, nodeRegistration, bestRound, recipient.String(), msg.Digest()) if err != nil { - jww.FATAL.Panicf("Failed to generate ephemeral ID when "+ - "sending to %s (msgDigest: %s): %+v", err, recipient, - msg.Digest()) + jww.WARN.Printf("SendCmix failed to process round (will retry): %v", err) + continue } + // Build the messages to send stream := rng.GetStream() - ephIdFilled, err := ephID.Fill(uint(bestRound.AddressSpaceSize), stream) - if err != nil { - jww.FATAL.Panicf("Failed to obfuscate the ephemeralID when "+ - "sending to %s (msgDigest: %s): %+v", recipient, msg.Digest(), - err) - } - stream.Close() - - msg.SetEphemeralRID(ephIdFilled[:]) - - //set the identity fingerprint - ifp := fingerprint.IdentityFP(msg.GetContents(), recipient) - msg.SetIdentityFP(ifp) - //build the topology - idList, err := id.NewIDListFromBytes(bestRound.Topology) + wrappedMsg, encMsg, ephID, err := buildSlotMessage(msg, recipient, + firstGateway, stream, senderId, bestRound, roundKeys) if err != nil { - jww.ERROR.Printf("Failed to use topology for round %d when "+ - "sending to %s (msgDigest: %s): %+v", bestRound.ID, - recipient, msg.Digest(), err) - continue - } - topology := connect.NewCircuit(idList) - //get they keys for the round, reject if any nodes do not have - //keying relationships - roundKeys, missingKeys := session.Cmix().GetRoundKeys(topology) - if len(missingKeys) > 0 { - jww.WARN.Printf("Failed to send on round %d to %s "+ - "(msgDigest: %s) due to missing relationships with nodes: %s", - bestRound.ID, recipient, msg.Digest(), missingKeys) - go handleMissingNodeKeys(instance, nodeRegistration, missingKeys) - time.Sleep(cmixParams.RetryDelay) - continue + stream.Close() + return 0, ephemeral.Id{}, err } - - //get the gateway to transmit to - firstGateway := topology.GetNodeAtIndex(0).DeepCopy() - firstGateway.SetType(id.Gateway) - - //encrypt the message - stream = rng.GetStream() - salt := make([]byte, 32) - _, err = stream.Read(salt) stream.Close() - if err != nil { - jww.ERROR.Printf("Failed to generate salt when sending to "+ - "%s (msgDigest: %s): %+v", recipient, msg.Digest(), err) - return 0, ephemeral.Id{}, errors.WithMessage(err, - "Failed to generate salt, this should never happen") - } - - encMsg, kmacs := roundKeys.Encrypt(msg, salt, id.Round(bestRound.ID)) - - //build the message payload - msgPacket := &pb.Slot{ - SenderID: senderId.Bytes(), - PayloadA: encMsg.GetPayloadA(), - PayloadB: encMsg.GetPayloadB(), - Salt: salt, - KMACs: kmacs, - } - - //create the wrapper to the gateway - wrappedMsg := &pb.GatewaySlot{ - Message: msgPacket, - RoundID: bestRound.ID, - } - //Add the mac proving ownership - wrappedMsg.MAC = roundKeys.MakeClientGatewayKey(salt, - network.GenerateSlotDigest(wrappedMsg)) - jww.INFO.Printf("Sending to EphID %d (%s) on round %d, "+ "(msgDigest: %s, ecrMsgDigest: %s) via gateway %s", ephID.Int64(), recipient, bestRound.ID, msg.Digest(), @@ -179,42 +109,36 @@ func sendCmixHelper(sender *gateway.Sender, msg format.Message, recipient *id.ID wrappedMsg.Target = target.Marshal() result, err := comms.SendPutMessage(host, wrappedMsg) if err != nil { - if strings.Contains(err.Error(), - "try a different round.") { - jww.WARN.Printf("Failed to send to %s (msgDigest: %s) "+ - "due to round error with round %d, retrying: %+v", - recipient, msg.Digest(), bestRound.ID, err) - return nil, true, err - } else if strings.Contains(err.Error(), - "Could not authenticate client. Is the client registered "+ - "with this node?") { - jww.WARN.Printf("Failed to send to %s (msgDigest: %s) "+ - "via %s due to failed authentication: %s", - recipient, msg.Digest(), firstGateway.String(), err) - //if we failed to send due to the gateway not recognizing our - // authorization, renegotiate with the node to refresh it - nodeID := firstGateway.DeepCopy() - nodeID.SetType(id.Node) - //delete the keys - session.Cmix().Remove(nodeID) - //trigger - go handleMissingNodeKeys(instance, nodeRegistration, []*id.ID{nodeID}) - return nil, true, err + // fixme: should we provide as a slice the whole topology? + warn, err := handlePutMessageError(firstGateway, instance, session, nodeRegistration, recipient.String(), bestRound, err) + if warn { + jww.WARN.Printf("SendCmix Failed: %+v", err) + } else { + return result, false, errors.WithMessagef(err, "SendCmix %s", unrecoverableError) } } return result, false, err } - result, err := sender.SendToPreferred([]*id.ID{firstGateway}, sendFunc) + result, err := sender.SendToPreferred([]*id.ID{firstGateway}, sendFunc, stop) + + // Exit if the thread has been stopped + if stoppable.CheckErr(err) { + return 0, ephemeral.Id{}, err + } //if the comm errors or the message fails to send, continue retrying. - //return if it sends properly if err != nil { - jww.ERROR.Printf("Failed to send to EphID %d (%s) on "+ - "round %d, trying a new round: %+v", ephID.Int64(), recipient, - bestRound.ID, err) - continue + if !strings.Contains(err.Error(), unrecoverableError) { + jww.ERROR.Printf("SendCmix failed to send to EphID %d (%s) on "+ + "round %d, trying a new round: %+v", ephID.Int64(), recipient, + bestRound.ID, err) + continue + } + + return 0, ephemeral.Id{}, err } + // Return if it sends properly gwSlotResp := result.(*pb.GatewaySlotResponse) if gwSlotResp.Accepted { jww.INFO.Printf("Successfully sent to EphID %v (source: %s) "+ @@ -223,29 +147,10 @@ func sendCmixHelper(sender *gateway.Sender, msg format.Message, recipient *id.ID } else { jww.FATAL.Panicf("Gateway %s returned no error, but failed "+ "to accept message when sending to EphID %d (%s) on round %d", - firstGateway.String(), ephID.Int64(), recipient, bestRound.ID) + firstGateway, ephID.Int64(), recipient, bestRound.ID) } + } return 0, ephemeral.Id{}, errors.New("failed to send the message, " + "unknown error") } - -// Signals to the node registration thread to register a node if keys are -// missing. Identity is triggered automatically when the node is first seen, -// so this should on trigger on rare events. -func handleMissingNodeKeys(instance *network.Instance, - newNodeChan chan network.NodeGateway, nodes []*id.ID) { - for _, n := range nodes { - ng, err := instance.GetNodeAndGateway(n) - if err != nil { - jww.ERROR.Printf("Node contained in round cannot be found: %s", err) - continue - } - select { - case newNodeChan <- ng: - default: - jww.ERROR.Printf("Failed to send node registration for %s", n) - } - - } -} diff --git a/network/message/sendCmixUtils.go b/network/message/sendCmixUtils.go new file mode 100644 index 0000000000000000000000000000000000000000..c1d8e295885b1dadaa018ea881d07b1896625d87 --- /dev/null +++ b/network/message/sendCmixUtils.go @@ -0,0 +1,231 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package message + +import ( + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/storage" + "gitlab.com/elixxir/client/storage/cmix" + pb "gitlab.com/elixxir/comms/mixmessages" + "gitlab.com/elixxir/comms/network" + "gitlab.com/elixxir/crypto/fastRNG" + "gitlab.com/elixxir/crypto/fingerprint" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/elixxir/primitives/states" + "gitlab.com/xx_network/comms/connect" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/id/ephemeral" + "strconv" + "strings" + "time" +) + +// Interface for SendCMIX comms; allows mocking this in testing. +type sendCmixCommsInterface interface { + SendPutMessage(host *connect.Host, message *pb.GatewaySlot) (*pb.GatewaySlotResponse, error) + SendPutManyMessages(host *connect.Host, messages *pb.GatewaySlots) (*pb.GatewaySlotResponse, error) +} + +// 2.5 seconds +const sendTimeBuffer = 2500 * time.Millisecond +const unrecoverableError = "failed with an unrecoverable error" + +// handlePutMessageError handles errors received from a PutMessage or a +// PutManyMessage network call. A printable error will be returned giving more +// context. If the error is not among recoverable errors, then the recoverable +// boolean will be returned false. If the error is among recoverable errors, +// then the boolean will return true. +func handlePutMessageError(firstGateway *id.ID, instance *network.Instance, + session *storage.Session, nodeRegistration chan network.NodeGateway, + recipientString string, bestRound *pb.RoundInfo, + err error) (recoverable bool, returnErr error) { + + // If the comm errors or the message fails to send, then continue retrying; + // otherwise, return if it sends properly + if strings.Contains(err.Error(), "try a different round.") { + return true, errors.WithMessagef(err, "Failed to send to [%s] due to "+ + "round error with round %d, retrying...", + recipientString, bestRound.ID) + } else if strings.Contains(err.Error(), "Could not authenticate client. "+ + "Is the client registered with this node?") { + // If send failed due to the gateway not recognizing the authorization, + // then renegotiate with the node to refresh it + nodeID := firstGateway.DeepCopy() + nodeID.SetType(id.Node) + + // Delete the keys + session.Cmix().Remove(nodeID) + + // Trigger + go handleMissingNodeKeys(instance, nodeRegistration, []*id.ID{nodeID}) + + return true, errors.WithMessagef(err, "Failed to send to [%s] via %s "+ + "due to failed authentication, retrying...", + recipientString, firstGateway) + } + + return false, errors.WithMessage(err, "Failed to put cmix message") + +} + +// processRound is a helper function that determines the gateway to send to for +// a round and retrieves the round keys. +func processRound(instance *network.Instance, session *storage.Session, + nodeRegistration chan network.NodeGateway, bestRound *pb.RoundInfo, + recipientString, messageDigest string) (*id.ID, *cmix.RoundKeys, error) { + + // Build the topology + idList, err := id.NewIDListFromBytes(bestRound.Topology) + if err != nil { + return nil, nil, errors.WithMessagef(err, "Failed to use topology for "+ + "round %d when sending to [%s] (msgDigest(s): %s)", + bestRound.ID, recipientString, messageDigest) + } + topology := connect.NewCircuit(idList) + + // Get the keys for the round, reject if any nodes do not have keying + // relationships + roundKeys, missingKeys := session.Cmix().GetRoundKeys(topology) + if len(missingKeys) > 0 { + go handleMissingNodeKeys(instance, nodeRegistration, missingKeys) + + return nil, nil, errors.Errorf("Failed to send on round %d to [%s] "+ + "(msgDigest(s): %s) due to missing relationships with nodes: %s", + bestRound.ID, recipientString, messageDigest, missingKeys) + } + + // Get the gateway to transmit to + firstGateway := topology.GetNodeAtIndex(0).DeepCopy() + firstGateway.SetType(id.Gateway) + + return firstGateway, roundKeys, nil +} + +// buildSlotMessage is a helper function which forms a slotted message to send +// to a gateway. It encrypts passed in message and generates an ephemeral ID for +// the recipient. +func buildSlotMessage(msg format.Message, recipient *id.ID, target *id.ID, + stream *fastRNG.Stream, senderId *id.ID, bestRound *pb.RoundInfo, + roundKeys *cmix.RoundKeys) (*pb.GatewaySlot, format.Message, ephemeral.Id, + error) { + + // Set the ephemeral ID + ephID, _, _, err := ephemeral.GetId(recipient, + uint(bestRound.AddressSpaceSize), + int64(bestRound.Timestamps[states.QUEUED])) + if err != nil { + jww.FATAL.Panicf("Failed to generate ephemeral ID when sending to %s "+ + "(msgDigest: %s): %+v", err, recipient, msg.Digest()) + } + + ephIdFilled, err := ephID.Fill(uint(bestRound.AddressSpaceSize), stream) + if err != nil { + jww.FATAL.Panicf("Failed to obfuscate the ephemeralID when sending "+ + "to %s (msgDigest: %s): %+v", recipient, msg.Digest(), err) + } + + msg.SetEphemeralRID(ephIdFilled[:]) + + // Set the identity fingerprint + ifp := fingerprint.IdentityFP(msg.GetContents(), recipient) + + msg.SetIdentityFP(ifp) + + // Encrypt the message + salt := make([]byte, 32) + _, err = stream.Read(salt) + if err != nil { + jww.ERROR.Printf("Failed to generate salt when sending to %s "+ + "(msgDigest: %s): %+v", recipient, msg.Digest(), err) + return nil, format.Message{}, ephemeral.Id{}, errors.WithMessage(err, + "Failed to generate salt, this should never happen") + } + + encMsg, kmacs := roundKeys.Encrypt(msg, salt, id.Round(bestRound.ID)) + + // Build the message payload + msgPacket := &pb.Slot{ + SenderID: senderId.Bytes(), + PayloadA: encMsg.GetPayloadA(), + PayloadB: encMsg.GetPayloadB(), + Salt: salt, + KMACs: kmacs, + } + + // Create the wrapper to the gateway + slot := &pb.GatewaySlot{ + Message: msgPacket, + RoundID: bestRound.ID, + Target: target.Bytes(), + } + + // Add the mac proving ownership + slot.MAC = roundKeys.MakeClientGatewayKey(salt, + network.GenerateSlotDigest(slot)) + + return slot, encMsg, ephID, nil +} + +// handleMissingNodeKeys signals to the node registration thread to register a +// node if keys are missing. Identity is triggered automatically when the node +// is first seen, so this should on trigger on rare events. +func handleMissingNodeKeys(instance *network.Instance, + newNodeChan chan network.NodeGateway, nodes []*id.ID) { + for _, n := range nodes { + ng, err := instance.GetNodeAndGateway(n) + if err != nil { + jww.ERROR.Printf("Node contained in round cannot be found: %s", err) + continue + } + + select { + case newNodeChan <- ng: + default: + jww.ERROR.Printf("Failed to send node registration for %s", n) + } + + } +} + +// messageMapToStrings serializes a map of IDs and messages into a string of IDs +// and a string of message digests. Intended for use in printing to logs. +func messageMapToStrings(msgList map[id.ID]format.Message) (string, string) { + idStrings := make([]string, 0, len(msgList)) + msgDigests := make([]string, 0, len(msgList)) + for uid, msg := range msgList { + idStrings = append(idStrings, uid.String()) + msgDigests = append(msgDigests, msg.Digest()) + } + + return strings.Join(idStrings, ","), strings.Join(msgDigests, ",") +} + +// messagesToDigestString serializes a list of messages into a string of message +// digests. Intended for use in printing to the logs. +func messagesToDigestString(msgs []format.Message) string { + msgDigests := make([]string, 0, len(msgs)) + for _, msg := range msgs { + msgDigests = append(msgDigests, msg.Digest()) + } + + return strings.Join(msgDigests, ",") +} + +// ephemeralIdListToString serializes a list of ephemeral IDs into a human- +// readable format. Intended for use in printing to logs. +func ephemeralIdListToString(idList []ephemeral.Id) string { + idStrings := make([]string, 0, len(idList)) + + for i := 0; i < len(idList); i++ { + ephIdStr := strconv.FormatInt(idList[i].Int64(), 10) + idStrings = append(idStrings, ephIdStr) + } + + return strings.Join(idStrings, ",") +} diff --git a/network/message/sendCmix_test.go b/network/message/sendCmix_test.go index f3182a8a2407c4f29c1c615160a28fd793ab4ee6..28acd56525e7393500819539561cda4f8160ba07 100644 --- a/network/message/sendCmix_test.go +++ b/network/message/sendCmix_test.go @@ -18,47 +18,15 @@ import ( "gitlab.com/elixxir/crypto/fastRNG" "gitlab.com/elixxir/primitives/format" "gitlab.com/elixxir/primitives/states" - "gitlab.com/xx_network/comms/connect" "gitlab.com/xx_network/crypto/csprng" "gitlab.com/xx_network/crypto/large" "gitlab.com/xx_network/primitives/id" - "gitlab.com/xx_network/primitives/ndf" "gitlab.com/xx_network/primitives/netTime" "testing" "time" ) -type MockSendCMIXComms struct { - t *testing.T -} - -func (mc *MockSendCMIXComms) GetHost(hostId *id.ID) (*connect.Host, bool) { - nid1 := id.NewIdFromString("zezima", id.Node, mc.t) - gwid := nid1.DeepCopy() - gwid.SetType(id.Gateway) - h, _ := connect.NewHost(gwid, "0.0.0.0", []byte(""), connect.HostParams{ - MaxRetries: 0, - AuthEnabled: false, - }) - return h, true -} - -func (mc *MockSendCMIXComms) AddHost(hid *id.ID, address string, cert []byte, params connect.HostParams) (host *connect.Host, err error) { - host, _ = mc.GetHost(nil) - return host, nil -} - -func (mc *MockSendCMIXComms) RemoveHost(hid *id.ID) { - -} - -func (mc *MockSendCMIXComms) SendPutMessage(host *connect.Host, message *mixmessages.GatewaySlot) (*mixmessages.GatewaySlotResponse, error) { - return &mixmessages.GatewaySlotResponse{ - Accepted: true, - RoundID: 3, - }, nil -} - +// Unit test func Test_attemptSendCmix(t *testing.T) { sess1 := storage.InitTestingSession(t) @@ -143,68 +111,12 @@ func Test_attemptSendCmix(t *testing.T) { msgCmix := format.NewMessage(m.Session.Cmix().GetGroup().GetP().ByteLen()) msgCmix.SetContents([]byte("test")) e2e.SetUnencrypted(msgCmix, m.Session.User().GetCryptographicIdentity().GetTransmissionID()) - _, _, err = sendCmixHelper(sender, msgCmix, sess2.GetUser().ReceptionID, params.GetDefaultMessage(), params.GetDefaultCMIX(), - m.Instance, m.Session, m.nodeRegistration, m.Rng, - m.TransmissionID, &MockSendCMIXComms{t: t}) + _, _, err = sendCmixHelper(sender, msgCmix, sess2.GetUser().ReceptionID, + params.GetDefaultCMIX(), m.Instance, m.Session, m.nodeRegistration, + m.Rng, m.TransmissionID, &MockSendCMIXComms{t: t}, nil) if err != nil { t.Errorf("Failed to sendcmix: %+v", err) panic("t") return } } - -func getNDF() *ndf.NetworkDefinition { - nodeId := id.NewIdFromString("zezima", id.Node, &testing.T{}) - gwId := nodeId.DeepCopy() - gwId.SetType(id.Gateway) - return &ndf.NetworkDefinition{ - E2E: ndf.Group{ - Prime: "E2EE983D031DC1DB6F1A7A67DF0E9A8E5561DB8E8D49413394C049B" + - "7A8ACCEDC298708F121951D9CF920EC5D146727AA4AE535B0922C688B55B3DD2AE" + - "DF6C01C94764DAB937935AA83BE36E67760713AB44A6337C20E7861575E745D31F" + - "8B9E9AD8412118C62A3E2E29DF46B0864D0C951C394A5CBBDC6ADC718DD2A3E041" + - "023DBB5AB23EBB4742DE9C1687B5B34FA48C3521632C4A530E8FFB1BC51DADDF45" + - "3B0B2717C2BC6669ED76B4BDD5C9FF558E88F26E5785302BEDBCA23EAC5ACE9209" + - "6EE8A60642FB61E8F3D24990B8CB12EE448EEF78E184C7242DD161C7738F32BF29" + - "A841698978825B4111B4BC3E1E198455095958333D776D8B2BEEED3A1A1A221A6E" + - "37E664A64B83981C46FFDDC1A45E3D5211AAF8BFBC072768C4F50D7D7803D2D4F2" + - "78DE8014A47323631D7E064DE81C0C6BFA43EF0E6998860F1390B5D3FEACAF1696" + - "015CB79C3F9C2D93D961120CD0E5F12CBB687EAB045241F96789C38E89D796138E" + - "6319BE62E35D87B1048CA28BE389B575E994DCA755471584A09EC723742DC35873" + - "847AEF49F66E43873", - Generator: "2", - }, - CMIX: ndf.Group{ - Prime: "9DB6FB5951B66BB6FE1E140F1D2CE5502374161FD6538DF1648218642F0B5C48" + - "C8F7A41AADFA187324B87674FA1822B00F1ECF8136943D7C55757264E5A1A44F" + - "FE012E9936E00C1D3E9310B01C7D179805D3058B2A9F4BB6F9716BFE6117C6B5" + - "B3CC4D9BE341104AD4A80AD6C94E005F4B993E14F091EB51743BF33050C38DE2" + - "35567E1B34C3D6A5C0CEAA1A0F368213C3D19843D0B4B09DCB9FC72D39C8DE41" + - "F1BF14D4BB4563CA28371621CAD3324B6A2D392145BEBFAC748805236F5CA2FE" + - "92B871CD8F9C36D3292B5509CA8CAA77A2ADFC7BFD77DDA6F71125A7456FEA15" + - "3E433256A2261C6A06ED3693797E7995FAD5AABBCFBE3EDA2741E375404AE25B", - Generator: "5C7FF6B06F8F143FE8288433493E4769C4D988ACE5BE25A0E24809670716C613" + - "D7B0CEE6932F8FAA7C44D2CB24523DA53FBE4F6EC3595892D1AA58C4328A06C4" + - "6A15662E7EAA703A1DECF8BBB2D05DBE2EB956C142A338661D10461C0D135472" + - "085057F3494309FFA73C611F78B32ADBB5740C361C9F35BE90997DB2014E2EF5" + - "AA61782F52ABEB8BD6432C4DD097BC5423B285DAFB60DC364E8161F4A2A35ACA" + - "3A10B1C4D203CC76A470A33AFDCBDD92959859ABD8B56E1725252D78EAC66E71" + - "BA9AE3F1DD2487199874393CD4D832186800654760E1E34C09E4D155179F9EC0" + - "DC4473F996BDCE6EED1CABED8B6F116F7AD9CF505DF0F998E34AB27514B0FFE7", - }, - Gateways: []ndf.Gateway{ - { - ID: gwId.Marshal(), - Address: "0.0.0.0", - TlsCertificate: "", - }, - }, - Nodes: []ndf.Node{ - { - ID: nodeId.Marshal(), - Address: "0.0.0.0", - TlsCertificate: "", - }, - }, - } -} diff --git a/network/message/sendE2E.go b/network/message/sendE2E.go index b09e96c4a19598e7724ccf0bd5781ff60745b3d1..7b468ad1a30c818396d631b4b9f85b5945dd88ab 100644 --- a/network/message/sendE2E.go +++ b/network/message/sendE2E.go @@ -13,6 +13,7 @@ import ( "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/client/keyExchange" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/crypto/e2e" "gitlab.com/elixxir/primitives/format" "gitlab.com/xx_network/primitives/id" @@ -21,7 +22,8 @@ import ( "time" ) -func (m *Manager) SendE2E(msg message.Send, param params.E2E) ([]id.Round, e2e.MessageID, error) { +func (m *Manager) SendE2E(msg message.Send, param params.E2E, + stop *stoppable.Single) ([]id.Round, e2e.MessageID, error) { if msg.MessageType == message.Raw { return nil, e2e.MessageID{}, errors.Errorf("Raw (%d) is a reserved "+ "message type", msg.MessageType) @@ -58,7 +60,7 @@ func (m *Manager) SendE2E(msg message.Send, param params.E2E) ([]id.Round, e2e.M if msg.MessageType != message.KeyExchangeTrigger { // check if any rekeys need to happen and trigger them keyExchange.CheckKeyExchanges(m.Instance, m.SendE2E, - m.Session, partner, 1*time.Minute) + m.Session, partner, 1*time.Minute, stop) } //create the cmix message @@ -96,7 +98,7 @@ func (m *Manager) SendE2E(msg message.Send, param params.E2E) ([]id.Round, e2e.M go func(i int) { var err error roundIds[i], _, err = m.SendCMIX(m.sender, msgEnc, msg.Recipient, - param.CMIX) + param.CMIX, stop) if err != nil { errCh <- err } diff --git a/network/message/sendManyCmix.go b/network/message/sendManyCmix.go new file mode 100644 index 0000000000000000000000000000000000000000..e9d598182bb805fc3220804389af43456eccf022 --- /dev/null +++ b/network/message/sendManyCmix.go @@ -0,0 +1,184 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package message + +import ( + "github.com/golang-collections/collections/set" + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/network/gateway" + "gitlab.com/elixxir/client/storage" + pb "gitlab.com/elixxir/comms/mixmessages" + "gitlab.com/elixxir/comms/network" + "gitlab.com/elixxir/crypto/fastRNG" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/comms/connect" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/id/ephemeral" + "gitlab.com/xx_network/primitives/netTime" + "strings" +) + +// SendManyCMIX sends many "raw" cMix message payloads to each of the provided +// recipients. Used to send messages in group chats. Metadata is NOT protected +// with this call and can leak data about yourself. Returns the round ID of the +// round the payload was sent or an error if it fails. +// WARNING: Potentially Unsafe +func (m *Manager) SendManyCMIX(sender *gateway.Sender, + messages map[id.ID]format.Message, p params.CMIX) (id.Round, []ephemeral.Id, + error) { + + // Create message copies + messagesCopy := make(map[id.ID]format.Message, len(messages)) + for rid, msg := range messages { + messagesCopy[rid] = msg.Copy() + } + + return sendManyCmixHelper(sender, messagesCopy, p, m.Instance, m.Session, + m.nodeRegistration, m.Rng, m.TransmissionID, m.Comms) +} + +// sendManyCmixHelper is a helper function for Manager.SendManyCMIX. +// +// NOTE: Payloads sent are not end to end encrypted, metadata is NOT protected +// with this call; see SendE2E for end to end encryption and full privacy +// protection. Internal SendManyCMIX, which bypasses the network check, will +// attempt to send to the network without checking state. It has a built in +// retry system which can be configured through the params object. +// +// If the message is successfully sent, the ID of the round sent it is returned, +// which can be registered with the network instance to get a callback on its +// status. +func sendManyCmixHelper(sender *gateway.Sender, msgs map[id.ID]format.Message, + param params.CMIX, instance *network.Instance, session *storage.Session, + nodeRegistration chan network.NodeGateway, rng *fastRNG.StreamGenerator, + senderId *id.ID, comms sendCmixCommsInterface) (id.Round, []ephemeral.Id, error) { + + timeStart := netTime.Now() + attempted := set.New() + stream := rng.GetStream() + defer stream.Close() + + recipientString, msgDigests := messageMapToStrings(msgs) + + jww.INFO.Printf("Looking for round to send cMix messages to [%s] "+ + "(msgDigest: %s)", recipientString, msgDigests) + + for numRoundTries := uint(0); numRoundTries < param.RoundTries; numRoundTries++ { + elapsed := netTime.Now().Sub(timeStart) + + if elapsed > param.Timeout { + jww.INFO.Printf("No rounds to send to %s (msgDigest: %s) were found "+ + "before timeout %s", recipientString, msgDigests, param.Timeout) + return 0, []ephemeral.Id{}, + errors.New("sending cMix message timed out") + } + + if numRoundTries > 0 { + jww.INFO.Printf("Attempt %d to find round to send message to %s "+ + "(msgDigest: %s)", numRoundTries+1, recipientString, msgDigests) + } + + remainingTime := param.Timeout - elapsed + + // Find the best round to send to, excluding attempted rounds + bestRound, _ := instance.GetWaitingRounds().GetUpcomingRealtime( + remainingTime, attempted, sendTimeBuffer) + if bestRound == nil { + continue + } + + // Add the round on to the list of attempted so it is not tried again + attempted.Insert(bestRound) + + // Retrieve host and key information from round + firstGateway, roundKeys, err := processRound(instance, session, + nodeRegistration, bestRound, recipientString, msgDigests) + if err != nil { + jww.WARN.Printf("SendManyCMIX failed to process round %d "+ + "(will retry): %+v", bestRound.ID, err) + continue + } + + // Build a slot for every message and recipient + slots := make([]*pb.GatewaySlot, len(msgs)) + ephemeralIds := make([]ephemeral.Id, len(msgs)) + encMsgs := make([]format.Message, len(msgs)) + i := 0 + for recipient, msg := range msgs { + slots[i], encMsgs[i], ephemeralIds[i], err = buildSlotMessage( + msg, &recipient, firstGateway, stream, senderId, bestRound, roundKeys) + if err != nil { + return 0, []ephemeral.Id{}, errors.Errorf("failed to build "+ + "slot message for %s: %+v", recipient, err) + } + i++ + } + + // Serialize lists into a printable format + ephemeralIdsString := ephemeralIdListToString(ephemeralIds) + encMsgsDigest := messagesToDigestString(encMsgs) + + jww.INFO.Printf("Sending to EphIDs [%s] (%s) on round %d, "+ + "(msgDigest: %s, ecrMsgDigest: %s) via gateway %s", + ephemeralIdsString, recipientString, bestRound.ID, msgDigests, + encMsgsDigest, firstGateway) + + // Wrap slots in the proper message type + wrappedMessage := &pb.GatewaySlots{ + Messages: slots, + RoundID: bestRound.ID, + } + + // Send the payload + sendFunc := func(host *connect.Host, target *id.ID) (interface{}, bool, error) { + wrappedMessage.Target = target.Marshal() + result, err := comms.SendPutManyMessages(host, wrappedMessage) + if err != nil { + warn, err := handlePutMessageError(firstGateway, instance, + session, nodeRegistration, recipientString, bestRound, err) + if warn { + jww.WARN.Printf("SendManyCMIX Failed: %+v", err) + } else { + return result, false, errors.WithMessagef(err, + "SendManyCMIX %s", unrecoverableError) + } + } + return result, false, err + } + result, err := sender.SendToPreferred([]*id.ID{firstGateway}, sendFunc, nil) + + // If the comm errors or the message fails to send, continue retrying + if err != nil { + if !strings.Contains(err.Error(), unrecoverableError) { + jww.ERROR.Printf("SendManyCMIX failed to send to EphIDs [%s] "+ + "(sources: %s) on round %d, trying a new round %+v", + ephemeralIdsString, recipientString, bestRound.ID, err) + continue + } + + return 0, []ephemeral.Id{}, err + } + + // Return if it sends properly + gwSlotResp := result.(*pb.GatewaySlotResponse) + if gwSlotResp.Accepted { + jww.INFO.Printf("Successfully sent to EphIDs %v (sources: [%s]) in "+ + "round %d", ephemeralIdsString, recipientString, bestRound.ID) + return id.Round(bestRound.ID), ephemeralIds, nil + } else { + jww.FATAL.Panicf("Gateway %s returned no error, but failed to "+ + "accept message when sending to EphIDs [%s] (%s) on round %d", + firstGateway, ephemeralIdsString, recipientString, bestRound.ID) + } + } + + return 0, []ephemeral.Id{}, + errors.New("failed to send the message, unknown error") +} diff --git a/network/message/sendManyCmix_test.go b/network/message/sendManyCmix_test.go new file mode 100644 index 0000000000000000000000000000000000000000..2d07cf456e4177af469a44360132590bb4fc62d3 --- /dev/null +++ b/network/message/sendManyCmix_test.go @@ -0,0 +1,134 @@ +package message + +import ( + "github.com/pkg/errors" + "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/network/gateway" + "gitlab.com/elixxir/client/network/internal" + "gitlab.com/elixxir/client/storage" + "gitlab.com/elixxir/client/switchboard" + "gitlab.com/elixxir/comms/client" + "gitlab.com/elixxir/comms/mixmessages" + "gitlab.com/elixxir/comms/network" + ds "gitlab.com/elixxir/comms/network/dataStructures" + "gitlab.com/elixxir/comms/testutils" + "gitlab.com/elixxir/crypto/cyclic" + "gitlab.com/elixxir/crypto/e2e" + "gitlab.com/elixxir/crypto/fastRNG" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/elixxir/primitives/states" + "gitlab.com/xx_network/crypto/csprng" + "gitlab.com/xx_network/crypto/large" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" + "testing" + "time" +) + +// Unit test +func Test_attemptSendManyCmix(t *testing.T) { + sess1 := storage.InitTestingSession(t) + + numRecipients := 3 + recipients := make([]*id.ID, numRecipients) + sw := switchboard.New() + l := TestListener{ + ch: make(chan bool), + } + for i := 0; i < numRecipients; i++ { + sess := storage.InitTestingSession(t) + sw.RegisterListener(sess.GetUser().TransmissionID, message.Raw, l) + recipients[i] = sess.GetUser().ReceptionID + } + + comms, err := client.NewClientComms(sess1.GetUser().TransmissionID, nil, nil, nil) + if err != nil { + t.Errorf("Failed to start client comms: %+v", err) + } + inst, err := network.NewInstanceTesting(comms.ProtoComms, getNDF(), nil, nil, nil, t) + if err != nil { + t.Errorf("Failed to start instance: %+v", err) + } + now := netTime.Now() + nid1 := id.NewIdFromString("zezima", id.Node, t) + nid2 := id.NewIdFromString("jakexx360", id.Node, t) + nid3 := id.NewIdFromString("westparkhome", id.Node, t) + grp := cyclic.NewGroup(large.NewInt(7), large.NewInt(13)) + sess1.Cmix().Add(nid1, grp.NewInt(1)) + sess1.Cmix().Add(nid2, grp.NewInt(2)) + sess1.Cmix().Add(nid3, grp.NewInt(3)) + + timestamps := []uint64{ + uint64(now.Add(-30 * time.Second).UnixNano()), // PENDING + uint64(now.Add(-25 * time.Second).UnixNano()), // PRECOMPUTING + uint64(now.Add(-5 * time.Second).UnixNano()), // STANDBY + uint64(now.Add(5 * time.Second).UnixNano()), // QUEUED + 0} // REALTIME + + ri := &mixmessages.RoundInfo{ + ID: 3, + UpdateID: 0, + State: uint32(states.QUEUED), + BatchSize: 0, + Topology: [][]byte{nid1.Marshal(), nid2.Marshal(), nid3.Marshal()}, + Timestamps: timestamps, + Errors: nil, + ClientErrors: nil, + ResourceQueueTimeoutMillis: 0, + Signature: nil, + AddressSpaceSize: 4, + } + + if err = testutils.SignRoundInfoRsa(ri, t); err != nil { + t.Errorf("Failed to sign mock round info: %v", err) + } + + pubKey, err := testutils.LoadPublicKeyTesting(t) + if err != nil { + t.Errorf("Failed to load a key for testing: %v", err) + } + rnd := ds.NewRound(ri, pubKey, nil) + inst.GetWaitingRounds().Insert(rnd) + i := internal.Internal{ + Session: sess1, + Switchboard: sw, + Rng: fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG), + Comms: comms, + Health: nil, + TransmissionID: sess1.GetUser().TransmissionID, + Instance: inst, + NodeRegistration: nil, + } + p := gateway.DefaultPoolParams() + p.MaxPoolSize = 1 + sender, err := gateway.NewSender(p, i.Rng, getNDF(), &MockSendCMIXComms{t: t}, i.Session, nil) + if err != nil { + t.Errorf("%+v", errors.New(err.Error())) + return + } + m := NewManager(i, params.Messages{ + MessageReceptionBuffLen: 20, + MessageReceptionWorkerPoolSize: 20, + MaxChecksGarbledMessage: 20, + GarbledMessageWait: time.Hour, + }, nil, sender) + msgCmix := format.NewMessage(m.Session.Cmix().GetGroup().GetP().ByteLen()) + msgCmix.SetContents([]byte("test")) + e2e.SetUnencrypted(msgCmix, m.Session.User().GetCryptographicIdentity().GetTransmissionID()) + messages := make([]format.Message, numRecipients) + for i := 0; i < numRecipients; i++ { + messages[i] = msgCmix + } + + msgMap := make(map[id.ID]format.Message, numRecipients) + for i := 0; i < numRecipients; i++ { + msgMap[*recipients[i]] = msgCmix + } + + _, _, err = sendManyCmixHelper(sender, msgMap, params.GetDefaultCMIX(), m.Instance, + m.Session, m.nodeRegistration, m.Rng, m.TransmissionID, &MockSendCMIXComms{t: t}) + if err != nil { + t.Errorf("Failed to sendcmix: %+v", err) + } +} diff --git a/network/message/sendUnsafe.go b/network/message/sendUnsafe.go index 19f7d5d5ce0bfe6dd14df77b76f0a21c824b2404..938bcc07c8bab9bd24e1bdd53cb1c38c3b4111c4 100644 --- a/network/message/sendUnsafe.go +++ b/network/message/sendUnsafe.go @@ -64,7 +64,7 @@ func (m *Manager) SendUnsafe(msg message.Send, param params.Unsafe) ([]id.Round, wg.Add(1) go func(i int) { var err error - roundIds[i], _, err = m.SendCMIX(m.sender, msgCmix, msg.Recipient, param.CMIX) + roundIds[i], _, err = m.SendCMIX(m.sender, msgCmix, msg.Recipient, param.CMIX, nil) if err != nil { errCh <- err } diff --git a/network/message/utils_test.go b/network/message/utils_test.go new file mode 100644 index 0000000000000000000000000000000000000000..18284ad0bb6b6693db062f2fb5b048b5e18ca08f --- /dev/null +++ b/network/message/utils_test.go @@ -0,0 +1,106 @@ +package message + +import ( + "gitlab.com/elixxir/comms/mixmessages" + "gitlab.com/xx_network/comms/connect" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/ndf" + "testing" +) + +type MockSendCMIXComms struct { + t *testing.T +} + +func (mc *MockSendCMIXComms) GetHost(*id.ID) (*connect.Host, bool) { + nid1 := id.NewIdFromString("zezima", id.Node, mc.t) + gwID := nid1.DeepCopy() + gwID.SetType(id.Gateway) + h, _ := connect.NewHost(gwID, "0.0.0.0", []byte(""), connect.HostParams{ + MaxRetries: 0, + AuthEnabled: false, + }) + return h, true +} + +func (mc *MockSendCMIXComms) AddHost(*id.ID, string, []byte, connect.HostParams) (host *connect.Host, err error) { + host, _ = mc.GetHost(nil) + return host, nil +} + +func (mc *MockSendCMIXComms) RemoveHost(*id.ID) { + +} + +func (mc *MockSendCMIXComms) SendPutMessage(*connect.Host, *mixmessages.GatewaySlot) (*mixmessages.GatewaySlotResponse, error) { + return &mixmessages.GatewaySlotResponse{ + Accepted: true, + RoundID: 3, + }, nil +} + +func (mc *MockSendCMIXComms) SendPutManyMessages(*connect.Host, *mixmessages.GatewaySlots) (*mixmessages.GatewaySlotResponse, error) { + return &mixmessages.GatewaySlotResponse{ + Accepted: true, + RoundID: 3, + }, nil +} + +func getNDF() *ndf.NetworkDefinition { + nodeId := id.NewIdFromString("zezima", id.Node, &testing.T{}) + gwId := nodeId.DeepCopy() + gwId.SetType(id.Gateway) + return &ndf.NetworkDefinition{ + E2E: ndf.Group{ + Prime: "E2EE983D031DC1DB6F1A7A67DF0E9A8E5561DB8E8D49413394C049B7A" + + "8ACCEDC298708F121951D9CF920EC5D146727AA4AE535B0922C688B55B3D" + + "D2AEDF6C01C94764DAB937935AA83BE36E67760713AB44A6337C20E78615" + + "75E745D31F8B9E9AD8412118C62A3E2E29DF46B0864D0C951C394A5CBBDC" + + "6ADC718DD2A3E041023DBB5AB23EBB4742DE9C1687B5B34FA48C3521632C" + + "4A530E8FFB1BC51DADDF453B0B2717C2BC6669ED76B4BDD5C9FF558E88F2" + + "6E5785302BEDBCA23EAC5ACE92096EE8A60642FB61E8F3D24990B8CB12EE" + + "448EEF78E184C7242DD161C7738F32BF29A841698978825B4111B4BC3E1E" + + "198455095958333D776D8B2BEEED3A1A1A221A6E37E664A64B83981C46FF" + + "DDC1A45E3D5211AAF8BFBC072768C4F50D7D7803D2D4F278DE8014A47323" + + "631D7E064DE81C0C6BFA43EF0E6998860F1390B5D3FEACAF1696015CB79C" + + "3F9C2D93D961120CD0E5F12CBB687EAB045241F96789C38E89D796138E63" + + "19BE62E35D87B1048CA28BE389B575E994DCA755471584A09EC723742DC3" + + "5873847AEF49F66E43873", + Generator: "2", + }, + CMIX: ndf.Group{ + Prime: "9DB6FB5951B66BB6FE1E140F1D2CE5502374161FD6538DF1648218642" + + "F0B5C48C8F7A41AADFA187324B87674FA1822B00F1ECF8136943D7C55757" + + "264E5A1A44FFE012E9936E00C1D3E9310B01C7D179805D3058B2A9F4BB6F" + + "9716BFE6117C6B5B3CC4D9BE341104AD4A80AD6C94E005F4B993E14F091E" + + "B51743BF33050C38DE235567E1B34C3D6A5C0CEAA1A0F368213C3D19843D" + + "0B4B09DCB9FC72D39C8DE41F1BF14D4BB4563CA28371621CAD3324B6A2D3" + + "92145BEBFAC748805236F5CA2FE92B871CD8F9C36D3292B5509CA8CAA77A" + + "2ADFC7BFD77DDA6F71125A7456FEA153E433256A2261C6A06ED3693797E7" + + "995FAD5AABBCFBE3EDA2741E375404AE25B", + Generator: "5C7FF6B06F8F143FE8288433493E4769C4D988ACE5BE25A0E2480" + + "9670716C613D7B0CEE6932F8FAA7C44D2CB24523DA53FBE4F6EC3595892D" + + "1AA58C4328A06C46A15662E7EAA703A1DECF8BBB2D05DBE2EB956C142A33" + + "8661D10461C0D135472085057F3494309FFA73C611F78B32ADBB5740C361" + + "C9F35BE90997DB2014E2EF5AA61782F52ABEB8BD6432C4DD097BC5423B28" + + "5DAFB60DC364E8161F4A2A35ACA3A10B1C4D203CC76A470A33AFDCBDD929" + + "59859ABD8B56E1725252D78EAC66E71BA9AE3F1DD2487199874393CD4D83" + + "2186800654760E1E34C09E4D155179F9EC0DC4473F996BDCE6EED1CABED8" + + "B6F116F7AD9CF505DF0F998E34AB27514B0FFE7", + }, + Gateways: []ndf.Gateway{ + { + ID: gwId.Marshal(), + Address: "0.0.0.0", + TlsCertificate: "", + }, + }, + Nodes: []ndf.Node{ + { + ID: nodeId.Marshal(), + Address: "0.0.0.0", + TlsCertificate: "", + }, + }, + } +} diff --git a/network/node/register.go b/network/node/register.go index 82ab2cc3d937245adbc63812d6e875bc63657845..a16e1cd08486b3bd45abd59a3f3061b7cd43a3bd 100644 --- a/network/node/register.go +++ b/network/node/register.go @@ -55,7 +55,8 @@ func StartRegistration(sender *gateway.Sender, session *storage.Session, rngGen return multi } -func registerNodes(sender *gateway.Sender, session *storage.Session, rngGen *fastRNG.StreamGenerator, comms RegisterNodeCommsInterface, +func registerNodes(sender *gateway.Sender, session *storage.Session, + rngGen *fastRNG.StreamGenerator, comms RegisterNodeCommsInterface, stop *stoppable.Single, c chan network.NodeGateway) { u := session.User() regSignature := u.GetTransmissionRegistrationValidationSignature() @@ -67,13 +68,15 @@ func registerNodes(sender *gateway.Sender, session *storage.Session, rngGen *fas rng := rngGen.GetStream() interval := time.Duration(500) * time.Millisecond t := time.NewTicker(interval) - for true { + for { select { case <-stop.Quit(): t.Stop() + stop.ToStopped() return case gw := <-c: - err := registerWithNode(sender, comms, gw, regSignature, regTimestamp, uci, cmix, rng) + err := registerWithNode(sender, comms, gw, regSignature, + regTimestamp, uci, cmix, rng, stop) if err != nil { jww.ERROR.Printf("Failed to register node: %+v", err) } @@ -84,9 +87,10 @@ func registerNodes(sender *gateway.Sender, session *storage.Session, rngGen *fas //registerWithNode serves as a helper for RegisterWithNodes // It registers a user with a specific in the client's ndf. -func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw network.NodeGateway, - regSig []byte, registrationTimestampNano int64, uci *user.CryptographicIdentity, - store *cmix.Store, rng csprng.Source) error { +func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface, + ngw network.NodeGateway, regSig []byte, registrationTimestampNano int64, + uci *user.CryptographicIdentity, store *cmix.Store, rng csprng.Source, + stop *stoppable.Single) error { nodeID, err := ngw.Node.GetNodeId() if err != nil { @@ -122,7 +126,9 @@ func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface, // keys transmissionHash, _ := hash.NewCMixHash() - nonce, dhPub, err := requestNonce(sender, comms, gatewayID, regSig, registrationTimestampNano, uci, store, rng) + nonce, dhPub, err := requestNonce(sender, comms, gatewayID, regSig, + registrationTimestampNano, uci, store, rng, stop) + if err != nil { return errors.Errorf("Failed to request nonce: %+v", err) } @@ -133,7 +139,8 @@ func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface, // Confirm received nonce jww.INFO.Printf("Register: Confirming received nonce from node %s", nodeID.String()) err = confirmNonce(sender, comms, uci.GetTransmissionID().Bytes(), - nonce, uci.GetTransmissionRSA(), gatewayID) + nonce, uci.GetTransmissionRSA(), gatewayID, stop) + if err != nil { errMsg := fmt.Sprintf("Register: Unable to confirm nonce: %v", err) return errors.New(errMsg) @@ -151,7 +158,7 @@ func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface, func requestNonce(sender *gateway.Sender, comms RegisterNodeCommsInterface, gwId *id.ID, regSig []byte, registrationTimestampNano int64, uci *user.CryptographicIdentity, - store *cmix.Store, rng csprng.Source) ([]byte, []byte, error) { + store *cmix.Store, rng csprng.Source, stop *stoppable.Single) ([]byte, []byte, error) { dhPub := store.GetDHPublicKey().Bytes() opts := rsa.NewDefaultOptions() @@ -195,7 +202,8 @@ func requestNonce(sender *gateway.Sender, comms RegisterNodeCommsInterface, gwId return nil, err } return nonceResponse, nil - }) + }, stop) + if err != nil { return nil, nil, err } @@ -208,8 +216,9 @@ func requestNonce(sender *gateway.Sender, comms RegisterNodeCommsInterface, gwId // confirmNonce is a helper for the Register function // It signs a nonce and sends it for confirmation // Returns nil if successful, error otherwise -func confirmNonce(sender *gateway.Sender, comms RegisterNodeCommsInterface, UID, nonce []byte, - privateKeyRSA *rsa.PrivateKey, gwID *id.ID) error { +func confirmNonce(sender *gateway.Sender, comms RegisterNodeCommsInterface, UID, + nonce []byte, privateKeyRSA *rsa.PrivateKey, gwID *id.ID, + stop *stoppable.Single) error { opts := rsa.NewDefaultOptions() opts.Hash = hash.CMixHash h, _ := hash.NewCMixHash() @@ -248,6 +257,7 @@ func confirmNonce(sender *gateway.Sender, comms RegisterNodeCommsInterface, UID, return nil, err } return confirmResponse, nil - }) + }, stop) + return err } diff --git a/network/rounds/historical.go b/network/rounds/historical.go index 45aed7fdfaa7b296a46de79645e19fd010f88634..b4920335458283243fe7f4b77d4777bde2efb4e7 100644 --- a/network/rounds/historical.go +++ b/network/rounds/historical.go @@ -9,6 +9,7 @@ package rounds import ( jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage/reception" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/xx_network/comms/connect" @@ -41,19 +42,18 @@ type historicalRoundRequest struct { // Long running thread which process historical rounds // Can be killed by sending a signal to the quit channel // takes a comms interface to aid in testing -func (m *Manager) processHistoricalRounds(comm historicalRoundsComms, quitCh <-chan struct{}) { +func (m *Manager) processHistoricalRounds(comm historicalRoundsComms, stop *stoppable.Single) { timerCh := make(<-chan time.Time) rng := m.Rng.GetStream() var roundRequests []historicalRoundRequest - done := false - for !done { + for { shouldProcess := false // wait for a quit or new round to check select { - case <-quitCh: + case <-stop.Quit(): rng.Close() // return all roundRequests in the queue to the input channel so they can // be checked in the future. If the queue is full, disable them as @@ -64,7 +64,8 @@ func (m *Manager) processHistoricalRounds(comm historicalRoundsComms, quitCh <-c default: } } - done = true + stop.ToStopped() + return // if the timer elapses process roundRequests to ensure the delay isn't too long case <-timerCh: if len(roundRequests) > 0 { @@ -100,7 +101,7 @@ func (m *Manager) processHistoricalRounds(comm historicalRoundsComms, quitCh <-c jww.DEBUG.Printf("Requesting Historical rounds %v from "+ "gateway %s", rounds, host.GetId()) return comm.RequestHistoricalRounds(host, hr) - }) + }, stop) if err != nil { jww.ERROR.Printf("Failed to request historical roundRequests "+ diff --git a/network/rounds/manager.go b/network/rounds/manager.go index c488e39f3b2dc2311166b99d21b680dd3d0b8028..c695cd7a8c78be622769a4388a669eda5757b304 100644 --- a/network/rounds/manager.go +++ b/network/rounds/manager.go @@ -8,17 +8,16 @@ package rounds import ( - "fmt" "gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/network/internal" "gitlab.com/elixxir/client/network/message" "gitlab.com/elixxir/client/stoppable" + "strconv" ) type Manager struct { params params.Rounds - internal.Internal sender *gateway.Sender @@ -48,19 +47,19 @@ func (m *Manager) StartProcessors() stoppable.Stoppable { //start the historical rounds thread historicalRoundsStopper := stoppable.NewSingle("ProcessHistoricalRounds") - go m.processHistoricalRounds(m.Comms, historicalRoundsStopper.Quit()) + go m.processHistoricalRounds(m.Comms, historicalRoundsStopper) multi.Add(historicalRoundsStopper) //start the message retrieval worker pool for i := uint(0); i < m.params.NumMessageRetrievalWorkers; i++ { - stopper := stoppable.NewSingle(fmt.Sprintf("Messager Retriever %v", i)) - go m.processMessageRetrieval(m.Comms, stopper.Quit()) + stopper := stoppable.NewSingle("Message Retriever " + strconv.Itoa(int(i))) + go m.processMessageRetrieval(m.Comms, stopper) multi.Add(stopper) } // Start the periodic unchecked round worker stopper := stoppable.NewSingle("UncheckRound") - go m.processUncheckedRounds(m.params.UncheckRoundPeriod, backOffTable, stopper.Quit()) + go m.processUncheckedRounds(m.params.UncheckRoundPeriod, backOffTable, stopper) multi.Add(stopper) return multi diff --git a/network/rounds/retrieve.go b/network/rounds/retrieve.go index 8b5029f07322d90ea11fa64a3ed0f4ba49ed33d0..caf5ce0e70d2b688a2045feb8c0dd21bcecdee6a 100644 --- a/network/rounds/retrieve.go +++ b/network/rounds/retrieve.go @@ -12,6 +12,7 @@ import ( "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/network/message" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage/reception" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/crypto/shuffle" @@ -37,20 +38,20 @@ const noRoundError = "does not have round %d" // processMessageRetrieval received a roundLookup request and pings the gateways // of that round for messages for the requested identity in the roundLookup func (m *Manager) processMessageRetrieval(comms messageRetrievalComms, - quitCh <-chan struct{}) { + stop *stoppable.Single) { - done := false - for !done { + for { select { - case <-quitCh: - done = true + case <-stop.Quit(): + stop.ToStopped() + return case rl := <-m.lookupRoundMessages: - // wrap this around the pickup logic ri := rl.roundInfo + jww.DEBUG.Printf("Checking for messages in round %d", ri.ID) err := m.Session.UncheckedRounds().AddRound(rl.roundInfo, rl.identity.EphId, rl.identity.Source) if err != nil { - jww.ERROR.Printf("Could not find round %d in unchecked rounds store: %v", + jww.ERROR.Printf("Could not add round %d in unchecked rounds store: %v", rl.roundInfo.ID, err) } @@ -81,11 +82,17 @@ func (m *Manager) processMessageRetrieval(comms messageRetrievalComms, }) // If ForceMessagePickupRetry, we are forcing processUncheckedRounds by - // randomly not picking up messages + // randomly not picking up messages (FOR INTEGRATION TEST). Only done if + // round has not been ignored before var bundle message.Bundle if m.params.ForceMessagePickupRetry { - jww.INFO.Printf("Forcing message pickup retry for round %d", ri.ID) - bundle, err = m.forceMessagePickupRetry(ri, rl, comms, gwIds) + bundle, err = m.forceMessagePickupRetry(ri, rl, comms, gwIds, stop) + + // Exit if the thread has been stopped + if stoppable.CheckErr(err) { + jww.ERROR.Print(err) + continue + } if err != nil { jww.ERROR.Printf("Failed to get pickup round %d "+ "from all gateways (%v): %s", @@ -93,7 +100,14 @@ func (m *Manager) processMessageRetrieval(comms messageRetrievalComms, } } else { // Attempt to request for this gateway - bundle, err = m.getMessagesFromGateway(id.Round(ri.ID), rl.identity, comms, gwIds) + bundle, err = m.getMessagesFromGateway(id.Round(ri.ID), rl.identity, comms, gwIds, stop) + + // Exit if the thread has been stopped + if stoppable.CheckErr(err) { + jww.ERROR.Print(err) + continue + } + // After trying all gateways, if none returned we mark the round as a // failure and print out the last error if err != nil { @@ -105,6 +119,7 @@ func (m *Manager) processMessageRetrieval(comms messageRetrievalComms, } if len(bundle.Messages) != 0 { + jww.DEBUG.Printf("Removing round %d from unchecked store", ri.ID) err = m.Session.UncheckedRounds().Remove(id.Round(ri.ID)) if err != nil { jww.ERROR.Printf("Could not remove round %d "+ @@ -113,6 +128,7 @@ func (m *Manager) processMessageRetrieval(comms messageRetrievalComms, // If successful and there are messages, we send them to another thread bundle.Identity = rl.identity + bundle.RoundInfo = rl.roundInfo m.messageBundles <- bundle } @@ -122,8 +138,9 @@ func (m *Manager) processMessageRetrieval(comms messageRetrievalComms, // getMessagesFromGateway attempts to get messages from their assigned // gateway host in the round specified. If successful -func (m *Manager) getMessagesFromGateway(roundID id.Round, identity reception.IdentityUse, - comms messageRetrievalComms, gwIds []*id.ID) (message.Bundle, error) { +func (m *Manager) getMessagesFromGateway(roundID id.Round, + identity reception.IdentityUse, comms messageRetrievalComms, gwIds []*id.ID, + stop *stoppable.Single) (message.Bundle, error) { start := time.Now() // Send to the gateways using backup proxies result, err := m.sender.SendToPreferred(gwIds, func(host *connect.Host, target *id.ID) (interface{}, bool, error) { @@ -140,12 +157,13 @@ func (m *Manager) getMessagesFromGateway(roundID id.Round, identity reception.Id // If the gateway doesnt have the round, return an error msgResp, err := comms.RequestMessages(host, msgReq) if err == nil && !msgResp.GetHasRound() { + jww.INFO.Printf("No round error for round %d received from %s", roundID, target) return message.Bundle{}, false, errors.Errorf(noRoundError, roundID) } return msgResp, false, err - }) - + }, stop) + jww.INFO.Printf("Received message for round %d, processing...", roundID) // Fail the round if an error occurs so it can be tried again later if err != nil { return message.Bundle{}, errors.WithMessagef(err, "Failed to "+ @@ -190,22 +208,28 @@ func (m *Manager) getMessagesFromGateway(roundID id.Round, identity reception.Id // Helper function which forces processUncheckedRounds by randomly // not looking up messages func (m *Manager) forceMessagePickupRetry(ri *pb.RoundInfo, rl roundLookup, - comms messageRetrievalComms, gwIds []*id.ID) (bundle message.Bundle, err error) { - // Flip a coin to determine whether to pick up message - stream := m.Rng.GetStream() - defer stream.Close() - b := make([]byte, 8) - _, err = stream.Read(b) - if err != nil { - jww.FATAL.Panic(err.Error()) - } - result := binary.BigEndian.Uint64(b) - if result%2 == 0 { - // Do not call get message, leaving the round to be picked up - // in unchecked round scheduler process - return + comms messageRetrievalComms, gwIds []*id.ID, + stop *stoppable.Single) (bundle message.Bundle, err error) { + rnd, _ := m.Session.UncheckedRounds().GetRound(id.Round(ri.ID)) + if rnd.NumChecks == 0 { + // Flip a coin to determine whether to pick up message + stream := m.Rng.GetStream() + defer stream.Close() + b := make([]byte, 8) + _, err = stream.Read(b) + if err != nil { + jww.FATAL.Panic(err.Error()) + } + result := binary.BigEndian.Uint64(b) + if result%2 == 0 { + jww.INFO.Printf("Forcing a message pickup retry for round %d", ri.ID) + // Do not call get message, leaving the round to be picked up + // in unchecked round scheduler process + return + } + } // Attempt to request for this gateway - return m.getMessagesFromGateway(id.Round(ri.ID), rl.identity, comms, gwIds) + return m.getMessagesFromGateway(id.Round(ri.ID), rl.identity, comms, gwIds, stop) } diff --git a/network/rounds/retrieve_test.go b/network/rounds/retrieve_test.go index ae54730119dd1c441116ba2187985b078caf6c84..856fa05fa3e9b911f6180b0ed368031993c907a1 100644 --- a/network/rounds/retrieve_test.go +++ b/network/rounds/retrieve_test.go @@ -10,6 +10,7 @@ import ( "bytes" "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/network/message" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage/reception" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/crypto/fastRNG" @@ -28,7 +29,7 @@ func TestManager_ProcessMessageRetrieval(t *testing.T) { testManager := newManager(t) roundId := id.Round(5) mockComms := &mockMessageRetrievalComms{testingSignature: t} - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") testNdf := getNDF() nodeId := id.NewIdFromString(ReturningGateway, id.Node, &testing.T{}) gwId := nodeId.DeepCopy() @@ -51,7 +52,7 @@ func TestManager_ProcessMessageRetrieval(t *testing.T) { testManager.messageBundles = messageBundleChan // Initialize the message retrieval - go testManager.processMessageRetrieval(mockComms, quitChan) + go testManager.processMessageRetrieval(mockComms, stop) // Construct expected values for checking expectedEphID := ephemeral.Id{1, 2, 3, 4, 5, 6, 7, 8} @@ -92,8 +93,10 @@ func TestManager_ProcessMessageRetrieval(t *testing.T) { testBundle = <-messageBundleChan // Close the process - quitChan <- struct{}{} - + err := stop.Close() + if err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } }() // Ensure bundle received and has expected values @@ -136,7 +139,7 @@ func TestManager_ProcessMessageRetrieval_NoRound(t *testing.T) { testManager.sender, _ = gateway.NewSender(p, testManager.Rng, testNdf, mockComms, testManager.Session, nil) - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") // Create a local channel so reception is possible (testManager.messageBundles is // send only via newManager call above) @@ -144,7 +147,7 @@ func TestManager_ProcessMessageRetrieval_NoRound(t *testing.T) { testManager.messageBundles = messageBundleChan // Initialize the message retrieval - go testManager.processMessageRetrieval(mockComms, quitChan) + go testManager.processMessageRetrieval(mockComms, stop) expectedEphID := ephemeral.Id{1, 2, 3, 4, 5, 6, 7, 8} @@ -183,8 +186,9 @@ func TestManager_ProcessMessageRetrieval_NoRound(t *testing.T) { testBundle = <-messageBundleChan // Close the process - quitChan <- struct{}{} - + if err := stop.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } }() time.Sleep(2 * time.Second) @@ -202,7 +206,7 @@ func TestManager_ProcessMessageRetrieval_FalsePositive(t *testing.T) { testManager := newManager(t) roundId := id.Round(5) mockComms := &mockMessageRetrievalComms{testingSignature: t} - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") testNdf := getNDF() nodeId := id.NewIdFromString(FalsePositive, id.Node, &testing.T{}) gwId := nodeId.DeepCopy() @@ -222,7 +226,7 @@ func TestManager_ProcessMessageRetrieval_FalsePositive(t *testing.T) { testManager.messageBundles = messageBundleChan // Initialize the message retrieval - go testManager.processMessageRetrieval(mockComms, quitChan) + go testManager.processMessageRetrieval(mockComms, stop) // Construct expected values for checking expectedEphID := ephemeral.Id{1, 2, 3, 4, 5, 6, 7, 8} @@ -263,8 +267,9 @@ func TestManager_ProcessMessageRetrieval_FalsePositive(t *testing.T) { testBundle = <-messageBundleChan // Close the process - quitChan <- struct{}{} - + if err := stop.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } }() // Ensure no bundle was received due to false positive test @@ -282,7 +287,7 @@ func TestManager_ProcessMessageRetrieval_Quit(t *testing.T) { testManager := newManager(t) roundId := id.Round(5) mockComms := &mockMessageRetrievalComms{testingSignature: t} - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") // Create a local channel so reception is possible (testManager.messageBundles is // send only via newManager call above) @@ -290,10 +295,16 @@ func TestManager_ProcessMessageRetrieval_Quit(t *testing.T) { testManager.messageBundles = messageBundleChan // Initialize the message retrieval - go testManager.processMessageRetrieval(mockComms, quitChan) + go testManager.processMessageRetrieval(mockComms, stop) // Close the process early, before any logic below can be completed - quitChan <- struct{}{} + if err := stop.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } + + if err := stoppable.WaitForStopped(stop, 300*time.Millisecond); err != nil { + t.Fatalf("Failed to stop stoppable: %+v", err) + } // Construct expected values for checking expectedEphID := ephemeral.Id{1, 2, 3, 4, 5, 6, 7, 8} @@ -348,7 +359,7 @@ func TestManager_ProcessMessageRetrieval_MultipleGateways(t *testing.T) { testManager := newManager(t) roundId := id.Round(5) mockComms := &mockMessageRetrievalComms{testingSignature: t} - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") testNdf := getNDF() nodeId := id.NewIdFromString(ReturningGateway, id.Node, &testing.T{}) gwId := nodeId.DeepCopy() @@ -368,7 +379,7 @@ func TestManager_ProcessMessageRetrieval_MultipleGateways(t *testing.T) { testManager.messageBundles = messageBundleChan // Initialize the message retrieval - go testManager.processMessageRetrieval(mockComms, quitChan) + go testManager.processMessageRetrieval(mockComms, stop) // Construct expected values for checking expectedEphID := ephemeral.Id{1, 2, 3, 4, 5, 6, 7, 8} @@ -410,8 +421,9 @@ func TestManager_ProcessMessageRetrieval_MultipleGateways(t *testing.T) { testBundle = <-messageBundleChan // Close the process - quitChan <- struct{}{} - + if err := stop.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } }() // Ensure that expected bundle is still received from happy comm diff --git a/network/rounds/unchecked.go b/network/rounds/unchecked.go index df2f6c85875b50bf85c1cbc1254d4d8337a4b9dd..e62bff0c6885d71080b4544cf6602ec684fa2a9a 100644 --- a/network/rounds/unchecked.go +++ b/network/rounds/unchecked.go @@ -9,6 +9,7 @@ package rounds import ( jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage/reception" "gitlab.com/xx_network/primitives/netTime" "time" @@ -36,14 +37,14 @@ var backOffTable = [cappedTries]time.Duration{tryZero, tryOne, tryTwo, tryThree, // If a round is found to be due on a periodical check, the round is sent // back to processMessageRetrieval. func (m *Manager) processUncheckedRounds(checkInterval time.Duration, backoffTable [cappedTries]time.Duration, - quitCh <-chan struct{}) { + stop *stoppable.Single) { ticker := time.NewTicker(checkInterval) uncheckedRoundStore := m.Session.UncheckedRounds() - done := false - for !done { + for { select { - case <-quitCh: - done = true + case <-stop.Quit(): + stop.ToStopped() + return case <-ticker.C: // Pull and iterate through uncheckedRound list @@ -67,7 +68,8 @@ func (m *Manager) processUncheckedRounds(checkInterval time.Duration, backoffTab // Send to processMessageRetrieval select { case m.lookupRoundMessages <- rl: - case <-time.After(500 * time.Second): + case <-time.After(1 * time.Second): + jww.WARN.Printf("Timing out, not retrying round %d", rl.roundInfo.ID) } // Update the state of the round for next look-up (if needed) diff --git a/network/rounds/unchecked_test.go b/network/rounds/unchecked_test.go index cc66c02712acbffae509eeb7270cc88463e80c58..980ca1e6157a583cc239cab4332e358c5b04e84a 100644 --- a/network/rounds/unchecked_test.go +++ b/network/rounds/unchecked_test.go @@ -10,6 +10,7 @@ package rounds import ( "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/network/message" + "gitlab.com/elixxir/client/stoppable" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/crypto/fastRNG" "gitlab.com/xx_network/crypto/csprng" @@ -27,7 +28,8 @@ func TestUncheckedRoundScheduler(t *testing.T) { testManager := newManager(t) roundId := id.Round(5) mockComms := &mockMessageRetrievalComms{testingSignature: t} - quitChan := make(chan struct{}) + stop1 := stoppable.NewSingle("singleStoppable1") + stop2 := stoppable.NewSingle("singleStoppable2") testNdf := getNDF() nodeId := id.NewIdFromString(ReturningGateway, id.Node, &testing.T{}) gwId := nodeId.DeepCopy() @@ -47,8 +49,8 @@ func TestUncheckedRoundScheduler(t *testing.T) { testBackoffTable := newTestBackoffTable(t) checkInterval := 250 * time.Millisecond // Initialize the message retrieval - go testManager.processMessageRetrieval(mockComms, quitChan) - go testManager.processUncheckedRounds(checkInterval, testBackoffTable, quitChan) + go testManager.processMessageRetrieval(mockComms, stop1) + go testManager.processUncheckedRounds(checkInterval, testBackoffTable, stop2) requestGateway := id.NewIdFromString(ReturningGateway, id.Gateway, t) @@ -73,7 +75,12 @@ func TestUncheckedRoundScheduler(t *testing.T) { testBundle = <-messageBundleChan // Close the process - quitChan <- struct{}{} + if err := stop1.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } + if err := stop2.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } }() diff --git a/network/send.go b/network/send.go index d70a0c6661c2ee6f4c40f313219baa7cbb86fd52..d54e478c7c705f4995c451b1e302e5ae77d292f7 100644 --- a/network/send.go +++ b/network/send.go @@ -12,6 +12,7 @@ import ( jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/crypto/e2e" "gitlab.com/elixxir/primitives/format" "gitlab.com/xx_network/primitives/id" @@ -28,7 +29,16 @@ func (m *manager) SendCMIX(msg format.Message, recipient *id.ID, param params.CM "network is not healthy") } - return m.message.SendCMIX(m.GetSender(), msg, recipient, param) + return m.message.SendCMIX(m.GetSender(), msg, recipient, param, nil) +} + +// SendManyCMIX sends many "raw" CMIX message payloads to each of the +// provided recipients. Used for group chat functionality. Returns the +// round ID of the round the payload was sent or an error if it fails. +func (m *manager) SendManyCMIX(messages map[id.ID]format.Message, + p params.CMIX) (id.Round, []ephemeral.Id, error) { + + return m.message.SendManyCMIX(m.sender, messages, p) } // SendUnsafe sends an unencrypted payload to the provided recipient @@ -52,7 +62,7 @@ func (m *manager) SendUnsafe(msg message.Send, param params.Unsafe) ([]id.Round, // SendE2E sends an end-to-end payload to the provided recipient with // the provided msgType. Returns the list of rounds in which parts of // the message were sent or an error if it fails. -func (m *manager) SendE2E(msg message.Send, e2eP params.E2E) ( +func (m *manager) SendE2E(msg message.Send, e2eP params.E2E, stop *stoppable.Single) ( []id.Round, e2e.MessageID, error) { if !m.Health.IsHealthy() { @@ -60,5 +70,5 @@ func (m *manager) SendE2E(msg message.Send, e2eP params.E2E) ( "message when the network is not healthy") } - return m.message.SendE2E(msg, e2eP) + return m.message.SendE2E(msg, e2eP, stop) } diff --git a/single/manager.go b/single/manager.go index 78481829180017f242ca45e43b2bacc4175cf4dd..063f7e8541e0f8f3f9063228fbc33d618247e2b7 100644 --- a/single/manager.go +++ b/single/manager.go @@ -74,13 +74,13 @@ func (m *Manager) StartProcesses() stoppable.Stoppable { transmissionStop := stoppable.NewSingle(singleUseTransmission) transmissionChan := make(chan message.Receive, rawMessageBuffSize) m.swb.RegisterChannel(singleUseReceiveTransmission, &id.ID{}, message.Raw, transmissionChan) - go m.receiveTransmissionHandler(transmissionChan, transmissionStop.Quit()) + go m.receiveTransmissionHandler(transmissionChan, transmissionStop) // Start waiting for single-use response responseStop := stoppable.NewSingle(singleUseResponse) responseChan := make(chan message.Receive, rawMessageBuffSize) m.swb.RegisterChannel(singleUseReceiveResponse, &id.ID{}, message.Raw, responseChan) - go m.receiveResponseHandler(responseChan, responseStop.Quit()) + go m.receiveResponseHandler(responseChan, responseStop) // Create a multi stoppable singleUseMulti := stoppable.NewMulti(singleUseStop) diff --git a/single/manager_test.go b/single/manager_test.go index a4384b166322df3206937ff58d56187aed0a6f52..4d570c35e900e2f36d1ee77bb76a298e927f09a7 100644 --- a/single/manager_test.go +++ b/single/manager_test.go @@ -181,7 +181,7 @@ func TestManager_StartProcesses_Stop(t *testing.T) { t.Error("Stoppable is not running.") } - err = stop.Close(1 * time.Millisecond) + err = stop.Close() if err != nil { t.Errorf("Failed to close: %+v", err) } @@ -283,8 +283,8 @@ func (tnm *testNetworkManager) GetMsg(i int) format.Message { return tnm.msgs[i] } -func (tnm *testNetworkManager) SendE2E(_ message.Send, _ params.E2E) ([]id.Round, e2e.MessageID, error) { - return nil, [32]byte{}, nil +func (tnm *testNetworkManager) SendE2E(message.Send, params.E2E, *stoppable.Single) ([]id.Round, e2e.MessageID, error) { + return nil, e2e.MessageID{}, nil } func (tnm *testNetworkManager) SendUnsafe(_ message.Send, _ params.Unsafe) ([]id.Round, error) { @@ -306,6 +306,23 @@ func (tnm *testNetworkManager) SendCMIX(msg format.Message, _ *id.ID, _ params.C return id.Round(rand.Uint64()), ephemeral.Id{}, nil } +func (tnm *testNetworkManager) SendManyCMIX(messages map[id.ID]format.Message, p params.CMIX) (id.Round, []ephemeral.Id, error) { + if tnm.cmixTimeout != 0 { + time.Sleep(tnm.cmixTimeout) + } else if tnm.cmixErr { + return 0, []ephemeral.Id{}, errors.New("sendCMIX error") + } + + tnm.Lock() + defer tnm.Unlock() + + for _, msg := range messages { + tnm.msgs = append(tnm.msgs, msg) + } + + return id.Round(rand.Uint64()), []ephemeral.Id{}, nil +} + func (tnm *testNetworkManager) GetInstance() *network.Instance { return tnm.instance } diff --git a/single/receiveResponse.go b/single/receiveResponse.go index d25880dae86dee70d882463aeba08d0310daa08b..6a6fa94fddbc10f71a6e311fea944a71ea8f35eb 100644 --- a/single/receiveResponse.go +++ b/single/receiveResponse.go @@ -11,6 +11,7 @@ import ( "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/crypto/e2e/auth" "gitlab.com/elixxir/crypto/e2e/singleUse" "gitlab.com/elixxir/primitives/format" @@ -20,13 +21,14 @@ import ( // receiveResponseHandler handles the reception of single-use response messages. func (m *Manager) receiveResponseHandler(rawMessages chan message.Receive, - quitChan <-chan struct{}) { + stop *stoppable.Single) { jww.DEBUG.Print("Waiting to receive single-use response messages.") for { select { - case <-quitChan: + case <-stop.Quit(): jww.DEBUG.Printf("Stopping waiting to receive single-use " + "response message.") + stop.ToStopped() return case msg := <-rawMessages: jww.DEBUG.Printf("Received CMIX message; checking if it is a " + diff --git a/single/receiveResponse_test.go b/single/receiveResponse_test.go index 4e7e8c352e8196f0ff24d5a0d60633f78439e8dd..ea75c5e3ec10cb0af1971af162b4e0fc69b78871 100644 --- a/single/receiveResponse_test.go +++ b/single/receiveResponse_test.go @@ -10,6 +10,7 @@ package single import ( "bytes" "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/crypto/e2e/auth" "gitlab.com/elixxir/crypto/e2e/singleUse" "gitlab.com/elixxir/primitives/format" @@ -26,7 +27,7 @@ import ( func TestManager_ReceiveResponseHandler(t *testing.T) { m := newTestManager(0, false, t) rawMessages := make(chan message.Receive, rawMessageBuffSize) - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") partner := NewContact(id.NewIdFromString("recipientID", id.User, t), m.store.E2e().GetGroup().NewInt(43), m.store.E2e().GetGroup().NewInt(42), singleUse.TagFP{}, 8) @@ -52,7 +53,7 @@ func TestManager_ReceiveResponseHandler(t *testing.T) { } }() - go m.receiveResponseHandler(rawMessages, quitChan) + go m.receiveResponseHandler(rawMessages, stop) for _, msg := range msgs { rawMessages <- message.Receive{ @@ -78,14 +79,16 @@ func TestManager_ReceiveResponseHandler(t *testing.T) { t.Errorf("Callback failed to be called.") } - quitChan <- struct{}{} + if err := stop.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } } // Error path: invalid CMIX message. func TestManager_ReceiveResponseHandler_CmixMessageError(t *testing.T) { m := newTestManager(0, false, t) rawMessages := make(chan message.Receive, rawMessageBuffSize) - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") partner := NewContact(id.NewIdFromString("recipientID", id.User, t), m.store.E2e().GetGroup().NewInt(43), m.store.E2e().GetGroup().NewInt(42), singleUse.TagFP{}, 8) @@ -106,7 +109,7 @@ func TestManager_ReceiveResponseHandler_CmixMessageError(t *testing.T) { } }() - go m.receiveResponseHandler(rawMessages, quitChan) + go m.receiveResponseHandler(rawMessages, stop) rawMessages <- message.Receive{ Payload: make([]byte, format.MinimumPrimeSize*2), @@ -124,7 +127,9 @@ func TestManager_ReceiveResponseHandler_CmixMessageError(t *testing.T) { case <-timer.C: } - quitChan <- struct{}{} + if err := stop.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } } // Happy path. diff --git a/single/reception.go b/single/reception.go index 53b6c12eda52386a3544b547bee0b54b91bc1d2a..23ca6f8bef891b8ae4426fe24172d4a88c2510c5 100644 --- a/single/reception.go +++ b/single/reception.go @@ -11,6 +11,7 @@ import ( "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" cAuth "gitlab.com/elixxir/crypto/e2e/auth" "gitlab.com/elixxir/crypto/e2e/singleUse" "gitlab.com/elixxir/primitives/format" @@ -19,14 +20,15 @@ import ( // receiveTransmissionHandler waits to receive single-use transmissions. When // a message is received, its is returned via its registered callback. func (m *Manager) receiveTransmissionHandler(rawMessages chan message.Receive, - quitChan <-chan struct{}) { + stop *stoppable.Single) { fp := singleUse.NewTransmitFingerprint(m.store.E2e().GetDHPublicKey()) jww.DEBUG.Print("Waiting to receive single-use transmission messages.") for { select { - case <-quitChan: + case <-stop.Quit(): jww.DEBUG.Printf("Stopping waiting to receive single-use " + "transmission message.") + stop.ToStopped() return case msg := <-rawMessages: jww.DEBUG.Printf("Received CMIX message; checking if it is a " + diff --git a/single/reception_test.go b/single/reception_test.go index 3266d9904b2fb51dd592c766b61a9c40409e0925..27a87b9a181e94437b8a3b471492123c2451547b 100644 --- a/single/reception_test.go +++ b/single/reception_test.go @@ -3,6 +3,7 @@ package single import ( "bytes" "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" contact2 "gitlab.com/elixxir/crypto/contact" "gitlab.com/elixxir/crypto/e2e/singleUse" "gitlab.com/elixxir/primitives/format" @@ -17,7 +18,6 @@ import ( func TestManager_receiveTransmissionHandler(t *testing.T) { m := newTestManager(0, false, t) rawMessages := make(chan message.Receive, rawMessageBuffSize) - quitChan := make(chan struct{}) partner := contact2.Contact{ ID: id.NewIdFromString("recipientID", id.User, t), DhPubKey: m.store.E2e().GetDHPublicKey(), @@ -35,7 +35,7 @@ func TestManager_receiveTransmissionHandler(t *testing.T) { m.callbackMap.registerCallback(tag, callback) - go m.receiveTransmissionHandler(rawMessages, quitChan) + go m.receiveTransmissionHandler(rawMessages, stoppable.NewSingle("singleStoppable")) rawMessages <- message.Receive{ Payload: msg.Marshal(), } @@ -57,7 +57,7 @@ func TestManager_receiveTransmissionHandler(t *testing.T) { func TestManager_receiveTransmissionHandler_QuitChan(t *testing.T) { m := newTestManager(0, false, t) rawMessages := make(chan message.Receive, rawMessageBuffSize) - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") tag := "Test tag" payload := make([]byte, 132) rand.New(rand.NewSource(42)).Read(payload) @@ -65,8 +65,11 @@ func TestManager_receiveTransmissionHandler_QuitChan(t *testing.T) { m.callbackMap.registerCallback(tag, callback) - go m.receiveTransmissionHandler(rawMessages, quitChan) - quitChan <- struct{}{} + go m.receiveTransmissionHandler(rawMessages, stop) + + if err := stop.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } timer := time.NewTimer(50 * time.Millisecond) @@ -82,7 +85,7 @@ func TestManager_receiveTransmissionHandler_QuitChan(t *testing.T) { func TestManager_receiveTransmissionHandler_FingerPrintError(t *testing.T) { m := newTestManager(0, false, t) rawMessages := make(chan message.Receive, rawMessageBuffSize) - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") partner := contact2.Contact{ ID: id.NewIdFromString("recipientID", id.User, t), DhPubKey: m.store.E2e().GetGroup().NewInt(42), @@ -100,7 +103,7 @@ func TestManager_receiveTransmissionHandler_FingerPrintError(t *testing.T) { m.callbackMap.registerCallback(tag, callback) - go m.receiveTransmissionHandler(rawMessages, quitChan) + go m.receiveTransmissionHandler(rawMessages, stop) rawMessages <- message.Receive{ Payload: msg.Marshal(), } @@ -119,7 +122,7 @@ func TestManager_receiveTransmissionHandler_FingerPrintError(t *testing.T) { func TestManager_receiveTransmissionHandler_ProcessMessageError(t *testing.T) { m := newTestManager(0, false, t) rawMessages := make(chan message.Receive, rawMessageBuffSize) - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") partner := contact2.Contact{ ID: id.NewIdFromString("recipientID", id.User, t), DhPubKey: m.store.E2e().GetDHPublicKey(), @@ -139,7 +142,7 @@ func TestManager_receiveTransmissionHandler_ProcessMessageError(t *testing.T) { m.callbackMap.registerCallback(tag, callback) - go m.receiveTransmissionHandler(rawMessages, quitChan) + go m.receiveTransmissionHandler(rawMessages, stop) rawMessages <- message.Receive{ Payload: msg.Marshal(), } @@ -158,7 +161,7 @@ func TestManager_receiveTransmissionHandler_ProcessMessageError(t *testing.T) { func TestManager_receiveTransmissionHandler_TagFpError(t *testing.T) { m := newTestManager(0, false, t) rawMessages := make(chan message.Receive, rawMessageBuffSize) - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") partner := contact2.Contact{ ID: id.NewIdFromString("recipientID", id.User, t), DhPubKey: m.store.E2e().GetDHPublicKey(), @@ -173,7 +176,7 @@ func TestManager_receiveTransmissionHandler_TagFpError(t *testing.T) { t.Fatalf("Failed to create tranmission CMIX message: %+v", err) } - go m.receiveTransmissionHandler(rawMessages, quitChan) + go m.receiveTransmissionHandler(rawMessages, stop) rawMessages <- message.Receive{ Payload: msg.Marshal(), } diff --git a/single/singleUseMap_test.go b/single/singleUseMap_test.go index 3b4a18171640935735b5d3b28042141c77ded768..1f91c755f048fa2a2444109a521931b21cdc6867 100644 --- a/single/singleUseMap_test.go +++ b/single/singleUseMap_test.go @@ -117,7 +117,7 @@ func Test_pending_addState_TimeoutError(t *testing.T) { *expectedState, *state) } - timer := time.NewTimer(timeout * 2) + timerTimeout := timeout * 4 select { case results := <-callbackChan: @@ -132,8 +132,8 @@ func Test_pending_addState_TimeoutError(t *testing.T) { if results.err == nil || !strings.Contains(results.err.Error(), "timed out") { t.Errorf("Callback did not return a time out error on return: %+v", results.err) } - case <-timer.C: - t.Error("Failed to time out.") + case <-time.NewTimer(timerTimeout).C: + t.Errorf("Failed to time out after %s.", timerTimeout) } } diff --git a/stoppable/bindings.go b/stoppable/bindings.go deleted file mode 100644 index 55784d6522bcb0567d8eb86b282bb9895302ab57..0000000000000000000000000000000000000000 --- a/stoppable/bindings.go +++ /dev/null @@ -1,37 +0,0 @@ -/////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -/////////////////////////////////////////////////////////////////////////////// - -package stoppable - -import "time" - -type Bindings interface { - Close(timeoutMS int) error - IsRunning() bool - Name() string -} - -func WrapForBindings(s Stoppable) Bindings { - return &bindingsStoppable{s: s} -} - -type bindingsStoppable struct { - s Stoppable -} - -func (bs *bindingsStoppable) Close(timeoutMS int) error { - timeout := time.Duration(timeoutMS) * time.Millisecond - return bs.s.Close(timeout) -} - -func (bs *bindingsStoppable) IsRunning() bool { - return bs.s.IsRunning() -} - -func (bs *bindingsStoppable) Name() string { - return bs.s.Name() -} diff --git a/stoppable/cleanup.go b/stoppable/cleanup.go deleted file mode 100644 index 0a84235087dfc0d420c8695679f036294ac9460c..0000000000000000000000000000000000000000 --- a/stoppable/cleanup.go +++ /dev/null @@ -1,99 +0,0 @@ -/////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -/////////////////////////////////////////////////////////////////////////////// - -package stoppable - -import ( - "github.com/pkg/errors" - jww "github.com/spf13/jwalterweatherman" - "gitlab.com/xx_network/primitives/netTime" - "sync" - "sync/atomic" - "time" -) - -const nameTag = " with cleanup" - -type CleanFunc func(duration time.Duration) error - -// Cleanup wraps any stoppable and runs a callback after to stop for cleanup -// behavior. The cleanup is run under the remainder of the timeout but will not -// be canceled if the timeout runs out. The cleanup function does not run if the -// thread does not stop. -type Cleanup struct { - stop Stoppable - // clean receives how long it has to run before the timeout, this is not - // expected to be used in most cases - clean CleanFunc - running uint32 - once sync.Once -} - -// NewCleanup creates a new Cleanup from the passed in stoppable and clean -// function. -func NewCleanup(stop Stoppable, clean CleanFunc) *Cleanup { - return &Cleanup{ - stop: stop, - clean: clean, - running: stopped, - } -} - -// IsRunning returns true if the thread is still running and its cleanup has -// completed. -func (c *Cleanup) IsRunning() bool { - return atomic.LoadUint32(&c.running) == running -} - -// Name returns the name of the stoppable denoting it has cleanup. -func (c *Cleanup) Name() string { - return c.stop.Name() + nameTag -} - -// Close stops the wrapped stoppable and after, runs the cleanup function. The -// cleanup function does not run if the thread fails to stop. -func (c *Cleanup) Close(timeout time.Duration) error { - var err error - - c.once.Do( - func() { - defer atomic.StoreUint32(&c.running, stopped) - start := netTime.Now() - - // Close each stoppable - if err := c.stop.Close(timeout); err != nil { - err = errors.WithMessagef(err, "Cleanup not executed for %s", - c.stop.Name()) - return - } - - // Run the cleanup function with the remaining time as a timeout - elapsed := netTime.Now().Sub(start) - - complete := make(chan error, 1) - go func() { - complete <- c.clean(elapsed) - }() - - select { - case err := <-complete: - if err != nil { - err = errors.WithMessagef(err, "Cleanup for %s failed", - c.stop.Name()) - } - case <-time.NewTimer(elapsed).C: - err = errors.Errorf("Clean up for %s timed out after %s", - c.stop.Name(), elapsed) - } - }) - - if err != nil { - jww.ERROR.Printf(err.Error()) - } - - return err -} diff --git a/stoppable/cleanup_test.go b/stoppable/cleanup_test.go deleted file mode 100644 index bc60be5def7b0a6296711262c456da591adfd08e..0000000000000000000000000000000000000000 --- a/stoppable/cleanup_test.go +++ /dev/null @@ -1,78 +0,0 @@ -/////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -/////////////////////////////////////////////////////////////////////////////// - -package stoppable - -import ( - "testing" -) - -// Tests that NewCleanup returns a Cleanup that is stopped with the given -// Stoppable. -func TestNewCleanup(t *testing.T) { - single := NewSingle("testSingle") - cleanup := NewCleanup(single, single.Close) - - if cleanup.stop != single { - t.Errorf("NewCleanup returned cleanup with incorrect Stoppable."+ - "\nexpected: %+v\nreceived: %+v", single, cleanup.stop) - } - - if cleanup.running != stopped { - t.Errorf("NewMulti returned Multi with incorrect running."+ - "\nexpected: %d\nreceived: %d", stopped, cleanup.running) - } -} - -// Tests that Cleanup.IsRunning returns the expected value when the Cleanup is -// marked as both running and not running. -func TestCleanup_IsRunning(t *testing.T) { - single := NewSingle("threadName") - cleanup := NewCleanup(single, single.Close) - - if cleanup.IsRunning() { - t.Errorf("IsRunning returned the wrong value when running."+ - "\nexpected: %t\nreceived: %t", true, cleanup.IsRunning()) - } - - cleanup.running = running - if !single.IsRunning() { - t.Errorf("IsRunning returned the wrong value when running."+ - "\nexpected: %t\nreceived: %t", false, single.IsRunning()) - } -} - -// Unit test of Cleanup.Name. -func TestCleanup_Name(t *testing.T) { - name := "threadName" - single := NewSingle(name) - cleanup := NewCleanup(single, single.Close) - - if name+nameTag != cleanup.Name() { - t.Errorf("Name did not return the expected name."+ - "\nexpected: %s\nreceived: %s", name+nameTag, cleanup.Name()) - } -} - -// Tests happy path of Cleanup.Close(). -func TestCleanup_Close(t *testing.T) { - single := NewSingle("threadName") - cleanup := NewCleanup(single, single.Close) - - // go func() { - // select { - // case <-time.NewTimer(10 * time.Millisecond).C: - // t.Error("Timed out waiting for quit channel.") - // case <-single.Quit(): - // } - // }() - - err := cleanup.Close(0) - if err != nil { - t.Errorf("Close() returned an error: %+v", err) - } -} diff --git a/stoppable/multi.go b/stoppable/multi.go index 5e60ab2d7759d6874f8018a18643a96dc1e7af26..60d1c8530300f2d52babf469d923627f236a6f1d 100644 --- a/stoppable/multi.go +++ b/stoppable/multi.go @@ -13,16 +13,14 @@ import ( "strings" "sync" "sync/atomic" - "time" ) // Error message. -const closeMultiErr = "MultiStopper %s failed to close %d/%d stoppers" +const closeMultiErr = "multi stoppable %q failed to close %d/%d stoppables" type Multi struct { stoppables []Stoppable name string - running uint32 mux sync.RWMutex once sync.Once } @@ -30,17 +28,11 @@ type Multi struct { // NewMulti returns a new multi Stoppable. func NewMulti(name string) *Multi { return &Multi{ - name: name, - running: running, + name: name, } } -// IsRunning returns true if stoppable is marked as running. -func (m *Multi) IsRunning() bool { - return atomic.LoadUint32(&m.running) == running -} - -// Add adds the given stoppable to the list of stoppables. +// Add adds the given Stoppable to the list of stoppables. func (m *Multi) Add(stoppable Stoppable) { m.mux.Lock() m.stoppables = append(m.stoppables, stoppable) @@ -51,47 +43,85 @@ func (m *Multi) Add(stoppable Stoppable) { // it contains. func (m *Multi) Name() string { m.mux.RLock() - defer m.mux.RUnlock() names := make([]string, len(m.stoppables)) for i, s := range m.stoppables { names[i] = s.Name() } - return m.name + ": {" + strings.Join(names, ", ") + "}" + m.mux.RUnlock() + + return m.name + "{" + strings.Join(names, ", ") + "}" +} + +// GetStatus returns the lowest status of all of the Stoppable children. The +// status is not the status of all Stoppables, but the status of the Stoppable +// with the lowest status. +func (m *Multi) GetStatus() Status { + lowestStatus := Stopped + m.mux.RLock() + + for _, s := range m.stoppables { + status := s.GetStatus() + if status < lowestStatus { + lowestStatus = status + } + } + + m.mux.RUnlock() + + return lowestStatus +} + +// IsRunning returns true if Stoppable is marked as running. +func (m *Multi) IsRunning() bool { + return m.GetStatus() == Running +} + +// IsStopping returns true if Stoppable is marked as stopping. +func (m *Multi) IsStopping() bool { + return m.GetStatus() == Stopping +} + +// IsStopped returns true if Stoppable is marked as stopped. +func (m *Multi) IsStopped() bool { + return m.GetStatus() == Stopped } -// Close closes all child stoppers. It does not return their errors and assumes -// they print them to the log. -func (m *Multi) Close(timeout time.Duration) error { - var err error - m.once.Do( - func() { - atomic.StoreUint32(&m.running, stopped) - - var numErrors uint32 - var wg sync.WaitGroup - - m.mux.Lock() - for _, stoppable := range m.stoppables { - wg.Add(1) - go func(stoppable Stoppable) { - if stoppable.Close(timeout) != nil { - atomic.AddUint32(&numErrors, 1) - } - wg.Done() - }(stoppable) - } - m.mux.Unlock() - - wg.Wait() - - if numErrors > 0 { - err = errors.Errorf( - closeMultiErr, m.name, numErrors, len(m.stoppables)) - jww.ERROR.Print(err.Error()) - } - }) - - return err +// Close issues a close signal to all child stoppables and marks the status of +// the Multi Stoppable as stopping. Returns an error if one or more child +// stoppables failed to close but it does not return their specific errors and +// assumes they print them to the log. +func (m *Multi) Close() error { + var numErrors uint32 + + m.once.Do(func() { + var wg sync.WaitGroup + + jww.TRACE.Printf("Sending on quit channel to multi stoppable %q.", + m.Name()) + + m.mux.Lock() + // Attempt to stop each stoppable in its own goroutine + for _, stoppable := range m.stoppables { + wg.Add(1) + go func(stoppable Stoppable) { + if stoppable.Close() != nil { + atomic.AddUint32(&numErrors, 1) + } + wg.Done() + }(stoppable) + } + m.mux.Unlock() + + wg.Wait() + }) + + if numErrors > 0 { + err := errors.Errorf(closeMultiErr, m.name, numErrors, len(m.stoppables)) + jww.ERROR.Print(err.Error()) + return err + } + + return nil } diff --git a/stoppable/multi_test.go b/stoppable/multi_test.go index b633e1d77d4e80317cff8ca938c0ff440e855baa..4a7eaf0873019d7ab9bd89e4f6aac2ba8f1661c9 100644 --- a/stoppable/multi_test.go +++ b/stoppable/multi_test.go @@ -12,6 +12,8 @@ import ( "reflect" "strconv" "strings" + "sync" + "sync/atomic" "testing" "time" ) @@ -25,28 +27,6 @@ func TestNewMulti(t *testing.T) { t.Errorf("NewMulti returned Multi with incorrect name."+ "\nexpected: %s\nreceived: %s", name, multi.name) } - - if multi.running != running { - t.Errorf("NewMulti returned Multi with incorrect running."+ - "\nexpected: %d\nreceived: %d", running, multi.running) - } -} - -// Tests that Multi.IsRunning returns the expected value when the Multi is -// marked as both running and not running. -func TestMulti_IsRunning(t *testing.T) { - multi := NewMulti("testMulti") - - if !multi.IsRunning() { - t.Errorf("IsRunning returned the wrong value when running."+ - "\nexpected: %t\nreceived: %t", true, multi.IsRunning()) - } - - multi.running = stopped - if multi.IsRunning() { - t.Errorf("IsRunning returned the wrong value when not running."+ - "\nexpected: %t\nreceived: %t", false, multi.IsRunning()) - } } // Tests that Multi.Add adds all the stoppables to the list. @@ -93,7 +73,7 @@ func TestMulti_Name(t *testing.T) { nameList = append(nameList, newName) } - expected := name + ": {" + strings.Join(nameList, ", ") + "}" + expected := name + "{" + strings.Join(nameList, ", ") + "}" if multi.Name() != expected { t.Errorf("Name failed to return the expected string."+ @@ -106,7 +86,7 @@ func TestMulti_Name_NoStoppables(t *testing.T) { name := "testMulti" multi := NewMulti(name) - expected := name + ": {" + "}" + expected := name + "{}" if multi.Name() != expected { t.Errorf("Name failed to return the expected string."+ @@ -114,8 +94,136 @@ func TestMulti_Name_NoStoppables(t *testing.T) { } } -// Tests that Multi.Close sends on all Single quit channels. -func TestMulti_Close(t *testing.T) { +// Tests that Multi.GetStatus returns the expected Status. +func TestMulti_GetStatus(t *testing.T) { + multi := NewMulti("testMulti") + single1 := NewSingle("testSingle1") + single2 := NewSingle("testSingle2") + atomic.StoreUint32((*uint32)(&single2.status), uint32(Stopped)) + multi.Add(single1) + multi.Add(single2) + + status := multi.GetStatus() + if status != Running { + t.Errorf("GetStatus returned the wrong status."+ + "\nexpected: %s\nreceived: %s", Running, status) + } + + atomic.StoreUint32((*uint32)(&single1.status), uint32(Stopping)) + status = multi.GetStatus() + if status != Stopping { + t.Errorf("GetStatus returned the wrong status."+ + "\nexpected: %s\nreceived: %s", Stopping, status) + } + + atomic.StoreUint32((*uint32)(&single1.status), uint32(Stopped)) + status = multi.GetStatus() + if status != Stopped { + t.Errorf("GetStatus returned the wrong status."+ + "\nexpected: %s\nreceived: %s", Stopped, status) + } +} + +// Tests that Multi.GetStatus returns the expected Status when it has no +// children. +func TestMulti_GetStatus_NoChildren(t *testing.T) { + multi := NewMulti("testMulti") + + status := multi.GetStatus() + if status != Stopped { + t.Errorf("GetStatus returned the wrong status."+ + "\nexpected: %s\nreceived: %s", Stopped, status) + } +} + +// Tests that Multi.IsRunning returns the expected value when the Multi is +// marked as running, stopping, and stopped. +func TestMulti_IsRunning(t *testing.T) { + multi := NewMulti("testMulti") + single1 := NewSingle("testSingle1") + single2 := NewSingle("testSingle2") + atomic.StoreUint32((*uint32)(&single2.status), uint32(Stopping)) + multi.Add(single1) + multi.Add(single2) + + if result := multi.IsRunning(); !result { + t.Errorf("IsRunning returned the wrong value when running."+ + "\nexpected: %t\nreceived: %t", true, result) + } + + atomic.StoreUint32((*uint32)(&single1.status), uint32(Stopping)) + atomic.StoreUint32((*uint32)(&single2.status), uint32(Stopped)) + if result := multi.IsRunning(); result { + t.Errorf("IsRunning returned the wrong value when stopping."+ + "\nexpected: %t\nreceived: %t", false, result) + } + + atomic.StoreUint32((*uint32)(&single2.status), uint32(Stopped)) + if result := multi.IsRunning(); result { + t.Errorf("IsRunning returned the wrong value when stopped."+ + "\nexpected: %t\nreceived: %t", false, result) + } +} + +// Tests that Multi.IsStopping returns the expected value when the Multi is +// marked as running, stopping, and stopped. +func TestMulti_IsStopping(t *testing.T) { + multi := NewMulti("testMulti") + single1 := NewSingle("testSingle1") + single2 := NewSingle("testSingle2") + atomic.StoreUint32((*uint32)(&single2.status), uint32(Stopped)) + multi.Add(single1) + multi.Add(single2) + + if result := multi.IsStopping(); result { + t.Errorf("IsStopping returned the wrong value when running."+ + "\nexpected: %t\nreceived: %t", true, result) + } + + atomic.StoreUint32((*uint32)(&single1.status), uint32(Stopping)) + if result := multi.IsStopping(); !result { + t.Errorf("IsStopping returned the wrong value when stopping."+ + "\nexpected: %t\nreceived: %t", false, result) + } + + atomic.StoreUint32((*uint32)(&single1.status), uint32(Stopped)) + if result := multi.IsStopping(); result { + t.Errorf("IsStopping returned the wrong value when stopped."+ + "\nexpected: %t\nreceived: %t", false, result) + } +} + +// Tests that Multi.IsStopped returns the expected value when the Multi is +// marked as running, stopping, and stopped. +func TestMulti_IsStopped(t *testing.T) { + multi := NewMulti("testMulti") + single1 := NewSingle("testSingle1") + single2 := NewSingle("testSingle2") + atomic.StoreUint32((*uint32)(&single2.status), uint32(Stopped)) + multi.Add(single1) + multi.Add(single2) + + if result := multi.IsStopped(); result { + t.Errorf("IsStopped returned the wrong value when running."+ + "\nexpected: %t\nreceived: %t", true, result) + } + + atomic.StoreUint32((*uint32)(&single1.status), uint32(Stopping)) + if result := multi.IsStopped(); result { + t.Errorf("IsStopped returned the wrong value when stopping."+ + "\nexpected: %t\nreceived: %t", false, result) + } + + atomic.StoreUint32((*uint32)(&single1.status), uint32(Stopped)) + if result := multi.IsStopped(); !result { + t.Errorf("IsStopped returned the wrong value when stopped."+ + "\nexpected: %t\nreceived: %t", false, result) + } +} + +// Tests that Multi.IsStopped returns true when all of the child stoppables are +// stopped. +func TestMulti_IsStopped_StoppedStatus(t *testing.T) { multi := NewMulti("testMulti") singles := []*Single{ NewSingle("testSingle0"), @@ -125,37 +233,48 @@ func TestMulti_Close(t *testing.T) { NewSingle("testSingle4"), } for _, single := range singles[:3] { + atomic.StoreUint32((*uint32)(&single.status), uint32(Stopped)) multi.Add(single) } subMulti := NewMulti("subMulti") for _, single := range singles[3:] { + atomic.StoreUint32((*uint32)(&single.status), uint32(Stopped)) subMulti.Add(single) } multi.Add(subMulti) + if !multi.IsStopped() { + t.Error("IsStopped did not find all stoppables as stopped.") + } +} + +// Error path: tests that Multi.IsStopped returns false when not all of the +// child stoppables are stopped. +func TestMulti_IsStopped_NotStoppedError(t *testing.T) { + multi := NewMulti("testMulti") + singles := []*Single{ + NewSingle("testSingle0"), + NewSingle("testSingle1"), + NewSingle("testSingle2"), + NewSingle("testSingle3"), + NewSingle("testSingle4"), + } for _, single := range singles { - go func(single *Single) { - select { - case <-time.NewTimer(5 * time.Millisecond).C: - t.Errorf("Single %s failed to quit.", single.Name()) - case <-single.Quit(): - } - }(single) + multi.Add(single) } - err := multi.Close(5 * time.Millisecond) - if err != nil { - t.Errorf("Close() returned an error: %v", err) + for _, single := range singles[:4] { + atomic.StoreUint32((*uint32)(&single.status), uint32(Stopped)) } - err = multi.Close(0) - if err != nil { - t.Errorf("Close() returned an error: %v", err) + if multi.IsStopped() { + t.Error("IsStopped found all the stoppables as stopped when some are " + + "still running") } } // Tests that Multi.Close sends on all Single quit channels. -func TestMulti_Close_Error(t *testing.T) { +func TestMulti_Close(t *testing.T) { multi := NewMulti("testMulti") singles := []*Single{ NewSingle("testSingle0"), @@ -173,7 +292,7 @@ func TestMulti_Close_Error(t *testing.T) { } multi.Add(subMulti) - for _, single := range singles[:2] { + for _, single := range singles { go func(single *Single) { select { case <-time.NewTimer(5 * time.Millisecond).C: @@ -182,16 +301,56 @@ func TestMulti_Close_Error(t *testing.T) { } }(single) } + + err := multi.Close() + if err != nil { + t.Errorf("Close() returned an error: %v", err) + } + + err = multi.Close() + if err != nil { + t.Errorf("Close() returned an error: %v", err) + } +} + +// Error path: tests that Multi.Close returns the expected error when the Single +// stoppables are not running. +func TestMulti_Close_StoppableCloseError(t *testing.T) { + multi := NewMulti("testMulti") + var singles []*Single + for i := 0; i < 5; i++ { + single := NewSingle("testSingle" + strconv.Itoa(i)) + singles = append(singles, single) + multi.Add(single) + atomic.StoreUint32((*uint32)(&single.status), uint32(Stopped)) + } + + var wg sync.WaitGroup + for _, single := range singles { + wg.Add(1) + go func(single *Single) { + select { + case <-time.NewTimer(15 * time.Millisecond).C: + case <-single.Quit(): + t.Errorf("Single %s to quit when it should have failed.", + single.Name()) + } + wg.Done() + }(single) + } + expectedErr := fmt.Sprintf(closeMultiErr, multi.name, 0, 0) expectedErr = strings.SplitN(expectedErr, " 0/0", 2)[0] - err := multi.Close(5 * time.Millisecond) + err := multi.Close() if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Errorf("Close() did not return the expected error."+ "\nexpected: %s\nreceived: %v", expectedErr, err) } - err = multi.Close(0) + wg.Wait() + + err = multi.Close() if err != nil { t.Errorf("Close() returned an error: %v", err) } diff --git a/stoppable/single.go b/stoppable/single.go index df2a5468c4aa1a348ea3b26c85949cca2b54ec0a..dfde7242ed83f0af975efa0656933b56d0b8145a 100644 --- a/stoppable/single.go +++ b/stoppable/single.go @@ -12,33 +12,79 @@ import ( jww "github.com/spf13/jwalterweatherman" "sync" "sync/atomic" - "time" ) // Error message. -const closeTimeoutErr = "stopper for %s failed to stop after timeout of %s" +const toStoppingErr = "failed to set the status of single stoppable %q to " + + "stopped when status is %s instead of %s" // Single allows stopping a single goroutine using a channel. It adheres to the // Stoppable interface. type Single struct { - name string - quit chan struct{} - running uint32 - once sync.Once + name string + quit chan struct{} + status Status + once sync.Once } // NewSingle returns a new single Stoppable. func NewSingle(name string) *Single { return &Single{ - name: name, - quit: make(chan struct{}), - running: running, + name: name, + quit: make(chan struct{}, 1), + status: Running, } } -// IsRunning returns true if stoppable is marked as running. +// Name returns the name of the Single Stoppable. +func (s *Single) Name() string { + return s.name +} + +// GetStatus returns the status of the Stoppable. +func (s *Single) GetStatus() Status { + return Status(atomic.LoadUint32((*uint32)(&s.status))) +} + +// IsRunning returns true if Stoppable is marked as running. func (s *Single) IsRunning() bool { - return atomic.LoadUint32(&s.running) == running + return s.GetStatus() == Running +} + +// IsStopping returns true if Stoppable is marked as stopping. +func (s *Single) IsStopping() bool { + return s.GetStatus() == Stopping +} + +// IsStopped returns true if Stoppable is marked as stopped. +func (s *Single) IsStopped() bool { + return s.GetStatus() == Stopped +} + +// toStopping changes the status from running to stopping. An error is returned +// if the status is not already set to running. +func (s *Single) toStopping() error { + if !atomic.CompareAndSwapUint32((*uint32)(&s.status), uint32(Running), uint32(Stopping)) { + return errors.Errorf(toStoppingErr, s.Name(), s.GetStatus(), Running) + } + + jww.TRACE.Printf("Switched status of single stoppable %q from %s to %s.", + s.Name(), Running, Stopping) + + return nil +} + +// ToStopped changes the status from stopping to stopped. Panics if the status +// is not already set to stopping. +func (s *Single) ToStopped() { + if !atomic.CompareAndSwapUint32((*uint32)(&s.status), uint32(Stopping), uint32(Stopped)) { + jww.FATAL.Panicf("Failed to set the status of single stoppable %q to "+ + "stopped when status is %s instead of %s.", + s.Name(), s.GetStatus(), Stopping) + } + + jww.TRACE.Printf("Switched status of single stoppable %q from %s to %s.", + s.Name(), Stopping, Stopped) } // Quit returns a receive-only channel that will be triggered when the Stoppable @@ -47,24 +93,28 @@ func (s *Single) Quit() <-chan struct{} { return s.quit } -// Name returns the name of the Single Stoppable. -func (s *Single) Name() string { - return s.name -} - // Close signals the Single to close via the quit channel. Returns an error if -// sending on the quit channel times out. -func (s *Single) Close(timeout time.Duration) error { +// the status of the Single is not Running. +func (s *Single) Close() error { var err error + s.once.Do(func() { - select { - case <-time.NewTimer(timeout).C: - err = errors.Errorf(closeTimeoutErr, s.name, timeout) - jww.ERROR.Print(err.Error()) - case s.quit <- struct{}{}: + // Attempt to set status to stopping or return an error if unable + err = s.toStopping() + if err != nil { + return } - atomic.StoreUint32(&s.running, stopped) + + jww.TRACE.Printf("Sending on quit channel to single stoppable %q.", + s.Name()) + + // Send on quit channel + s.quit <- struct{}{} }) + if err != nil { + jww.ERROR.Print(err.Error()) + } + return err } diff --git a/stoppable/single_test.go b/stoppable/single_test.go index 4898a4a59008c6ea747f7a54c59ee4b84388b23c..c93a1ebfcc1815b2d5c82e7f2e2dc43ff96a076c 100644 --- a/stoppable/single_test.go +++ b/stoppable/single_test.go @@ -9,6 +9,7 @@ package stoppable import ( "fmt" + "sync/atomic" "testing" "time" ) @@ -23,29 +24,186 @@ func TestNewSingle(t *testing.T) { "\nexpected: %s\nreceived: %s", name, single.name) } - if single.running != running { - t.Errorf("NewSingle returned Single with incorrect running."+ - "\nexpected: %d\nreceived: %d", running, single.running) + if single.status != Running { + t.Errorf("NewSingle returned Single with incorrect status."+ + "\nexpected: %s\nreceived: %s", Running, single.status) + } +} + +// Unit test of Single.Name. +func TestSingle_Name(t *testing.T) { + name := "threadName" + single := NewSingle(name) + + if name != single.Name() { + t.Errorf("Name did not return the expected name."+ + "\nexpected: %s\nreceived: %s", name, single.Name()) + } +} + +// Tests that Single.GetStatus returns the expected Status. +func TestSingle_GetStatus(t *testing.T) { + single := NewSingle("threadName") + + status := single.GetStatus() + if status != Running { + t.Errorf("GetStatus returned the wrong status."+ + "\nexpected: %s\nreceived: %s", Running, status) + } + + atomic.StoreUint32((*uint32)(&single.status), uint32(Stopping)) + status = single.GetStatus() + if status != Stopping { + t.Errorf("GetStatus returned the wrong status."+ + "\nexpected: %s\nreceived: %s", Stopping, status) + } + + atomic.StoreUint32((*uint32)(&single.status), uint32(Stopped)) + status = single.GetStatus() + if status != Stopped { + t.Errorf("GetStatus returned the wrong status."+ + "\nexpected: %s\nreceived: %s", Stopped, status) } } // Tests that Single.IsRunning returns the expected value when the Single is -// marked as both running and not running. +// marked as running, stopping, and stopped. func TestSingle_IsRunning(t *testing.T) { single := NewSingle("threadName") - if !single.IsRunning() { + if result := single.IsRunning(); !result { t.Errorf("IsRunning returned the wrong value when running."+ - "\nexpected: %t\nreceived: %t", true, single.IsRunning()) + "\nexpected: %t\nreceived: %t", true, result) + } + + single.status = Stopping + if result := single.IsRunning(); result { + t.Errorf("IsRunning returned the wrong value when stopping."+ + "\nexpected: %t\nreceived: %t", false, result) + } + + single.status = Stopped + if result := single.IsRunning(); result { + t.Errorf("IsRunning returned the wrong value when stopped."+ + "\nexpected: %t\nreceived: %t", false, result) + } +} + +// Tests that Single.IsStopping returns the expected value when the Single is +// marked as running, stopping, and stopped. +func TestSingle_IsStopping(t *testing.T) { + single := NewSingle("threadName") + + if result := single.IsStopping(); result { + t.Errorf("IsStopping returned the wrong value when running."+ + "\nexpected: %t\nreceived: %t", true, result) + } + + single.status = Stopping + if result := single.IsStopping(); !result { + t.Errorf("IsStopping returned the wrong value when stopping."+ + "\nexpected: %t\nreceived: %t", false, result) + } + + single.status = Stopped + if result := single.IsStopping(); result { + t.Errorf("IsStopping returned the wrong value when stopped."+ + "\nexpected: %t\nreceived: %t", false, result) + } +} + +// Tests that Single.IsStopped returns the expected value when the Single is +// marked as running, stopping, and stopped. +func TestSingle_IsStopped(t *testing.T) { + single := NewSingle("threadName") + + if result := single.IsStopped(); result { + t.Errorf("IsStopped returned the wrong value when running."+ + "\nexpected: %t\nreceived: %t", true, result) + } + + single.status = Stopping + if result := single.IsStopped(); result { + t.Errorf("IsStopped returned the wrong value when stopping."+ + "\nexpected: %t\nreceived: %t", false, result) + } + + single.status = Stopped + if result := single.IsStopped(); !result { + t.Errorf("IsStopped returned the wrong value when stopped."+ + "\nexpected: %t\nreceived: %t", false, result) + } +} + +// Tests that Single.toStopping changes the status to stopping. +func TestSingle_toStopping(t *testing.T) { + single := NewSingle("threadName") + + err := single.toStopping() + if err != nil { + t.Errorf("toStopping returned an error: %+v", err) } - single.running = stopped - if single.IsRunning() { - t.Errorf("IsRunning returned the wrong value when not running."+ - "\nexpected: %t\nreceived: %t", false, single.IsRunning()) + if single.status != Stopping { + t.Errorf("toStopping failed to set the status correctly."+ + "\nexpected: %s\nreceived: %s", Stopping, single.status) } } +// Error path: tests that Single.toStopping returns an error when failing to +// change the status to stopping when the current status is not running. +func TestSingle_toStopping_StatusError(t *testing.T) { + single := NewSingle("threadName") + single.status = Stopped + expectedErr := fmt.Sprintf( + toStoppingErr, single.Name(), single.GetStatus(), Running) + + err := single.toStopping() + if err == nil || err.Error() != expectedErr { + t.Errorf("toStopping failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } + + if single.status != Stopped { + t.Errorf("toStopping changed the status when the compare failed."+ + "\nexpected: %s\nreceived: %s", Stopped, single.status) + } +} + +// Tests that Single.ToStopped changes the status to stopped. +func TestSingle_ToStopped(t *testing.T) { + single := NewSingle("threadName") + + single.status = Stopping + single.ToStopped() + + if single.status != Stopped { + t.Errorf("ToStopped failed to set the status correctly."+ + "\nexpected: %s\nreceived: %s", Stopped, single.status) + } +} + +// Panic path: tests that Single.ToStopped panics when failing to change the +// status to stopped when the current status is not stopping. +func TestSingle_ToStopped_StatusPanic(t *testing.T) { + single := NewSingle("threadName") + + defer func() { + if r := recover(); r == nil { + t.Errorf("ToStopped failed to panic when the status should not " + + "have changed.") + } else { + if single.status != Running { + t.Errorf("ToStopped changed the status when the compare failed."+ + "\nexpected: %s\nreceived: %s", Running, single.status) + } + } + }() + + single.status = Running + single.ToStopped() +} + // Tests that Single.Quit returns a channel that is triggered when the Single // quit channel is triggered. func TestSingle_Quit(t *testing.T) { @@ -62,48 +220,40 @@ func TestSingle_Quit(t *testing.T) { single.quit <- struct{}{} } -// Unit test of Single.Name. -func TestSingle_Name(t *testing.T) { - name := "threadName" - single := NewSingle(name) - - if name != single.Name() { - t.Errorf("Name did not return the expected name."+ - "\nexpected: %s\nreceived: %s", name, single.Name()) - } -} - // Test happy path of Single.Close(). func TestSingle_Close(t *testing.T) { single := NewSingle("threadName") + timeout := 10 * time.Millisecond go func() { select { - case <-time.NewTimer(10 * time.Millisecond).C: - t.Error("Timed out waiting for quit channel.") + case <-time.NewTimer(timeout).C: + t.Errorf("Timed out waiting to receive on quit channel after %s.", + timeout) case <-single.Quit(): + if !single.IsStopping() { + t.Errorf("Status of stoppable incorrect."+ + "\nexpected: %s\nreceived: %s", Stopping, single.status) + } + atomic.StoreUint32((*uint32)(&single.status), uint32(Stopped)) } }() - err := single.Close(5 * time.Millisecond) + err := single.Close() if err != nil { t.Errorf("Close returned an error: %v", err) } } -// Error path: tests that Single.Close returns an error when the timeout is -// reached. +// Error path: tests that Single.Close returns an error when the status fails +// to change to stopping. func TestSingle_Close_Error(t *testing.T) { single := NewSingle("threadName") - timeout := time.Millisecond - expectedErr := fmt.Sprintf(closeTimeoutErr, single.Name(), timeout) - - go func() { - time.Sleep(5 * time.Millisecond) - <-single.Quit() - }() + single.status = Stopped + expectedErr := fmt.Sprintf( + toStoppingErr, single.Name(), single.GetStatus(), Running) - err := single.Close(timeout) + err := single.Close() if err == nil || err.Error() != expectedErr { t.Errorf("Close did not return the expected error."+ "\nexpected: %s\nreceived: %v", expectedErr, err) diff --git a/stoppable/status.go b/stoppable/status.go new file mode 100644 index 0000000000000000000000000000000000000000..1b306bd69a13394d8321c52b742ef5ecf82a83a9 --- /dev/null +++ b/stoppable/status.go @@ -0,0 +1,36 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package stoppable + +import ( + "strconv" +) + +const ( + Running Status = iota + Stopping + Stopped +) + +// Status holds the current status of a Stoppable. +type Status uint32 + +// String prints a string representation of the current Status. This functions +// satisfies the fmt.Stringer interface. +func (s Status) String() string { + switch s { + case Running: + return "running" + case Stopping: + return "stopping" + case Stopped: + return "stopped" + default: + return "INVALID STATUS: " + strconv.FormatUint(uint64(s), 10) + } +} diff --git a/stoppable/status_test.go b/stoppable/status_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c2f4bcd05f106e1f0bc6e6083a88096a3cdde1f7 --- /dev/null +++ b/stoppable/status_test.go @@ -0,0 +1,32 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package stoppable + +import ( + "testing" +) + +// Unit test of Status.String. +func TestStatus_String(t *testing.T) { + testValues := []struct { + status Status + expected string + }{ + {Running, "running"}, + {Stopping, "stopping"}, + {Stopped, "stopped"}, + {100, "INVALID STATUS: 100"}, + } + + for i, val := range testValues { + if val.status.String() != val.expected { + t.Errorf("String did not return the expected value (%d)."+ + "\nexpected: %s\nreceived: %s", i, val.status.String(), val.expected) + } + } +} diff --git a/stoppable/stoppable.go b/stoppable/stoppable.go index b3b219edf9883e12cc301c8570199c35e40bbb27..b5b072d1424feddf97cd029fa08d519ef2998c01 100644 --- a/stoppable/stoppable.go +++ b/stoppable/stoppable.go @@ -7,16 +7,80 @@ package stoppable -import "time" +import ( + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "strings" + "time" +) +// Error message returned after a comms operations ends and finds that its +// parent thread is stopping or stopped. const ( - stopped = 0 - running = 1 + errKey = "[StoppableNotRunning]" + ErrMsg = "stoppable %q is not running, exiting %s early " + errKey + timeoutErr = "timed out after %s waiting for the stoppable to stop for %q" ) -// Stoppable interface for stopping a goroutine. +// pollPeriod is the duration to wait between polls to see of stoppables are +// stopped. +const pollPeriod = 100 * time.Millisecond + +// Stoppable interface for stopping a goroutine. All functions are thread safe. type Stoppable interface { - Close(timeout time.Duration) error - IsRunning() bool + // Name returns the name of the Stoppable. Name() string + + // GetStatus returns the status of the Stoppable. + GetStatus() Status + + // IsRunning returns true if the Stoppable is running. + IsRunning() bool + + // IsStopping returns true if Stoppable is marked as stopping. + IsStopping() bool + + // IsStopped returns true if Stoppable is marked as stopped. + IsStopped() bool + + // Close marks the Stoppable as stopping and issues a close signal to the + // Stoppable or any children it may have. + Close() error +} + +// WaitForStopped polls the stoppable and all its children to see if they are +// stopped. Returns an error if its times out waiting for all children to stop. +func WaitForStopped(s Stoppable, timeout time.Duration) error { + done := make(chan struct{}) + + // Launch the processes to check if all stoppables are stopped in separate + // goroutine so that when the timeout is reached, no time is wasted exiting + go func() { + for !s.IsStopped() { + time.Sleep(pollPeriod) + } + + select { + case done <- struct{}{}: + case <-time.NewTimer(50 * time.Millisecond).C: + } + }() + + select { + case <-done: + jww.INFO.Printf("All stoppables have stopped for %q.", s.Name()) + return nil + case <-time.NewTimer(timeout).C: + return errors.Errorf(timeoutErr, timeout, s.Name()) + } +} + +// CheckErr returns true if the error contains a stoppable error message. This +// function is used by callers to determine if a sub function quit due to a +// stoppable closing and tells the caller to exit. +func CheckErr(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), errKey) } diff --git a/stoppable/stoppable_test.go b/stoppable/stoppable_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0f070317aedef9fca29a0b66d023b4e7d151517f --- /dev/null +++ b/stoppable/stoppable_test.go @@ -0,0 +1,123 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package stoppable + +import ( + "fmt" + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "os" + "strconv" + "testing" + "time" +) + +func TestMain(m *testing.M) { + jww.SetStdoutThreshold(jww.LevelTrace) + + os.Exit(m.Run()) +} + +// Tests that WaitForStopped does not return an error when all children are +// stopped. +func TestWaitForStopped(t *testing.T) { + m := newTestMulti() + + err := m.Close() + if err != nil { + t.Errorf("Failed to close multi stoppable: %+v", err) + } + + err = WaitForStopped(m, 2*time.Second) + if err != nil { + t.Errorf("WaitForStopped returned an error: %+v", err) + } +} + +// Error path: tests that WaitForStopped returns an error if the timeout is +// reached before all stoppables are checked. +func TestWaitForStopped_TimeoutError(t *testing.T) { + m := newTestMulti() + + err := m.Close() + if err != nil { + t.Errorf("Failed to close multi stoppable: %+v", err) + } + + expectedErr := fmt.Sprintf(timeoutErr, time.Duration(0), m.Name()) + + err = WaitForStopped(m, 0) + if err == nil || err.Error() != expectedErr { + t.Errorf("WaitForStopped did not return the expected error."+ + "\nexpected: %s\nrecieved: %+v", expectedErr, err) + } +} + +// Tests that TestCheckErr returns true for stoppable errors and false for all +// other errors +func TestCheckErr(t *testing.T) { + testValues := []struct { + err error + expected bool + }{ + {errors.Errorf(ErrMsg, "testThre", "testFunc"), true}, + {errors.Errorf(ErrMsg, "", ""), true}, + {errors.Errorf(errKey), true}, + {errors.Errorf("Random error"), false}, + {errors.Errorf(""), false}, + {nil, false}, + } + + for i, val := range testValues { + result := CheckErr(val.err) + if result != val.expected { + t.Errorf("CheckErr failed to return the expected value (%d)."+ + "\nexpected: %t\nreceived: %t", i, val.expected, result) + } + } +} + +// newTestMulti creates a new Multi Stoppable that has many Single and Multi +// stoppable children. +func newTestMulti() *Multi { + singles := make([]*Single, 15) + for i := range singles { + singles[i] = NewSingle("testSingle_" + strconv.Itoa(i)) + go func(single *Single) { + <-single.Quit() + time.Sleep(600 * time.Millisecond) + single.ToStopped() + }(singles[i]) + } + + m := NewMulti("testMulti") + for _, s := range singles[:5] { + m.Add(s) + } + m0 := NewMulti("testMulti_0") + for _, s := range singles[5:8] { + m0.Add(s) + } + m.Add(m0) + m1 := NewMulti("testMulti_1") + for _, s := range singles[8:10] { + m1.Add(s) + } + m2 := NewMulti("testMulti_2") + for _, s := range singles[10:13] { + m2.Add(s) + } + m1.Add(m2) + m.Add(m1) + for _, s := range singles[13:] { + m.Add(s) + } + m.Add(NewMulti("testMulti_3")) + + return m +} diff --git a/storage/e2e/manager.go b/storage/e2e/manager.go index 5f12e6b449e35d91b3cfc3f15e0240cc89f6fdae..6434f5283e342d2c1bc36a26006f24d6798f0bb1 100644 --- a/storage/e2e/manager.go +++ b/storage/e2e/manager.go @@ -8,6 +8,8 @@ package e2e import ( + "bytes" + "encoding/base64" "fmt" "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" @@ -17,6 +19,8 @@ import ( "gitlab.com/elixxir/crypto/cyclic" dh "gitlab.com/elixxir/crypto/diffieHellman" "gitlab.com/xx_network/primitives/id" + "golang.org/x/crypto/blake2b" + "sort" ) const managerPrefix = "Manager{partner:%s}" @@ -223,3 +227,24 @@ func (m *Manager) GetMyOriginPrivateKey() *cyclic.Int { func (m *Manager) GetPartnerOriginPublicKey() *cyclic.Int { return m.originPartnerPubKey.DeepCopy() } + +const relationshipFpLength = 15 + +// GetRelationshipFingerprint returns a unique fingerprint for an E2E +// relationship. The fingerprint is a base 64 encoded hash of of the two +// relationship fingerprints truncated to 15 characters. +func (m *Manager) GetRelationshipFingerprint() string { + // Sort fingerprints + fps := [][]byte{m.receive.fingerprint, m.send.fingerprint} + less := func(i, j int) bool { return bytes.Compare(fps[i], fps[j]) == -1 } + sort.Slice(fps, less) + + // Hash fingerprints + h, _ := blake2b.New256(nil) + for _, fp := range fps { + h.Write(fp) + } + + // Base 64 encode hash and truncate + return base64.StdEncoding.EncodeToString(h.Sum(nil))[:relationshipFpLength] +} diff --git a/storage/e2e/manager_test.go b/storage/e2e/manager_test.go index fed60f22bb7ba288a9b1d757d4431a706a784612..114afac841407a4c4f8a0116ac452b7bd01c68cc 100644 --- a/storage/e2e/manager_test.go +++ b/storage/e2e/manager_test.go @@ -9,12 +9,14 @@ package e2e import ( "bytes" + "encoding/base64" "fmt" "gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/client/storage/versioned" "gitlab.com/elixxir/ekv" "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/netTime" + "golang.org/x/crypto/blake2b" "math/rand" "reflect" "testing" @@ -307,3 +309,65 @@ func managersEqual(expected, received *Manager, t *testing.T) bool { return equal } + +// Unit test of Manager.GetRelationshipFingerprint. +func TestManager_GetRelationshipFingerprint(t *testing.T) { + m, _ := newTestManager(t) + m.receive.fingerprint = []byte{5} + m.send.fingerprint = []byte{10} + h, _ := blake2b.New256(nil) + h.Write(append(m.receive.fingerprint, m.send.fingerprint...)) + expected := base64.StdEncoding.EncodeToString(h.Sum(nil))[:relationshipFpLength] + + fp := m.GetRelationshipFingerprint() + if fp != expected { + t.Errorf("GetRelationshipFingerprint did not return the expected "+ + "fingerprint.\nexpected: %s\nreceived: %s", expected, fp) + } + + // Flip the order and show that the output is the same. + m.receive.fingerprint, m.send.fingerprint = m.send.fingerprint, m.receive.fingerprint + + fp = m.GetRelationshipFingerprint() + if fp != expected { + t.Errorf("GetRelationshipFingerprint did not return the expected "+ + "fingerprint.\nexpected: %s\nreceived: %s", expected, fp) + } +} + +// Tests the consistency of the output of Manager.GetRelationshipFingerprint. +func TestManager_GetRelationshipFingerprint_Consistency(t *testing.T) { + m, _ := newTestManager(t) + prng := rand.New(rand.NewSource(42)) + expectedFps := []string{ + "GmeTCfxGOqRqeID", "gbpJjHd3tIe8BKy", "2/ZdG+WNzODJBiF", + "+V1ySeDLQfQNSkv", "23OMC+rBmCk+gsu", "qHu5MUVs83oMqy8", + "kuXqxsezI0kS9Bc", "SlEhsoZ4BzAMTtr", "yG8m6SPQfV/sbTR", + "j01ZSSm762TH7mj", "SKFDbFvsPcohKPw", "6JB5HK8DHGwS4uX", + "dU3mS1ujduGD+VY", "BDXAy3trbs8P4mu", "I4HoXW45EwWR0oD", + "661YH2l2jfOkHbA", "cSS9ZyTOQKVx67a", "ojfubzDIsMNYc/t", + "2WrEw83Yz6Rhq9I", "TQILxBIUWMiQS2j", "rEqdieDTXJfCQ6I", + } + + for i, expected := range expectedFps { + prng.Read(m.receive.fingerprint) + prng.Read(m.send.fingerprint) + + fp := m.GetRelationshipFingerprint() + if fp != expected { + t.Errorf("GetRelationshipFingerprint did not return the expected "+ + "fingerprint (%d).\nexpected: %s\nreceived: %s", i, expected, fp) + } + + // Flip the order and show that the output is the same. + m.receive.fingerprint, m.send.fingerprint = m.send.fingerprint, m.receive.fingerprint + + fp = m.GetRelationshipFingerprint() + if fp != expected { + t.Errorf("GetRelationshipFingerprint did not return the expected "+ + "fingerprint (%d).\nexpected: %s\nreceived: %s", i, expected, fp) + } + + // fmt.Printf("\"%s\",\n", fp) // Uncomment to reprint expected values + } +} diff --git a/storage/e2e/relationship.go b/storage/e2e/relationship.go index 9089b04d871e6a394a563fc141cb24b5481726c2..2006cb451a4d70b3c41772b95bd98803ae6dc5c7 100644 --- a/storage/e2e/relationship.go +++ b/storage/e2e/relationship.go @@ -71,7 +71,7 @@ func NewRelationship(manager *Manager, t RelationshipType, if err := s.save(); err != nil { jww.FATAL.Panicf("Failed to Send session after setting to "+ - "confimred: %+v", err) + "confirmed: %+v", err) } r.addSession(s) diff --git a/storage/e2e/session.go b/storage/e2e/session.go index ad684725a04918389f706566e44642c6c1fbaf16..6e85d873e36768d2372b5f713c45527625233c3c 100644 --- a/storage/e2e/session.go +++ b/storage/e2e/session.go @@ -105,8 +105,8 @@ func newSession(ship *relationship, t RelationshipType, myPrivKey, partnerPubKey negotiationStatus Negotiation, e2eParams params.E2ESessionParams) *Session { if e2eParams.MinKeys < 10 { - jww.FATAL.Panicf("Cannot create a session with a minnimum number " + - "of keys less than 10") + jww.FATAL.Panicf("Cannot create a session with a minimum number "+ + "of keys (%d) less than 10", e2eParams.MinKeys) } session := &Session{ diff --git a/storage/e2e/store.go b/storage/e2e/store.go index a7868e0aa5322510824453a621e4bc5f0deec605..2a8006902a9adf5a94704175f55f548f61eb0d7a 100644 --- a/storage/e2e/store.go +++ b/storage/e2e/store.go @@ -14,6 +14,7 @@ import ( "gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/client/storage/utility" "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/crypto/contact" "gitlab.com/elixxir/crypto/cyclic" "gitlab.com/elixxir/crypto/diffieHellman" "gitlab.com/elixxir/crypto/fastRNG" @@ -180,7 +181,7 @@ func (s *Store) AddPartner(partnerID *id.ID, partnerPubKey, myPrivKey *cyclic.In s.managers[*partnerID] = m if err := s.save(); err != nil { - jww.FATAL.Printf("Failed to add Parter %s: Save of store failed: %s", + jww.FATAL.Printf("Failed to add Partner %s: Save of store failed: %s", partnerID, err) } @@ -215,6 +216,28 @@ func (s *Store) GetPartner(partnerID *id.ID) (*Manager, error) { return m, nil } +// GetPartnerContact find the partner with the given ID and assembles and +// returns a contact.Contact with their ID and DH key. An error is returned if +// no partner exists for the given ID. +func (s *Store) GetPartnerContact(partnerID *id.ID) (contact.Contact, error) { + s.mux.RLock() + defer s.mux.RUnlock() + + // Get partner + m, exists := s.managers[*partnerID] + if !exists { + return contact.Contact{}, errors.New(NoPartnerErrorStr) + } + + // Assemble Contact + c := contact.Contact{ + ID: m.GetPartnerID(), + DhPubKey: m.GetPartnerOriginPublicKey(), + } + + return c, nil +} + // PopKey pops a key for use based upon its fingerprint. func (s *Store) PopKey(f format.Fingerprint) (*Key, bool) { return s.fingerprints.Pop(f) diff --git a/storage/e2e/store_test.go b/storage/e2e/store_test.go index f5d3e7f7ccbc037ea6185470b857aaef3428a8b6..82a536165e8b38f30d713c4d4b151c01dc6b27eb 100644 --- a/storage/e2e/store_test.go +++ b/storage/e2e/store_test.go @@ -11,6 +11,7 @@ import ( "bytes" "gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/crypto/contact" "gitlab.com/elixxir/crypto/cyclic" "gitlab.com/elixxir/crypto/diffieHellman" "gitlab.com/elixxir/crypto/fastRNG" @@ -145,7 +146,7 @@ func TestStore_GetPartner(t *testing.T) { p := params.GetDefaultE2ESessionParams() expectedManager := newManager(s.context, s.kv, partnerID, s.dhPrivateKey, pubKey, p, p) - s.AddPartner(partnerID, pubKey, s.dhPrivateKey, p, p) + _ = s.AddPartner(partnerID, pubKey, s.dhPrivateKey, p, p) m, err := s.GetPartner(partnerID) if err != nil { @@ -174,6 +175,41 @@ func TestStore_GetPartner_Error(t *testing.T) { } } +// Tests happy path of Store.GetPartnerContact. +func TestStore_GetPartnerContact(t *testing.T) { + s, _, _ := makeTestStore() + partnerID := id.NewIdFromUInt(rand.Uint64(), id.User, t) + pubKey := diffieHellman.GeneratePublicKey(s.dhPrivateKey, s.grp) + p := params.GetDefaultE2ESessionParams() + expected := contact.Contact{ + ID: partnerID, + DhPubKey: pubKey, + } + _ = s.AddPartner(partnerID, pubKey, s.dhPrivateKey, p, p) + + c, err := s.GetPartnerContact(partnerID) + if err != nil { + t.Errorf("GetPartnerContact() produced an error: %+v", err) + } + + if !reflect.DeepEqual(expected, c) { + t.Errorf("GetPartnerContact() returned wrong Contact."+ + "\nexpected: %s\nreceived: %s", expected, c) + } +} + +// Tests that Store.GetPartnerContact returns an error for non existent partnerID. +func TestStore_GetPartnerContact_Error(t *testing.T) { + s, _, _ := makeTestStore() + partnerID := id.NewIdFromUInt(rand.Uint64(), id.User, t) + + _, err := s.GetPartnerContact(partnerID) + if err == nil || err.Error() != NoPartnerErrorStr { + t.Errorf("GetPartnerContact() did not produce the expected error."+ + "\nexpected: %s\nreceived: %+v", NoPartnerErrorStr, err) + } +} + // Tests happy path of Store.PopKey. func TestStore_PopKey(t *testing.T) { s, _, _ := makeTestStore() diff --git a/storage/session.go b/storage/session.go index 0277166a98b9dd51f06509efb6f35e514aae249d..7fdeea29e807fac8de73b3e92b20da789682962f 100644 --- a/storage/session.go +++ b/storage/session.go @@ -315,6 +315,13 @@ func (s *Session) Delete(key string) error { return s.kv.Delete(key, currentSessionVersion) } +// GetKV returns the Session versioned.KV. +func (s *Session) GetKV() *versioned.KV { + s.mux.RLock() + defer s.mux.RUnlock() + return s.kv +} + // Initializes a Session object wrapped around a MemStore object. // FOR TESTING ONLY func InitTestingSession(i interface{}) *Session { @@ -343,7 +350,6 @@ func InitTestingSession(i interface{}) *Session { } u.SetRegistrationTimestamp(testTime.UnixNano()) - s.user = u cmixGrp := cyclic.NewGroup( large.NewIntFromString("9DB6FB5951B66BB6FE1E140F1D2CE5502374161FD6538DF1648218642F0B5C48"+