Skip to content
Snippets Groups Projects
Commit cfe0d3f1 authored by Jono Wenger's avatar Jono Wenger
Browse files

Merge branch 'jono/staleConnectionCleanup' into 'release'

Stale connection cleanup

See merge request !254
parents 0f58d3df 3aa9510e
Branches
Tags
2 merge requests!510Release,!254Stale connection cleanup
Showing
with 761 additions and 420 deletions
......@@ -100,6 +100,10 @@ type State interface {
// auth callback for the given partner ID.
DeletePartnerCallback(partnerId *id.ID)
// DeletePartner deletes the request and/or confirmation for the given
// partner.
DeletePartner(partner *id.ID) error
// Closer stops listening to auth.
io.Closer
}
......
......@@ -142,6 +142,22 @@ func (s *state) Close() error {
return nil
}
// DeletePartner deletes the request and/or confirmation for the given partner.
func (s *state) DeletePartner(partner *id.ID) error {
err := s.store.DeleteRequest(partner)
err2 := s.store.DeleteConfirmation(partner)
// Only return an error if both failed to delete
if err != nil && err2 != nil {
return errors.Errorf("Failed to delete partner: no requests or "+
"confirmations found: %s, %s", err, err2)
}
s.DeletePartnerCallback(partner)
return nil
}
// AddPartnerCallback that overrides the generic auth callback for the given partnerId
func (s *state) AddPartnerCallback(partnerId *id.ID, cb Callbacks) {
s.partnerCallbacks.AddPartnerCallback(partnerId, cb)
......
......@@ -88,6 +88,8 @@ func (m mockE2eHandler) Unregister(listenerID receive.ListenerID) {
return
}
func (m mockE2eHandler) UnregisterUserListeners(*id.ID) {}
func (m mockE2eHandler) AddPartner(partnerID *id.ID,
partnerPubKey, myPrivKey *cyclic.Int,
partnerSIDHPubKey *sidh.PublicKey, mySIDHPrivKey *sidh.PrivateKey,
......
......@@ -88,6 +88,7 @@ func (c *Connection) GetPartner() []byte {
// RegisterListener is used for E2E reception
// and allows for reading data sent from the partner.Manager
// Returns marshalled ListenerID
func (c *Connection) RegisterListener(messageType int, newListener Listener) {
_ = c.connection.RegisterListener(catalog.MessageType(messageType), listener{l: newListener})
func (c *Connection) RegisterListener(messageType int, newListener Listener) error {
_, err := c.connection.RegisterListener(catalog.MessageType(messageType), listener{l: newListener})
return err
}
......@@ -90,6 +90,9 @@ type serverAuthCallback struct {
confirmCallback Callback
requestCallback Callback
// Used to track stale connections
cl *ConnectionList
// Used for building new Connection objects
connectionParams Params
}
......@@ -97,10 +100,12 @@ type serverAuthCallback struct {
// getServerAuthCallback returns a callback interface to be passed into the creation
// of a xxdk.E2e object.
// it will accept requests only if a request callback is passed in
func getServerAuthCallback(confirm, request Callback, params Params) *serverAuthCallback {
func getServerAuthCallback(confirm, request Callback, cl *ConnectionList,
params Params) *serverAuthCallback {
return &serverAuthCallback{
confirmCallback: confirm,
requestCallback: request,
cl: cl,
connectionParams: params,
}
}
......@@ -136,8 +141,10 @@ func (a serverAuthCallback) Request(requestor contact.Contact,
}
// Return the new Connection object
a.requestCallback(BuildConnection(newPartner, e2e.GetE2E(),
e2e.GetAuth(), a.connectionParams))
c := BuildConnection(
newPartner, e2e.GetE2E(), e2e.GetAuth(), a.connectionParams)
a.cl.Add(c)
a.requestCallback(c)
}
// Reset will be called when an auth Reset operation occurs.
......
......@@ -170,7 +170,7 @@ func connectWithAuthentication(conn Connection, timeStart time.Time,
// authenticate themselves. An established AuthenticatedConnection will
// be passed via the callback.
func StartAuthenticatedServer(identity xxdk.ReceptionIdentity,
cb AuthenticatedCallback, net *xxdk.Cmix, p Params) (*xxdk.E2e, error) {
cb AuthenticatedCallback, net *xxdk.Cmix, p Params) (*ConnectionServer, error) {
// Register the waiter for a connection establishment
connCb := Callback(func(connection Connection) {
......@@ -178,8 +178,14 @@ func StartAuthenticatedServer(identity xxdk.ReceptionIdentity,
// client's identity proof. If an identity authentication
// message is received and validated, an authenticated connection will
// be passed along via the AuthenticatedCallback
connection.RegisterListener(catalog.ConnectionAuthenticationRequest,
_, err := connection.RegisterListener(
catalog.ConnectionAuthenticationRequest,
buildAuthConfirmationHandler(cb, connection))
if err != nil {
jww.ERROR.Printf(
"Failed to register listener on connection with %s: %+v",
connection.GetPartner().PartnerId(), err)
}
})
return StartServer(identity, connCb, net, p)
}
......
......@@ -8,8 +8,12 @@ package connect
import (
"encoding/json"
"gitlab.com/elixxir/client/e2e/rekey"
"gitlab.com/elixxir/client/event"
"gitlab.com/elixxir/client/xxdk"
"gitlab.com/xx_network/primitives/netTime"
"io"
"sync/atomic"
"time"
"github.com/pkg/errors"
......@@ -19,8 +23,6 @@ import (
clientE2e "gitlab.com/elixxir/client/e2e"
"gitlab.com/elixxir/client/e2e/ratchet/partner"
"gitlab.com/elixxir/client/e2e/receive"
"gitlab.com/elixxir/client/e2e/rekey"
"gitlab.com/elixxir/client/event"
"gitlab.com/elixxir/crypto/contact"
"gitlab.com/elixxir/crypto/e2e"
"gitlab.com/xx_network/primitives/id"
......@@ -32,6 +34,8 @@ const (
connectionTimeout = 15 * time.Second
)
var alreadyClosedErr = errors.New("connection is closed")
// Connection is a wrapper for the E2E and auth packages.
// It can be used to automatically establish an E2E partnership
// with a partner.Manager, or be built from an existing E2E partnership.
......@@ -52,7 +56,7 @@ type Connection interface {
// RegisterListener is used for E2E reception
// and allows for reading data sent from the partner.Manager
RegisterListener(messageType catalog.MessageType,
newListener receive.Listener) receive.ListenerID
newListener receive.Listener) (receive.ListenerID, error)
// Unregister listener for E2E reception
Unregister(listenerID receive.ListenerID)
......@@ -71,6 +75,10 @@ type Connection interface {
// PayloadSize Returns the max payload size for a partitionable E2E
// message
PayloadSize() uint
// LastUse returns the timestamp of the last time the connection was
// utilised.
LastUse() time.Time
}
// Callback is the callback format required to retrieve
......@@ -82,6 +90,7 @@ type Params struct {
Auth auth.Params
Rekey rekey.Params
Event event.Reporter `json:"-"`
List ConnectionListParams
Timeout time.Duration
}
......@@ -91,6 +100,7 @@ func GetDefaultParams() Params {
Auth: auth.GetDefaultTemporaryParams(),
Rekey: rekey.GetDefaultEphemeralParams(),
Event: event.NewEventManager(),
List: DefaultConnectionListParams(),
Timeout: connectionTimeout,
}
}
......@@ -158,13 +168,31 @@ func Connect(recipient contact.Contact, e2eClient *xxdk.E2e,
// This call does an xxDK.ephemeralLogin under the hood and the connection
// server must be the only listener on auth.
func StartServer(identity xxdk.ReceptionIdentity, cb Callback, net *xxdk.Cmix,
p Params) (*xxdk.E2e, error) {
p Params) (*ConnectionServer, error) {
// Create connection list and start cleanup thread
cl := NewConnectionList(p.List)
err := net.AddService(cl.CleanupThread)
if err != nil {
return nil, err
}
// Build callback for E2E negotiation
callback := getServerAuthCallback(nil, cb, p)
callback := getServerAuthCallback(nil, cb, cl, p)
e2eClient, err := xxdk.LoginEphemeral(net, callback, identity)
if err != nil {
return nil, err
}
// Return an ephemeral E2e object
return xxdk.LoginEphemeral(net, callback, identity)
return &ConnectionServer{e2eClient, cl}, nil
}
// ConnectionServer contains
type ConnectionServer struct {
e2e *xxdk.E2e
cl *ConnectionList
}
// handler provides an implementation for the Connection interface.
......@@ -172,6 +200,13 @@ type handler struct {
auth auth.State
partner partner.Manager
e2e clientE2e.Handler
// Timestamp of last time a message was sent or received (Unix nanoseconds)
lastUse *int64
// Indicates if the connection has been closed (0 = open, 1 = closed)
closed *uint32
params Params
}
......@@ -188,12 +223,30 @@ func BuildConnection(partner partner.Manager, e2eHandler clientE2e.Handler,
}
}
// Close deletes this Connection's partner.Manager and releases resources.
// Close deletes this Connection's partner.Manager and releases resources. If
// the connection is already closed, then nil is returned.
func (h *handler) Close() error {
if err := h.e2e.DeletePartner(h.partner.PartnerId()); err != nil {
if h.isClosed() {
return nil
}
// Get partner ID once at the top because PartnerId makes a copy
partnerID := h.partner.PartnerId()
// Unregister all listeners
h.e2e.UnregisterUserListeners(partnerID)
// Delete partner from e2e and auth
if err := h.e2e.DeletePartner(partnerID); err != nil {
return err
}
return h.auth.Close()
if err := h.auth.DeletePartner(partnerID); err != nil {
return err
}
atomic.StoreUint32(h.closed, 1)
return nil
}
// GetPartner returns the partner.Manager for this Connection.
......@@ -204,17 +257,25 @@ func (h *handler) GetPartner() partner.Manager {
// SendE2E is a wrapper for sending specifically to the Connection's
// partner.Manager.
func (h *handler) SendE2E(mt catalog.MessageType, payload []byte,
params clientE2e.Params) (
[]id.Round, e2e.MessageID, time.Time, error) {
params clientE2e.Params) ([]id.Round, e2e.MessageID, time.Time, error) {
if h.isClosed() {
return nil, e2e.MessageID{}, time.Time{}, alreadyClosedErr
}
h.updateLastUse(netTime.Now())
return h.e2e.SendE2E(mt, h.partner.PartnerId(), payload, params)
}
// RegisterListener is used for E2E reception
// and allows for reading data sent from the partner.Manager.
func (h *handler) RegisterListener(messageType catalog.MessageType,
newListener receive.Listener) receive.ListenerID {
return h.e2e.RegisterListener(h.partner.PartnerId(),
messageType, newListener)
newListener receive.Listener) (receive.ListenerID, error) {
if h.isClosed() {
return receive.ListenerID{}, alreadyClosedErr
}
lt := &listenerTracker{h, newListener}
return h.e2e.RegisterListener(h.partner.PartnerId(), messageType, lt), nil
}
// Unregister listener for E2E reception.
......@@ -245,3 +306,18 @@ func (h *handler) PartitionSize(payloadIndex uint) uint {
func (h *handler) PayloadSize() uint {
return h.e2e.PayloadSize()
}
// LastUse returns the timestamp of the last time the connection was utilised.
func (h *handler) LastUse() time.Time {
return time.Unix(0, atomic.LoadInt64(h.lastUse))
}
// updateLastUse updates the last use time stamp to the given time.
func (h *handler) updateLastUse(t time.Time) {
atomic.StoreInt64(h.lastUse, t.UnixNano())
}
// isClosed returns true if the connection is closed.
func (h *handler) isClosed() bool {
return atomic.LoadUint32(h.closed) == 1
}
////////////////////////////////////////////////////////////////////////////////
// Copyright © 2020 xx network SEZC //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file //
////////////////////////////////////////////////////////////////////////////////
package connect
import (
jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/client/stoppable"
"gitlab.com/xx_network/primitives/id"
"gitlab.com/xx_network/primitives/netTime"
"sync"
"time"
)
// ConnectionList is a list of all connections.
type ConnectionList struct {
list map[id.ID]Connection
p ConnectionListParams
mux sync.Mutex
}
// NewConnectionList initialises an empty ConnectionList.
func NewConnectionList(p ConnectionListParams) *ConnectionList {
return &ConnectionList{
list: make(map[id.ID]Connection),
p: p,
}
}
// Add adds the connection to the list.
func (cl *ConnectionList) Add(c Connection) {
cl.mux.Lock()
defer cl.mux.Unlock()
cl.list[*c.GetPartner().PartnerId()] = c
}
// CleanupThread runs the loop that runs the cleanup processes periodically.
func (cl *ConnectionList) CleanupThread() (stoppable.Stoppable, error) {
stop := stoppable.NewSingle("StaleConnectionCleanup")
go func() {
jww.INFO.Printf("Starting stale connection cleanup thread to delete "+
"connections older than %s. Running every %s.",
cl.p.MaxAge, cl.p.CleanupPeriod)
ticker := time.NewTicker(cl.p.CleanupPeriod)
for {
select {
case <-stop.Quit():
jww.INFO.Print(
"Stopping connection cleanup thread: stoppable triggered")
ticker.Stop()
stop.ToStopped()
case <-ticker.C:
jww.DEBUG.Print("Starting connection cleanup.")
cl.Cleanup()
}
}
}()
return stop, nil
}
// Cleanup disconnects all connections that have been stale for longer than the
// max allowed time.
func (cl *ConnectionList) Cleanup() {
cl.mux.Lock()
defer cl.mux.Unlock()
for partnerID, c := range cl.list {
lastUse := c.LastUse()
timeSinceLastUse := netTime.Since(lastUse)
if timeSinceLastUse > cl.p.MaxAge {
err := c.Close()
if err != nil {
jww.ERROR.Printf(
"Could not close connection with partner %s: %+v",
partnerID, err)
}
delete(cl.list, partnerID)
jww.INFO.Printf("Deleted stale connection for partner %s. "+
"Last use was %s ago (%s)",
&partnerID, timeSinceLastUse, lastUse.Local())
}
}
}
////////////////////////////////////////////////////////////////////////////////
// Parameters //
////////////////////////////////////////////////////////////////////////////////
// Default values.
const (
cleanupPeriodDefault = 5 * time.Minute
maxAgeDefault = 30 * time.Minute
)
// ConnectionListParams are the parameters used for the ConnectionList.
type ConnectionListParams struct {
// CleanupPeriod is the duration between when cleanups occur.
CleanupPeriod time.Duration
// MaxAge is the maximum age of an unused connection before it is deleted.
MaxAge time.Duration
}
// DefaultConnectionListParams returns a ConnectionListParams filled with
// default values.
func DefaultConnectionListParams() ConnectionListParams {
return ConnectionListParams{
CleanupPeriod: cleanupPeriodDefault,
MaxAge: maxAgeDefault,
}
}
////////////////////////////////////////////////////////////////////////////////
// Copyright © 2020 xx network SEZC //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file //
////////////////////////////////////////////////////////////////////////////////
package connect
import (
"gitlab.com/xx_network/primitives/id"
"gitlab.com/xx_network/primitives/netTime"
"reflect"
"testing"
"time"
)
// Tests that NewConnectionList returned the expected new ConnectionList.
func TestNewConnectionList(t *testing.T) {
expected := &ConnectionList{
list: make(map[id.ID]Connection),
p: DefaultConnectionListParams(),
}
cl := NewConnectionList(expected.p)
if !reflect.DeepEqual(expected, cl) {
t.Errorf("New ConnectionList did not match expected."+
"\nexpected: %+v\nreceived: %+v", expected, cl)
}
}
// Tests that ConnectionList.Add adds all the given connections to the list.
func TestConnectionList_Add(t *testing.T) {
cl := NewConnectionList(DefaultConnectionListParams())
expected := map[id.ID]Connection{
*id.NewIdFromString("p1", id.User, t): &handler{
partner: &mockPartner{partnerId: id.NewIdFromString("p1", id.User, t)}},
*id.NewIdFromString("p2", id.User, t): &handler{
partner: &mockPartner{partnerId: id.NewIdFromString("p2", id.User, t)}},
*id.NewIdFromString("p3", id.User, t): &handler{
partner: &mockPartner{partnerId: id.NewIdFromString("p3", id.User, t)}},
*id.NewIdFromString("p4", id.User, t): &handler{
partner: &mockPartner{partnerId: id.NewIdFromString("p4", id.User, t)}},
*id.NewIdFromString("p5", id.User, t): &handler{
partner: &mockPartner{partnerId: id.NewIdFromString("p5", id.User, t)}},
}
for _, c := range expected {
cl.Add(c)
}
if !reflect.DeepEqual(expected, cl.list) {
t.Errorf("List does not have expected connections."+
"\nexpected: %+v\nreceived: %+v", expected, cl.list)
}
}
// Tests that ConnectionList.Cleanup deletes only stale connections from the
// list and that they are closed.
func TestConnectionList_Cleanup(t *testing.T) {
cl := NewConnectionList(DefaultConnectionListParams())
list := []*mockConnection{
{
partner: &mockPartner{partnerId: id.NewIdFromString("p0", id.User, t)},
lastUse: netTime.Now().Add(-(cl.p.MaxAge * 2)),
}, {
partner: &mockPartner{partnerId: id.NewIdFromString("p1", id.User, t)},
lastUse: netTime.Now().Add(-(cl.p.MaxAge / 2)),
}, {
partner: &mockPartner{partnerId: id.NewIdFromString("p2", id.User, t)},
lastUse: netTime.Now().Add(-(cl.p.MaxAge + 10)),
}, {
partner: &mockPartner{partnerId: id.NewIdFromString("p3", id.User, t)},
lastUse: netTime.Now().Add(-(cl.p.MaxAge - time.Second)),
},
}
for _, c := range list {
cl.Add(c)
}
cl.Cleanup()
for i, c := range list {
if i%2 == 0 {
if _, exists := cl.list[*c.GetPartner().PartnerId()]; exists {
t.Errorf("Connection #%d exists while being stale.", i)
}
if !c.closed {
t.Errorf("Connection #%d was not closed.", i)
}
} else {
if _, exists := cl.list[*c.GetPartner().PartnerId()]; !exists {
t.Errorf("Connection #%d was removed when it was not stale.", i)
}
if c.closed {
t.Errorf("Connection #%d was closed.", i)
}
}
}
}
////////////////////////////////////////////////////////////////////////////////
// Parameters //
////////////////////////////////////////////////////////////////////////////////
// Tests that DefaultConnectionListParams returns a ConnectionListParams with
// the expected default values.
func TestDefaultConnectionListParams(t *testing.T) {
expected := ConnectionListParams{
CleanupPeriod: cleanupPeriodDefault,
MaxAge: maxAgeDefault,
}
p := DefaultConnectionListParams()
if !reflect.DeepEqual(expected, p) {
t.Errorf("Default ConnectionListParams does not match expected."+
"\nexpected: %+v\nreceived: %+v", expected, p)
}
}
////////////////////////////////////////////////////////////////////////////////
// Copyright © 2020 xx network SEZC //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file //
////////////////////////////////////////////////////////////////////////////////
package connect
import (
"gitlab.com/elixxir/client/e2e/receive"
"gitlab.com/xx_network/primitives/netTime"
)
// listenerTracker wraps a listener and updates the last use timestamp on every
// call to Hear.
type listenerTracker struct {
h *handler
l receive.Listener
}
// Hear updates the last call timestamp and then calls Hear on the wrapped
// listener.
func (lt *listenerTracker) Hear(item receive.Message) {
lt.h.updateLastUse(netTime.Now())
lt.l.Hear(item)
}
// Name returns a name, used for debugging.
func (lt *listenerTracker) Name() string { return lt.l.Name() }
////////////////////////////////////////////////////////////////////////////////
// Copyright © 2020 xx network SEZC //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file //
////////////////////////////////////////////////////////////////////////////////
package connect
import (
"gitlab.com/elixxir/client/e2e/receive"
"gitlab.com/xx_network/primitives/id"
"gitlab.com/xx_network/primitives/netTime"
"reflect"
"testing"
"time"
)
// Tests that listenerTracker.Hear correctly updates the last use timestamp and
// calls the wrapped Hear function.
func Test_listenerTracker_Hear(t *testing.T) {
lastUsed, closed := int64(0), uint32(0)
itemChan := make(chan receive.Message, 1)
lt := &listenerTracker{
h: &handler{
lastUse: &lastUsed,
closed: &closed,
},
l: &mockListener{itemChan, "mockListener"},
}
expected := receive.Message{
Payload: []byte("Message payload."),
Sender: id.NewIdFromString("senderID", id.User, t),
RecipientID: id.NewIdFromString("RecipientID", id.User, t),
}
lt.Hear(expected)
select {
case r := <-itemChan:
if !reflect.DeepEqual(expected, r) {
t.Errorf("Did not receive expected receive.Message."+
"\nexpected: %+v\nreceived: %+v", expected, r)
}
if netTime.Since(lt.h.LastUse()) > 300*time.Millisecond {
t.Errorf("Last use has incorrect time: %s", lt.h.LastUse())
}
case <-time.After(30 * time.Millisecond):
t.Error("Timed out waiting for Hear to be called.")
}
}
// Tests that listenerTracker.Name calls the wrapped listener Name function.
func Test_listenerTracker_Name(t *testing.T) {
expected := "mockListener"
lt := &listenerTracker{
l: &mockListener{make(chan receive.Message, 1), "mockListener"},
}
if lt.Name() != expected {
t.Errorf("Did not get expected name.\nexected: %s\nreceived: %s",
expected, lt.Name())
}
}
type mockListener struct {
item chan receive.Message
name string
}
func (m *mockListener) Hear(item receive.Message) { m.item <- item }
func (m *mockListener) Name() string { return m.name }
This diff is collapsed.
......@@ -103,6 +103,10 @@ type Handler interface {
// will no longer get called
Unregister(listenerID receive.ListenerID)
// UnregisterUserListeners removes all the listeners registered with the
// specified user.
UnregisterUserListeners(userID *id.ID)
/* === Partners ===================================================== */
// AddPartner adds a partner. Automatically creates both send
......
......@@ -44,13 +44,13 @@ func (bi *byId) Get(uid *id.ID) *set.Set {
// adds a listener to a set for the given ID. Creates a new set to add it to if
// the set does not exist
func (bi *byId) Add(uid *id.ID, l Listener) *set.Set {
s, ok := bi.list[*uid]
func (bi *byId) Add(lid ListenerID) *set.Set {
s, ok := bi.list[*lid.userID]
if !ok {
s = set.New(l)
bi.list[*uid] = s
s = set.New(lid)
bi.list[*lid.userID] = s
} else {
s.Insert(l)
s.Insert(lid)
}
return s
......@@ -58,13 +58,20 @@ func (bi *byId) Add(uid *id.ID, l Listener) *set.Set {
// Removes the passed listener from the set for UserID and
// deletes the set if it is empty if the ID is not a generic one
func (bi *byId) Remove(uid *id.ID, l Listener) {
s, ok := bi.list[*uid]
func (bi *byId) Remove(lid ListenerID) {
s, ok := bi.list[*lid.userID]
if ok {
s.Remove(l)
s.Remove(lid)
if s.Len() == 0 && !uid.Cmp(AnyUser()) && !uid.Cmp(&id.ID{}) {
delete(bi.list, *uid)
if s.Len() == 0 && !lid.userID.Cmp(AnyUser()) && !lid.userID.Cmp(&id.ID{}) {
delete(bi.list, *lid.userID)
}
}
}
// RemoveId removes all listeners registered for the given user ID.
func (bi *byId) RemoveId(uid *id.ID) {
if !uid.Cmp(AnyUser()) && !uid.Cmp(&id.ID{}) {
delete(bi.list, *uid)
}
}
......@@ -120,9 +120,9 @@ func TestById_Add_New(t *testing.T) {
uid := id.NewIdFromUInt(42, id.User, t)
l := &funcListener{}
lid := ListenerID{uid, 1, &funcListener{}}
nbi.Add(uid, l)
nbi.Add(lid)
s := nbi.list[*uid]
......@@ -130,7 +130,7 @@ func TestById_Add_New(t *testing.T) {
t.Errorf("Should a set of the wrong size")
}
if !s.Has(l) {
if !s.Has(lid) {
t.Errorf("Wrong set returned")
}
}
......@@ -142,26 +142,26 @@ func TestById_Add_Old(t *testing.T) {
uid := id.NewIdFromUInt(42, id.User, t)
l1 := &funcListener{}
l2 := &funcListener{}
lid1 := ListenerID{uid, 1, &funcListener{}}
lid2 := ListenerID{uid, 1, &funcListener{}}
set1 := set.New(l1)
set1 := set.New(lid1)
nbi.list[*uid] = set1
nbi.Add(uid, l2)
nbi.Add(lid2)
s := nbi.list[*uid]
if s.Len() != 2 {
t.Errorf("Should have returned a set")
t.Errorf("Incorrect set length.\nexpected: %d\nreceived: %d", 2, s.Len())
}
if !s.Has(l1) {
if !s.Has(lid1) {
t.Errorf("Set does not include the initial listener")
}
if !s.Has(l2) {
if !s.Has(lid2) {
t.Errorf("Set does not include the new listener")
}
}
......@@ -171,11 +171,11 @@ func TestById_Add_Old(t *testing.T) {
func TestById_Add_Generic(t *testing.T) {
nbi := newById()
l1 := &funcListener{}
l2 := &funcListener{}
lid1 := ListenerID{&id.ID{}, 1, &funcListener{}}
lid2 := ListenerID{AnyUser(), 1, &funcListener{}}
nbi.Add(&id.ID{}, l1)
nbi.Add(AnyUser(), l2)
nbi.Add(lid1)
nbi.Add(lid2)
s := nbi.generic
......@@ -183,11 +183,11 @@ func TestById_Add_Generic(t *testing.T) {
t.Errorf("Should have returned a set of size 2")
}
if !s.Has(l1) {
if !s.Has(lid1) {
t.Errorf("Set does not include the ZeroUser listener")
}
if !s.Has(l2) {
if !s.Has(lid2) {
t.Errorf("Set does not include the empty user listener")
}
}
......@@ -199,14 +199,14 @@ func TestById_Remove_ManyInSet(t *testing.T) {
uid := id.NewIdFromUInt(42, id.User, t)
l1 := &funcListener{}
l2 := &funcListener{}
lid1 := ListenerID{uid, 1, &funcListener{}}
lid2 := ListenerID{uid, 1, &funcListener{}}
set1 := set.New(l1, l2)
set1 := set.New(lid1, lid2)
nbi.list[*uid] = set1
nbi.Remove(uid, l1)
nbi.Remove(lid1)
if _, ok := nbi.list[*uid]; !ok {
t.Errorf("Set removed when it should not have been")
......@@ -217,11 +217,11 @@ func TestById_Remove_ManyInSet(t *testing.T) {
set1.Len())
}
if set1.Has(l1) {
if set1.Has(lid1) {
t.Errorf("Listener 1 still in set, it should not be")
}
if !set1.Has(l2) {
if !set1.Has(lid2) {
t.Errorf("Listener 2 not still in set, it should be")
}
......@@ -234,13 +234,13 @@ func TestById_Remove_SingleInSet(t *testing.T) {
uid := id.NewIdFromUInt(42, id.User, t)
l1 := &funcListener{}
lid1 := ListenerID{uid, 1, &funcListener{}}
set1 := set.New(l1)
set1 := set.New(lid1)
nbi.list[*uid] = set1
nbi.Remove(uid, l1)
nbi.Remove(lid1)
if _, ok := nbi.list[*uid]; ok {
t.Errorf("Set not removed when it should have been")
......@@ -251,7 +251,7 @@ func TestById_Remove_SingleInSet(t *testing.T) {
set1.Len())
}
if set1.Has(l1) {
if set1.Has(lid1) {
t.Errorf("Listener 1 still in set, it should not be")
}
}
......@@ -263,13 +263,13 @@ func TestById_Remove_SingleInSet_ZeroUser(t *testing.T) {
uid := &id.ZeroUser
l1 := &funcListener{}
lid1 := ListenerID{uid, 1, &funcListener{}}
set1 := set.New(l1)
set1 := set.New(lid1)
nbi.list[*uid] = set1
nbi.Remove(uid, l1)
nbi.Remove(lid1)
if _, ok := nbi.list[*uid]; !ok {
t.Errorf("Set removed when it should not have been")
......@@ -280,7 +280,7 @@ func TestById_Remove_SingleInSet_ZeroUser(t *testing.T) {
set1.Len())
}
if set1.Has(l1) {
if set1.Has(lid1) {
t.Errorf("Listener 1 still in set, it should not be")
}
}
......@@ -292,13 +292,13 @@ func TestById_Remove_SingleInSet_EmptyUser(t *testing.T) {
uid := &id.ID{}
l1 := &funcListener{}
lid1 := ListenerID{uid, 1, &funcListener{}}
set1 := set.New(l1)
set1 := set.New(lid1)
nbi.list[*uid] = set1
nbi.Remove(uid, l1)
nbi.Remove(lid1)
if _, ok := nbi.list[*uid]; !ok {
t.Errorf("Set removed when it should not have been")
......@@ -309,7 +309,7 @@ func TestById_Remove_SingleInSet_EmptyUser(t *testing.T) {
set1.Len())
}
if set1.Has(l1) {
if set1.Has(lid1) {
t.Errorf("Listener 1 still in set, it should not be")
}
}
......@@ -45,13 +45,13 @@ func (bt *byType) Get(messageType catalog.MessageType) *set.Set {
// adds a listener to a set for the given messageType. Creates a new set to add
// it to if the set does not exist
func (bt *byType) Add(messageType catalog.MessageType, r Listener) *set.Set {
s, ok := bt.list[messageType]
func (bt *byType) Add(lid ListenerID) *set.Set {
s, ok := bt.list[lid.messageType]
if !ok {
s = set.New(r)
bt.list[messageType] = s
s = set.New(lid)
bt.list[lid.messageType] = s
} else {
s.Insert(r)
s.Insert(lid)
}
return s
......@@ -59,13 +59,13 @@ func (bt *byType) Add(messageType catalog.MessageType, r Listener) *set.Set {
// Removes the passed listener from the set for messageType and
// deletes the set if it is empty and the type is not AnyType
func (bt *byType) Remove(mt catalog.MessageType, l Listener) {
s, ok := bt.list[mt]
func (bt *byType) Remove(lid ListenerID) {
s, ok := bt.list[lid.messageType]
if ok {
s.Remove(l)
s.Remove(lid)
if s.Len() == 0 && mt != AnyType {
delete(bt.list, mt)
if s.Len() == 0 && lid.messageType != AnyType {
delete(bt.list, lid.messageType)
}
}
}
......@@ -10,6 +10,7 @@ package receive
import (
"github.com/golang-collections/collections/set"
"gitlab.com/elixxir/client/catalog"
"gitlab.com/xx_network/primitives/id"
"testing"
)
......@@ -108,9 +109,9 @@ func TestByType_Add_New(t *testing.T) {
m := catalog.MessageType(42)
l := &funcListener{}
l := ListenerID{&id.ZeroUser, m, &funcListener{}}
nbt.Add(m, l)
nbt.Add(l)
s := nbt.list[m]
......@@ -130,14 +131,14 @@ func TestByType_Add_Old(t *testing.T) {
m := catalog.MessageType(42)
l1 := &funcListener{}
l2 := &funcListener{}
lid1 := ListenerID{&id.ZeroUser, m, &funcListener{}}
lid2 := ListenerID{&id.ZeroUser, m, &funcListener{}}
set1 := set.New(l1)
set1 := set.New(lid1)
nbt.list[m] = set1
nbt.Add(m, l2)
nbt.Add(lid2)
s := nbt.list[m]
......@@ -145,11 +146,11 @@ func TestByType_Add_Old(t *testing.T) {
t.Errorf("Should have returned a set")
}
if !s.Has(l1) {
if !s.Has(lid1) {
t.Errorf("Set does not include the initial listener")
}
if !s.Has(l2) {
if !s.Has(lid2) {
t.Errorf("Set does not include the new listener")
}
}
......@@ -159,9 +160,9 @@ func TestByType_Add_Old(t *testing.T) {
func TestByType_Add_Generic(t *testing.T) {
nbt := newByType()
l1 := &funcListener{}
lid1 := ListenerID{&id.ZeroUser, AnyType, &funcListener{}}
nbt.Add(AnyType, l1)
nbt.Add(lid1)
s := nbt.generic
......@@ -169,7 +170,7 @@ func TestByType_Add_Generic(t *testing.T) {
t.Errorf("Should have returned a set of size 2")
}
if !s.Has(l1) {
if !s.Has(lid1) {
t.Errorf("Set does not include the ZeroUser listener")
}
}
......@@ -181,13 +182,13 @@ func TestByType_Remove_SingleInSet(t *testing.T) {
m := catalog.MessageType(42)
l1 := &funcListener{}
lid1 := ListenerID{&id.ZeroUser, m, &funcListener{}}
set1 := set.New(l1)
set1 := set.New(lid1)
nbt.list[m] = set1
nbt.Remove(m, l1)
nbt.Remove(lid1)
if _, ok := nbt.list[m]; ok {
t.Errorf("Set not removed when it should have been")
......@@ -198,7 +199,7 @@ func TestByType_Remove_SingleInSet(t *testing.T) {
set1.Len())
}
if set1.Has(l1) {
if set1.Has(lid1) {
t.Errorf("Listener 1 still in set, it should not be")
}
}
......@@ -210,13 +211,13 @@ func TestByType_Remove_SingleInSet_AnyType(t *testing.T) {
m := AnyType
l1 := &funcListener{}
lid1 := ListenerID{&id.ZeroUser, m, &funcListener{}}
set1 := set.New(l1)
set1 := set.New(lid1)
nbt.list[m] = set1
nbt.Remove(m, l1)
nbt.Remove(lid1)
if _, ok := nbt.list[m]; !ok {
t.Errorf("Set removed when it should not have been")
......@@ -227,7 +228,7 @@ func TestByType_Remove_SingleInSet_AnyType(t *testing.T) {
set1.Len())
}
if set1.Has(l1) {
if set1.Has(lid1) {
t.Errorf("Listener 1 still in set, it should not be")
}
}
......@@ -11,6 +11,8 @@ import (
jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/client/catalog"
"gitlab.com/xx_network/primitives/id"
"strconv"
"strings"
)
//Listener interface for a listener adhere to
......@@ -50,6 +52,19 @@ func (lid ListenerID) GetName() string {
return lid.listener.Name()
}
// String returns the values in the ListenerID in a human-readable format. This
// functions adheres to the fmt.Stringer interface.
func (lid ListenerID) String() string {
str := []string{
"userID:" + lid.userID.String(),
"messageType:" + lid.messageType.String() +
"(" + strconv.FormatUint(uint64(lid.messageType), 10) + ")",
"listener:" + lid.listener.Name(),
}
return "{" + strings.Join(str, " ") + "}"
}
/*internal listener implementations*/
//listener based off of a function
......
......@@ -55,20 +55,23 @@ func (sw *Switchboard) RegisterListener(user *id.ID,
jww.FATAL.Panicf("cannot register nil listener")
}
// Create new listener ID
lid := ListenerID{
userID: user,
messageType: messageType,
listener: newListener,
}
//register the listener by both ID and messageType
sw.mux.Lock()
sw.id.Add(user, newListener)
sw.messageType.Add(messageType, newListener)
sw.id.Add(lid)
sw.messageType.Add(lid)
sw.mux.Unlock()
//return a ListenerID so it can be unregistered in the future
return ListenerID{
userID: user,
messageType: messageType,
listener: newListener,
}
return lid
}
// RegisterFunc Registers a new listener built around the passed function.
......@@ -141,8 +144,8 @@ func (sw *Switchboard) Speak(item Message) {
//Execute hear on all matched listeners in a new goroutine
matches.Do(func(i interface{}) {
r := i.(Listener)
go r.Hear(item)
lid := i.(ListenerID)
go lid.listener.Hear(item)
})
// print to log if nothing was heard
......@@ -158,12 +161,32 @@ func (sw *Switchboard) Speak(item Message) {
func (sw *Switchboard) Unregister(listenerID ListenerID) {
sw.mux.Lock()
sw.id.Remove(listenerID.userID, listenerID.listener)
sw.messageType.Remove(listenerID.messageType, listenerID.listener)
sw.id.Remove(listenerID)
sw.messageType.Remove(listenerID)
sw.mux.Unlock()
}
// UnregisterUserListeners removes all the listeners registered with the
// specified user.
func (sw *Switchboard) UnregisterUserListeners(userID *id.ID) {
sw.mux.Lock()
defer sw.mux.Unlock()
// Get list of all listeners for the specified user
idSet := sw.id.Get(userID)
// Find each listener in the messageType list and delete it
idSet.Do(func(i interface{}) {
lid := i.(ListenerID)
mtSet := sw.messageType.list[lid.messageType]
mtSet.Remove(lid)
})
// Remove all listeners for the user from the ID list
sw.id.RemoveId(userID)
}
// finds all listeners who match the items sender or ID, or have those fields
// as generic
func (sw *Switchboard) matchListeners(item Message) *set.Set {
......
......@@ -85,13 +85,13 @@ func TestSwitchboard_RegisterListener(t *testing.T) {
//check that the listener is registered in the appropriate location
setID := sw.id.Get(uid)
if !setID.Has(l) {
if !setID.Has(lid) {
t.Errorf("Listener is not registered by ID")
}
setType := sw.messageType.Get(mt)
if !setType.Has(l) {
if !setType.Has(lid) {
t.Errorf("Listener is not registered by Message Tag")
}
......@@ -152,13 +152,13 @@ func TestSwitchboard_RegisterFunc(t *testing.T) {
//check that the listener is registered in the appropriate location
setID := sw.id.Get(uid)
if !setID.Has(lid.listener) {
if !setID.Has(lid) {
t.Errorf("Listener is not registered by ID")
}
setType := sw.messageType.Get(mt)
if !setType.Has(lid.listener) {
if !setType.Has(lid) {
t.Errorf("Listener is not registered by Message Tag")
}
......@@ -221,13 +221,13 @@ func TestSwitchboard_RegisterChan(t *testing.T) {
//check that the listener is registered in the appropriate location
setID := sw.id.Get(uid)
if !setID.Has(lid.listener) {
if !setID.Has(lid) {
t.Errorf("Listener is not registered by ID")
}
setType := sw.messageType.Get(mt)
if !setType.Has(lid.listener) {
if !setType.Has(lid) {
t.Errorf("Listener is not registered by Message Tag")
}
......@@ -330,21 +330,67 @@ func TestSwitchboard_Unregister(t *testing.T) {
setType := sw.messageType.Get(mt)
//check that the removed listener is not registered
if setID.Has(lid1.listener) {
if setID.Has(lid1) {
t.Errorf("Removed Listener is registered by ID, should not be")
}
if setType.Has(lid1.listener) {
if setType.Has(lid1) {
t.Errorf("Removed Listener not registered by Message Tag, " +
"should not be")
}
//check that the not removed listener is still registered
if !setID.Has(lid2.listener) {
if !setID.Has(lid2) {
t.Errorf("Remaining Listener is not registered by ID")
}
if !setType.Has(lid2.listener) {
if !setType.Has(lid2) {
t.Errorf("Remaining Listener is not registered by Message Tag")
}
}
// Tests that three listeners with three different message types but the same
// user are deleted with Switchboard.UnregisterUserListeners and that three
// other listeners with the same message types but different users are not
// delete.
func TestSwitchboard_UnregisterUserListeners(t *testing.T) {
sw := New()
uid1 := id.NewIdFromUInt(42, id.User, t)
uid2 := id.NewIdFromUInt(11, id.User, t)
l := func(receive Message) {}
lid1 := sw.RegisterFunc("a", uid1, catalog.NoType, l)
lid2 := sw.RegisterFunc("b", uid1, catalog.XxMessage, l)
lid3 := sw.RegisterFunc("c", uid1, catalog.KeyExchangeTrigger, l)
lid4 := sw.RegisterFunc("d", uid2, catalog.NoType, l)
lid5 := sw.RegisterFunc("e", uid2, catalog.XxMessage, l)
lid6 := sw.RegisterFunc("f", uid2, catalog.KeyExchangeTrigger, l)
sw.UnregisterUserListeners(uid1)
if s, exists := sw.id.list[*uid1]; exists {
t.Errorf("Set for ID %s still exists: %+v", uid1, s)
}
if sw.messageType.Get(lid1.messageType).Has(lid1) {
t.Errorf("Listener %+v still exists", lid1)
}
if sw.messageType.Get(lid2.messageType).Has(lid2) {
t.Errorf("Listener %+v still exists", lid2)
}
if sw.messageType.Get(lid3.messageType).Has(lid3) {
t.Errorf("Listener %+v still exists", lid3)
}
if !sw.messageType.Get(lid4.messageType).Has(lid4) {
t.Errorf("Listener %+v does not exist", lid4)
}
if !sw.messageType.Get(lid5.messageType).Has(lid5) {
t.Errorf("Listener %+v does not exist", lid5)
}
if !sw.messageType.Get(lid6.messageType).Has(lid6) {
t.Errorf("Listener %+v does not exist", lid6)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment