From 0610e442803252f2ecfcbb182a9376596bd870f5 Mon Sep 17 00:00:00 2001 From: Jono Wenger <jono@elixxir.io> Date: Tue, 30 Aug 2022 20:00:34 +0000 Subject: [PATCH] XX-4120 / joinedChannel Tests --- bindings/broadcast.go | 2 +- broadcast/asymmetric_test.go | 2 +- broadcast/client.go | 18 +- broadcast/interface.go | 8 +- broadcast/symmetric_test.go | 2 +- channels/eventModel.go | 2 +- channels/eventModel_test.go | 4 +- channels/joinedChannel.go | 70 ++--- channels/joinedChannel_test.go | 464 +++++++++++++++++++++++++++++++++ channels/manager.go | 4 +- channels/send_test.go | 8 +- 11 files changed, 531 insertions(+), 53 deletions(-) create mode 100644 channels/joinedChannel_test.go diff --git a/bindings/broadcast.go b/bindings/broadcast.go index 6cea173b9..bf9cd1431 100644 --- a/bindings/broadcast.go +++ b/bindings/broadcast.go @@ -103,7 +103,7 @@ func NewBroadcastChannel(cmixId int, channelDefinition []byte) (*Channel, error) return nil, errors.WithMessage(err, "Failed to load public key") } - ch, err := broadcast.NewBroadcastChannel(cryptoBroadcast.Channel{ + ch, err := broadcast.NewBroadcastChannel(&cryptoBroadcast.Channel{ ReceptionID: channelID, Name: def.Name, Description: def.Description, diff --git a/broadcast/asymmetric_test.go b/broadcast/asymmetric_test.go index 776879127..fd0a0cca5 100644 --- a/broadcast/asymmetric_test.go +++ b/broadcast/asymmetric_test.go @@ -38,7 +38,7 @@ func Test_asymmetricClient_Smoke(t *testing.T) { if err != nil { t.Errorf("Failed to create channel ID: %+v", err) } - channel := crypto.Channel{ + channel := &crypto.Channel{ ReceptionID: cid, Name: cname, Description: cdesc, diff --git a/broadcast/client.go b/broadcast/client.go index b9daa8c11..37837f383 100644 --- a/broadcast/client.go +++ b/broadcast/client.go @@ -19,15 +19,15 @@ import ( // broadcastClient implements the Channel interface for sending/receiving asymmetric or symmetric broadcast messages type broadcastClient struct { - channel crypto.Channel + channel *crypto.Channel net Client rng *fastRNG.StreamGenerator } -type NewBroadcastChannelFunc func(channel crypto.Channel, net Client, rng *fastRNG.StreamGenerator) (Channel, error) +type NewBroadcastChannelFunc func(channel *crypto.Channel, net Client, rng *fastRNG.StreamGenerator) (Channel, error) // NewBroadcastChannel creates a channel interface based on crypto.Channel, accepts net client connection & callback for received messages -func NewBroadcastChannel(channel crypto.Channel, net Client, rng *fastRNG.StreamGenerator) (Channel, error) { +func NewBroadcastChannel(channel *crypto.Channel, net Client, rng *fastRNG.StreamGenerator) (Channel, error) { bc := &broadcastClient{ channel: channel, net: net, @@ -60,7 +60,7 @@ func (bc *broadcastClient) RegisterListener(listenerCb ListenerFunc, method Meth } p := &processor{ - c: &bc.channel, + c: bc.channel, cb: listenerCb, method: method, } @@ -84,14 +84,16 @@ func (bc *broadcastClient) Stop() { bc.net.DeleteClientService(bc.channel.ReceptionID) } -// Get returns the underlying crypto.Channel object -func (bc *broadcastClient) Get() crypto.Channel { +// Get returns the underlying crypto.Channel object. +func (bc *broadcastClient) Get() *crypto.Channel { return bc.channel } -// verifyID generates a symmetric ID based on the info in the channel & compares it to the one passed in +// verifyID generates a symmetric ID based on the info in the channel and +// compares it to the one passed in. func (bc *broadcastClient) verifyID() bool { - gen, err := crypto.NewChannelID(bc.channel.Name, bc.channel.Description, bc.channel.Salt, rsa.CreatePublicKeyPem(bc.channel.RsaPubKey)) + gen, err := crypto.NewChannelID(bc.channel.Name, bc.channel.Description, + bc.channel.Salt, rsa.CreatePublicKeyPem(bc.channel.RsaPubKey)) if err != nil { jww.FATAL.Panicf("[verifyID] Failed to generate verified channel ID") return false diff --git a/broadcast/interface.go b/broadcast/interface.go index 7eef2d0d7..ef322dd66 100644 --- a/broadcast/interface.go +++ b/broadcast/interface.go @@ -19,8 +19,8 @@ import ( "time" ) -// ListenerFunc is registered when creating a new broadcasting channel -// and receives all new broadcast messages for the channel. +// ListenerFunc is registered when creating a new broadcasting channel and +// receives all new broadcast messages for the channel. type ListenerFunc func(payload []byte, receptionID receptionID.EphemeralIdentity, round rounds.Round) @@ -30,11 +30,11 @@ type Channel interface { MaxPayloadSize() int // MaxAsymmetricPayloadSize returns the maximum size for an asymmetric - //broadcast payload + // broadcast payload MaxAsymmetricPayloadSize() int // Get returns the underlying crypto.Channel - Get() crypto.Channel + Get() *crypto.Channel // Broadcast broadcasts the payload to the channel. The payload size must be // equal to MaxPayloadSize. diff --git a/broadcast/symmetric_test.go b/broadcast/symmetric_test.go index 2eb26d29e..b8d750d31 100644 --- a/broadcast/symmetric_test.go +++ b/broadcast/symmetric_test.go @@ -44,7 +44,7 @@ func Test_symmetricClient_Smoke(t *testing.T) { if err != nil { t.Errorf("Failed to create channel ID: %+v", err) } - channel := crypto.Channel{ + channel := &crypto.Channel{ ReceptionID: cid, Name: cname, Description: cdesc, diff --git a/channels/eventModel.go b/channels/eventModel.go index cd67c8123..0dd58014e 100644 --- a/channels/eventModel.go +++ b/channels/eventModel.go @@ -26,7 +26,7 @@ var ( // system passed an object which adheres to in order to get events on the channel type EventModel interface { // JoinChannel is called whenever a channel is joined locally - JoinChannel(channel cryptoBroadcast.Channel) + JoinChannel(channel *cryptoBroadcast.Channel) // LeaveChannel is called whenever a channel is left locally LeaveChannel(channelID *id.ID) diff --git a/channels/eventModel_test.go b/channels/eventModel_test.go index 22c0af9e4..9379c1945 100644 --- a/channels/eventModel_test.go +++ b/channels/eventModel_test.go @@ -30,8 +30,8 @@ type MockEvent struct { eventReceive } -func (*MockEvent) JoinChannel(channel cryptoBroadcast.Channel) {} -func (*MockEvent) LeaveChannel(channelID *id.ID) {} +func (*MockEvent) JoinChannel(channel *cryptoBroadcast.Channel) {} +func (*MockEvent) LeaveChannel(channelID *id.ID) {} func (m *MockEvent) ReceiveMessage(channelID *id.ID, messageID cryptoChannel.MessageID, senderUsername string, text string, timestamp time.Time, lease time.Duration, round rounds.Round) { diff --git a/channels/joinedChannel.go b/channels/joinedChannel.go index ce53e29ac..7f058e3b5 100644 --- a/channels/joinedChannel.go +++ b/channels/joinedChannel.go @@ -48,10 +48,9 @@ func (m *manager) storeUnsafe() error { // loadChannels loads all currently joined channels from disk and registers them // for message reception. func (m *manager) loadChannels() { - obj, err := m.kv.Get(joinedChannelsKey, joinedChannelsVersion) if !m.kv.Exists(err) { - m.channels = make(map[*id.ID]*joinedChannel) + m.channels = make(map[id.ID]*joinedChannel) return } else if err != nil { jww.FATAL.Panicf("Failed to load channels: %+v", err) @@ -62,7 +61,7 @@ func (m *manager) loadChannels() { jww.FATAL.Panicf("Failed to load channels: %+v", err) } - chMap := make(map[*id.ID]*joinedChannel) + chMap := make(map[id.ID]*joinedChannel) for i := range chList { jc, err := loadJoinedChannel( @@ -70,17 +69,17 @@ func (m *manager) loadChannels() { if err != nil { jww.FATAL.Panicf("Failed to load channel %s: %+v", chList[i], err) } - chMap[chList[i]] = jc + chMap[*chList[i]] = jc } m.channels = chMap } // addChannel adds a channel. -func (m *manager) addChannel(channel cryptoBroadcast.Channel) error { +func (m *manager) addChannel(channel *cryptoBroadcast.Channel) error { m.mux.Lock() defer m.mux.Unlock() - if _, exists := m.channels[channel.ReceptionID]; exists { + if _, exists := m.channels[*channel.ReceptionID]; exists { return ChannelAlreadyExistsErr } @@ -89,6 +88,19 @@ func (m *manager) addChannel(channel cryptoBroadcast.Channel) error { return err } + jc := &joinedChannel{b} + if err = jc.Store(m.kv); err != nil { + go b.Stop() + return err + } + + m.channels[*jc.broadcast.Get().ReceptionID] = jc + + if err = m.storeUnsafe(); err != nil { + go b.Stop() + return err + } + // Connect to listeners err = b.RegisterListener((&userListener{ name: m.name, @@ -107,17 +119,6 @@ func (m *manager) addChannel(channel cryptoBroadcast.Channel) error { return err } - jc := &joinedChannel{b} - if err = jc.Store(m.kv); err != nil { - go b.Stop() - return err - } - - if err = m.storeUnsafe(); err != nil { - go b.Stop() - return err - } - return nil } @@ -128,16 +129,21 @@ func (m *manager) removeChannel(channelID *id.ID) error { m.mux.Lock() defer m.mux.Unlock() - ch, exists := m.channels[channelID] + ch, exists := m.channels[*channelID] if !exists { return ChannelDoesNotExistsErr } ch.broadcast.Stop() - delete(m.channels, channelID) + delete(m.channels, *channelID) - return nil + err := m.storeUnsafe() + if err != nil { + return err + } + + return ch.delete(m.kv) } // getChannel returns the given channel. Returns ChannelDoesNotExistsErr error @@ -146,7 +152,7 @@ func (m *manager) getChannel(channelID *id.ID) (*joinedChannel, error) { m.mux.RLock() defer m.mux.RUnlock() - jc, exists := m.channels[channelID] + jc, exists := m.channels[*channelID] if !exists { return nil, ChannelDoesNotExistsErr } @@ -168,7 +174,7 @@ func (m *manager) getChannels() []*id.ID { func (m *manager) getChannelsUnsafe() []*id.ID { list := make([]*id.ID, 0, len(m.channels)) for chID := range m.channels { - list = append(list, chID) + list = append(list, chID.DeepCopy()) } return list } @@ -181,12 +187,12 @@ type joinedChannel struct { // joinedChannelDisk is the representation of joinedChannel for storage. type joinedChannelDisk struct { - broadcast cryptoBroadcast.Channel + Broadcast *cryptoBroadcast.Channel } // Store writes the given channel to a unique storage location within the EKV. func (jc *joinedChannel) Store(kv *versioned.KV) error { - jcd := joinedChannelDisk{broadcast: jc.broadcast.Get()} + jcd := joinedChannelDisk{jc.broadcast.Get()} data, err := json.Marshal(&jcd) if err != nil { return err @@ -198,8 +204,8 @@ func (jc *joinedChannel) Store(kv *versioned.KV) error { Data: data, } - return kv.Set(makeJoinedChannelKey( - jc.broadcast.Get().ReceptionID), joinedChannelVersion, obj) + return kv.Set(makeJoinedChannelKey(jc.broadcast.Get().ReceptionID), + joinedChannelVersion, obj) } // loadJoinedChannel loads a given channel from ekv storage. @@ -217,14 +223,14 @@ func loadJoinedChannel(chId *id.ID, kv *versioned.KV, net broadcast.Client, if err != nil { return nil, err } - b, err := broadcastMaker(jcd.broadcast, net, rngGen) + b, err := broadcastMaker(jcd.Broadcast, net, rngGen) if err != nil { return nil, err } err = b.RegisterListener((&userListener{ name: name, - chID: jcd.broadcast.ReceptionID, + chID: jcd.Broadcast.ReceptionID, trigger: e.triggerEvent, }).Listen, broadcast.Symmetric) if err != nil { @@ -232,7 +238,7 @@ func loadJoinedChannel(chId *id.ID, kv *versioned.KV, net broadcast.Client, } err = b.RegisterListener((&adminListener{ - chID: jcd.broadcast.ReceptionID, + chID: jcd.Broadcast.ReceptionID, trigger: e.triggerAdminEvent, }).Listen, broadcast.Asymmetric) if err != nil { @@ -243,6 +249,12 @@ func loadJoinedChannel(chId *id.ID, kv *versioned.KV, net broadcast.Client, return jc, nil } +// delete removes the channel from the kv. +func (jc *joinedChannel) delete(kv *versioned.KV) error { + return kv.Delete(makeJoinedChannelKey(jc.broadcast.Get().ReceptionID), + joinedChannelVersion) +} + func makeJoinedChannelKey(chId *id.ID) string { return joinedChannelKey + chId.HexEncode() } diff --git a/channels/joinedChannel_test.go b/channels/joinedChannel_test.go new file mode 100644 index 000000000..2043ec7d6 --- /dev/null +++ b/channels/joinedChannel_test.go @@ -0,0 +1,464 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package channels + +import ( + "bytes" + "encoding/binary" + "gitlab.com/elixxir/client/broadcast" + "gitlab.com/elixxir/client/storage/versioned" + cryptoBroadcast "gitlab.com/elixxir/crypto/broadcast" + "gitlab.com/elixxir/crypto/cmix" + "gitlab.com/elixxir/crypto/fastRNG" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/crypto/csprng" + "gitlab.com/xx_network/crypto/signature/rsa" + "gitlab.com/xx_network/primitives/id" + "reflect" + "sort" + "strconv" + "testing" +) + +// Tests that manager.store stores the channel list in the ekv. +func Test_manager_store(t *testing.T) { + m := NewManager(versioned.NewKV(ekv.MakeMemstore()), + new(mockBroadcastClient), + fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG), + new(mockNameService), new(mockEventModel)).(*manager) + + for i := 0; i < 10; i++ { + ch, _, err := newTestChannel("name_"+strconv.Itoa(i), + "description_"+strconv.Itoa(i), m.rng.GetStream()) + if err != nil { + t.Errorf("Failed to create new channel %d: %+v", i, err) + } + + b, err := broadcast.NewBroadcastChannel(ch, m.client, m.rng) + if err != nil { + t.Errorf("Failed to make new broadcast channel: %+v", err) + } + + m.channels[*ch.ReceptionID] = &joinedChannel{b} + } + + err := m.store() + if err != nil { + t.Errorf("Error storing channels: %+v", err) + } + + _, err = m.kv.Get(joinedChannelsKey, joinedChannelsVersion) + if !ekv.Exists(err) { + t.Errorf("channel list not found in KV: %+v", err) + } +} + +// Tests that the manager.loadChannels loads all the expected channels from the +// ekv. +func Test_manager_loadChannels(t *testing.T) { + m := NewManager(versioned.NewKV(ekv.MakeMemstore()), + new(mockBroadcastClient), + fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG), + new(mockNameService), new(mockEventModel)).(*manager) + + expected := make([]*joinedChannel, 10) + + for i := range expected { + ch, _, err := newTestChannel("name_"+strconv.Itoa(i), + "description_"+strconv.Itoa(i), m.rng.GetStream()) + if err != nil { + t.Errorf("Failed to create new channel %d: %+v", i, err) + } + + b, err := broadcast.NewBroadcastChannel(ch, m.client, m.rng) + if err != nil { + t.Errorf("Failed to make new broadcast channel: %+v", err) + } + + jc := &joinedChannel{b} + if err = jc.Store(m.kv); err != nil { + t.Errorf("Failed to store joinedChannel %d: %+v", i, err) + } + + chID := *ch.ReceptionID + m.channels[chID] = jc + expected[i] = jc + } + + err := m.store() + if err != nil { + t.Errorf("Error storing channels: %+v", err) + } + + newManager := &manager{ + channels: make(map[id.ID]*joinedChannel), + kv: m.kv, + client: m.client, + rng: m.rng, + broadcastMaker: m.broadcastMaker, + } + + newManager.loadChannels() + + for chID, loadedCh := range newManager.channels { + ch, exists := m.channels[chID] + if !exists { + t.Errorf("Channel %s does not exist.", &chID) + } + + if !reflect.DeepEqual(ch.broadcast, loadedCh.broadcast) { + t.Errorf("Channel %s does not match loaded channel."+ + "\nexpected: %+v\nreceived: %+v", &chID, ch.broadcast, loadedCh.broadcast) + } + } +} + +// Tests that manager.addChannel adds the channel to the map and stores it in +// the kv. +func Test_manager_addChannel(t *testing.T) { + m := NewManager(versioned.NewKV(ekv.MakeMemstore()), + new(mockBroadcastClient), + fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG), + new(mockNameService), new(mockEventModel)).(*manager) + + ch, _, err := newTestChannel("name", "description", m.rng.GetStream()) + if err != nil { + t.Errorf("Failed to create new channel: %+v", err) + } + + err = m.addChannel(ch) + if err != nil { + t.Errorf("Failed to add new channel: %+v", err) + } + + if _, exists := m.channels[*ch.ReceptionID]; !exists { + t.Errorf("Channel %s not added to channel map.", ch.Name) + } + + _, err = m.kv.Get(makeJoinedChannelKey(ch.ReceptionID), joinedChannelVersion) + if err != nil { + t.Errorf("Failed to get joinedChannel from kv: %+v", err) + } + + _, err = m.kv.Get(joinedChannelsKey, joinedChannelsVersion) + if err != nil { + t.Errorf("Failed to get channels from kv: %+v", err) + } +} + +// Error path: tests that manager.addChannel returns ChannelAlreadyExistsErr +// when the channel was already added. +func Test_manager_addChannel_ChannelAlreadyExistsErr(t *testing.T) { + m := NewManager(versioned.NewKV(ekv.MakeMemstore()), + new(mockBroadcastClient), + fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG), + new(mockNameService), new(mockEventModel)).(*manager) + + ch, _, err := newTestChannel("name", "description", m.rng.GetStream()) + if err != nil { + t.Errorf("Failed to create new channel: %+v", err) + } + + err = m.addChannel(ch) + if err != nil { + t.Errorf("Failed to add new channel: %+v", err) + } + + err = m.addChannel(ch) + if err == nil || err != ChannelAlreadyExistsErr { + t.Errorf("Received incorrect error when adding a channel that already "+ + "exists.\nexpected: %s\nreceived: %+v", ChannelAlreadyExistsErr, err) + } +} + +// Tests the manager.removeChannel deletes the channel from the map. +func Test_manager_removeChannel(t *testing.T) { + m := NewManager(versioned.NewKV(ekv.MakeMemstore()), + new(mockBroadcastClient), + fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG), + new(mockNameService), new(mockEventModel)).(*manager) + + ch, _, err := newTestChannel("name", "description", m.rng.GetStream()) + if err != nil { + t.Errorf("Failed to create new channel: %+v", err) + } + + err = m.addChannel(ch) + if err != nil { + t.Errorf("Failed to add new channel: %+v", err) + } + + err = m.removeChannel(ch.ReceptionID) + if err != nil { + t.Errorf("Error removing channel: %+v", err) + } + + if _, exists := m.channels[*ch.ReceptionID]; exists { + t.Errorf("Channel %s was not remove from the channel map.", ch.Name) + } + + _, err = m.kv.Get(makeJoinedChannelKey(ch.ReceptionID), joinedChannelVersion) + if ekv.Exists(err) { + t.Errorf("joinedChannel not removed from kv: %+v", err) + } +} + +// Error path: tests that manager.removeChannel returns ChannelDoesNotExistsErr +// when the channel was never added. +func Test_manager_removeChannel_ChannelDoesNotExistsErr(t *testing.T) { + m := NewManager(versioned.NewKV(ekv.MakeMemstore()), + new(mockBroadcastClient), + fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG), + new(mockNameService), new(mockEventModel)).(*manager) + + ch, _, err := newTestChannel("name", "description", m.rng.GetStream()) + if err != nil { + t.Errorf("Failed to create new channel: %+v", err) + } + + err = m.removeChannel(ch.ReceptionID) + if err == nil || err != ChannelDoesNotExistsErr { + t.Errorf("Received incorrect error when removing a channel that does "+ + "not exists.\nexpected: %s\nreceived: %+v", + ChannelDoesNotExistsErr, err) + } +} + +// Tests the manager.getChannel returns the expected channel. +func Test_manager_getChannel(t *testing.T) { + m := NewManager(versioned.NewKV(ekv.MakeMemstore()), + new(mockBroadcastClient), + fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG), + new(mockNameService), new(mockEventModel)).(*manager) + + ch, _, err := newTestChannel("name", "description", m.rng.GetStream()) + if err != nil { + t.Errorf("Failed to create new channel: %+v", err) + } + + err = m.addChannel(ch) + if err != nil { + t.Errorf("Failed to add new channel: %+v", err) + } + + jc, err := m.getChannel(ch.ReceptionID) + if err != nil { + t.Errorf("Error getting channel: %+v", err) + } + + if !reflect.DeepEqual(ch, jc.broadcast.Get()) { + t.Errorf("Received unexpected channel.\nexpected: %+v\nreceived: %+v", + ch, jc.broadcast.Get()) + } +} + +// Error path: tests that manager.getChannel returns ChannelDoesNotExistsErr +// when the channel was never added. +func Test_manager_getChannel_ChannelDoesNotExistsErr(t *testing.T) { + m := NewManager(versioned.NewKV(ekv.MakeMemstore()), + new(mockBroadcastClient), + fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG), + new(mockNameService), new(mockEventModel)).(*manager) + + ch, _, err := newTestChannel("name", "description", m.rng.GetStream()) + if err != nil { + t.Errorf("Failed to create new channel: %+v", err) + } + + _, err = m.getChannel(ch.ReceptionID) + if err == nil || err != ChannelDoesNotExistsErr { + t.Errorf("Received incorrect error when getting a channel that does "+ + "not exists.\nexpected: %s\nreceived: %+v", + ChannelDoesNotExistsErr, err) + } +} + +// Tests that manager.getChannels returns all the channels that were added to +// the map. +func Test_manager_getChannels(t *testing.T) { + m := NewManager(versioned.NewKV(ekv.MakeMemstore()), + new(mockBroadcastClient), + fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG), + new(mockNameService), new(mockEventModel)).(*manager) + + expected := make([]*id.ID, 10) + + for i := range expected { + ch, _, err := newTestChannel("name_"+strconv.Itoa(i), + "description_"+strconv.Itoa(i), m.rng.GetStream()) + if err != nil { + t.Errorf("Failed to create new channel %d: %+v", i, err) + } + expected[i] = ch.ReceptionID + + err = m.addChannel(ch) + if err != nil { + t.Errorf("Failed to add new channel %d: %+v", i, err) + } + } + + channelIDs := m.getChannelsUnsafe() + + sort.SliceStable(expected, func(i, j int) bool { + return bytes.Compare(expected[i][:], expected[j][:]) == -1 + }) + sort.SliceStable(channelIDs, func(i, j int) bool { + return bytes.Compare(channelIDs[i][:], channelIDs[j][:]) == -1 + }) + + if !reflect.DeepEqual(expected, channelIDs) { + t.Errorf("ID list does not match expected.\nexpected: %v\nreceived: %v", + expected, channelIDs) + } +} + +// Tests that joinedChannel.Store saves the joinedChannel to the expected place +// in the ekv. +func Test_joinedChannel_Store(t *testing.T) { + kv := versioned.NewKV(ekv.MakeMemstore()) + rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) + ch, _, err := newTestChannel("name", "description", rng.GetStream()) + if err != nil { + t.Errorf("Failed to create new channel: %+v", err) + } + + b, err := broadcast.NewBroadcastChannel(ch, new(mockBroadcastClient), rng) + if err != nil { + t.Errorf("Failed to create new broadcast channel: %+v", err) + } + + jc := &joinedChannel{b} + + err = jc.Store(kv) + if err != nil { + t.Errorf("Error storing joinedChannel: %+v", err) + } + + _, err = kv.Get(makeJoinedChannelKey(ch.ReceptionID), joinedChannelVersion) + if !ekv.Exists(err) { + t.Errorf("joinedChannel not found in KV: %+v", err) + } +} + +// Tests that loadJoinedChannel returns a joinedChannel from storage that +// matches the original. +func Test_loadJoinedChannel(t *testing.T) { + m := NewManager(versioned.NewKV(ekv.MakeMemstore()), + new(mockBroadcastClient), + fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG), + new(mockNameService), new(mockEventModel)).(*manager) + + ch, _, err := newTestChannel("name", "description", m.rng.GetStream()) + if err != nil { + t.Errorf("Failed to create new channel: %+v", err) + } + + err = m.addChannel(ch) + if err != nil { + t.Errorf("Failed to add channel: %+v", err) + } + + loadedJc, err := loadJoinedChannel(ch.ReceptionID, m.kv, m.client, m.rng, + m.name, m.events, m.broadcastMaker) + if err != nil { + t.Errorf("Failed to load joinedChannel: %+v", err) + } + + if !reflect.DeepEqual(ch, loadedJc.broadcast.Get()) { + t.Errorf("Loaded joinedChannel does not match original."+ + "\nexpected: %+v\nreceived: %+v", ch, loadedJc.broadcast.Get()) + } +} + +// Tests that joinedChannel.delete deletes the stored joinedChannel from the +// ekv. +func Test_joinedChannel_delete(t *testing.T) { + kv := versioned.NewKV(ekv.MakeMemstore()) + rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) + ch, _, err := newTestChannel("name", "description", rng.GetStream()) + if err != nil { + t.Errorf("Failed to create new channel: %+v", err) + } + + b, err := broadcast.NewBroadcastChannel(ch, new(mockBroadcastClient), rng) + if err != nil { + t.Errorf("Failed to create new broadcast channel: %+v", err) + } + + jc := &joinedChannel{b} + + err = jc.Store(kv) + if err != nil { + t.Errorf("Error storing joinedChannel: %+v", err) + } + + err = jc.delete(kv) + if err != nil { + t.Errorf("Error deleting joinedChannel: %+v", err) + } + + _, err = kv.Get(makeJoinedChannelKey(ch.ReceptionID), joinedChannelVersion) + if ekv.Exists(err) { + t.Errorf("joinedChannel found in KV: %+v", err) + } +} + +// Consistency test of makeJoinedChannelKey. +func Test_makeJoinedChannelKey_Consistency(t *testing.T) { + values := map[*id.ID]string{ + id.NewIdFromUInt(0, id.User, t): "JoinedChannelKey-0x0000000000000000000000000000000000000000000000000000000000000000", + id.NewIdFromUInt(1, id.User, t): "JoinedChannelKey-0x0000000000000001000000000000000000000000000000000000000000000000", + id.NewIdFromUInt(2, id.User, t): "JoinedChannelKey-0x0000000000000002000000000000000000000000000000000000000000000000", + id.NewIdFromUInt(3, id.User, t): "JoinedChannelKey-0x0000000000000003000000000000000000000000000000000000000000000000", + id.NewIdFromUInt(4, id.User, t): "JoinedChannelKey-0x0000000000000004000000000000000000000000000000000000000000000000", + id.NewIdFromUInt(5, id.User, t): "JoinedChannelKey-0x0000000000000005000000000000000000000000000000000000000000000000", + id.NewIdFromUInt(6, id.User, t): "JoinedChannelKey-0x0000000000000006000000000000000000000000000000000000000000000000", + id.NewIdFromUInt(7, id.User, t): "JoinedChannelKey-0x0000000000000007000000000000000000000000000000000000000000000000", + id.NewIdFromUInt(8, id.User, t): "JoinedChannelKey-0x0000000000000008000000000000000000000000000000000000000000000000", + id.NewIdFromUInt(9, id.User, t): "JoinedChannelKey-0x0000000000000009000000000000000000000000000000000000000000000000", + } + + for chID, expected := range values { + key := makeJoinedChannelKey(chID) + + if expected != key { + t.Errorf("Unexpected key for ID %d.\nexpected: %s\nreceived: %s", + binary.BigEndian.Uint64(chID[:8]), expected, key) + } + } + +} + +// newTestChannel creates a new cryptoBroadcast.Channel in the same way that +// cryptoBroadcast.NewChannel does but with a smaller RSA key and salt to make +// tests run quicker. +func newTestChannel(name, description string, rng csprng.Source) ( + *cryptoBroadcast.Channel, *rsa.PrivateKey, error) { + // Uses 128 bits instead of 4096 bits + pk, err := rsa.GenerateKey(rng, 128) + if err != nil { + return nil, nil, err + } + + // Uses 16 bits instead of 512 bits + salt := cmix.NewSalt(rng, 16) + + channelID, err := cryptoBroadcast.NewChannelID( + name, description, salt, rsa.CreatePublicKeyPem(pk.GetPublic())) + if err != nil { + return nil, nil, err + } + + return &cryptoBroadcast.Channel{ + ReceptionID: channelID, + Name: name, + Description: description, + Salt: salt, + RsaPubKey: pk.GetPublic(), + }, pk, nil +} diff --git a/channels/manager.go b/channels/manager.go index 2d4b958de..0e931e359 100644 --- a/channels/manager.go +++ b/channels/manager.go @@ -11,7 +11,7 @@ import ( type manager struct { // List of all channels - channels map[*id.ID]*joinedChannel + channels map[id.ID]*joinedChannel mux sync.RWMutex // External references @@ -53,7 +53,7 @@ func NewManager(kv *versioned.KV, client broadcast.Client, // JoinChannel joins the given channel. It will fail if the channel has already // been joined. -func (m *manager) JoinChannel(channel cryptoBroadcast.Channel) error { +func (m *manager) JoinChannel(channel *cryptoBroadcast.Channel) error { err := m.addChannel(channel) if err != nil { return err diff --git a/channels/send_test.go b/channels/send_test.go index 453fe9332..cf0fe3992 100644 --- a/channels/send_test.go +++ b/channels/send_test.go @@ -31,8 +31,8 @@ func (m *mockBroadcastChannel) MaxAsymmetricPayloadSize() int { return 123 } -func (m *mockBroadcastChannel) Get() cryptoBroadcast.Channel { - return cryptoBroadcast.Channel{} +func (m *mockBroadcastChannel) Get() *cryptoBroadcast.Channel { + return &cryptoBroadcast.Channel{} } func (m *mockBroadcastChannel) Broadcast(payload []byte, cMixParams cmix.CMIXParams) ( @@ -123,7 +123,7 @@ func (m *mockNameService) ValidateChannelMessage(username string, lease time.Tim type mockEventModel struct{} -func (m *mockEventModel) JoinChannel(channel cryptoBroadcast.Channel) {} +func (m *mockEventModel) JoinChannel(channel *cryptoBroadcast.Channel) {} func (m *mockEventModel) LeaveChannel(channelID *id.ID) {} @@ -164,7 +164,7 @@ func TestSendGeneric(t *testing.T) { validUntil := time.Hour params := new(cmix.CMIXParams) - m.channels[channelID] = &joinedChannel{ + m.channels[*channelID] = &joinedChannel{ broadcast: &mockBroadcastChannel{}, } -- GitLab