Skip to content
Snippets Groups Projects
Commit 68319306 authored by Jonah Husson's avatar Jonah Husson
Browse files

Merge branch 'hotfix/channel-methods' into 'release'

Integration update, allow client to use both broadcast methods

See merge request !240
parents 9caf0b14 24f70007
No related branches found
No related tags found
2 merge requests!510Release,!240Integration update, allow client to use both broadcast methods
...@@ -8,10 +8,10 @@ ...@@ -8,10 +8,10 @@
package broadcast package broadcast
import ( import (
"encoding/binary"
"github.com/pkg/errors" "github.com/pkg/errors"
"gitlab.com/elixxir/client/cmix" "gitlab.com/elixxir/client/cmix"
"gitlab.com/elixxir/client/cmix/message" "gitlab.com/elixxir/client/cmix/message"
"gitlab.com/elixxir/primitives/format"
"gitlab.com/xx_network/crypto/multicastRSA" "gitlab.com/xx_network/crypto/multicastRSA"
"gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/id"
"gitlab.com/xx_network/primitives/id/ephemeral" "gitlab.com/xx_network/primitives/id/ephemeral"
...@@ -20,6 +20,7 @@ import ( ...@@ -20,6 +20,7 @@ import (
const ( const (
asymmetricBroadcastServiceTag = "AsymmBcast" asymmetricBroadcastServiceTag = "AsymmBcast"
asymmCMixSendTag = "AsymmetricBroadcast" asymmCMixSendTag = "AsymmetricBroadcast"
internalPayloadSizeLength = 2
) )
// MaxAsymmetricPayloadSize returns the maximum size for an asymmetric broadcast payload // MaxAsymmetricPayloadSize returns the maximum size for an asymmetric broadcast payload
...@@ -29,40 +30,29 @@ func (bc *broadcastClient) maxAsymmetricPayload() int { ...@@ -29,40 +30,29 @@ func (bc *broadcastClient) maxAsymmetricPayload() int {
// BroadcastAsymmetric broadcasts the payload to the channel. Requires a healthy network state to send // BroadcastAsymmetric broadcasts the payload to the channel. Requires a healthy network state to send
// Payload must be equal to bc.MaxAsymmetricPayloadSize, and the channel PrivateKey must be passed in // Payload must be equal to bc.MaxAsymmetricPayloadSize, and the channel PrivateKey must be passed in
// Broadcast method must be set to asymmetric
// When a payload is sent, it is split into partitons of size bc.channel.MaxAsymmetricPayloadSize
// which are each encrypted using multicastRSA
func (bc *broadcastClient) BroadcastAsymmetric(pk multicastRSA.PrivateKey, payload []byte, cMixParams cmix.CMIXParams) ( func (bc *broadcastClient) BroadcastAsymmetric(pk multicastRSA.PrivateKey, payload []byte, cMixParams cmix.CMIXParams) (
id.Round, ephemeral.Id, error) { id.Round, ephemeral.Id, error) {
if bc.param.Method != Asymmetric { // Confirm network health
return 0, ephemeral.Id{}, errors.Errorf(errBroadcastMethodType, Asymmetric, bc.param.Method)
}
if !bc.net.IsHealthy() { if !bc.net.IsHealthy() {
return 0, ephemeral.Id{}, errors.New(errNetworkHealth) return 0, ephemeral.Id{}, errors.New(errNetworkHealth)
} }
if len(payload) != bc.maxAsymmetricPayload() { // Check payload size
if len(payload) > bc.MaxAsymmetricPayloadSize() {
return 0, ephemeral.Id{}, return 0, ephemeral.Id{},
errors.Errorf(errPayloadSize, len(payload), bc.maxAsymmetricPayload()) errors.Errorf(errPayloadSize, len(payload), bc.maxAsymmetricPayload())
} }
payloadLength := uint16(len(payload))
numParts := bc.maxParts() finalPayload := make([]byte, bc.maxAsymmetricPayloadSizeRaw())
size := bc.channel.MaxAsymmetricPayloadSize() binary.BigEndian.PutUint16(finalPayload[:internalPayloadSizeLength], payloadLength)
var mac []byte copy(finalPayload[internalPayloadSizeLength:], payload)
var fp format.Fingerprint
var sequential []byte // Encrypt payload
for i := 0; i < numParts; i++ { encryptedPayload, mac, fp, err := bc.channel.EncryptAsymmetric(finalPayload, pk, bc.rng.GetStream())
// Encrypt payload to send using asymmetric channel
var encryptedPayload []byte
var err error
encryptedPayload, mac, fp, err = bc.channel.EncryptAsymmetric(payload[:size], pk, bc.rng.GetStream())
if err != nil { if err != nil {
return 0, ephemeral.Id{}, errors.WithMessage(err, "Failed to encrypt asymmetric broadcast message") return 0, ephemeral.Id{}, errors.WithMessage(err, "Failed to encrypt asymmetric broadcast message")
} }
payload = payload[size:]
sequential = append(sequential, encryptedPayload...)
}
// Create service object to send message // Create service object to send message
service := message.Service{ service := message.Service{
...@@ -74,10 +64,14 @@ func (bc *broadcastClient) BroadcastAsymmetric(pk multicastRSA.PrivateKey, paylo ...@@ -74,10 +64,14 @@ func (bc *broadcastClient) BroadcastAsymmetric(pk multicastRSA.PrivateKey, paylo
cMixParams.DebugTag = asymmCMixSendTag cMixParams.DebugTag = asymmCMixSendTag
} }
sizedPayload, err := NewSizedBroadcast(bc.net.GetMaxMessageLength(), sequential) // Create payload sized for sending over cmix
sizedPayload := make([]byte, bc.net.GetMaxMessageLength())
// Read random data into sized payload
_, err = bc.rng.GetStream().Read(sizedPayload)
if err != nil { if err != nil {
return id.Round(0), ephemeral.Id{}, err return 0, ephemeral.Id{}, errors.WithMessage(err, "Failed to add random data to sized broadcast")
} }
copy(sizedPayload[:len(encryptedPayload)], encryptedPayload)
return bc.net.Send( return bc.net.Send(
bc.channel.ReceptionID, fp, service, sizedPayload, mac, cMixParams) bc.channel.ReceptionID, fp, service, sizedPayload, mac, cMixParams)
......
...@@ -56,11 +56,16 @@ func Test_asymmetricClient_Smoke(t *testing.T) { ...@@ -56,11 +56,16 @@ func Test_asymmetricClient_Smoke(t *testing.T) {
cbChan <- payload cbChan <- payload
} }
s, err := NewBroadcastChannel(channel, cb, newMockCmix(cMixHandler), rngGen, Param{Method: Asymmetric}) s, err := NewBroadcastChannel(channel, newMockCmix(cMixHandler), rngGen)
if err != nil { if err != nil {
t.Errorf("Failed to create broadcast channel: %+v", err) t.Errorf("Failed to create broadcast channel: %+v", err)
} }
err = s.RegisterListener(cb, Asymmetric)
if err != nil {
t.Errorf("Failed to register listener: %+v", err)
}
cbChans[i] = cbChan cbChans[i] = cbChan
clients[i] = s clients[i] = s
...@@ -73,7 +78,7 @@ func Test_asymmetricClient_Smoke(t *testing.T) { ...@@ -73,7 +78,7 @@ func Test_asymmetricClient_Smoke(t *testing.T) {
// Send broadcast from each client // Send broadcast from each client
for i := range clients { for i := range clients {
payload := make([]byte, clients[i].MaxPayloadSize()) payload := make([]byte, clients[i].MaxAsymmetricPayloadSize())
copy(payload, copy(payload,
fmt.Sprintf("Hello from client %d of %d.", i, len(clients))) fmt.Sprintf("Hello from client %d of %d.", i, len(clients)))
...@@ -112,7 +117,7 @@ func Test_asymmetricClient_Smoke(t *testing.T) { ...@@ -112,7 +117,7 @@ func Test_asymmetricClient_Smoke(t *testing.T) {
clients[i].Stop() clients[i].Stop()
} }
payload := make([]byte, clients[0].MaxPayloadSize()) payload := make([]byte, clients[0].MaxAsymmetricPayloadSize())
copy(payload, "This message should not get through.") copy(payload, "This message should not get through.")
// Start waiting on channels and error if anything is received // Start waiting on channels and error if anything is received
......
...@@ -17,26 +17,19 @@ import ( ...@@ -17,26 +17,19 @@ import (
"gitlab.com/xx_network/crypto/signature/rsa" "gitlab.com/xx_network/crypto/signature/rsa"
) )
// Param encapsulates configuration options for a broadcastClient
type Param struct {
Method Method
}
// broadcastClient implements the Channel interface for sending/receiving asymmetric or symmetric broadcast messages // broadcastClient implements the Channel interface for sending/receiving asymmetric or symmetric broadcast messages
type broadcastClient struct { type broadcastClient struct {
channel crypto.Channel channel crypto.Channel
net Client net Client
rng *fastRNG.StreamGenerator rng *fastRNG.StreamGenerator
param Param
} }
// NewBroadcastChannel creates a channel interface based on crypto.Channel, accepts net client connection & callback for received messages // NewBroadcastChannel creates a channel interface based on crypto.Channel, accepts net client connection & callback for received messages
func NewBroadcastChannel(channel crypto.Channel, listenerCb ListenerFunc, net Client, rng *fastRNG.StreamGenerator, param Param) (Channel, error) { func NewBroadcastChannel(channel crypto.Channel, net Client, rng *fastRNG.StreamGenerator) (Channel, error) {
bc := &broadcastClient{ bc := &broadcastClient{
channel: channel, channel: channel,
net: net, net: net,
rng: rng, rng: rng,
param: param,
} }
if !bc.verifyID() { if !bc.verifyID() {
...@@ -46,31 +39,37 @@ func NewBroadcastChannel(channel crypto.Channel, listenerCb ListenerFunc, net Cl ...@@ -46,31 +39,37 @@ func NewBroadcastChannel(channel crypto.Channel, listenerCb ListenerFunc, net Cl
// Add channel's identity // Add channel's identity
net.AddIdentity(channel.ReceptionID, identity.Forever, true) net.AddIdentity(channel.ReceptionID, identity.Forever, true)
p := &processor{ jww.INFO.Printf("New broadcast channel client created for channel %q (%s)",
c: &channel, channel.Name, channel.ReceptionID)
cb: listenerCb,
method: param.Method, return bc, nil
} }
// RegisterListener adds a service to hear broadcast messages of a given type via the passed in callback
func (bc *broadcastClient) RegisterListener(listenerCb ListenerFunc, method Method) error {
var tag string var tag string
switch param.Method { switch method {
case Symmetric: case Symmetric:
tag = symmetricBroadcastServiceTag tag = symmetricBroadcastServiceTag
case Asymmetric: case Asymmetric:
tag = asymmetricBroadcastServiceTag tag = asymmetricBroadcastServiceTag
default: default:
return nil, errors.Errorf("Cannot make broadcast client for unknown broadcast method %s", param.Method) return errors.Errorf("Cannot register listener for broadcast method %s", method)
}
p := &processor{
c: &bc.channel,
cb: listenerCb,
method: method,
} }
service := message.Service{ service := message.Service{
Identifier: channel.ReceptionID.Bytes(), Identifier: bc.channel.ReceptionID.Bytes(),
Tag: tag, Tag: tag,
} }
net.AddService(channel.ReceptionID, service, p) bc.net.AddService(bc.channel.ReceptionID, service, p)
return nil
jww.INFO.Printf("New %s broadcast client created for channel %q (%s)",
param.Method, channel.Name, channel.ReceptionID)
return bc, nil
} }
// Stop unregisters the listener callback and stops the channel's identity // Stop unregisters the listener callback and stops the channel's identity
...@@ -89,7 +88,6 @@ func (bc *broadcastClient) Get() crypto.Channel { ...@@ -89,7 +88,6 @@ func (bc *broadcastClient) Get() crypto.Channel {
} }
// verifyID generates a symmetric ID based on the info in the channel & compares it to the one passed in // verifyID generates a symmetric ID based on the info in the channel & compares it to the one passed in
// TODO: it seems very odd to me that we do this, rather than just making the ID a private/ephemeral component like the key
func (bc *broadcastClient) verifyID() bool { func (bc *broadcastClient) verifyID() bool {
gen, err := crypto.NewChannelID(bc.channel.Name, bc.channel.Description, bc.channel.Salt, rsa.CreatePublicKeyPem(bc.channel.RsaPubKey)) gen, err := crypto.NewChannelID(bc.channel.Name, bc.channel.Description, bc.channel.Salt, rsa.CreatePublicKeyPem(bc.channel.RsaPubKey))
if err != nil { if err != nil {
...@@ -100,12 +98,13 @@ func (bc *broadcastClient) verifyID() bool { ...@@ -100,12 +98,13 @@ func (bc *broadcastClient) verifyID() bool {
} }
func (bc *broadcastClient) MaxPayloadSize() int { func (bc *broadcastClient) MaxPayloadSize() int {
switch bc.param.Method {
case Symmetric:
return bc.maxSymmetricPayload() return bc.maxSymmetricPayload()
case Asymmetric:
return bc.maxAsymmetricPayload()
default:
return -1
} }
func (bc *broadcastClient) MaxAsymmetricPayloadSize() int {
return bc.maxAsymmetricPayloadSizeRaw() - internalPayloadSizeLength
}
func (bc *broadcastClient) maxAsymmetricPayloadSizeRaw() int {
return bc.channel.MaxAsymmetricPayloadSize()
} }
...@@ -26,9 +26,12 @@ type ListenerFunc func(payload []byte, ...@@ -26,9 +26,12 @@ type ListenerFunc func(payload []byte,
receptionID receptionID.EphemeralIdentity, round rounds.Round) receptionID receptionID.EphemeralIdentity, round rounds.Round)
type Channel interface { type Channel interface {
// MaxPayloadSize returns the maximum size for a broadcast payload. Different math depending on broadcast method. // MaxPayloadSize returns the maximum size for a symmetric broadcast payload
MaxPayloadSize() int MaxPayloadSize() int
// MaxAsymmetricPayloadSize returns the maximum size for an asymmetric broadcast payload
MaxAsymmetricPayloadSize() int
// Get returns the underlying crypto.Channel // Get returns the underlying crypto.Channel
Get() crypto.Channel Get() crypto.Channel
...@@ -42,6 +45,8 @@ type Channel interface { ...@@ -42,6 +45,8 @@ type Channel interface {
BroadcastAsymmetric(pk multicastRSA.PrivateKey, payload []byte, cMixParams cmix.CMIXParams) ( BroadcastAsymmetric(pk multicastRSA.PrivateKey, payload []byte, cMixParams cmix.CMIXParams) (
id.Round, ephemeral.Id, error) id.Round, ephemeral.Id, error)
RegisterListener(listenerCb ListenerFunc, method Method) error
// Stop unregisters the listener callback and stops the channel's identity // Stop unregisters the listener callback and stops the channel's identity
// from being tracked. // from being tracked.
Stop() Stop()
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
package broadcast package broadcast
import ( import (
"encoding/binary"
jww "github.com/spf13/jwalterweatherman" jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/client/cmix/identity/receptionID" "gitlab.com/elixxir/client/cmix/identity/receptionID"
"gitlab.com/elixxir/client/cmix/rounds" "gitlab.com/elixxir/client/cmix/rounds"
...@@ -35,25 +36,15 @@ func (p *processor) Process(msg format.Message, ...@@ -35,25 +36,15 @@ func (p *processor) Process(msg format.Message,
var err error var err error
switch p.method { switch p.method {
case Asymmetric: case Asymmetric:
// We use sized broadcast to fill any remaining bytes in the cmix payload, decode it here encPartSize := p.c.RsaPubKey.Size() // Size returned by multicast RSA encryption
unsizedPayload, err := DecodeSizedBroadcast(msg.GetContents()) encodedMessage := msg.GetContents()[:encPartSize] // Only one message is encoded, rest of it is random data
if err != nil { decodedMessage, decryptErr := p.c.DecryptAsymmetric(encodedMessage)
jww.ERROR.Printf("Failed to decode sized broadcast: %+v", err) if decryptErr != nil {
return jww.ERROR.Printf(errDecrypt, p.c.ReceptionID, p.c.Name, decryptErr)
}
encPartSize := p.c.RsaPubKey.Size() // Size of each chunk returned by multicast RSA encryption
numParts := len(unsizedPayload) / encPartSize // Number of chunks in the payload
// Iterate through & decrypt each chunk, appending to aggregate payload
for i := 0; i < numParts; i++ {
var decrypted []byte
decrypted, err = p.c.DecryptAsymmetric(unsizedPayload[:encPartSize])
if err != nil {
jww.ERROR.Printf(errDecrypt, p.c.ReceptionID, p.c.Name, err)
return return
} }
unsizedPayload = unsizedPayload[encPartSize:] size := binary.BigEndian.Uint16(decodedMessage[:internalPayloadSizeLength])
payload = append(payload, decrypted...) payload = decodedMessage[internalPayloadSizeLength : size+internalPayloadSizeLength]
}
case Symmetric: case Symmetric:
payload, err = p.c.DecryptSymmetric(msg.GetContents(), msg.GetMac(), msg.GetKeyFP()) payload, err = p.c.DecryptSymmetric(msg.GetContents(), msg.GetMac(), msg.GetKeyFP())
......
...@@ -19,7 +19,7 @@ import ( ...@@ -19,7 +19,7 @@ import (
const ( const (
// symmetricClient.Broadcast // symmetricClient.Broadcast
errNetworkHealth = "cannot send broadcast when the network is not healthy" errNetworkHealth = "cannot send broadcast when the network is not healthy"
errPayloadSize = "size of payload %d must be %d" errPayloadSize = "size of payload %d must be less than %d"
errBroadcastMethodType = "cannot call %s broadcast using %s channel" errBroadcastMethodType = "cannot call %s broadcast using %s channel"
) )
...@@ -35,15 +35,10 @@ func (bc *broadcastClient) maxSymmetricPayload() int { ...@@ -35,15 +35,10 @@ func (bc *broadcastClient) maxSymmetricPayload() int {
} }
// Broadcast broadcasts a payload over a symmetric channel. // Broadcast broadcasts a payload over a symmetric channel.
// broadcast method must be set to Symmetric
// Network must be healthy to send // Network must be healthy to send
// Requires a payload of size bc.MaxSymmetricPayloadSize() // Requires a payload of size bc.MaxSymmetricPayloadSize()
func (bc *broadcastClient) Broadcast(payload []byte, cMixParams cmix.CMIXParams) ( func (bc *broadcastClient) Broadcast(payload []byte, cMixParams cmix.CMIXParams) (
id.Round, ephemeral.Id, error) { id.Round, ephemeral.Id, error) {
if bc.param.Method != Symmetric {
return 0, ephemeral.Id{}, errors.Errorf(errBroadcastMethodType, Symmetric, bc.param.Method)
}
if !bc.net.IsHealthy() { if !bc.net.IsHealthy() {
return 0, ephemeral.Id{}, errors.New(errNetworkHealth) return 0, ephemeral.Id{}, errors.New(errNetworkHealth)
} }
......
...@@ -63,10 +63,16 @@ func Test_symmetricClient_Smoke(t *testing.T) { ...@@ -63,10 +63,16 @@ func Test_symmetricClient_Smoke(t *testing.T) {
cbChan <- payload cbChan <- payload
} }
s, err := NewBroadcastChannel(channel, cb, newMockCmix(cMixHandler), rngGen, Param{Method: Symmetric}) s, err := NewBroadcastChannel(channel, newMockCmix(cMixHandler), rngGen)
if err != nil { if err != nil {
t.Errorf("Failed to create broadcast channel: %+v", err) t.Errorf("Failed to create broadcast channel: %+v", err)
} }
err = s.RegisterListener(cb, Symmetric)
if err != nil {
t.Errorf("Failed to register listener: %+v", err)
}
cbChans[i] = cbChan cbChans[i] = cbChan
clients[i] = s clients[i] = s
......
...@@ -15,9 +15,10 @@ import ( ...@@ -15,9 +15,10 @@ import (
crypto "gitlab.com/elixxir/crypto/broadcast" crypto "gitlab.com/elixxir/crypto/broadcast"
"gitlab.com/xx_network/crypto/signature/rsa" "gitlab.com/xx_network/crypto/signature/rsa"
"gitlab.com/xx_network/primitives/utils" "gitlab.com/xx_network/primitives/utils"
"sync"
) )
// singleCmd is the single-use subcommand that allows for sending and responding // singleCmd is the single-use subcommand that allows for sening and responding
// to single-use messages. // to single-use messages.
var broadcastCmd = &cobra.Command{ var broadcastCmd = &cobra.Command{
Use: "broadcast", Use: "broadcast",
...@@ -46,7 +47,6 @@ var broadcastCmd = &cobra.Command{ ...@@ -46,7 +47,6 @@ var broadcastCmd = &cobra.Command{
connected <- isConnected connected <- isConnected
}) })
waitUntilConnected(connected) waitUntilConnected(connected)
/* Set up underlying crypto broadcast.Channel */ /* Set up underlying crypto broadcast.Channel */
var channel *crypto.Channel var channel *crypto.Channel
var pk *rsa.PrivateKey var pk *rsa.PrivateKey
...@@ -81,13 +81,14 @@ var broadcastCmd = &cobra.Command{ ...@@ -81,13 +81,14 @@ var broadcastCmd = &cobra.Command{
} }
if keyPath != "" { if keyPath != "" {
err = utils.WriteFile(path, rsa.CreatePrivateKeyPem(pk), os.ModePerm, os.ModeDir) err = utils.WriteFile(keyPath, rsa.CreatePrivateKeyPem(pk), os.ModePerm, os.ModeDir)
if err != nil { if err != nil {
jww.ERROR.Printf("Failed to write private key to path %s: %+v", path, err) jww.ERROR.Printf("Failed to write private key to path %s: %+v", path, err)
} }
} else { } else {
fmt.Printf("Private key generated for channel: %+v", rsa.CreatePrivateKeyPem(pk)) fmt.Printf("Private key generated for channel: %+v", rsa.CreatePrivateKeyPem(pk))
} }
fmt.Printf("New broadcast channel generated")
} else { } else {
// Read rest of info from config & build object manually // Read rest of info from config & build object manually
pubKeyBytes := []byte(viper.GetString(broadcastRsaPubFlag)) pubKeyBytes := []byte(viper.GetString(broadcastRsaPubFlag))
...@@ -109,23 +110,6 @@ var broadcastCmd = &cobra.Command{ ...@@ -109,23 +110,6 @@ var broadcastCmd = &cobra.Command{
Salt: salt, Salt: salt,
RsaPubKey: pubKey, RsaPubKey: pubKey,
} }
// Load key if it's there
if keyPath != "" {
if ep, err := utils.ExpandPath(keyPath); err == nil {
keyBytes, err := utils.ReadFile(ep)
if err != nil {
jww.ERROR.Printf("Failed to read private key from %s: %+v", ep, err)
}
pk, err = rsa.LoadPrivateKeyFromPem(keyBytes)
if err != nil {
jww.ERROR.Printf("Failed to load private key %+v: %+v", keyBytes, err)
}
} else {
jww.ERROR.Printf("Failed to expand private key path: %+v", err)
}
}
} }
// Save channel to disk // Save channel to disk
...@@ -144,66 +128,123 @@ var broadcastCmd = &cobra.Command{ ...@@ -144,66 +128,123 @@ var broadcastCmd = &cobra.Command{
} }
} }
// Load key if needed
if pk == nil && keyPath != "" {
jww.DEBUG.Printf("Attempting to load private key at %s", keyPath)
if ep, err := utils.ExpandPath(keyPath); err == nil {
keyBytes, err := utils.ReadFile(ep)
if err != nil {
jww.ERROR.Printf("Failed to read private key from %s: %+v", ep, err)
}
pk, err = rsa.LoadPrivateKeyFromPem(keyBytes)
if err != nil {
jww.ERROR.Printf("Failed to load private key %+v: %+v", keyBytes, err)
}
} else {
jww.ERROR.Printf("Failed to expand private key path: %+v", err)
}
}
/* Broadcast client setup */ /* Broadcast client setup */
// Create receiver callback // Select broadcast method
symmetric := viper.GetString("symmetric")
asymmetric := viper.GetString("asymmetric")
// Connect to broadcast channel
bcl, err := broadcast.NewBroadcastChannel(*channel, client.GetCmix(), client.GetRng())
// Create & register symmetric receiver callback
receiveChan := make(chan []byte, 100) receiveChan := make(chan []byte, 100)
cb := func(payload []byte, scb := func(payload []byte,
receptionID receptionID.EphemeralIdentity, round rounds.Round) { receptionID receptionID.EphemeralIdentity, round rounds.Round) {
jww.INFO.Printf("Received symmetric message from %s over round %d", receptionID, round.ID) jww.INFO.Printf("Received symmetric message from %s over round %d", receptionID, round.ID)
receiveChan <- payload receiveChan <- payload
} }
err = bcl.RegisterListener(scb, broadcast.Symmetric)
if err != nil {
jww.FATAL.Panicf("Failed to register asymmetric listener: %+v", err)
}
// Select broadcast method // Create & register asymmetric receiver callback
var method broadcast.Method asymmetricReceiveChan := make(chan []byte, 100)
symmetric := viper.GetBool(broadcastSymmetricFlag) acb := func(payload []byte,
asymmetric := viper.GetBool(broadcastAsymmetricFlag) receptionID receptionID.EphemeralIdentity, round rounds.Round) {
if symmetric && asymmetric { jww.INFO.Printf("Received asymmetric message from %s over round %d", receptionID, round.ID)
jww.FATAL.Panicf("Cannot simultaneously broadcast symmetric & asymmetric") asymmetricReceiveChan <- payload
} }
if symmetric { err = bcl.RegisterListener(acb, broadcast.Asymmetric)
method = broadcast.Symmetric if err != nil {
} else if asymmetric { jww.FATAL.Panicf("Failed to register asymmetric listener: %+v", err)
method = broadcast.Asymmetric
} }
// Connect to broadcast channel jww.INFO.Printf("Broadcast listeners registered...")
bcl, err := broadcast.NewBroadcastChannel(*channel, cb, client.GetCmix(), client.GetRng(), broadcast.Param{Method: method})
/* Create properly sized broadcast message */ /* Broadcast messages to the channel */
message := viper.GetString(broadcastFlag) if symmetric != "" || asymmetric != "" {
fmt.Println(message) wg := sync.WaitGroup{}
var broadcastMessage []byte wg.Add(1)
if message != "" { go func() {
broadcastMessage, err = broadcast.NewSizedBroadcast(bcl.MaxPayloadSize(), []byte(message)) jww.INFO.Printf("Attempting to send broadcasts...")
if err != nil {
jww.ERROR.Printf("Failed to create sized broadcast: %+v", err)
}
sendDelay := time.Duration(viper.GetUint("sendDelay"))
maxRetries := 10
retries := 0
for {
// Wait for sendDelay before sending (to allow connection to establish)
if maxRetries == retries {
jww.FATAL.Panicf("Max retries reached")
} }
time.Sleep(sendDelay*time.Millisecond*time.Duration(retries) + 1)
/* Broadcast message to the channel */ /* Send symmetric broadcast */
switch method { if symmetric != "" {
case broadcast.Symmetric: // Create properly sized broadcast message
broadcastMessage, err := broadcast.NewSizedBroadcast(bcl.MaxPayloadSize(), []byte(symmetric))
if err != nil {
jww.FATAL.Panicf("Failed to create sized broadcast: %+v", err)
}
rid, eid, err := bcl.Broadcast(broadcastMessage, cmix.GetDefaultCMIXParams()) rid, eid, err := bcl.Broadcast(broadcastMessage, cmix.GetDefaultCMIXParams())
if err != nil { if err != nil {
jww.ERROR.Printf("Failed to send symmetric broadcast message: %+v", err) jww.ERROR.Printf("Failed to send symmetric broadcast message: %+v", err)
retries++
continue
} }
fmt.Printf("Sent symmetric broadcast message: %s", symmetric)
jww.INFO.Printf("Sent symmetric broadcast message to %s over round %d", eid, rid) jww.INFO.Printf("Sent symmetric broadcast message to %s over round %d", eid, rid)
case broadcast.Asymmetric: }
/* Send asymmetric broadcast */
if asymmetric != "" {
// Create properly sized broadcast message
broadcastMessage, err := broadcast.NewSizedBroadcast(bcl.MaxAsymmetricPayloadSize(), []byte(asymmetric))
if err != nil {
jww.FATAL.Panicf("Failed to create sized broadcast: %+v", err)
}
if pk == nil { if pk == nil {
jww.FATAL.Panicf("CANNOT SEND ASYMMETRIC BROADCAST WITHOUT PRIVATE KEY") jww.FATAL.Panicf("CANNOT SEND ASYMMETRIC BROADCAST WITHOUT PRIVATE KEY")
} }
rid, eid, err := bcl.BroadcastAsymmetric(pk, broadcastMessage, cmix.GetDefaultCMIXParams()) rid, eid, err := bcl.BroadcastAsymmetric(pk, broadcastMessage, cmix.GetDefaultCMIXParams())
if err != nil { if err != nil {
jww.ERROR.Printf("Failed to send asymmetric broadcast message: %+v", err) jww.ERROR.Printf("Failed to send asymmetric broadcast message: %+v", err)
retries++
continue
} }
fmt.Printf("Sent asymmetric broadcast message: %s", asymmetric)
jww.INFO.Printf("Sent asymmetric broadcast message to %s over round %d", eid, rid) jww.INFO.Printf("Sent asymmetric broadcast message to %s over round %d", eid, rid)
default:
jww.WARN.Printf("Unknown broadcast type (this should not happen)")
} }
wg.Done()
break
}
}()
wg.Wait()
}
/* Create properly sized broadcast message */
/* Receive broadcast messages over the channel */ /* Receive broadcast messages over the channel */
jww.INFO.Printf("Waiting for message reception...")
waitSecs := viper.GetUint(waitTimeoutFlag) waitSecs := viper.GetUint(waitTimeoutFlag)
expectedCnt := viper.GetUint(receiveCountFlag) expectedCnt := viper.GetUint(receiveCountFlag)
waitTimeout := time.Duration(waitSecs) * time.Second waitTimeout := time.Duration(waitSecs) * time.Second
...@@ -212,6 +253,17 @@ var broadcastCmd = &cobra.Command{ ...@@ -212,6 +253,17 @@ var broadcastCmd = &cobra.Command{
for !done && expectedCnt != 0 { for !done && expectedCnt != 0 {
timeout := time.NewTimer(waitTimeout) timeout := time.NewTimer(waitTimeout)
select { select {
case receivedPayload := <-asymmetricReceiveChan:
receivedCount++
receivedBroadcast, err := broadcast.DecodeSizedBroadcast(receivedPayload)
if err != nil {
jww.ERROR.Printf("Failed to decode sized broadcast: %+v", err)
continue
}
fmt.Printf("Asymmetric broadcast message received: %s\n", string(receivedBroadcast))
if receivedCount == expectedCnt {
done = true
}
case receivedPayload := <-receiveChan: case receivedPayload := <-receiveChan:
receivedCount++ receivedCount++
receivedBroadcast, err := broadcast.DecodeSizedBroadcast(receivedPayload) receivedBroadcast, err := broadcast.DecodeSizedBroadcast(receivedPayload)
...@@ -219,7 +271,7 @@ var broadcastCmd = &cobra.Command{ ...@@ -219,7 +271,7 @@ var broadcastCmd = &cobra.Command{
jww.ERROR.Printf("Failed to decode sized broadcast: %+v", err) jww.ERROR.Printf("Failed to decode sized broadcast: %+v", err)
continue continue
} }
fmt.Printf("Symmetric broadcast message %d/%d received: %s\n", receivedCount, expectedCnt, string(receivedBroadcast)) fmt.Printf("Symmetric broadcast message received: %s\n", string(receivedBroadcast))
if receivedCount == expectedCnt { if receivedCount == expectedCnt {
done = true done = true
} }
...@@ -269,16 +321,14 @@ func init() { ...@@ -269,16 +321,14 @@ func init() {
"Create new broadcast channel") "Create new broadcast channel")
bindFlagHelper(broadcastNewFlag, broadcastCmd) bindFlagHelper(broadcastNewFlag, broadcastCmd)
broadcastCmd.Flags().StringP(broadcastFlag, "", "", broadcastCmd.Flags().StringP(broadcastSymmetricFlag, "", "",
"Message contents for broadcast") "Send symmetric broadcast message")
bindFlagHelper(broadcastFlag, broadcastCmd) _ = viper.BindPFlag("symmetric", broadcastCmd.Flags().Lookup("symmetric"))
broadcastCmd.Flags().BoolP(broadcastSymmetricFlag, "", false,
"Set broadcast method to symmetric")
bindFlagHelper(broadcastSymmetricFlag, broadcastCmd) bindFlagHelper(broadcastSymmetricFlag, broadcastCmd)
broadcastCmd.Flags().BoolP(broadcastAsymmetricFlag, "", false, broadcastCmd.Flags().StringP(broadcastAsymmetricFlag, "", "",
"Set broadcast method to asymmetric") "Send asymmetric broadcast message (must be used with keyPath)")
_ = viper.BindPFlag("asymmetric", broadcastCmd.Flags().Lookup("asymmetric"))
bindFlagHelper(broadcastAsymmetricFlag, broadcastCmd) bindFlagHelper(broadcastAsymmetricFlag, broadcastCmd)
rootCmd.AddCommand(broadcastCmd) rootCmd.AddCommand(broadcastCmd)
......
...@@ -993,9 +993,9 @@ func init() { ...@@ -993,9 +993,9 @@ func init() {
rootCmd.PersistentFlags().UintP(sendCountFlag, rootCmd.PersistentFlags().UintP(sendCountFlag,
"", 1, "The number of times to send the message") "", 1, "The number of times to send the message")
viper.BindPFlag(sendCountFlag, rootCmd.PersistentFlags().Lookup(sendCountFlag)) viper.BindPFlag(sendCountFlag, rootCmd.PersistentFlags().Lookup(sendCountFlag))
rootCmd.Flags().UintP(sendDelayFlag, rootCmd.PersistentFlags().UintP(sendDelayFlag,
"", 500, "The delay between sending the messages in ms") "", 500, "The delay between sending the messages in ms")
viper.BindPFlag(sendDelayFlag, rootCmd.Flags().Lookup(sendDelayFlag)) viper.BindPFlag(sendDelayFlag, rootCmd.PersistentFlags().Lookup(sendDelayFlag))
rootCmd.Flags().BoolP(splitSendsFlag, rootCmd.Flags().BoolP(splitSendsFlag,
"", false, "Force sends to go over multiple rounds if possible") "", false, "Force sends to go over multiple rounds if possible")
viper.BindPFlag(splitSendsFlag, rootCmd.Flags().Lookup(splitSendsFlag)) viper.BindPFlag(splitSendsFlag, rootCmd.Flags().Lookup(splitSendsFlag))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment