diff --git a/auth/interface.go b/auth/interface.go index c6f5a5193a13b6cebb42a4e930a0f58d6c52e23d..ea1b108744ed2fceb201eb6268c2d54d684f8002 100644 --- a/auth/interface.go +++ b/auth/interface.go @@ -100,6 +100,10 @@ type State interface { // auth callback for the given partner ID. DeletePartnerCallback(partnerId *id.ID) + // DeletePartner deletes the request and/or confirmation for the given + // partner. + DeletePartner(partner *id.ID) error + // Closer stops listening to auth. io.Closer } diff --git a/auth/state.go b/auth/state.go index fda5ba2c98c08f36be9ae3b58aedcec618ef0b4b..3b82f15e46049abd5cc8be2367fb4abb45b68e71 100644 --- a/auth/state.go +++ b/auth/state.go @@ -142,6 +142,22 @@ func (s *state) Close() error { return nil } +// DeletePartner deletes the request and/or confirmation for the given partner. +func (s *state) DeletePartner(partner *id.ID) error { + err := s.store.DeleteRequest(partner) + err2 := s.store.DeleteConfirmation(partner) + + // Only return an error if both failed to delete + if err != nil && err2 != nil { + return errors.Errorf("Failed to delete partner: no requests or "+ + "confirmations found: %s, %s", err, err2) + } + + s.DeletePartnerCallback(partner) + + return nil +} + // AddPartnerCallback that overrides the generic auth callback for the given partnerId func (s *state) AddPartnerCallback(partnerId *id.ID, cb Callbacks) { s.partnerCallbacks.AddPartnerCallback(partnerId, cb) diff --git a/auth/utils_test.go b/auth/utils_test.go index 4d6b50ade9db9ff930e3824af62515a098578e77..4c53ff91a95c40e4b138258592cb292b12d8e749 100644 --- a/auth/utils_test.go +++ b/auth/utils_test.go @@ -88,6 +88,8 @@ func (m mockE2eHandler) Unregister(listenerID receive.ListenerID) { return } +func (m mockE2eHandler) UnregisterUserListeners(*id.ID) {} + func (m mockE2eHandler) AddPartner(partnerID *id.ID, partnerPubKey, myPrivKey *cyclic.Int, partnerSIDHPubKey *sidh.PublicKey, mySIDHPrivKey *sidh.PrivateKey, diff --git a/bindings/connect.go b/bindings/connect.go index 5d61f5baf5a52fabea7cd668950a69e1791da233..98f9795398040d1cb372e8c54fb4d273efd84dbc 100644 --- a/bindings/connect.go +++ b/bindings/connect.go @@ -88,6 +88,7 @@ func (c *Connection) GetPartner() []byte { // RegisterListener is used for E2E reception // and allows for reading data sent from the partner.Manager // Returns marshalled ListenerID -func (c *Connection) RegisterListener(messageType int, newListener Listener) { - _ = c.connection.RegisterListener(catalog.MessageType(messageType), listener{l: newListener}) +func (c *Connection) RegisterListener(messageType int, newListener Listener) error { + _, err := c.connection.RegisterListener(catalog.MessageType(messageType), listener{l: newListener}) + return err } diff --git a/connect/authCallbacks.go b/connect/authCallbacks.go index 3c4670bd3fce5be27d230e4b67d8fef43143d661..cdf8b849fdf3b4065d3e19b62ec5191b79fbf48a 100644 --- a/connect/authCallbacks.go +++ b/connect/authCallbacks.go @@ -90,6 +90,9 @@ type serverAuthCallback struct { confirmCallback Callback requestCallback Callback + // Used to track stale connections + cl *ConnectionList + // Used for building new Connection objects connectionParams Params } @@ -97,10 +100,12 @@ type serverAuthCallback struct { // getServerAuthCallback returns a callback interface to be passed into the creation // of a xxdk.E2e object. // it will accept requests only if a request callback is passed in -func getServerAuthCallback(confirm, request Callback, params Params) *serverAuthCallback { +func getServerAuthCallback(confirm, request Callback, cl *ConnectionList, + params Params) *serverAuthCallback { return &serverAuthCallback{ confirmCallback: confirm, requestCallback: request, + cl: cl, connectionParams: params, } } @@ -136,8 +141,10 @@ func (a serverAuthCallback) Request(requestor contact.Contact, } // Return the new Connection object - a.requestCallback(BuildConnection(newPartner, e2e.GetE2E(), - e2e.GetAuth(), a.connectionParams)) + c := BuildConnection( + newPartner, e2e.GetE2E(), e2e.GetAuth(), a.connectionParams) + a.cl.Add(c) + a.requestCallback(c) } // Reset will be called when an auth Reset operation occurs. diff --git a/connect/authenticated.go b/connect/authenticated.go index 89311eff3ee33dd957b790c06a98e3de7bd745b9..7fc0d0e7c6d15266f3c04b876cb46ff179123f5c 100644 --- a/connect/authenticated.go +++ b/connect/authenticated.go @@ -170,7 +170,7 @@ func connectWithAuthentication(conn Connection, timeStart time.Time, // authenticate themselves. An established AuthenticatedConnection will // be passed via the callback. func StartAuthenticatedServer(identity xxdk.ReceptionIdentity, - cb AuthenticatedCallback, net *xxdk.Cmix, p Params) (*xxdk.E2e, error) { + cb AuthenticatedCallback, net *xxdk.Cmix, p Params) (*ConnectionServer, error) { // Register the waiter for a connection establishment connCb := Callback(func(connection Connection) { @@ -178,8 +178,14 @@ func StartAuthenticatedServer(identity xxdk.ReceptionIdentity, // client's identity proof. If an identity authentication // message is received and validated, an authenticated connection will // be passed along via the AuthenticatedCallback - connection.RegisterListener(catalog.ConnectionAuthenticationRequest, + _, err := connection.RegisterListener( + catalog.ConnectionAuthenticationRequest, buildAuthConfirmationHandler(cb, connection)) + if err != nil { + jww.ERROR.Printf( + "Failed to register listener on connection with %s: %+v", + connection.GetPartner().PartnerId(), err) + } }) return StartServer(identity, connCb, net, p) } diff --git a/connect/connect.go b/connect/connect.go index 429f7331d4f1923cec27497f048fa9a3764567a0..96f4b900fc8f93805359981c33b8f9cafbf56493 100644 --- a/connect/connect.go +++ b/connect/connect.go @@ -8,8 +8,12 @@ package connect import ( "encoding/json" + "gitlab.com/elixxir/client/e2e/rekey" + "gitlab.com/elixxir/client/event" "gitlab.com/elixxir/client/xxdk" + "gitlab.com/xx_network/primitives/netTime" "io" + "sync/atomic" "time" "github.com/pkg/errors" @@ -19,8 +23,6 @@ import ( clientE2e "gitlab.com/elixxir/client/e2e" "gitlab.com/elixxir/client/e2e/ratchet/partner" "gitlab.com/elixxir/client/e2e/receive" - "gitlab.com/elixxir/client/e2e/rekey" - "gitlab.com/elixxir/client/event" "gitlab.com/elixxir/crypto/contact" "gitlab.com/elixxir/crypto/e2e" "gitlab.com/xx_network/primitives/id" @@ -32,6 +34,8 @@ const ( connectionTimeout = 15 * time.Second ) +var alreadyClosedErr = errors.New("connection is closed") + // Connection is a wrapper for the E2E and auth packages. // It can be used to automatically establish an E2E partnership // with a partner.Manager, or be built from an existing E2E partnership. @@ -52,7 +56,7 @@ type Connection interface { // RegisterListener is used for E2E reception // and allows for reading data sent from the partner.Manager RegisterListener(messageType catalog.MessageType, - newListener receive.Listener) receive.ListenerID + newListener receive.Listener) (receive.ListenerID, error) // Unregister listener for E2E reception Unregister(listenerID receive.ListenerID) @@ -71,6 +75,10 @@ type Connection interface { // PayloadSize Returns the max payload size for a partitionable E2E // message PayloadSize() uint + + // LastUse returns the timestamp of the last time the connection was + // utilised. + LastUse() time.Time } // Callback is the callback format required to retrieve @@ -82,6 +90,7 @@ type Params struct { Auth auth.Params Rekey rekey.Params Event event.Reporter `json:"-"` + List ConnectionListParams Timeout time.Duration } @@ -91,6 +100,7 @@ func GetDefaultParams() Params { Auth: auth.GetDefaultTemporaryParams(), Rekey: rekey.GetDefaultEphemeralParams(), Event: event.NewEventManager(), + List: DefaultConnectionListParams(), Timeout: connectionTimeout, } } @@ -158,13 +168,31 @@ func Connect(recipient contact.Contact, e2eClient *xxdk.E2e, // This call does an xxDK.ephemeralLogin under the hood and the connection // server must be the only listener on auth. func StartServer(identity xxdk.ReceptionIdentity, cb Callback, net *xxdk.Cmix, - p Params) (*xxdk.E2e, error) { + p Params) (*ConnectionServer, error) { + + // Create connection list and start cleanup thread + cl := NewConnectionList(p.List) + err := net.AddService(cl.CleanupThread) + if err != nil { + return nil, err + } // Build callback for E2E negotiation - callback := getServerAuthCallback(nil, cb, p) + callback := getServerAuthCallback(nil, cb, cl, p) + + e2eClient, err := xxdk.LoginEphemeral(net, callback, identity) + if err != nil { + return nil, err + } // Return an ephemeral E2e object - return xxdk.LoginEphemeral(net, callback, identity) + return &ConnectionServer{e2eClient, cl}, nil +} + +// ConnectionServer contains +type ConnectionServer struct { + e2e *xxdk.E2e + cl *ConnectionList } // handler provides an implementation for the Connection interface. @@ -172,7 +200,14 @@ type handler struct { auth auth.State partner partner.Manager e2e clientE2e.Handler - params Params + + // Timestamp of last time a message was sent or received (Unix nanoseconds) + lastUse *int64 + + // Indicates if the connection has been closed (0 = open, 1 = closed) + closed *uint32 + + params Params } // BuildConnection assembles a Connection object @@ -188,12 +223,30 @@ func BuildConnection(partner partner.Manager, e2eHandler clientE2e.Handler, } } -// Close deletes this Connection's partner.Manager and releases resources. +// Close deletes this Connection's partner.Manager and releases resources. If +// the connection is already closed, then nil is returned. func (h *handler) Close() error { - if err := h.e2e.DeletePartner(h.partner.PartnerId()); err != nil { + if h.isClosed() { + return nil + } + + // Get partner ID once at the top because PartnerId makes a copy + partnerID := h.partner.PartnerId() + + // Unregister all listeners + h.e2e.UnregisterUserListeners(partnerID) + + // Delete partner from e2e and auth + if err := h.e2e.DeletePartner(partnerID); err != nil { return err } - return h.auth.Close() + if err := h.auth.DeletePartner(partnerID); err != nil { + return err + } + + atomic.StoreUint32(h.closed, 1) + + return nil } // GetPartner returns the partner.Manager for this Connection. @@ -204,17 +257,25 @@ func (h *handler) GetPartner() partner.Manager { // SendE2E is a wrapper for sending specifically to the Connection's // partner.Manager. func (h *handler) SendE2E(mt catalog.MessageType, payload []byte, - params clientE2e.Params) ( - []id.Round, e2e.MessageID, time.Time, error) { + params clientE2e.Params) ([]id.Round, e2e.MessageID, time.Time, error) { + if h.isClosed() { + return nil, e2e.MessageID{}, time.Time{}, alreadyClosedErr + } + + h.updateLastUse(netTime.Now()) + return h.e2e.SendE2E(mt, h.partner.PartnerId(), payload, params) } // RegisterListener is used for E2E reception // and allows for reading data sent from the partner.Manager. func (h *handler) RegisterListener(messageType catalog.MessageType, - newListener receive.Listener) receive.ListenerID { - return h.e2e.RegisterListener(h.partner.PartnerId(), - messageType, newListener) + newListener receive.Listener) (receive.ListenerID, error) { + if h.isClosed() { + return receive.ListenerID{}, alreadyClosedErr + } + lt := &listenerTracker{h, newListener} + return h.e2e.RegisterListener(h.partner.PartnerId(), messageType, lt), nil } // Unregister listener for E2E reception. @@ -245,3 +306,18 @@ func (h *handler) PartitionSize(payloadIndex uint) uint { func (h *handler) PayloadSize() uint { return h.e2e.PayloadSize() } + +// LastUse returns the timestamp of the last time the connection was utilised. +func (h *handler) LastUse() time.Time { + return time.Unix(0, atomic.LoadInt64(h.lastUse)) +} + +// updateLastUse updates the last use time stamp to the given time. +func (h *handler) updateLastUse(t time.Time) { + atomic.StoreInt64(h.lastUse, t.UnixNano()) +} + +// isClosed returns true if the connection is closed. +func (h *handler) isClosed() bool { + return atomic.LoadUint32(h.closed) == 1 +} diff --git a/connect/connectionList.go b/connect/connectionList.go new file mode 100644 index 0000000000000000000000000000000000000000..fab3404126334ce4c3e0b29f46f9e307a751894e --- /dev/null +++ b/connect/connectionList.go @@ -0,0 +1,119 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package connect + +import ( + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/stoppable" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" + "sync" + "time" +) + +// ConnectionList is a list of all connections. +type ConnectionList struct { + list map[id.ID]Connection + p ConnectionListParams + mux sync.Mutex +} + +// NewConnectionList initialises an empty ConnectionList. +func NewConnectionList(p ConnectionListParams) *ConnectionList { + return &ConnectionList{ + list: make(map[id.ID]Connection), + p: p, + } +} + +// Add adds the connection to the list. +func (cl *ConnectionList) Add(c Connection) { + cl.mux.Lock() + defer cl.mux.Unlock() + + cl.list[*c.GetPartner().PartnerId()] = c +} + +// CleanupThread runs the loop that runs the cleanup processes periodically. +func (cl *ConnectionList) CleanupThread() (stoppable.Stoppable, error) { + stop := stoppable.NewSingle("StaleConnectionCleanup") + + go func() { + jww.INFO.Printf("Starting stale connection cleanup thread to delete "+ + "connections older than %s. Running every %s.", + cl.p.MaxAge, cl.p.CleanupPeriod) + ticker := time.NewTicker(cl.p.CleanupPeriod) + for { + select { + case <-stop.Quit(): + jww.INFO.Print( + "Stopping connection cleanup thread: stoppable triggered") + ticker.Stop() + stop.ToStopped() + case <-ticker.C: + jww.DEBUG.Print("Starting connection cleanup.") + cl.Cleanup() + } + } + }() + + return stop, nil +} + +// Cleanup disconnects all connections that have been stale for longer than the +// max allowed time. +func (cl *ConnectionList) Cleanup() { + cl.mux.Lock() + defer cl.mux.Unlock() + + for partnerID, c := range cl.list { + lastUse := c.LastUse() + timeSinceLastUse := netTime.Since(lastUse) + if timeSinceLastUse > cl.p.MaxAge { + err := c.Close() + if err != nil { + jww.ERROR.Printf( + "Could not close connection with partner %s: %+v", + partnerID, err) + } + delete(cl.list, partnerID) + + jww.INFO.Printf("Deleted stale connection for partner %s. "+ + "Last use was %s ago (%s)", + &partnerID, timeSinceLastUse, lastUse.Local()) + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Parameters // +//////////////////////////////////////////////////////////////////////////////// + +// Default values. +const ( + cleanupPeriodDefault = 5 * time.Minute + maxAgeDefault = 30 * time.Minute +) + +// ConnectionListParams are the parameters used for the ConnectionList. +type ConnectionListParams struct { + // CleanupPeriod is the duration between when cleanups occur. + CleanupPeriod time.Duration + + // MaxAge is the maximum age of an unused connection before it is deleted. + MaxAge time.Duration +} + +// DefaultConnectionListParams returns a ConnectionListParams filled with +// default values. +func DefaultConnectionListParams() ConnectionListParams { + return ConnectionListParams{ + CleanupPeriod: cleanupPeriodDefault, + MaxAge: maxAgeDefault, + } +} diff --git a/connect/connectionList_test.go b/connect/connectionList_test.go new file mode 100644 index 0000000000000000000000000000000000000000..b76c6b35876ee6656b099a46b663a75820db08e5 --- /dev/null +++ b/connect/connectionList_test.go @@ -0,0 +1,125 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package connect + +import ( + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" + "reflect" + "testing" + "time" +) + +// Tests that NewConnectionList returned the expected new ConnectionList. +func TestNewConnectionList(t *testing.T) { + expected := &ConnectionList{ + list: make(map[id.ID]Connection), + p: DefaultConnectionListParams(), + } + + cl := NewConnectionList(expected.p) + + if !reflect.DeepEqual(expected, cl) { + t.Errorf("New ConnectionList did not match expected."+ + "\nexpected: %+v\nreceived: %+v", expected, cl) + } +} + +// Tests that ConnectionList.Add adds all the given connections to the list. +func TestConnectionList_Add(t *testing.T) { + cl := NewConnectionList(DefaultConnectionListParams()) + + expected := map[id.ID]Connection{ + *id.NewIdFromString("p1", id.User, t): &handler{ + partner: &mockPartner{partnerId: id.NewIdFromString("p1", id.User, t)}}, + *id.NewIdFromString("p2", id.User, t): &handler{ + partner: &mockPartner{partnerId: id.NewIdFromString("p2", id.User, t)}}, + *id.NewIdFromString("p3", id.User, t): &handler{ + partner: &mockPartner{partnerId: id.NewIdFromString("p3", id.User, t)}}, + *id.NewIdFromString("p4", id.User, t): &handler{ + partner: &mockPartner{partnerId: id.NewIdFromString("p4", id.User, t)}}, + *id.NewIdFromString("p5", id.User, t): &handler{ + partner: &mockPartner{partnerId: id.NewIdFromString("p5", id.User, t)}}, + } + + for _, c := range expected { + cl.Add(c) + } + + if !reflect.DeepEqual(expected, cl.list) { + t.Errorf("List does not have expected connections."+ + "\nexpected: %+v\nreceived: %+v", expected, cl.list) + } + +} + +// Tests that ConnectionList.Cleanup deletes only stale connections from the +// list and that they are closed. +func TestConnectionList_Cleanup(t *testing.T) { + cl := NewConnectionList(DefaultConnectionListParams()) + + list := []*mockConnection{ + { + partner: &mockPartner{partnerId: id.NewIdFromString("p0", id.User, t)}, + lastUse: netTime.Now().Add(-(cl.p.MaxAge * 2)), + }, { + partner: &mockPartner{partnerId: id.NewIdFromString("p1", id.User, t)}, + lastUse: netTime.Now().Add(-(cl.p.MaxAge / 2)), + }, { + partner: &mockPartner{partnerId: id.NewIdFromString("p2", id.User, t)}, + lastUse: netTime.Now().Add(-(cl.p.MaxAge + 10)), + }, { + partner: &mockPartner{partnerId: id.NewIdFromString("p3", id.User, t)}, + lastUse: netTime.Now().Add(-(cl.p.MaxAge - time.Second)), + }, + } + + for _, c := range list { + cl.Add(c) + } + + cl.Cleanup() + + for i, c := range list { + if i%2 == 0 { + if _, exists := cl.list[*c.GetPartner().PartnerId()]; exists { + t.Errorf("Connection #%d exists while being stale.", i) + } + if !c.closed { + t.Errorf("Connection #%d was not closed.", i) + } + } else { + if _, exists := cl.list[*c.GetPartner().PartnerId()]; !exists { + t.Errorf("Connection #%d was removed when it was not stale.", i) + } + if c.closed { + t.Errorf("Connection #%d was closed.", i) + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Parameters // +//////////////////////////////////////////////////////////////////////////////// + +// Tests that DefaultConnectionListParams returns a ConnectionListParams with +// the expected default values. +func TestDefaultConnectionListParams(t *testing.T) { + expected := ConnectionListParams{ + CleanupPeriod: cleanupPeriodDefault, + MaxAge: maxAgeDefault, + } + + p := DefaultConnectionListParams() + + if !reflect.DeepEqual(expected, p) { + t.Errorf("Default ConnectionListParams does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expected, p) + } +} diff --git a/connect/listenerTracker.go b/connect/listenerTracker.go new file mode 100644 index 0000000000000000000000000000000000000000..875598dfe53d314cf0052c2ab592ceff425b032c --- /dev/null +++ b/connect/listenerTracker.go @@ -0,0 +1,30 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package connect + +import ( + "gitlab.com/elixxir/client/e2e/receive" + "gitlab.com/xx_network/primitives/netTime" +) + +// listenerTracker wraps a listener and updates the last use timestamp on every +// call to Hear. +type listenerTracker struct { + h *handler + l receive.Listener +} + +// Hear updates the last call timestamp and then calls Hear on the wrapped +// listener. +func (lt *listenerTracker) Hear(item receive.Message) { + lt.h.updateLastUse(netTime.Now()) + lt.l.Hear(item) +} + +// Name returns a name, used for debugging. +func (lt *listenerTracker) Name() string { return lt.l.Name() } diff --git a/connect/listenerTracker_test.go b/connect/listenerTracker_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1cab220ad7955c6b25be5b716fc41197eefc055b --- /dev/null +++ b/connect/listenerTracker_test.go @@ -0,0 +1,74 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package connect + +import ( + "gitlab.com/elixxir/client/e2e/receive" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" + "reflect" + "testing" + "time" +) + +// Tests that listenerTracker.Hear correctly updates the last use timestamp and +// calls the wrapped Hear function. +func Test_listenerTracker_Hear(t *testing.T) { + lastUsed, closed := int64(0), uint32(0) + itemChan := make(chan receive.Message, 1) + lt := &listenerTracker{ + h: &handler{ + lastUse: &lastUsed, + closed: &closed, + }, + l: &mockListener{itemChan, "mockListener"}, + } + + expected := receive.Message{ + Payload: []byte("Message payload."), + Sender: id.NewIdFromString("senderID", id.User, t), + RecipientID: id.NewIdFromString("RecipientID", id.User, t), + } + + lt.Hear(expected) + + select { + case r := <-itemChan: + if !reflect.DeepEqual(expected, r) { + t.Errorf("Did not receive expected receive.Message."+ + "\nexpected: %+v\nreceived: %+v", expected, r) + } + if netTime.Since(lt.h.LastUse()) > 300*time.Millisecond { + t.Errorf("Last use has incorrect time: %s", lt.h.LastUse()) + } + + case <-time.After(30 * time.Millisecond): + t.Error("Timed out waiting for Hear to be called.") + } +} + +// Tests that listenerTracker.Name calls the wrapped listener Name function. +func Test_listenerTracker_Name(t *testing.T) { + expected := "mockListener" + lt := &listenerTracker{ + l: &mockListener{make(chan receive.Message, 1), "mockListener"}, + } + + if lt.Name() != expected { + t.Errorf("Did not get expected name.\nexected: %s\nreceived: %s", + expected, lt.Name()) + } +} + +type mockListener struct { + item chan receive.Message + name string +} + +func (m *mockListener) Hear(item receive.Message) { m.item <- item } +func (m *mockListener) Name() string { return m.name } diff --git a/connect/utils_test.go b/connect/utils_test.go index dec043d0433cfe2efe7d6016a45ae61a1c1d1488..6e4be5173d00cd8ba4b671d6abee3a8669e1c77c 100644 --- a/connect/utils_test.go +++ b/connect/utils_test.go @@ -22,15 +22,17 @@ import ( "gitlab.com/xx_network/crypto/large" "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/id/ephemeral" - "gitlab.com/xx_network/primitives/ndf" "gitlab.com/xx_network/primitives/netTime" "time" ) //////////////////////////////////////////////////////////////////////////////// -// Mock Partner Interface // +// Mock Partner Interface // //////////////////////////////////////////////////////////////////////////////// +// Tests that mockPartner adheres to the partner.Manager interface. +var _ partner.Manager = (*mockPartner)(nil) + type mockPartner struct { partnerId *id.ID myID *id.ID @@ -48,126 +50,73 @@ func newMockPartner(partnerId, myId *id.ID, } } -func (m mockPartner) PartnerId() *id.ID { - return m.partnerId -} - -func (m mockPartner) MyId() *id.ID { - return m.myID -} - -func (m mockPartner) MyRootPrivateKey() *cyclic.Int { - return m.myDhPrivKey -} - -func (m mockPartner) PartnerRootPublicKey() *cyclic.Int { - return m.partnerDhPubKey -} - -func (m mockPartner) SendRelationshipFingerprint() []byte { - return nil -} - -func (m mockPartner) ReceiveRelationshipFingerprint() []byte { - return nil -} - -func (m mockPartner) ConnectionFingerprint() partner.ConnectionFp { - return partner.ConnectionFp{} -} - -func (m mockPartner) Contact() contact.Contact { +func (m *mockPartner) PartnerId() *id.ID { return m.partnerId } +func (m *mockPartner) MyId() *id.ID { return m.myID } +func (m *mockPartner) MyRootPrivateKey() *cyclic.Int { return m.myDhPrivKey } +func (m *mockPartner) PartnerRootPublicKey() *cyclic.Int { return m.partnerDhPubKey } +func (m *mockPartner) SendRelationshipFingerprint() []byte { return nil } +func (m *mockPartner) ReceiveRelationshipFingerprint() []byte { return nil } +func (m *mockPartner) ConnectionFingerprint() partner.ConnectionFp { return partner.ConnectionFp{} } +func (m *mockPartner) Contact() contact.Contact { return contact.Contact{ ID: m.partnerId, DhPubKey: m.partnerDhPubKey, } } - -func (m mockPartner) PopSendCypher() (session.Cypher, error) { - return nil, nil -} - -func (m mockPartner) PopRekeyCypher() (session.Cypher, error) { - return nil, nil -} - -func (m mockPartner) NewReceiveSession(partnerPubKey *cyclic.Int, partnerSIDHPubKey *sidh.PublicKey, e2eParams session.Params, source *session.Session) (*session.Session, bool) { +func (m *mockPartner) PopSendCypher() (session.Cypher, error) { return nil, nil } +func (m *mockPartner) PopRekeyCypher() (session.Cypher, error) { return nil, nil } +func (m *mockPartner) NewReceiveSession(*cyclic.Int, *sidh.PublicKey, session.Params, *session.Session) (*session.Session, bool) { return nil, false } - -func (m mockPartner) NewSendSession(myDHPrivKey *cyclic.Int, mySIDHPrivateKey *sidh.PrivateKey, e2eParams session.Params, source *session.Session) *session.Session { - return nil -} - -func (m mockPartner) GetSendSession(sid session.SessionID) *session.Session { - return nil -} - -func (m mockPartner) GetReceiveSession(sid session.SessionID) *session.Session { - return nil -} - -func (m mockPartner) Confirm(sid session.SessionID) error { - return nil -} - -func (m mockPartner) TriggerNegotiations() []*session.Session { - return nil -} - -func (m mockPartner) MakeService(tag string) message.Service { - return message.Service{} -} - -func (m mockPartner) Delete() error { +func (m *mockPartner) NewSendSession(*cyclic.Int, *sidh.PrivateKey, session.Params, *session.Session) *session.Session { return nil } +func (m *mockPartner) GetSendSession(session.SessionID) *session.Session { return nil } +func (m *mockPartner) GetReceiveSession(session.SessionID) *session.Session { return nil } +func (m *mockPartner) Confirm(session.SessionID) error { return nil } +func (m *mockPartner) TriggerNegotiations() []*session.Session { return nil } +func (m *mockPartner) MakeService(string) message.Service { return message.Service{} } +func (m *mockPartner) Delete() error { return nil } //////////////////////////////////////////////////////////////////////////////// -// Mock Connection Interface // +// Mock Connection Interface // //////////////////////////////////////////////////////////////////////////////// +// Tests that mockConnection adheres to the Connection interface. +var _ Connection = (*mockConnection)(nil) + type mockConnection struct { partner *mockPartner payloadChan chan []byte listener server + lastUse time.Time + closed bool } -func newMockConnection(partnerId, myId *id.ID, - myDhPrivKey, partnerDhPubKey *cyclic.Int) *mockConnection { +func newMockConnection(partnerId, myId *id.ID, myDhPrivKey, + partnerDhPubKey *cyclic.Int) *mockConnection { return &mockConnection{ - partner: newMockPartner(partnerId, myId, - myDhPrivKey, partnerDhPubKey), + partner: newMockPartner(partnerId, myId, myDhPrivKey, partnerDhPubKey), payloadChan: make(chan []byte, 1), } } -func (m mockConnection) FirstPartitionSize() uint { - return 0 -} - -func (m mockConnection) SecondPartitionSize() uint { - return 0 -} - -func (m mockConnection) PartitionSize(payloadIndex uint) uint { - return 0 -} - -func (m mockConnection) PayloadSize() uint { - return 0 -} +func (m *mockConnection) FirstPartitionSize() uint { return 0 } +func (m *mockConnection) SecondPartitionSize() uint { return 0 } +func (m *mockConnection) PartitionSize(uint) uint { return 0 } +func (m *mockConnection) PayloadSize() uint { return 0 } -func (m mockConnection) Close() error { +func (m *mockConnection) Close() error { + m.closed = true return nil } -func (m mockConnection) GetPartner() partner.Manager { - return m.partner -} +func (m *mockConnection) GetPartner() partner.Manager { return m.partner } -func (m mockConnection) SendE2E(mt catalog.MessageType, payload []byte, params e2e.Params) ([]id.Round, cryptoE2e.MessageID, time.Time, error) { +func (m *mockConnection) SendE2E( + mt catalog.MessageType, payload []byte, _ e2e.Params) ( + []id.Round, cryptoE2e.MessageID, time.Time, error) { m.payloadChan <- payload m.listener.Hear(receive.Message{ MessageType: mt, @@ -178,250 +127,86 @@ func (m mockConnection) SendE2E(mt catalog.MessageType, payload []byte, params e return nil, cryptoE2e.MessageID{}, time.Time{}, nil } -func (m mockConnection) RegisterListener(messageType catalog.MessageType, newListener receive.Listener) receive.ListenerID { - return receive.ListenerID{} -} - -func (m mockConnection) Unregister(listenerID receive.ListenerID) { - +func (m *mockConnection) RegisterListener( + catalog.MessageType, receive.Listener) (receive.ListenerID, error) { + return receive.ListenerID{}, nil } +func (m *mockConnection) Unregister(receive.ListenerID) {} +func (m *mockConnection) LastUse() time.Time { return m.lastUse } //////////////////////////////////////////////////////////////////////////////// -// Mock cMix // +// Mock cMix // //////////////////////////////////////////////////////////////////////////////// +// Tests that mockCmix adheres to the cmix.Client interface. +var _ cmix.Client = (*mockCmix)(nil) + type mockCmix struct { instance *network.Instance } func newMockCmix() *mockCmix { - return &mockCmix{} } -func (m *mockCmix) Follow(report cmix.ClientErrorReport) (stoppable.Stoppable, error) { - return nil, nil -} +func (m *mockCmix) Follow(cmix.ClientErrorReport) (stoppable.Stoppable, error) { return nil, nil } -func (m *mockCmix) GetMaxMessageLength() int { - return 4096 -} - -func (m *mockCmix) Send(recipient *id.ID, fingerprint format.Fingerprint, - service message.Service, payload, mac []byte, - cmixParams cmix.CMIXParams) (id.Round, ephemeral.Id, error) { +func (m *mockCmix) GetMaxMessageLength() int { return 4096 } +func (m *mockCmix) Send(*id.ID, format.Fingerprint, message.Service, []byte, + []byte, cmix.CMIXParams) (id.Round, ephemeral.Id, error) { return 0, ephemeral.Id{}, nil } - -func (m *mockCmix) SendMany(messages []cmix.TargetedCmixMessage, p cmix.CMIXParams) (id.Round, []ephemeral.Id, error) { +func (m *mockCmix) SendMany([]cmix.TargetedCmixMessage, cmix.CMIXParams) (id.Round, []ephemeral.Id, error) { return 0, []ephemeral.Id{}, nil } - -func (m *mockCmix) AddIdentity(id *id.ID, validUntil time.Time, persistent bool) { -} - -func (m *mockCmix) RemoveIdentity(id *id.ID) { -} - -func (m *mockCmix) GetIdentity(get *id.ID) (identity.TrackedID, error) { - return identity.TrackedID{ - Creation: netTime.Now().Add(-time.Minute), - }, nil -} - -func (m *mockCmix) AddFingerprint(identity *id.ID, fp format.Fingerprint, mp message.Processor) error { - return nil -} - -func (m *mockCmix) DeleteFingerprint(identity *id.ID, fingerprint format.Fingerprint) { - return -} - -func (m *mockCmix) DeleteClientFingerprints(identity *id.ID) { - return -} - -func (m *mockCmix) AddService(clientID *id.ID, newService message.Service, response message.Processor) { - return -} - -func (m *mockCmix) DeleteService(clientID *id.ID, toDelete message.Service, processor message.Processor) { - return -} - -func (m *mockCmix) DeleteClientService(clientID *id.ID) { -} - -func (m *mockCmix) TrackServices(tracker message.ServicesTracker) { - return -} - -func (m *mockCmix) CheckInProgressMessages() { - return -} - -func (m *mockCmix) IsHealthy() bool { - return true -} - -func (m *mockCmix) WasHealthy() bool { - return true -} - -func (m *mockCmix) AddHealthCallback(f func(bool)) uint64 { - return 0 -} - -func (m *mockCmix) RemoveHealthCallback(u uint64) { - return -} - -func (m *mockCmix) HasNode(nid *id.ID) bool { - return true -} - -func (m *mockCmix) NumRegisteredNodes() int { - return 24 -} - -func (m *mockCmix) TriggerNodeRegistration(nid *id.ID) { - return -} - -func (m *mockCmix) GetRoundResults(timeout time.Duration, roundCallback cmix.RoundEventCallback, roundList ...id.Round) error { +func (m *mockCmix) AddIdentity(*id.ID, time.Time, bool) {} +func (m *mockCmix) RemoveIdentity(*id.ID) {} + +func (m *mockCmix) GetIdentity(*id.ID) (identity.TrackedID, error) { + return identity.TrackedID{Creation: netTime.Now().Add(-time.Minute)}, nil +} + +func (m *mockCmix) AddFingerprint(*id.ID, format.Fingerprint, message.Processor) error { return nil } +func (m *mockCmix) DeleteFingerprint(*id.ID, format.Fingerprint) {} +func (m *mockCmix) DeleteClientFingerprints(*id.ID) {} +func (m *mockCmix) AddService(*id.ID, message.Service, message.Processor) {} +func (m *mockCmix) DeleteService(*id.ID, message.Service, message.Processor) {} +func (m *mockCmix) DeleteClientService(*id.ID) {} +func (m *mockCmix) TrackServices(message.ServicesTracker) {} +func (m *mockCmix) CheckInProgressMessages() {} +func (m *mockCmix) IsHealthy() bool { return true } +func (m *mockCmix) WasHealthy() bool { return true } +func (m *mockCmix) AddHealthCallback(func(bool)) uint64 { return 0 } +func (m *mockCmix) RemoveHealthCallback(uint64) {} +func (m *mockCmix) HasNode(*id.ID) bool { return true } +func (m *mockCmix) NumRegisteredNodes() int { return 24 } +func (m *mockCmix) TriggerNodeRegistration(*id.ID) {} + +func (m *mockCmix) GetRoundResults(_ time.Duration, roundCallback cmix.RoundEventCallback, _ ...id.Round) error { roundCallback(true, false, nil) return nil } -func (m *mockCmix) LookupHistoricalRound(rid id.Round, callback rounds.RoundResultCallback) error { - return nil -} - -func (m *mockCmix) SendToAny(sendFunc func(host *connect.Host) (interface{}, error), stop *stoppable.Single) (interface{}, error) { +func (m *mockCmix) LookupHistoricalRound(id.Round, rounds.RoundResultCallback) error { return nil } +func (m *mockCmix) SendToAny(func(host *connect.Host) (interface{}, error), *stoppable.Single) (interface{}, error) { return nil, nil } - -func (m *mockCmix) SendToPreferred(targets []*id.ID, sendFunc gateway.SendToPreferredFunc, stop *stoppable.Single, timeout time.Duration) (interface{}, error) { - return nil, nil -} - -func (m *mockCmix) SetGatewayFilter(f gateway.Filter) { - return -} - -func (m *mockCmix) GetHostParams() connect.HostParams { - return connect.GetDefaultHostParams() -} - -func (m *mockCmix) GetAddressSpace() uint8 { - return 32 -} - -func (m *mockCmix) RegisterAddressSpaceNotification(tag string) (chan uint8, error) { +func (m *mockCmix) SendToPreferred([]*id.ID, gateway.SendToPreferredFunc, *stoppable.Single, time.Duration) (interface{}, error) { return nil, nil } - -func (m *mockCmix) UnregisterAddressSpaceNotification(tag string) { - return -} - -func (m *mockCmix) GetInstance() *network.Instance { - return m.instance -} - -func (m *mockCmix) GetVerboseRounds() string { - return "" -} +func (m *mockCmix) SetGatewayFilter(gateway.Filter) {} +func (m *mockCmix) GetHostParams() connect.HostParams { return connect.GetDefaultHostParams() } +func (m *mockCmix) GetAddressSpace() uint8 { return 32 } +func (m *mockCmix) RegisterAddressSpaceNotification(string) (chan uint8, error) { return nil, nil } +func (m *mockCmix) UnregisterAddressSpaceNotification(string) {} +func (m *mockCmix) GetInstance() *network.Instance { return m.instance } +func (m *mockCmix) GetVerboseRounds() string { return "" } //////////////////////////////////////////////////////////////////////////////// // Misc set-up utils // //////////////////////////////////////////////////////////////////////////////// -var testCert = `-----BEGIN CERTIFICATE----- -MIIF4DCCA8igAwIBAgIUegUvihtQooWNIzsNqj6lucXn6g8wDQYJKoZIhvcNAQEL -BQAwgYwxCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTESMBAGA1UEBwwJQ2xhcmVt -b250MRAwDgYDVQQKDAdFbGl4eGlyMRQwEgYDVQQLDAtEZXZlbG9wbWVudDETMBEG -A1UEAwwKZWxpeHhpci5pbzEfMB0GCSqGSIb3DQEJARYQYWRtaW5AZWxpeHhpci5p -bzAeFw0yMTExMzAxODMwMTdaFw0zMTExMjgxODMwMTdaMIGMMQswCQYDVQQGEwJV -UzELMAkGA1UECAwCQ0ExEjAQBgNVBAcMCUNsYXJlbW9udDEQMA4GA1UECgwHRWxp -eHhpcjEUMBIGA1UECwwLRGV2ZWxvcG1lbnQxEzARBgNVBAMMCmVsaXh4aXIuaW8x -HzAdBgkqhkiG9w0BCQEWEGFkbWluQGVsaXh4aXIuaW8wggIiMA0GCSqGSIb3DQEB -AQUAA4ICDwAwggIKAoICAQCckGabzUitkySleveyD9Yrxrpj50FiGkOvwkmgN1jF -9r5StN3otiU5tebderkjD82mVqB781czRA9vPqAggbw1ZdAyQPTvDPTj7rmzkByq -QIkdZBMshV/zX1z8oXoNB9bzZlUFVF4HTY3dEytAJONJRkGGAw4FTa/wCkWsITiT -mKvkP3ciKgz7s8uMyZzZpj9ElBphK9Nbwt83v/IOgTqDmn5qDBnHtoLw4roKJkC8 -00GF4ZUhlVSQC3oFWOCu6tvSUVCBCTUzVKYJLmCnoilmiE/8nCOU0VOivtsx88f5 -9RSPfePUk8u5CRmgThwOpxb0CAO0gd+sY1YJrn+FaW+dSR8OkM3bFuTq7fz9CEkS -XFfUwbJL+HzT0ZuSA3FupTIExyDmM/5dF8lC0RB3j4FNQF+H+j5Kso86e83xnXPI -e+IKKIYa/LVdW24kYRuBDpoONN5KS/F+F/5PzOzH9Swdt07J9b7z1dzWcLnKGtkN -WVsZ7Ue6cuI2zOEWqF1OEr9FladgORcdVBoF/WlsA63C2c1J0tjXqqcl/27GmqGW -gvhaA8Jkm20qLCEhxQ2JzrBdk/X/lCZdP/7A5TxnLqSBq8xxMuLJlZZbUG8U/BT9 -sHF5mXZyiucMjTEU7qHMR2UGNFot8TQ7ZXntIApa2NlB/qX2qI5D13PoXI9Hnyxa -8wIDAQABozgwNjAVBgNVHREEDjAMggplbGl4eGlyLmlvMB0GA1UdDgQWBBQimFud -gCzDVFD3Xz68zOAebDN6YDANBgkqhkiG9w0BAQsFAAOCAgEAccsH9JIyFZdytGxC -/6qjSHPgV23ZGmW7alg+GyEATBIAN187Du4Lj6cLbox5nqLdZgYzizVop32JQAHv -N1QPKjViOOkLaJprSUuRULa5kJ5fe+XfMoyhISI4mtJXXbMwl/PbOaDSdeDjl0ZO -auQggWslyv8ZOkfcbC6goEtAxljNZ01zY1ofSKUj+fBw9Lmomql6GAt7NuubANs4 -9mSjXwD27EZf3Aqaaju7gX1APW2O03/q4hDqhrGW14sN0gFt751ddPuPr5COGzCS -c3Xg2HqMpXx//FU4qHrZYzwv8SuGSshlCxGJpWku9LVwci1Kxi4LyZgTm6/xY4kB -5fsZf6C2yAZnkIJ8bEYr0Up4KzG1lNskU69uMv+d7W2+4Ie3Evf3HdYad/WeUskG -tc6LKY6B2NX3RMVkQt0ftsDaWsktnR8VBXVZSBVYVEQu318rKvYRdOwZJn339obI -jyMZC/3D721e5Anj/EqHpc3I9Yn3jRKw1xc8kpNLg/JIAibub8JYyDvT1gO4xjBO -+6EWOBFgDAsf7bSP2xQn1pQFWcA/sY1MnRsWeENmKNrkLXffP+8l1tEcijN+KCSF -ek1mr+qBwSaNV9TA+RXVhvqd3DEKPPJ1WhfxP1K81RdUESvHOV/4kdwnSahDyao0 -EnretBzQkeKeBwoB2u6NTiOmUjk= ------END CERTIFICATE----- -` - -func getNDF() *ndf.NetworkDefinition { - return &ndf.NetworkDefinition{ - UDB: ndf.UDB{ - ID: id.DummyUser.Bytes(), - Cert: testCert, - Address: "address", - DhPubKey: []byte{123, 34, 86, 97, 108, 117, 101, 34, 58, 53, 48, 49, 53, 53, 53, 52, 54, 53, 49, 48, 54, 49, 56, 57, 53, 54, 51, 48, 54, 52, 49, 51, 53, 49, 57, 56, 55, 57, 52, 57, 50, 48, 56, 49, 52, 57, 52, 50, 57, 51, 57, 53, 49, 50, 51, 54, 52, 56, 49, 57, 55, 48, 50, 50, 49, 48, 55, 55, 50, 52, 52, 48, 49, 54, 57, 52, 55, 52, 57, 53, 53, 56, 55, 54, 50, 57, 53, 57, 53, 48, 54, 55, 57, 55, 48, 53, 48, 48, 54, 54, 56, 49, 57, 50, 56, 48, 52, 48, 53, 51, 50, 48, 57, 55, 54, 56, 56, 53, 57, 54, 57, 56, 57, 49, 48, 54, 56, 54, 50, 52, 50, 52, 50, 56, 49, 48, 51, 51, 51, 54, 55, 53, 55, 54, 52, 51, 54, 55, 54, 56, 53, 56, 48, 55, 56, 49, 52, 55, 49, 52, 53, 49, 52, 52, 52, 52, 53, 51, 57, 57, 51, 57, 57, 53, 50, 52, 52, 53, 51, 56, 48, 49, 48, 54, 54, 55, 48, 52, 50, 49, 55, 54, 57, 53, 57, 57, 57, 51, 52, 48, 54, 54, 54, 49, 50, 48, 54, 56, 57, 51, 54, 57, 48, 52, 55, 55, 54, 50, 49, 49, 56, 56, 53, 51, 50, 57, 57, 50, 54, 53, 48, 52, 57, 51, 54, 55, 54, 48, 57, 56, 56, 49, 55, 52, 52, 57, 53, 57, 54, 53, 50, 55, 53, 52, 52, 52, 49, 57, 55, 49, 54, 50, 52, 52, 56, 50, 55, 55, 50, 49, 48, 53, 56, 56, 57, 54, 51, 53, 54, 54, 53, 53, 53, 53, 49, 56, 50, 53, 49, 49, 50, 57, 50, 48, 49, 56, 48, 48, 54, 49, 56, 57, 48, 55, 48, 51, 53, 51, 51, 56, 57, 52, 49, 50, 57, 49, 55, 50, 56, 55, 57, 57, 52, 55, 53, 51, 49, 55, 55, 48, 53, 55, 55, 49, 50, 51, 57, 49, 51, 55, 54, 48, 50, 49, 55, 50, 54, 54, 52, 56, 52, 48, 48, 54, 48, 52, 48, 53, 56, 56, 53, 54, 52, 56, 56, 49, 52, 52, 51, 57, 56, 51, 51, 57, 54, 55, 48, 49, 53, 55, 52, 53, 50, 56, 51, 49, 51, 48, 53, 52, 49, 49, 49, 49, 49, 56, 51, 53, 52, 52, 52, 52, 48, 53, 54, 57, 48, 54, 52, 56, 57, 52, 54, 53, 50, 56, 51, 53, 50, 48, 48, 50, 48, 48, 49, 50, 51, 51, 48, 48, 53, 48, 49, 50, 52, 56, 57, 48, 49, 51, 54, 55, 52, 57, 55, 50, 49, 48, 55, 53, 54, 49, 50, 52, 52, 57, 55, 48, 50, 56, 55, 55, 51, 51, 50, 53, 50, 48, 57, 52, 56, 57, 49, 49, 56, 49, 54, 57, 50, 55, 50, 51, 57, 51, 57, 54, 50, 56, 48, 54, 54, 49, 57, 55, 48, 50, 48, 57, 49, 51, 54, 50, 49, 50, 53, 50, 54, 50, 53, 53, 55, 57, 54, 51, 56, 49, 57, 48, 51, 49, 54, 54, 53, 51, 56, 56, 49, 48, 56, 48, 51, 57, 53, 49, 53, 53, 55, 49, 53, 57, 48, 57, 57, 55, 49, 56, 53, 55, 54, 48, 50, 54, 48, 49, 55, 57, 52, 55, 53, 51, 57, 49, 51, 53, 52, 49, 48, 50, 49, 55, 52, 51, 57, 48, 50, 56, 48, 50, 51, 53, 51, 54, 56, 49, 56, 50, 49, 55, 50, 57, 52, 51, 49, 56, 48, 56, 56, 50, 51, 53, 52, 56, 55, 49, 52, 55, 53, 50, 56, 48, 57, 55, 49, 53, 48, 48, 51, 50, 48, 57, 50, 50, 53, 50, 56, 51, 57, 55, 57, 49, 57, 50, 53, 56, 51, 55, 48, 51, 57, 54, 48, 50, 55, 54, 48, 54, 57, 55, 52, 53, 54, 52, 51, 56, 52, 53, 54, 48, 51, 57, 55, 55, 55, 49, 53, 57, 57, 49, 57, 52, 57, 56, 56, 54, 56, 50, 49, 49, 54, 56, 55, 56, 55, 51, 51, 57, 52, 53, 49, 52, 52, 55, 57, 53, 57, 49, 57, 52, 48, 51, 53, 49, 49, 49, 51, 48, 53, 54, 54, 50, 49, 56, 57, 52, 55, 50, 49, 54, 53, 57, 53, 50, 57, 50, 48, 51, 51, 52, 48, 56, 55, 54, 50, 49, 49, 49, 56, 53, 54, 57, 51, 57, 50, 53, 48, 53, 56, 56, 55, 56, 53, 54, 55, 51, 56, 55, 50, 53, 57, 56, 52, 54, 53, 49, 51, 50, 54, 51, 50, 48, 56, 56, 57, 52, 53, 57, 53, 56, 57, 57, 54, 52, 55, 55, 50, 57, 51, 51, 52, 55, 51, 48, 52, 56, 56, 50, 51, 50, 52, 53, 48, 51, 50, 56, 56, 50, 49, 55, 51, 51, 53, 54, 55, 51, 50, 51, 52, 56, 53, 52, 55, 48, 51, 56, 50, 51, 49, 53, 55, 52, 53, 53, 48, 55, 55, 56, 55, 48, 50, 51, 52, 50, 53, 52, 51, 48, 57, 56, 56, 54, 56, 54, 49, 57, 54, 48, 55, 55, 52, 57, 55, 56, 51, 48, 51, 57, 49, 55, 52, 49, 51, 49, 54, 57, 54, 50, 49, 52, 50, 55, 57, 55, 56, 56, 51, 49, 55, 51, 50, 54, 56, 49, 56, 53, 57, 48, 49, 49, 53, 48, 52, 53, 51, 51, 56, 52, 57, 57, 55, 54, 51, 55, 55, 48, 55, 49, 52, 50, 49, 54, 48, 49, 54, 52, 49, 57, 53, 56, 49, 54, 50, 55, 49, 52, 49, 52, 56, 49, 51, 52, 50, 53, 56, 55, 53, 57, 55, 52, 49, 57, 49, 55, 51, 55, 49, 51, 57, 54, 51, 49, 51, 49, 56, 53, 50, 49, 53, 52, 49, 51, 44, 34, 70, 105, 110, 103, 101, 114, 112, 114, 105, 110, 116, 34, 58, 49, 54, 56, 48, 49, 53, 52, 49, 53, 49, 49, 50, 51, 51, 48, 57, 56, 51, 54, 51, 125}, - }, - E2E: ndf.Group{ - Prime: "E2EE983D031DC1DB6F1A7A67DF0E9A8E5561DB8E8D49413394C049B7A" + - "8ACCEDC298708F121951D9CF920EC5D146727AA4AE535B0922C688B55B3D" + - "D2AEDF6C01C94764DAB937935AA83BE36E67760713AB44A6337C20E78615" + - "75E745D31F8B9E9AD8412118C62A3E2E29DF46B0864D0C951C394A5CBBDC" + - "6ADC718DD2A3E041023DBB5AB23EBB4742DE9C1687B5B34FA48C3521632C" + - "4A530E8FFB1BC51DADDF453B0B2717C2BC6669ED76B4BDD5C9FF558E88F2" + - "6E5785302BEDBCA23EAC5ACE92096EE8A60642FB61E8F3D24990B8CB12EE" + - "448EEF78E184C7242DD161C7738F32BF29A841698978825B4111B4BC3E1E" + - "198455095958333D776D8B2BEEED3A1A1A221A6E37E664A64B83981C46FF" + - "DDC1A45E3D5211AAF8BFBC072768C4F50D7D7803D2D4F278DE8014A47323" + - "631D7E064DE81C0C6BFA43EF0E6998860F1390B5D3FEACAF1696015CB79C" + - "3F9C2D93D961120CD0E5F12CBB687EAB045241F96789C38E89D796138E63" + - "19BE62E35D87B1048CA28BE389B575E994DCA755471584A09EC723742DC3" + - "5873847AEF49F66E43873", - Generator: "2", - }, - CMIX: ndf.Group{ - Prime: "9DB6FB5951B66BB6FE1E140F1D2CE5502374161FD6538DF1648218642" + - "F0B5C48C8F7A41AADFA187324B87674FA1822B00F1ECF8136943D7C55757" + - "264E5A1A44FFE012E9936E00C1D3E9310B01C7D179805D3058B2A9F4BB6F" + - "9716BFE6117C6B5B3CC4D9BE341104AD4A80AD6C94E005F4B993E14F091E" + - "B51743BF33050C38DE235567E1B34C3D6A5C0CEAA1A0F368213C3D19843D" + - "0B4B09DCB9FC72D39C8DE41F1BF14D4BB4563CA28371621CAD3324B6A2D3" + - "92145BEBFAC748805236F5CA2FE92B871CD8F9C36D3292B5509CA8CAA77A" + - "2ADFC7BFD77DDA6F71125A7456FEA153E433256A2261C6A06ED3693797E7" + - "995FAD5AABBCFBE3EDA2741E375404AE25B", - Generator: "5C7FF6B06F8F143FE8288433493E4769C4D988ACE5BE25A0E2480" + - "9670716C613D7B0CEE6932F8FAA7C44D2CB24523DA53FBE4F6EC3595892D" + - "1AA58C4328A06C46A15662E7EAA703A1DECF8BBB2D05DBE2EB956C142A33" + - "8661D10461C0D135472085057F3494309FFA73C611F78B32ADBB5740C361" + - "C9F35BE90997DB2014E2EF5AA61782F52ABEB8BD6432C4DD097BC5423B28" + - "5DAFB60DC364E8161F4A2A35ACA3A10B1C4D203CC76A470A33AFDCBDD929" + - "59859ABD8B56E1725252D78EAC66E71BA9AE3F1DD2487199874393CD4D83" + - "2186800654760E1E34C09E4D155179F9EC0DC4473F996BDCE6EED1CABED8" + - "B6F116F7AD9CF505DF0F998E34AB27514B0FFE7", - }, - } -} - func getGroup() *cyclic.Group { return cyclic.NewGroup( large.NewIntFromString("E2EE983D031DC1DB6F1A7A67DF0E9A8E5561DB8E8D4941"+ diff --git a/e2e/interface.go b/e2e/interface.go index d9710650594926f528985ddf5f7aa1c41365d74c..c4c6b68af6ccacfc85876ce964120a5868e3e4a2 100644 --- a/e2e/interface.go +++ b/e2e/interface.go @@ -103,6 +103,10 @@ type Handler interface { // will no longer get called Unregister(listenerID receive.ListenerID) + // UnregisterUserListeners removes all the listeners registered with the + // specified user. + UnregisterUserListeners(userID *id.ID) + /* === Partners ===================================================== */ // AddPartner adds a partner. Automatically creates both send diff --git a/e2e/receive/byID.go b/e2e/receive/byID.go index 625aa05db6958642b25961e18daf5196d3dd8659..425cd036e6b7c0bccc307621a709b08098b5ba2e 100644 --- a/e2e/receive/byID.go +++ b/e2e/receive/byID.go @@ -44,13 +44,13 @@ func (bi *byId) Get(uid *id.ID) *set.Set { // adds a listener to a set for the given ID. Creates a new set to add it to if // the set does not exist -func (bi *byId) Add(uid *id.ID, l Listener) *set.Set { - s, ok := bi.list[*uid] +func (bi *byId) Add(lid ListenerID) *set.Set { + s, ok := bi.list[*lid.userID] if !ok { - s = set.New(l) - bi.list[*uid] = s + s = set.New(lid) + bi.list[*lid.userID] = s } else { - s.Insert(l) + s.Insert(lid) } return s @@ -58,13 +58,20 @@ func (bi *byId) Add(uid *id.ID, l Listener) *set.Set { // Removes the passed listener from the set for UserID and // deletes the set if it is empty if the ID is not a generic one -func (bi *byId) Remove(uid *id.ID, l Listener) { - s, ok := bi.list[*uid] +func (bi *byId) Remove(lid ListenerID) { + s, ok := bi.list[*lid.userID] if ok { - s.Remove(l) + s.Remove(lid) - if s.Len() == 0 && !uid.Cmp(AnyUser()) && !uid.Cmp(&id.ID{}) { - delete(bi.list, *uid) + if s.Len() == 0 && !lid.userID.Cmp(AnyUser()) && !lid.userID.Cmp(&id.ID{}) { + delete(bi.list, *lid.userID) } } } + +// RemoveId removes all listeners registered for the given user ID. +func (bi *byId) RemoveId(uid *id.ID) { + if !uid.Cmp(AnyUser()) && !uid.Cmp(&id.ID{}) { + delete(bi.list, *uid) + } +} diff --git a/e2e/receive/byID_test.go b/e2e/receive/byID_test.go index 750c1bc6c958d76b56315f0ad10f7bdd954e63bd..e8aebd2777242ff1e910876d0f2d964f493c3bc7 100644 --- a/e2e/receive/byID_test.go +++ b/e2e/receive/byID_test.go @@ -120,9 +120,9 @@ func TestById_Add_New(t *testing.T) { uid := id.NewIdFromUInt(42, id.User, t) - l := &funcListener{} + lid := ListenerID{uid, 1, &funcListener{}} - nbi.Add(uid, l) + nbi.Add(lid) s := nbi.list[*uid] @@ -130,7 +130,7 @@ func TestById_Add_New(t *testing.T) { t.Errorf("Should a set of the wrong size") } - if !s.Has(l) { + if !s.Has(lid) { t.Errorf("Wrong set returned") } } @@ -142,26 +142,26 @@ func TestById_Add_Old(t *testing.T) { uid := id.NewIdFromUInt(42, id.User, t) - l1 := &funcListener{} - l2 := &funcListener{} + lid1 := ListenerID{uid, 1, &funcListener{}} + lid2 := ListenerID{uid, 1, &funcListener{}} - set1 := set.New(l1) + set1 := set.New(lid1) nbi.list[*uid] = set1 - nbi.Add(uid, l2) + nbi.Add(lid2) s := nbi.list[*uid] if s.Len() != 2 { - t.Errorf("Should have returned a set") + t.Errorf("Incorrect set length.\nexpected: %d\nreceived: %d", 2, s.Len()) } - if !s.Has(l1) { + if !s.Has(lid1) { t.Errorf("Set does not include the initial listener") } - if !s.Has(l2) { + if !s.Has(lid2) { t.Errorf("Set does not include the new listener") } } @@ -171,11 +171,11 @@ func TestById_Add_Old(t *testing.T) { func TestById_Add_Generic(t *testing.T) { nbi := newById() - l1 := &funcListener{} - l2 := &funcListener{} + lid1 := ListenerID{&id.ID{}, 1, &funcListener{}} + lid2 := ListenerID{AnyUser(), 1, &funcListener{}} - nbi.Add(&id.ID{}, l1) - nbi.Add(AnyUser(), l2) + nbi.Add(lid1) + nbi.Add(lid2) s := nbi.generic @@ -183,11 +183,11 @@ func TestById_Add_Generic(t *testing.T) { t.Errorf("Should have returned a set of size 2") } - if !s.Has(l1) { + if !s.Has(lid1) { t.Errorf("Set does not include the ZeroUser listener") } - if !s.Has(l2) { + if !s.Has(lid2) { t.Errorf("Set does not include the empty user listener") } } @@ -199,14 +199,14 @@ func TestById_Remove_ManyInSet(t *testing.T) { uid := id.NewIdFromUInt(42, id.User, t) - l1 := &funcListener{} - l2 := &funcListener{} + lid1 := ListenerID{uid, 1, &funcListener{}} + lid2 := ListenerID{uid, 1, &funcListener{}} - set1 := set.New(l1, l2) + set1 := set.New(lid1, lid2) nbi.list[*uid] = set1 - nbi.Remove(uid, l1) + nbi.Remove(lid1) if _, ok := nbi.list[*uid]; !ok { t.Errorf("Set removed when it should not have been") @@ -217,11 +217,11 @@ func TestById_Remove_ManyInSet(t *testing.T) { set1.Len()) } - if set1.Has(l1) { + if set1.Has(lid1) { t.Errorf("Listener 1 still in set, it should not be") } - if !set1.Has(l2) { + if !set1.Has(lid2) { t.Errorf("Listener 2 not still in set, it should be") } @@ -234,13 +234,13 @@ func TestById_Remove_SingleInSet(t *testing.T) { uid := id.NewIdFromUInt(42, id.User, t) - l1 := &funcListener{} + lid1 := ListenerID{uid, 1, &funcListener{}} - set1 := set.New(l1) + set1 := set.New(lid1) nbi.list[*uid] = set1 - nbi.Remove(uid, l1) + nbi.Remove(lid1) if _, ok := nbi.list[*uid]; ok { t.Errorf("Set not removed when it should have been") @@ -251,7 +251,7 @@ func TestById_Remove_SingleInSet(t *testing.T) { set1.Len()) } - if set1.Has(l1) { + if set1.Has(lid1) { t.Errorf("Listener 1 still in set, it should not be") } } @@ -263,13 +263,13 @@ func TestById_Remove_SingleInSet_ZeroUser(t *testing.T) { uid := &id.ZeroUser - l1 := &funcListener{} + lid1 := ListenerID{uid, 1, &funcListener{}} - set1 := set.New(l1) + set1 := set.New(lid1) nbi.list[*uid] = set1 - nbi.Remove(uid, l1) + nbi.Remove(lid1) if _, ok := nbi.list[*uid]; !ok { t.Errorf("Set removed when it should not have been") @@ -280,7 +280,7 @@ func TestById_Remove_SingleInSet_ZeroUser(t *testing.T) { set1.Len()) } - if set1.Has(l1) { + if set1.Has(lid1) { t.Errorf("Listener 1 still in set, it should not be") } } @@ -292,13 +292,13 @@ func TestById_Remove_SingleInSet_EmptyUser(t *testing.T) { uid := &id.ID{} - l1 := &funcListener{} + lid1 := ListenerID{uid, 1, &funcListener{}} - set1 := set.New(l1) + set1 := set.New(lid1) nbi.list[*uid] = set1 - nbi.Remove(uid, l1) + nbi.Remove(lid1) if _, ok := nbi.list[*uid]; !ok { t.Errorf("Set removed when it should not have been") @@ -309,7 +309,7 @@ func TestById_Remove_SingleInSet_EmptyUser(t *testing.T) { set1.Len()) } - if set1.Has(l1) { + if set1.Has(lid1) { t.Errorf("Listener 1 still in set, it should not be") } } diff --git a/e2e/receive/byType.go b/e2e/receive/byType.go index 0ec3f6b4502a2f8ece67599ea95d0e65dabe8eb0..183e2da360b5c03fceba294bcc578b6c73f6fb86 100644 --- a/e2e/receive/byType.go +++ b/e2e/receive/byType.go @@ -45,13 +45,13 @@ func (bt *byType) Get(messageType catalog.MessageType) *set.Set { // adds a listener to a set for the given messageType. Creates a new set to add // it to if the set does not exist -func (bt *byType) Add(messageType catalog.MessageType, r Listener) *set.Set { - s, ok := bt.list[messageType] +func (bt *byType) Add(lid ListenerID) *set.Set { + s, ok := bt.list[lid.messageType] if !ok { - s = set.New(r) - bt.list[messageType] = s + s = set.New(lid) + bt.list[lid.messageType] = s } else { - s.Insert(r) + s.Insert(lid) } return s @@ -59,13 +59,13 @@ func (bt *byType) Add(messageType catalog.MessageType, r Listener) *set.Set { // Removes the passed listener from the set for messageType and // deletes the set if it is empty and the type is not AnyType -func (bt *byType) Remove(mt catalog.MessageType, l Listener) { - s, ok := bt.list[mt] +func (bt *byType) Remove(lid ListenerID) { + s, ok := bt.list[lid.messageType] if ok { - s.Remove(l) + s.Remove(lid) - if s.Len() == 0 && mt != AnyType { - delete(bt.list, mt) + if s.Len() == 0 && lid.messageType != AnyType { + delete(bt.list, lid.messageType) } } } diff --git a/e2e/receive/byType_test.go b/e2e/receive/byType_test.go index 30f3c45bd005d922f44190e20a3b3f48bd31188f..3ca6127d222f2b7b9038c7fbe7d2ac389c833969 100644 --- a/e2e/receive/byType_test.go +++ b/e2e/receive/byType_test.go @@ -10,6 +10,7 @@ package receive import ( "github.com/golang-collections/collections/set" "gitlab.com/elixxir/client/catalog" + "gitlab.com/xx_network/primitives/id" "testing" ) @@ -108,9 +109,9 @@ func TestByType_Add_New(t *testing.T) { m := catalog.MessageType(42) - l := &funcListener{} + l := ListenerID{&id.ZeroUser, m, &funcListener{}} - nbt.Add(m, l) + nbt.Add(l) s := nbt.list[m] @@ -130,14 +131,14 @@ func TestByType_Add_Old(t *testing.T) { m := catalog.MessageType(42) - l1 := &funcListener{} - l2 := &funcListener{} + lid1 := ListenerID{&id.ZeroUser, m, &funcListener{}} + lid2 := ListenerID{&id.ZeroUser, m, &funcListener{}} - set1 := set.New(l1) + set1 := set.New(lid1) nbt.list[m] = set1 - nbt.Add(m, l2) + nbt.Add(lid2) s := nbt.list[m] @@ -145,11 +146,11 @@ func TestByType_Add_Old(t *testing.T) { t.Errorf("Should have returned a set") } - if !s.Has(l1) { + if !s.Has(lid1) { t.Errorf("Set does not include the initial listener") } - if !s.Has(l2) { + if !s.Has(lid2) { t.Errorf("Set does not include the new listener") } } @@ -159,9 +160,9 @@ func TestByType_Add_Old(t *testing.T) { func TestByType_Add_Generic(t *testing.T) { nbt := newByType() - l1 := &funcListener{} + lid1 := ListenerID{&id.ZeroUser, AnyType, &funcListener{}} - nbt.Add(AnyType, l1) + nbt.Add(lid1) s := nbt.generic @@ -169,7 +170,7 @@ func TestByType_Add_Generic(t *testing.T) { t.Errorf("Should have returned a set of size 2") } - if !s.Has(l1) { + if !s.Has(lid1) { t.Errorf("Set does not include the ZeroUser listener") } } @@ -181,13 +182,13 @@ func TestByType_Remove_SingleInSet(t *testing.T) { m := catalog.MessageType(42) - l1 := &funcListener{} + lid1 := ListenerID{&id.ZeroUser, m, &funcListener{}} - set1 := set.New(l1) + set1 := set.New(lid1) nbt.list[m] = set1 - nbt.Remove(m, l1) + nbt.Remove(lid1) if _, ok := nbt.list[m]; ok { t.Errorf("Set not removed when it should have been") @@ -198,7 +199,7 @@ func TestByType_Remove_SingleInSet(t *testing.T) { set1.Len()) } - if set1.Has(l1) { + if set1.Has(lid1) { t.Errorf("Listener 1 still in set, it should not be") } } @@ -210,13 +211,13 @@ func TestByType_Remove_SingleInSet_AnyType(t *testing.T) { m := AnyType - l1 := &funcListener{} + lid1 := ListenerID{&id.ZeroUser, m, &funcListener{}} - set1 := set.New(l1) + set1 := set.New(lid1) nbt.list[m] = set1 - nbt.Remove(m, l1) + nbt.Remove(lid1) if _, ok := nbt.list[m]; !ok { t.Errorf("Set removed when it should not have been") @@ -227,7 +228,7 @@ func TestByType_Remove_SingleInSet_AnyType(t *testing.T) { set1.Len()) } - if set1.Has(l1) { + if set1.Has(lid1) { t.Errorf("Listener 1 still in set, it should not be") } } diff --git a/e2e/receive/listener.go b/e2e/receive/listener.go index 2bef0525c185a291ee861d9447dd9e18824a6109..5e1d34067c5b7709509c18de6000103d73107687 100644 --- a/e2e/receive/listener.go +++ b/e2e/receive/listener.go @@ -11,6 +11,8 @@ import ( jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/catalog" "gitlab.com/xx_network/primitives/id" + "strconv" + "strings" ) //Listener interface for a listener adhere to @@ -50,6 +52,19 @@ func (lid ListenerID) GetName() string { return lid.listener.Name() } +// String returns the values in the ListenerID in a human-readable format. This +// functions adheres to the fmt.Stringer interface. +func (lid ListenerID) String() string { + str := []string{ + "userID:" + lid.userID.String(), + "messageType:" + lid.messageType.String() + + "(" + strconv.FormatUint(uint64(lid.messageType), 10) + ")", + "listener:" + lid.listener.Name(), + } + + return "{" + strings.Join(str, " ") + "}" +} + /*internal listener implementations*/ //listener based off of a function diff --git a/e2e/receive/switchboard.go b/e2e/receive/switchboard.go index ad214f53ae524c36fa36ab62a598d76e6e732fa0..3bb5e23cfae466cad43a4a2ca665a079caa48413 100644 --- a/e2e/receive/switchboard.go +++ b/e2e/receive/switchboard.go @@ -55,20 +55,23 @@ func (sw *Switchboard) RegisterListener(user *id.ID, jww.FATAL.Panicf("cannot register nil listener") } + // Create new listener ID + lid := ListenerID{ + userID: user, + messageType: messageType, + listener: newListener, + } + //register the listener by both ID and messageType sw.mux.Lock() - sw.id.Add(user, newListener) - sw.messageType.Add(messageType, newListener) + sw.id.Add(lid) + sw.messageType.Add(lid) sw.mux.Unlock() //return a ListenerID so it can be unregistered in the future - return ListenerID{ - userID: user, - messageType: messageType, - listener: newListener, - } + return lid } // RegisterFunc Registers a new listener built around the passed function. @@ -141,8 +144,8 @@ func (sw *Switchboard) Speak(item Message) { //Execute hear on all matched listeners in a new goroutine matches.Do(func(i interface{}) { - r := i.(Listener) - go r.Hear(item) + lid := i.(ListenerID) + go lid.listener.Hear(item) }) // print to log if nothing was heard @@ -158,12 +161,32 @@ func (sw *Switchboard) Speak(item Message) { func (sw *Switchboard) Unregister(listenerID ListenerID) { sw.mux.Lock() - sw.id.Remove(listenerID.userID, listenerID.listener) - sw.messageType.Remove(listenerID.messageType, listenerID.listener) + sw.id.Remove(listenerID) + sw.messageType.Remove(listenerID) sw.mux.Unlock() } +// UnregisterUserListeners removes all the listeners registered with the +// specified user. +func (sw *Switchboard) UnregisterUserListeners(userID *id.ID) { + sw.mux.Lock() + defer sw.mux.Unlock() + + // Get list of all listeners for the specified user + idSet := sw.id.Get(userID) + + // Find each listener in the messageType list and delete it + idSet.Do(func(i interface{}) { + lid := i.(ListenerID) + mtSet := sw.messageType.list[lid.messageType] + mtSet.Remove(lid) + }) + + // Remove all listeners for the user from the ID list + sw.id.RemoveId(userID) +} + // finds all listeners who match the items sender or ID, or have those fields // as generic func (sw *Switchboard) matchListeners(item Message) *set.Set { diff --git a/e2e/receive/switchboard_test.go b/e2e/receive/switchboard_test.go index aaee91e2d9d5090175334b503a0ca3a54f929cfe..7d2f8c4ceff37b2be6a6abe09675d0e50a6822dc 100644 --- a/e2e/receive/switchboard_test.go +++ b/e2e/receive/switchboard_test.go @@ -85,13 +85,13 @@ func TestSwitchboard_RegisterListener(t *testing.T) { //check that the listener is registered in the appropriate location setID := sw.id.Get(uid) - if !setID.Has(l) { + if !setID.Has(lid) { t.Errorf("Listener is not registered by ID") } setType := sw.messageType.Get(mt) - if !setType.Has(l) { + if !setType.Has(lid) { t.Errorf("Listener is not registered by Message Tag") } @@ -152,13 +152,13 @@ func TestSwitchboard_RegisterFunc(t *testing.T) { //check that the listener is registered in the appropriate location setID := sw.id.Get(uid) - if !setID.Has(lid.listener) { + if !setID.Has(lid) { t.Errorf("Listener is not registered by ID") } setType := sw.messageType.Get(mt) - if !setType.Has(lid.listener) { + if !setType.Has(lid) { t.Errorf("Listener is not registered by Message Tag") } @@ -221,13 +221,13 @@ func TestSwitchboard_RegisterChan(t *testing.T) { //check that the listener is registered in the appropriate location setID := sw.id.Get(uid) - if !setID.Has(lid.listener) { + if !setID.Has(lid) { t.Errorf("Listener is not registered by ID") } setType := sw.messageType.Get(mt) - if !setType.Has(lid.listener) { + if !setType.Has(lid) { t.Errorf("Listener is not registered by Message Tag") } @@ -330,21 +330,67 @@ func TestSwitchboard_Unregister(t *testing.T) { setType := sw.messageType.Get(mt) //check that the removed listener is not registered - if setID.Has(lid1.listener) { + if setID.Has(lid1) { t.Errorf("Removed Listener is registered by ID, should not be") } - if setType.Has(lid1.listener) { + if setType.Has(lid1) { t.Errorf("Removed Listener not registered by Message Tag, " + "should not be") } //check that the not removed listener is still registered - if !setID.Has(lid2.listener) { + if !setID.Has(lid2) { t.Errorf("Remaining Listener is not registered by ID") } - if !setType.Has(lid2.listener) { + if !setType.Has(lid2) { t.Errorf("Remaining Listener is not registered by Message Tag") } } + +// Tests that three listeners with three different message types but the same +// user are deleted with Switchboard.UnregisterUserListeners and that three +// other listeners with the same message types but different users are not +// delete. +func TestSwitchboard_UnregisterUserListeners(t *testing.T) { + sw := New() + + uid1 := id.NewIdFromUInt(42, id.User, t) + uid2 := id.NewIdFromUInt(11, id.User, t) + + l := func(receive Message) {} + + lid1 := sw.RegisterFunc("a", uid1, catalog.NoType, l) + lid2 := sw.RegisterFunc("b", uid1, catalog.XxMessage, l) + lid3 := sw.RegisterFunc("c", uid1, catalog.KeyExchangeTrigger, l) + lid4 := sw.RegisterFunc("d", uid2, catalog.NoType, l) + lid5 := sw.RegisterFunc("e", uid2, catalog.XxMessage, l) + lid6 := sw.RegisterFunc("f", uid2, catalog.KeyExchangeTrigger, l) + + sw.UnregisterUserListeners(uid1) + + if s, exists := sw.id.list[*uid1]; exists { + t.Errorf("Set for ID %s still exists: %+v", uid1, s) + } + + if sw.messageType.Get(lid1.messageType).Has(lid1) { + t.Errorf("Listener %+v still exists", lid1) + } + if sw.messageType.Get(lid2.messageType).Has(lid2) { + t.Errorf("Listener %+v still exists", lid2) + } + if sw.messageType.Get(lid3.messageType).Has(lid3) { + t.Errorf("Listener %+v still exists", lid3) + } + + if !sw.messageType.Get(lid4.messageType).Has(lid4) { + t.Errorf("Listener %+v does not exist", lid4) + } + if !sw.messageType.Get(lid5.messageType).Has(lid5) { + t.Errorf("Listener %+v does not exist", lid5) + } + if !sw.messageType.Get(lid6.messageType).Has(lid6) { + t.Errorf("Listener %+v does not exist", lid6) + } +} diff --git a/fileTransfer/connect/utils_test.go b/fileTransfer/connect/utils_test.go index bd155a4825642ec6cf629241b10e3a4e1e561523..e5c5c6d902c9d22578d847b95d4dc866e315e935 100644 --- a/fileTransfer/connect/utils_test.go +++ b/fileTransfer/connect/utils_test.go @@ -32,7 +32,7 @@ import ( ) //////////////////////////////////////////////////////////////////////////////// -// Mock cMix // +// Mock cMix // //////////////////////////////////////////////////////////////////////////////// type mockCmixHandler struct { @@ -137,6 +137,7 @@ func (m *mockCmix) GetRoundResults(_ time.Duration, //////////////////////////////////////////////////////////////////////////////// // Mock Connection Handler // //////////////////////////////////////////////////////////////////////////////// + func newMockListener(hearChan chan receive.Message) *mockListener { return &mockListener{hearChan: hearChan} } @@ -162,6 +163,9 @@ func newMockConnectionHandler() *mockConnectionHandler { } } +// Tests that mockConnection adheres to the Connection interface. +var _ Connection = (*mockConnection)(nil) + type mockConnection struct { myID *id.ID handler *mockConnectionHandler @@ -193,12 +197,12 @@ func (m *mockConnection) SendE2E(mt catalog.MessageType, payload []byte, return []id.Round{42}, e2eCrypto.MessageID{}, netTime.Now(), nil } -func (m *mockConnection) RegisterListener( - mt catalog.MessageType, listener receive.Listener) receive.ListenerID { +func (m *mockConnection) RegisterListener(mt catalog.MessageType, + listener receive.Listener) (receive.ListenerID, error) { m.handler.Lock() defer m.handler.Unlock() m.handler.listeners[mt] = listener - return receive.ListenerID{} + return receive.ListenerID{}, nil } //////////////////////////////////////////////////////////////////////////////// diff --git a/fileTransfer/connect/wrapper.go b/fileTransfer/connect/wrapper.go index fa78ee94b75879898f0dcfcefc67b1038bd11991..0bfde770e1fbf236858485ce4d6b2ae6e267406e 100644 --- a/fileTransfer/connect/wrapper.go +++ b/fileTransfer/connect/wrapper.go @@ -41,7 +41,7 @@ type Connection interface { SendE2E(mt catalog.MessageType, payload []byte, params e2e.Params) ( []id.Round, e2eCrypto.MessageID, time.Time, error) RegisterListener(messageType catalog.MessageType, - newListener receive.Listener) receive.ListenerID + newListener receive.Listener) (receive.ListenerID, error) } // NewWrapper generates a new file transfer manager using Connection E2E. @@ -56,7 +56,10 @@ func NewWrapper(receiveCB ft.ReceiveCallback, p Params, ft ft.FileTransfer, } // Register listener to receive new file transfers - w.conn.RegisterListener(catalog.NewFileTransfer, &listener{w}) + _, err := w.conn.RegisterListener(catalog.NewFileTransfer, &listener{w}) + if err != nil { + return nil, err + } return w, nil } @@ -109,7 +112,7 @@ func (w *Wrapper) RegisterSentProgressCallback(tid *ftCrypto.TransferID, return w.ft.RegisterSentProgressCallback(tid, modifiedProgressCB, period) } -// addEndMessageToCallback adds the sending of an Connection E2E message when +// addEndMessageToCallback adds the sending of a Connection E2E message when // the transfer completed to the callback. If NotifyUponCompletion is not set, // then the message is not sent. func (w *Wrapper) addEndMessageToCallback(progressCB ft.SentProgressCallback) ft.SentProgressCallback {