diff --git a/api/client.go b/api/client.go index 34c0cf9df33354b06048abf1c7a2197828bc3a77..c8ae08a392afeba240f06a932dfb067c2812d531 100644 --- a/api/client.go +++ b/api/client.go @@ -45,7 +45,7 @@ import ( type Client struct { storage globals.Storage session user.Session - commManager *io.CommManager + commManager *io.ReceptionManager ndf *ndf.NetworkDefinition topology *circuit.Circuit opStatus OperationProgressCallback @@ -220,7 +220,7 @@ func NewClient(s globals.Storage, locA, locB string, ndfJSON *ndf.NetworkDefinit cl := new(Client) cl.storage = store - cl.commManager = io.NewCommManager() + cl.commManager = io.NewReceptionManager(cl.rekeyChan) cl.ndf = ndfJSON //build the topology nodeIDs := make([]*id.Node, len(cl.ndf.Nodes)) @@ -637,7 +637,19 @@ func (cl *Client) StartMessageReceiver(callback func(error)) error { if !ok { return errors.New("Failed to retrieve host for transmission") } - go cl.commManager.MessageReceiver(cl.session, pollWaitTimeMillis, cl.rekeyChan, receptionHost, callback) + + go func() { + defer func() { + if r := recover(); r != nil { + globals.Log.ERROR.Println("Message Receiver Panicked: ", r) + time.Sleep(1 * time.Second) + go func() { + callback(errors.New(fmt.Sprintln("Message Receiver Panicked", r))) + }() + } + }() + cl.commManager.MessageReceiver(cl.session, pollWaitTimeMillis, receptionHost, callback) + }() return nil } @@ -901,8 +913,8 @@ func (cl *Client) GetSession() user.Session { return cl.session } -// CommManager returns the comm manager object for external access. Access +// ReceptionManager returns the comm manager object for external access. Access // at your own risk -func (cl *Client) GetCommManager() *io.CommManager { +func (cl *Client) GetCommManager() *io.ReceptionManager { return cl.commManager } diff --git a/api/client_test.go b/api/client_test.go index 18d6a19285c5eef41fcbfb345352610540377217..f62601ddb50db7fe02ec077babe34625b832c454 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -644,7 +644,7 @@ func TestClient_GetCommManager(t *testing.T) { testClient, _ := NewClient(&globals.RamStorage{}, "", "", def, dummyConnectionStatusHandler) - testClient.commManager = &io.CommManager{} + testClient.commManager = &io.ReceptionManager{} if !reflect.DeepEqual(testClient.GetCommManager(), testClient.commManager) { t.Error("Received session not the same as the real session") diff --git a/api/connect.go b/api/connect.go index 14bdc8010d8028aca28f85e5c319fd2e455b0177..6dbb55a4baf9260972cf36a4b92a574439b645be 100644 --- a/api/connect.go +++ b/api/connect.go @@ -4,8 +4,10 @@ import ( "fmt" "github.com/pkg/errors" "gitlab.com/elixxir/client/globals" + "gitlab.com/elixxir/client/io" "gitlab.com/elixxir/primitives/circuit" "gitlab.com/elixxir/primitives/id" + "gitlab.com/elixxir/primitives/ndf" ) // Checks version and connects to gateways using TLS filepaths to create @@ -13,7 +15,7 @@ import ( func (cl *Client) InitNetwork() error { //InitNetwork to permissioning if cl.ndf.Registration.Address != "" { - isConnected, err := cl.AddPermissioningHost() + isConnected, err := AddPermissioningHost(cl.commManager, cl.ndf) if err != nil { return err @@ -39,14 +41,6 @@ func (cl *Client) InitNetwork() error { globals.Log.WARN.Println("Registration not defined, not contacted") } - //build the topology - nodeIDs := make([]*id.Node, len(cl.ndf.Nodes)) - for i, node := range cl.ndf.Nodes { - nodeIDs[i] = id.NewNodeFromBytes(node.ID) - } - - cl.topology = circuit.New(nodeIDs) - // Only check the version if we got a remote version // The remote version won't have been populated if we didn't connect to permissioning if cl.GetRegistrationVersion() != "" { @@ -64,29 +58,38 @@ func (cl *Client) InitNetwork() error { "registration server, because it's not populated. Do you have " + "access to the registration server?") } - return cl.AddGatewayHosts() + + //build the topology + nodeIDs := make([]*id.Node, len(cl.ndf.Nodes)) + for i, node := range cl.ndf.Nodes { + nodeIDs[i] = id.NewNodeFromBytes(node.ID) + } + + cl.topology = circuit.New(nodeIDs) + + return AddGatewayHosts(cl.commManager, cl.ndf) } // Connects to gateways using tls filepaths to create credential information // for connection establishment -func (cl *Client) AddGatewayHosts() error { // tear out - if len(cl.ndf.Gateways) < 1 { +func AddGatewayHosts(rm *io.ReceptionManager, definition *ndf.NetworkDefinition) error { // tear out + if len(definition.Gateways) < 1 { return errors.New("could not connect due to invalid number of nodes") } // connect to all gateways var errs error = nil - for i, gateway := range cl.ndf.Gateways { + for i, gateway := range definition.Gateways { var gwCreds []byte if gateway.TlsCertificate != "" { gwCreds = []byte(gateway.TlsCertificate) } - gwID := id.NewNodeFromBytes(cl.ndf.Nodes[i].ID).NewGateway() + gwID := id.NewNodeFromBytes(definition.Nodes[i].ID).NewGateway() gwAddr := gateway.Address - _, err := cl.commManager.Comms.AddHost(gwID.String(), gwAddr, gwCreds, false) + _, err := rm.Comms.AddHost(gwID.String(), gwAddr, gwCreds, false) if err != nil { err = errors.Errorf("Failed to create host for gateway %s at %s: %+v", gwID.String(), gwAddr, err) @@ -103,25 +106,21 @@ func (cl *Client) AddGatewayHosts() error { // tear out // There's currently no need to keep connected to permissioning constantly, // so we have functions to connect to and disconnect from it when a connection // to permissioning is needed -func (cl *Client) AddPermissioningHost() (bool, error) { // this disappears, make host in simple call - if cl.ndf.Registration.Address != "" { - _, ok := cl.commManager.Comms.GetHost(PermissioningAddrID) - if ok { - return true, nil - } +func AddPermissioningHost(rm *io.ReceptionManager, definition *ndf.NetworkDefinition) (bool, error) { // this disappears, make host in simple call + if definition.Registration.Address != "" { var regCert []byte - if cl.ndf.Registration.TlsCertificate != "" { - regCert = []byte(cl.ndf.Registration.TlsCertificate) + if definition.Registration.TlsCertificate != "" { + regCert = []byte(definition.Registration.TlsCertificate) } - _, err := cl.commManager.Comms.AddHost(PermissioningAddrID, cl.ndf.Registration.Address, regCert, false) + _, err := rm.Comms.AddHost(PermissioningAddrID, definition.Registration.Address, regCert, false) if err != nil { return false, errors.New(fmt.Sprintf( "Failed connecting to create host for permissioning: %+v", err)) } return true, nil } else { - globals.Log.DEBUG.Printf("failed to connect to %v silently", cl.ndf.Registration.Address) + globals.Log.DEBUG.Printf("failed to connect to %v silently", definition.Registration.Address) // Without an NDF, we can't connect to permissioning, but this isn't an // error per se, because we should be phasing out permissioning at some // point diff --git a/api/private.go b/api/private.go index bd6b93454ca25ae984ee2c8e503a2b40fc84be1d..aae6ba8571217f5e7e9c31b1be326d7d7d590d1d 100644 --- a/api/private.go +++ b/api/private.go @@ -70,7 +70,7 @@ func (cl *Client) precannedRegister(registrationCode, nick string, // It sends a registration message and returns the registration signature func (cl *Client) sendRegistrationMessage(registrationCode string, publicKeyRSA *rsa.PublicKey) ([]byte, error) { - connected, err := cl.AddPermissioningHost() + connected, err := AddPermissioningHost(cl.commManager, cl.ndf) if err != nil { return nil, errors.Wrap(err, "Couldn't connect to permissioning to send registration message") } diff --git a/bindings/client.go b/bindings/client.go index 76fa89d66dbb93efccffb1dedc7435c7003a4ec4..2e50eb3382cd787cb73938155546ee316b31b085 100644 --- a/bindings/client.go +++ b/bindings/client.go @@ -90,7 +90,7 @@ func (cl *Client) EnableDebugLogs() { // Connects to gateways and registration server (if needed) // using tls filepaths to create credential information // for connection establishment -func (cl *Client) Connect() error { +func (cl *Client) InitNetwork() error { globals.Log.INFO.Printf("Binding call: InitNetwork()") return cl.client.InitNetwork() } diff --git a/bindings/client_test.go b/bindings/client_test.go index a921801849be78b1586c8dadecdd63c5ffe9f22b..910ca5f1f1f8fb3ff3f9601cc16824adf8c89cf9 100644 --- a/bindings/client_test.go +++ b/bindings/client_test.go @@ -110,7 +110,7 @@ func TestRegister(t *testing.T) { t.Errorf("Failed to marshal group JSON: %s", err) } - err = client.Connect() + err = client.InitNetwork() if err != nil { t.Errorf("Could not connect: %+v", err) } @@ -144,7 +144,7 @@ func TestLoginLogout(t *testing.T) { t.Errorf("Error starting client: %+v", err) } // InitNetwork to gateway - err = client.Connect() + err = client.InitNetwork() if err != nil { t.Errorf("Could not connect: %+v", err) } @@ -186,7 +186,7 @@ func TestListen(t *testing.T) { d := DummyStorage{LocationA: "Blah", StoreA: []byte{'a', 'b', 'c'}} client, err := NewClient(&d, "hello", "", ndfStr, pubKey) // InitNetwork to gateway - err = client.Connect() + err = client.InitNetwork() if err != nil { t.Errorf("Could not connect: %+v", err) @@ -225,7 +225,7 @@ func TestStopListening(t *testing.T) { d := DummyStorage{LocationA: "Blah", StoreA: []byte{'a', 'b', 'c'}} client, err := NewClient(&d, "hello", "", ndfStr, pubKey) // InitNetwork to gateway - err = client.Connect() + err = client.InitNetwork() if err != nil { t.Errorf("Could not connect: %+v", err) diff --git a/bots/bots_test.go b/bots/bots_test.go index 8482173f2267b289b89e7dfe626c9ffd5ee6e44d..a5da9e435487fa9d023c4e9bb8c8363770f0fc66 100644 --- a/bots/bots_test.go +++ b/bots/bots_test.go @@ -56,7 +56,7 @@ func (d *dummyMessaging) SendMessageNoPartition(sess user.Session, // MessageReceiver thread to get new messages func (d *dummyMessaging) MessageReceiver(session user.Session, - delay time.Duration, rekeyChan chan struct{}, transmissionHost *connect.Host, callback func(error)) { + delay time.Duration, transmissionHost *connect.Host, callback func(error)) { } var pubKeyBits string @@ -214,7 +214,7 @@ func (e *errorMessaging) SendMessageNoPartition(sess user.Session, // MessageReceiver thread to get new messages func (e *errorMessaging) MessageReceiver(session user.Session, - delay time.Duration, rekeyChan chan struct{}, transmissionHost *connect.Host, callback func(error)) { + delay time.Duration, transmissionHost *connect.Host, callback func(error)) { } // Test LookupNick returns error on sending problem diff --git a/io/interface.go b/io/interface.go index 80853702db82608806f49af97476b5fb374de12a..09b1dd50b32f1a53d5d53216e4b700931f31605e 100644 --- a/io/interface.go +++ b/io/interface.go @@ -29,6 +29,6 @@ type Communications interface { // this can go recipientID *id.User, cryptoType parse.CryptoType, message []byte, transmissionHost *connect.Host) error // MessageReceiver thread to get new messages - MessageReceiver(session user.Session, delay time.Duration, rekeyChan chan struct{}, + MessageReceiver(session user.Session, delay time.Duration, receptionHost *connect.Host, callback func(error)) } diff --git a/io/receive.go b/io/receive.go index a25c612a54ba7bb065d3db7aeef5dad4ad8f3b3a..f5cb0991f20fecbe6ec81a61c679c413878e96ab 100644 --- a/io/receive.go +++ b/io/receive.go @@ -29,7 +29,7 @@ const reportDuration = 30 * time.Second var errE2ENotFound = errors.New("E2EKey for matching fingerprint not found, can't process message") // MessageReceiver is a polling thread for receiving messages -func (cm *CommManager) MessageReceiver(session user.Session, delay time.Duration, rekeyChan chan struct{}, +func (rm *ReceptionManager) MessageReceiver(session user.Session, delay time.Duration, receptionHost *connect.Host, callback func(error)) { // FIXME: It's not clear we should be doing decryption here. if session == nil { @@ -73,7 +73,7 @@ func (cm *CommManager) MessageReceiver(session user.Session, delay time.Duration var err error - encryptedMessages, err = cm.receiveMessagesFromGateway(session, &pollingMessage, receptionHost) + encryptedMessages, err = rm.receiveMessagesFromGateway(session, &pollingMessage, receptionHost) if err != nil { @@ -84,13 +84,13 @@ func (cm *CommManager) MessageReceiver(session user.Session, delay time.Duration callback(err) } NumMessages += len(encryptedMessages) - case <-rekeyChan: + case <-rm.rekeyChan: encryptedMessages = session.PopGarbledMessages() } if len(encryptedMessages) != 0 { - decryptedMessages, senders, garbledMessages := cm.decryptMessages(session, encryptedMessages) + decryptedMessages, senders, garbledMessages := rm.decryptMessages(session, encryptedMessages) if len(garbledMessages) != 0 { session.AppendGarbledMessage(garbledMessages...) @@ -99,7 +99,7 @@ func (cm *CommManager) MessageReceiver(session user.Session, delay time.Duration if decryptedMessages != nil { for i := range decryptedMessages { // TODO Handle messages that do not need partitioning - assembledMessage := cm.collator.AddMessage(decryptedMessages[i], + assembledMessage := rm.collator.AddMessage(decryptedMessages[i], senders[i], time.Minute) if assembledMessage != nil { // we got a fully assembled message. let's broadcast it @@ -163,14 +163,14 @@ func handleE2EReceiving(session user.Session, return sender, rekey, err } -func (cm *CommManager) receiveMessagesFromGateway(session user.Session, +func (rm *ReceptionManager) receiveMessagesFromGateway(session user.Session, pollingMessage *pb.ClientRequest, receiveGateway *connect.Host) ([]*format.Message, error) { // Get the last message ID received pollingMessage.LastMessageID = session.GetLastMessageID() // FIXME: dont do this over an over // Gets a list of mssages that are newer than the last one recieved - messageIDs, err := cm.Comms.SendCheckMessages(receiveGateway, pollingMessage) + messageIDs, err := rm.Comms.SendCheckMessages(receiveGateway, pollingMessage) if err != nil { return nil, err @@ -190,15 +190,15 @@ func (cm *CommManager) receiveMessagesFromGateway(session user.Session, bufLoc := 0 for _, messageID := range messageIDs.IDs { // Get the first unseen message from the list of IDs - cm.recievedMesageLock.RLock() - _, received := cm.receivedMessages[messageID] - cm.recievedMesageLock.RUnlock() + rm.recievedMesageLock.RLock() + _, received := rm.receivedMessages[messageID] + rm.recievedMesageLock.RUnlock() if !received { globals.Log.INFO.Printf("Got a message waiting on the gateway: %v", messageID) // We haven't seen this message before. // So, we should retrieve it from the gateway. - newMessage, err := cm.Comms.SendGetMessage(receiveGateway, + newMessage, err := rm.Comms.SendGetMessage(receiveGateway, &pb.ClientRequest{ UserID: session.GetCurrentUser().User. Bytes(), @@ -232,9 +232,9 @@ func (cm *CommManager) receiveMessagesFromGateway(session user.Session, for i := 0; i < bufLoc; i++ { globals.Log.INFO.Printf( "Adding message ID %v to received message IDs", mIDs[i]) - cm.recievedMesageLock.Lock() - cm.receivedMessages[mIDs[i]] = struct{}{} - cm.recievedMesageLock.Unlock() + rm.recievedMesageLock.Lock() + rm.receivedMessages[mIDs[i]] = struct{}{} + rm.recievedMesageLock.Unlock() } session.SetLastMessageID(mIDs[bufLoc-1]) err = session.StoreSession() @@ -247,7 +247,7 @@ func (cm *CommManager) receiveMessagesFromGateway(session user.Session, return messages[:bufLoc], nil } -func (cm *CommManager) decryptMessages(session user.Session, +func (rm *ReceptionManager) decryptMessages(session user.Session, encryptedMessages []*format.Message) ([]*format.Message, []*id.User, []*format.Message) { diff --git a/io/commManager.go b/io/receptionManager.go similarity index 71% rename from io/commManager.go rename to io/receptionManager.go index 39e87bb33328ce82cec61fb6a8ff221c7d415837..b61843781de6a7d7bb655f3981d8e3abdb9b146d 100644 --- a/io/commManager.go +++ b/io/receptionManager.go @@ -25,8 +25,8 @@ func (a ConnAddr) String() string { return string(a) } -// CommManager implements the Communications interface -type CommManager struct { +// ReceptionManager implements the Communications interface +type ReceptionManager struct { // Comms pointer to send/recv messages Comms *client.Comms @@ -35,9 +35,9 @@ type CommManager struct { // blockTransmissions will use a mutex to prevent multiple threads from sending // messages at the same time. - blockTransmissions bool + blockTransmissions bool // pass into receiver // transmitDelay is the minimum delay between transmissions. - transmitDelay time.Duration + transmitDelay time.Duration // same // Map that holds a record of the messages that this client successfully // received during this session receivedMessages map[string]struct{} @@ -45,17 +45,18 @@ type CommManager struct { sendLock sync.Mutex - lock sync.RWMutex + rekeyChan chan struct{} } -func NewCommManager() *CommManager { - cm := &CommManager{ +func NewReceptionManager(rekeyChan chan struct{}) *ReceptionManager { + cm := &ReceptionManager{ nextId: parse.IDCounter(), collator: NewCollator(), blockTransmissions: true, transmitDelay: 1000 * time.Millisecond, receivedMessages: make(map[string]struct{}), Comms: &client.Comms{}, + rekeyChan: rekeyChan, } return cm @@ -63,12 +64,12 @@ func NewCommManager() *CommManager { // Connects to the permissioning server, if we know about it, to get the latest // version from it -func (cm *CommManager) GetRemoteVersion() (string, error) { // need this but make getremoteversion, handle versioning in client - permissioningHost, ok := cm.Comms.GetHost(PermissioningAddrID) +func (rm *ReceptionManager) GetRemoteVersion() (string, error) { // need this but make getremoteversion, handle versioning in client + permissioningHost, ok := rm.Comms.GetHost(PermissioningAddrID) if !ok { return "", errors.Errorf("Failed to find permissioning host with id %s", PermissioningAddrID) } - registrationVersion, err := cm.Comms. + registrationVersion, err := rm.Comms. SendGetCurrentClientVersionMessage(permissioningHost) if err != nil { return "", errors.Wrap(err, "Couldn't get current version from permissioning") @@ -76,10 +77,10 @@ func (cm *CommManager) GetRemoteVersion() (string, error) { // need this but mak return registrationVersion.Version, nil } -func (cm *CommManager) DisableBlockingTransmission() { // flag passed into receiver - cm.blockTransmissions = false +func (rm *ReceptionManager) DisableBlockingTransmission() { // flag passed into receiver + rm.blockTransmissions = false } -func (cm *CommManager) SetRateLimit(delay time.Duration) { // pass into received - cm.transmitDelay = delay +func (rm *ReceptionManager) SetRateLimit(delay time.Duration) { // pass into received + rm.transmitDelay = delay } diff --git a/io/send.go b/io/send.go index 623e7b90a1cbe2b5c6202e657f681c7833c2b93c..a95a8dedc075f9573be0482f83b01aad0a33c7e2 100644 --- a/io/send.go +++ b/io/send.go @@ -31,7 +31,7 @@ import ( // the keys) here. I won't touch crypto at this time, though... // TODO This method would be cleaner if it took a parse.Message (particularly // w.r.t. generating message IDs for multi-part messages.) -func (cm *CommManager) SendMessage(session user.Session, topology *circuit.Circuit, +func (rm *ReceptionManager) SendMessage(session user.Session, topology *circuit.Circuit, recipientID *id.User, cryptoType parse.CryptoType, message []byte, transmissionHost *connect.Host) error { // FIXME: We should really bring the plaintext parts of the NewMessage logic @@ -44,7 +44,7 @@ func (cm *CommManager) SendMessage(session user.Session, topology *circuit.Circu // in this library? why not pass a sender object instead? globals.Log.DEBUG.Printf("Sending message to %q: %q", *recipientID, message) parts, err := parse.Partition([]byte(message), - cm.nextId()) + rm.nextId()) if err != nil { return fmt.Errorf("SendMessage Partition() error: %v", err.Error()) } @@ -61,25 +61,23 @@ func (cm *CommManager) SendMessage(session user.Session, topology *circuit.Circu } // Add a byte for later encryption (15->16 bytes) extendedNowBytes := append(nowBytes, 0) - cm.lock.RLock() for i := range parts { message := format.NewMessage() message.SetRecipient(recipientID) message.SetTimestamp(extendedNowBytes) message.Contents.SetRightAligned(parts[i]) - err = cm.send(session, topology, cryptoType, message, false, transmissionHost) + err = rm.send(session, topology, cryptoType, message, false, transmissionHost) if err != nil { return errors.Wrap(err, "SendMessage send() error:") } } - cm.lock.RUnlock() return nil } // Send Message without doing partitions // This function will be needed for example to send a Rekey // message, where a new public key will take up the whole message -func (cm *CommManager) SendMessageNoPartition(session user.Session, +func (rm *ReceptionManager) SendMessageNoPartition(session user.Session, topology *circuit.Circuit, recipientID *id.User, cryptoType parse.CryptoType, message []byte, transmissionHost *connect.Host) error { size := len(message) @@ -104,7 +102,7 @@ func (cm *CommManager) SendMessageNoPartition(session user.Session, msg.Contents.Set(message) globals.Log.DEBUG.Printf("Sending message to %v: %x", *recipientID, message) - err = cm.send(session, topology, cryptoType, msg, true, transmissionHost) + err = rm.send(session, topology, cryptoType, msg, true, transmissionHost) if err != nil { return fmt.Errorf("SendMessageNoPartition send() error: %v", err.Error()) } @@ -112,16 +110,16 @@ func (cm *CommManager) SendMessageNoPartition(session user.Session, } // send actually sends the message to the server -func (cm *CommManager) send(session user.Session, topology *circuit.Circuit, +func (rm *ReceptionManager) send(session user.Session, topology *circuit.Circuit, cryptoType parse.CryptoType, message *format.Message, rekey bool, transmitGateway *connect.Host) error { // Enable transmission blocking if enabled - if cm.blockTransmissions { - cm.sendLock.Lock() + if rm.blockTransmissions { + rm.sendLock.Lock() defer func() { - time.Sleep(cm.transmitDelay) - cm.sendLock.Unlock() + time.Sleep(rm.transmitDelay) + rm.sendLock.Unlock() }() } @@ -149,7 +147,7 @@ func (cm *CommManager) send(session user.Session, topology *circuit.Circuit, KMACs: kmacs, } - return cm.Comms.SendPutMessage(transmitGateway, msgPacket) + return rm.Comms.SendPutMessage(transmitGateway, msgPacket) } func handleE2ESending(session user.Session, diff --git a/rekey/rekey_test.go b/rekey/rekey_test.go index 6323fac9ebba5a484e937249b923d310770a857d..bc53bb920ef116a89587f3b44c835bf5e7815395 100644 --- a/rekey/rekey_test.go +++ b/rekey/rekey_test.go @@ -51,7 +51,7 @@ func (d *dummyMessaging) SendMessageNoPartition(sess user.Session, // MessageReceiver thread to get new messages func (d *dummyMessaging) MessageReceiver(session user.Session, - delay time.Duration, rekeyChan chan struct{}, transmissionHost *connect.Host, callback func(error)) { + delay time.Duration, transmissionHost *connect.Host, callback func(error)) { } func TestMain(m *testing.M) {