Skip to content
Snippets Groups Projects
Commit 22e70bd0 authored by Jake Taylor's avatar Jake Taylor
Browse files

made more things use correct sender

parent e2f3f276
Branches
Tags
No related merge requests found
...@@ -37,12 +37,12 @@ func NewSender(poolParams PoolParams, rng io.Reader, ndf *ndf.NetworkDefinition, ...@@ -37,12 +37,12 @@ func NewSender(poolParams PoolParams, rng io.Reader, ndf *ndf.NetworkDefinition,
// Call given sendFunc to a specific Host in the HostPool, // Call given sendFunc to a specific Host in the HostPool,
// attempting with up to numProxies destinations in case of failure // attempting with up to numProxies destinations in case of failure
func (m *Sender) SendToSpecific(targets []*id.ID, numProxies int, func (m *Sender) SendToSpecific(targets []*id.ID, numProxies int,
sendFunc func(host *connect.Host) (interface{}, error)) (interface{}, error) { sendFunc func(host *connect.Host, target *id.ID) (interface{}, error)) (interface{}, error) {
for _, target := range targets { for _, target := range targets {
host, ok := m.GetSpecific(target) host, ok := m.GetSpecific(target)
if ok { if ok {
result, err := sendFunc(host) result, err := sendFunc(host, target)
if err == nil { if err == nil {
return result, m.ForceAdd([]*id.ID{host.GetId()}) return result, m.ForceAdd([]*id.ID{host.GetId()})
} }
...@@ -68,12 +68,12 @@ func (m *Sender) SendToAny(numProxies int, ...@@ -68,12 +68,12 @@ func (m *Sender) SendToAny(numProxies int,
} }
// Call given sendFunc to any Host in the HostPool, attempting with up to numProxies destinations // Call given sendFunc to any Host in the HostPool, attempting with up to numProxies destinations
func (m *Sender) SendToPreferred(targets []*id.ID, func (m *Sender) SendToPreferred(targets []*id.ID, numProxies int,
sendFunc func(host *connect.Host) (interface{}, error)) (interface{}, error) { sendFunc func(host *connect.Host, target *id.ID) (interface{}, error)) (interface{}, error) {
targetHosts := m.GetPreferred(targets) targetHosts := m.GetPreferred(targets)
for _, host := range targetHosts { for i, host := range targetHosts {
result, err := sendFunc(host) result, err := sendFunc(host, targets[i])
if err == nil { if err == nil {
return result, nil return result, nil
} }
......
...@@ -136,7 +136,7 @@ func (m *manager) Follow(report interfaces.ClientErrorReport) (stoppable.Stoppab ...@@ -136,7 +136,7 @@ func (m *manager) Follow(report interfaces.ClientErrorReport) (stoppable.Stoppab
multi.Add(healthStop) multi.Add(healthStop)
// Node Updates // Node Updates
multi.Add(node.StartRegistration(m.Instance, m.Session, m.Rng, multi.Add(node.StartRegistration(m.GetSender(), m.Session, m.Rng,
m.Comms, m.NodeRegistration, m.param.ParallelNodeRegistrations)) // Adding/Keys m.Comms, m.NodeRegistration, m.param.ParallelNodeRegistrations)) // Adding/Keys
//TODO-remover //TODO-remover
//m.runners.Add(StartNodeRemover(m.Context)) // Removing //m.runners.Add(StartNodeRemover(m.Context)) // Removing
......
...@@ -95,7 +95,7 @@ func (m *Manager) criticalMessages() { ...@@ -95,7 +95,7 @@ func (m *Manager) criticalMessages() {
jww.INFO.Printf("Resending critical raw message to %s "+ jww.INFO.Printf("Resending critical raw message to %s "+
"(msgDigest: %s)", rid, msg.Digest()) "(msgDigest: %s)", rid, msg.Digest())
//send the message //send the message
round, _, err := m.SendCMIX(msg, rid, param) round, _, err := m.SendCMIX(m.sender, msg, rid, param)
//if the message fail to send, notify the buffer so it can be handled //if the message fail to send, notify the buffer so it can be handled
//in the future and exit //in the future and exit
if err != nil { if err != nil {
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
jww "github.com/spf13/jwalterweatherman" jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/client/interfaces/params"
"gitlab.com/elixxir/client/network/gateway"
"gitlab.com/elixxir/client/storage" "gitlab.com/elixxir/client/storage"
pb "gitlab.com/elixxir/comms/mixmessages" pb "gitlab.com/elixxir/comms/mixmessages"
"gitlab.com/elixxir/comms/network" "gitlab.com/elixxir/comms/network"
...@@ -29,7 +30,6 @@ import ( ...@@ -29,7 +30,6 @@ import (
// interface for SendCMIX comms; allows mocking this in testing // interface for SendCMIX comms; allows mocking this in testing
type sendCmixCommsInterface interface { type sendCmixCommsInterface interface {
GetHost(hostId *id.ID) (*connect.Host, bool)
SendPutMessage(host *connect.Host, message *pb.GatewaySlot) (*pb.GatewaySlotResponse, error) SendPutMessage(host *connect.Host, message *pb.GatewaySlot) (*pb.GatewaySlotResponse, error)
} }
...@@ -38,9 +38,9 @@ const sendTimeBuffer = 2500 * time.Millisecond ...@@ -38,9 +38,9 @@ const sendTimeBuffer = 2500 * time.Millisecond
// WARNING: Potentially Unsafe // WARNING: Potentially Unsafe
// Public manager function to send a message over CMIX // Public manager function to send a message over CMIX
func (m *Manager) SendCMIX(msg format.Message, recipient *id.ID, param params.CMIX) (id.Round, ephemeral.Id, error) { func (m *Manager) SendCMIX(sender *gateway.Sender, msg format.Message, recipient *id.ID, param params.CMIX) (id.Round, ephemeral.Id, error) {
msgCopy := msg.Copy() msgCopy := msg.Copy()
return sendCmixHelper(msgCopy, recipient, param, m.Instance, m.Session, m.nodeRegistration, m.Rng, m.TransmissionID, m.Comms) return sendCmixHelper(sender, msgCopy, recipient, param, m.Instance, m.Session, m.nodeRegistration, m.Rng, m.TransmissionID, m.Comms)
} }
// Payloads send are not End to End encrypted, MetaData is NOT protected with // Payloads send are not End to End encrypted, MetaData is NOT protected with
...@@ -51,7 +51,7 @@ func (m *Manager) SendCMIX(msg format.Message, recipient *id.ID, param params.CM ...@@ -51,7 +51,7 @@ func (m *Manager) SendCMIX(msg format.Message, recipient *id.ID, param params.CM
// If the message is successfully sent, the id of the round sent it is returned, // If the message is successfully sent, the id of the round sent it is returned,
// which can be registered with the network instance to get a callback on // which can be registered with the network instance to get a callback on
// its status // its status
func sendCmixHelper(msg format.Message, recipient *id.ID, param params.CMIX, instance *network.Instance, func sendCmixHelper(sender *gateway.Sender, msg format.Message, recipient *id.ID, param params.CMIX, instance *network.Instance,
session *storage.Session, nodeRegistration chan network.NodeGateway, rng *fastRNG.StreamGenerator, senderId *id.ID, session *storage.Session, nodeRegistration chan network.NodeGateway, rng *fastRNG.StreamGenerator, senderId *id.ID,
comms sendCmixCommsInterface) (id.Round, ephemeral.Id, error) { comms sendCmixCommsInterface) (id.Round, ephemeral.Id, error) {
...@@ -142,15 +142,6 @@ func sendCmixHelper(msg format.Message, recipient *id.ID, param params.CMIX, ins ...@@ -142,15 +142,6 @@ func sendCmixHelper(msg format.Message, recipient *id.ID, param params.CMIX, ins
firstGateway := topology.GetNodeAtIndex(0).DeepCopy() firstGateway := topology.GetNodeAtIndex(0).DeepCopy()
firstGateway.SetType(id.Gateway) firstGateway.SetType(id.Gateway)
transmitGateway, ok := comms.GetHost(firstGateway)
if !ok {
jww.ERROR.Printf("Failed to get host for gateway %s when "+
"sending to %s (msgDigest: %s)", transmitGateway, recipient,
msg.Digest())
time.Sleep(param.RetryDelay)
continue
}
//encrypt the message //encrypt the message
stream = rng.GetStream() stream = rng.GetStream()
salt := make([]byte, 32) salt := make([]byte, 32)
...@@ -187,9 +178,15 @@ func sendCmixHelper(msg format.Message, recipient *id.ID, param params.CMIX, ins ...@@ -187,9 +178,15 @@ func sendCmixHelper(msg format.Message, recipient *id.ID, param params.CMIX, ins
jww.INFO.Printf("Sending to EphID %d (%s) on round %d, "+ jww.INFO.Printf("Sending to EphID %d (%s) on round %d, "+
"(msgDigest: %s, ecrMsgDigest: %s) via gateway %s", "(msgDigest: %s, ecrMsgDigest: %s) via gateway %s",
ephID.Int64(), recipient, bestRound.ID, msg.Digest(), ephID.Int64(), recipient, bestRound.ID, msg.Digest(),
encMsg.Digest(), transmitGateway.GetId()) encMsg.Digest(), firstGateway.String())
// //Send the payload
gwSlotResp, err := comms.SendPutMessage(transmitGateway, wrappedMsg) // Send the payload
result, err := sender.SendToSpecific([]*id.ID{firstGateway}, 3, func(host *connect.Host, target *id.ID) (interface{}, error) {
wrappedMsg.Target = target.Marshal()
return comms.SendPutMessage(host, wrappedMsg)
})
gwSlotResp := result.(*pb.GatewaySlotResponse)
//if the comm errors or the message fails to send, continue retrying. //if the comm errors or the message fails to send, continue retrying.
//return if it sends properly //return if it sends properly
if err != nil { if err != nil {
...@@ -204,10 +201,10 @@ func sendCmixHelper(msg format.Message, recipient *id.ID, param params.CMIX, ins ...@@ -204,10 +201,10 @@ func sendCmixHelper(msg format.Message, recipient *id.ID, param params.CMIX, ins
"with this node?") { "with this node?") {
jww.WARN.Printf("Failed to send to %s (msgDigest: %s) "+ jww.WARN.Printf("Failed to send to %s (msgDigest: %s) "+
"via %s due to failed authentication: %s", "via %s due to failed authentication: %s",
recipient, msg.Digest(), transmitGateway.GetId(), err) recipient, msg.Digest(), firstGateway.String(), err)
//if we failed to send due to the gateway not recognizing our //if we failed to send due to the gateway not recognizing our
// authorization, renegotiate with the node to refresh it // authorization, renegotiate with the node to refresh it
nodeID := transmitGateway.GetId().DeepCopy() nodeID := firstGateway.DeepCopy()
nodeID.SetType(id.Node) nodeID.SetType(id.Node)
//delete the keys //delete the keys
session.Cmix().Remove(nodeID) session.Cmix().Remove(nodeID)
...@@ -226,7 +223,7 @@ func sendCmixHelper(msg format.Message, recipient *id.ID, param params.CMIX, ins ...@@ -226,7 +223,7 @@ func sendCmixHelper(msg format.Message, recipient *id.ID, param params.CMIX, ins
} else { } else {
jww.FATAL.Panicf("Gateway %s returned no error, but failed "+ jww.FATAL.Panicf("Gateway %s returned no error, but failed "+
"to accept message when sending to EphID %d (%s) on round %d", "to accept message when sending to EphID %d (%s) on round %d",
transmitGateway.GetId(), ephID.Int64(), recipient, bestRound.ID) firstGateway.String(), ephID.Int64(), recipient, bestRound.ID)
} }
} }
return 0, ephemeral.Id{}, errors.New("failed to send the message, " + return 0, ephemeral.Id{}, errors.New("failed to send the message, " +
......
...@@ -95,7 +95,7 @@ func (m *Manager) SendE2E(msg message.Send, param params.E2E) ([]id.Round, e2e.M ...@@ -95,7 +95,7 @@ func (m *Manager) SendE2E(msg message.Send, param params.E2E) ([]id.Round, e2e.M
wg.Add(1) wg.Add(1)
go func(i int) { go func(i int) {
var err error var err error
roundIds[i], _, err = m.SendCMIX(msgEnc, msg.Recipient, roundIds[i], _, err = m.SendCMIX(m.sender, msgEnc, msg.Recipient,
param.CMIX) param.CMIX)
if err != nil { if err != nil {
errCh <- err errCh <- err
......
...@@ -64,7 +64,7 @@ func (m *Manager) SendUnsafe(msg message.Send, param params.Unsafe) ([]id.Round, ...@@ -64,7 +64,7 @@ func (m *Manager) SendUnsafe(msg message.Send, param params.Unsafe) ([]id.Round,
wg.Add(1) wg.Add(1)
go func(i int) { go func(i int) {
var err error var err error
roundIds[i], _, err = m.SendCMIX(msgCmix, msg.Recipient, param.CMIX) roundIds[i], _, err = m.SendCMIX(m.sender, msgCmix, msg.Recipient, param.CMIX)
if err != nil { if err != nil {
errCh <- err errCh <- err
} }
......
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"fmt" "fmt"
"github.com/pkg/errors" "github.com/pkg/errors"
jww "github.com/spf13/jwalterweatherman" jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/client/network/gateway"
"gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/stoppable"
"gitlab.com/elixxir/client/storage" "gitlab.com/elixxir/client/storage"
"gitlab.com/elixxir/client/storage/cmix" "gitlab.com/elixxir/client/storage/cmix"
...@@ -33,14 +34,13 @@ import ( ...@@ -33,14 +34,13 @@ import (
) )
type RegisterNodeCommsInterface interface { type RegisterNodeCommsInterface interface {
GetHost(hostId *id.ID) (*connect.Host, bool)
SendRequestNonceMessage(host *connect.Host, SendRequestNonceMessage(host *connect.Host,
message *pb.NonceRequest) (*pb.Nonce, error) message *pb.NonceRequest) (*pb.Nonce, error)
SendConfirmNonceMessage(host *connect.Host, SendConfirmNonceMessage(host *connect.Host,
message *pb.RequestRegistrationConfirmation) (*pb.RegistrationConfirmation, error) message *pb.RequestRegistrationConfirmation) (*pb.RegistrationConfirmation, error)
} }
func StartRegistration(instance *network.Instance, session *storage.Session, rngGen *fastRNG.StreamGenerator, comms RegisterNodeCommsInterface, func StartRegistration(sender *gateway.Sender, session *storage.Session, rngGen *fastRNG.StreamGenerator, comms RegisterNodeCommsInterface,
c chan network.NodeGateway, numParallel uint) stoppable.Stoppable { c chan network.NodeGateway, numParallel uint) stoppable.Stoppable {
multi := stoppable.NewMulti("NodeRegistrations") multi := stoppable.NewMulti("NodeRegistrations")
...@@ -48,14 +48,14 @@ func StartRegistration(instance *network.Instance, session *storage.Session, rng ...@@ -48,14 +48,14 @@ func StartRegistration(instance *network.Instance, session *storage.Session, rng
for i := uint(0); i < numParallel; i++ { for i := uint(0); i < numParallel; i++ {
stop := stoppable.NewSingle(fmt.Sprintf("NodeRegistration %d", i)) stop := stoppable.NewSingle(fmt.Sprintf("NodeRegistration %d", i))
go registerNodes(session, rngGen, comms, stop, c) go registerNodes(sender, session, rngGen, comms, stop, c)
multi.Add(stop) multi.Add(stop)
} }
return multi return multi
} }
func registerNodes(session *storage.Session, rngGen *fastRNG.StreamGenerator, comms RegisterNodeCommsInterface, func registerNodes(sender *gateway.Sender, session *storage.Session, rngGen *fastRNG.StreamGenerator, comms RegisterNodeCommsInterface,
stop *stoppable.Single, c chan network.NodeGateway) { stop *stoppable.Single, c chan network.NodeGateway) {
u := session.User() u := session.User()
regSignature := u.GetTransmissionRegistrationValidationSignature() regSignature := u.GetTransmissionRegistrationValidationSignature()
...@@ -71,7 +71,7 @@ func registerNodes(session *storage.Session, rngGen *fastRNG.StreamGenerator, co ...@@ -71,7 +71,7 @@ func registerNodes(session *storage.Session, rngGen *fastRNG.StreamGenerator, co
t.Stop() t.Stop()
return return
case gw := <-c: case gw := <-c:
err := registerWithNode(comms, gw, regSignature, uci, cmix, rng) err := registerWithNode(sender, comms, gw, regSignature, uci, cmix, rng)
if err != nil { if err != nil {
jww.ERROR.Printf("Failed to register node: %+v", err) jww.ERROR.Printf("Failed to register node: %+v", err)
} }
...@@ -82,7 +82,7 @@ func registerNodes(session *storage.Session, rngGen *fastRNG.StreamGenerator, co ...@@ -82,7 +82,7 @@ func registerNodes(session *storage.Session, rngGen *fastRNG.StreamGenerator, co
//registerWithNode serves as a helper for RegisterWithNodes //registerWithNode serves as a helper for RegisterWithNodes
// It registers a user with a specific in the client's ndf. // It registers a user with a specific in the client's ndf.
func registerWithNode(comms RegisterNodeCommsInterface, ngw network.NodeGateway, regSig []byte, func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw network.NodeGateway, regSig []byte,
uci *user.CryptographicIdentity, store *cmix.Store, rng csprng.Source) error { uci *user.CryptographicIdentity, store *cmix.Store, rng csprng.Source) error {
nodeID, err := ngw.Node.GetNodeId() nodeID, err := ngw.Node.GetNodeId()
if err != nil { if err != nil {
...@@ -109,7 +109,7 @@ func registerWithNode(comms RegisterNodeCommsInterface, ngw network.NodeGateway, ...@@ -109,7 +109,7 @@ func registerWithNode(comms RegisterNodeCommsInterface, ngw network.NodeGateway,
userNum := int(uci.GetTransmissionID().Bytes()[7]) userNum := int(uci.GetTransmissionID().Bytes()[7])
h := sha256.New() h := sha256.New()
h.Reset() h.Reset()
h.Write([]byte(strconv.Itoa(int(4000 + userNum)))) h.Write([]byte(strconv.Itoa(4000 + userNum)))
transmissionKey = store.GetGroup().NewIntFromBytes(h.Sum(nil)) transmissionKey = store.GetGroup().NewIntFromBytes(h.Sum(nil))
jww.INFO.Printf("transmissionKey: %v", transmissionKey.Bytes()) jww.INFO.Printf("transmissionKey: %v", transmissionKey.Bytes())
...@@ -118,9 +118,11 @@ func registerWithNode(comms RegisterNodeCommsInterface, ngw network.NodeGateway, ...@@ -118,9 +118,11 @@ func registerWithNode(comms RegisterNodeCommsInterface, ngw network.NodeGateway,
// keys // keys
transmissionHash, _ := hash.NewCMixHash() transmissionHash, _ := hash.NewCMixHash()
nonce, dhPub, err := requestNonce(comms, gatewayID, regSig, uci, store, rng) _, err := sender.SendToAny(1, func(host *connect.Host) (interface{}, error) {
nonce, dhPub, err := requestNonce(comms, host, gatewayID, regSig, uci, store, rng)
if err != nil { if err != nil {
return errors.Errorf("Failed to request nonce: %+v", err) return nil, errors.Errorf("Failed to request nonce: %+v", err)
} }
// Load server DH pubkey // Load server DH pubkey
...@@ -129,13 +131,19 @@ func registerWithNode(comms RegisterNodeCommsInterface, ngw network.NodeGateway, ...@@ -129,13 +131,19 @@ func registerWithNode(comms RegisterNodeCommsInterface, ngw network.NodeGateway,
// Confirm received nonce // Confirm received nonce
jww.INFO.Println("Register: Confirming received nonce") jww.INFO.Println("Register: Confirming received nonce")
err = confirmNonce(comms, uci.GetTransmissionID().Bytes(), err = confirmNonce(comms, uci.GetTransmissionID().Bytes(),
nonce, uci.GetTransmissionRSA(), gatewayID) nonce, uci.GetTransmissionRSA(), host, gatewayID)
if err != nil { if err != nil {
errMsg := fmt.Sprintf("Register: Unable to confirm nonce: %v", err) errMsg := fmt.Sprintf("Register: Unable to confirm nonce: %v", err)
return errors.New(errMsg) return nil, errors.New(errMsg)
} }
transmissionKey = registration.GenerateBaseKey(store.GetGroup(), transmissionKey = registration.GenerateBaseKey(store.GetGroup(),
serverPubDH, store.GetDHPrivateKey(), transmissionHash) serverPubDH, store.GetDHPrivateKey(), transmissionHash)
return nil, nil
})
if err != nil {
jww.ERROR.Printf("registerNode failed: %+v", err)
return err
}
} }
store.Add(nodeID, transmissionKey) store.Add(nodeID, transmissionKey)
...@@ -145,7 +153,7 @@ func registerWithNode(comms RegisterNodeCommsInterface, ngw network.NodeGateway, ...@@ -145,7 +153,7 @@ func registerWithNode(comms RegisterNodeCommsInterface, ngw network.NodeGateway,
return nil return nil
} }
func requestNonce(comms RegisterNodeCommsInterface, gwId *id.ID, regHash []byte, func requestNonce(comms RegisterNodeCommsInterface, host *connect.Host, gwId *id.ID, regHash []byte,
uci *user.CryptographicIdentity, store *cmix.Store, rng csprng.Source) ([]byte, []byte, error) { uci *user.CryptographicIdentity, store *cmix.Store, rng csprng.Source) ([]byte, []byte, error) {
dhPub := store.GetDHPublicKey().Bytes() dhPub := store.GetDHPublicKey().Bytes()
opts := rsa.NewDefaultOptions() opts := rsa.NewDefaultOptions()
...@@ -165,10 +173,6 @@ func requestNonce(comms RegisterNodeCommsInterface, gwId *id.ID, regHash []byte, ...@@ -165,10 +173,6 @@ func requestNonce(comms RegisterNodeCommsInterface, gwId *id.ID, regHash []byte,
jww.INFO.Printf("Register: Requesting nonce from gateway %v", jww.INFO.Printf("Register: Requesting nonce from gateway %v",
gwId.Bytes()) gwId.Bytes())
host, ok := comms.GetHost(gwId)
if !ok {
return nil, nil, errors.Errorf("Failed to find host with ID %s", gwId.String())
}
nonceResponse, err := comms.SendRequestNonceMessage(host, nonceResponse, err := comms.SendRequestNonceMessage(host,
&pb.NonceRequest{ &pb.NonceRequest{
Salt: uci.GetTransmissionSalt(), Salt: uci.GetTransmissionSalt(),
...@@ -180,6 +184,7 @@ func requestNonce(comms RegisterNodeCommsInterface, gwId *id.ID, regHash []byte, ...@@ -180,6 +184,7 @@ func requestNonce(comms RegisterNodeCommsInterface, gwId *id.ID, regHash []byte,
RequestSignature: &messages.RSASignature{ RequestSignature: &messages.RSASignature{
Signature: clientSig, Signature: clientSig,
}, },
Target: gwId.Marshal(),
}) })
if err != nil { if err != nil {
...@@ -198,7 +203,7 @@ func requestNonce(comms RegisterNodeCommsInterface, gwId *id.ID, regHash []byte, ...@@ -198,7 +203,7 @@ func requestNonce(comms RegisterNodeCommsInterface, gwId *id.ID, regHash []byte,
// It signs a nonce and sends it for confirmation // It signs a nonce and sends it for confirmation
// Returns nil if successful, error otherwise // Returns nil if successful, error otherwise
func confirmNonce(comms RegisterNodeCommsInterface, UID, nonce []byte, func confirmNonce(comms RegisterNodeCommsInterface, UID, nonce []byte,
privateKeyRSA *rsa.PrivateKey, gwID *id.ID) error { privateKeyRSA *rsa.PrivateKey, host *connect.Host, gwID *id.ID) error {
opts := rsa.NewDefaultOptions() opts := rsa.NewDefaultOptions()
opts.Hash = hash.CMixHash opts.Hash = hash.CMixHash
h, _ := hash.NewCMixHash() h, _ := hash.NewCMixHash()
...@@ -224,12 +229,9 @@ func confirmNonce(comms RegisterNodeCommsInterface, UID, nonce []byte, ...@@ -224,12 +229,9 @@ func confirmNonce(comms RegisterNodeCommsInterface, UID, nonce []byte,
NonceSignedByClient: &messages.RSASignature{ NonceSignedByClient: &messages.RSASignature{
Signature: sig, Signature: sig,
}, },
Target: gwID.Marshal(),
} }
host, ok := comms.GetHost(gwID)
if !ok {
return errors.Errorf("Failed to find host with ID %s", gwID.String())
}
confirmResponse, err := comms.SendConfirmNonceMessage(host, msg) confirmResponse, err := comms.SendConfirmNonceMessage(host, msg)
if err != nil { if err != nil {
err := errors.New(fmt.Sprintf( err := errors.New(fmt.Sprintf(
......
...@@ -55,9 +55,9 @@ func (m *Manager) processMessageRetrieval(comms messageRetrievalComms, ...@@ -55,9 +55,9 @@ func (m *Manager) processMessageRetrieval(comms messageRetrievalComms,
} }
// Send to the gateways using backup proxies // Send to the gateways using backup proxies
result, err := m.sender.SendToSpecific(gwIds, 5, func(host *connect.Host) (interface{}, error) { result, err := m.sender.SendToPreferred(gwIds, 5, func(host *connect.Host, target *id.ID) (interface{}, error) {
// Attempt to request for this gateway // Attempt to request for this gateway
result, err := m.getMessagesFromGateway(id.Round(ri.ID), rl.identity, comms, host) result, err := m.getMessagesFromGateway(id.Round(ri.ID), rl.identity, comms, host, target)
if err != nil { if err != nil {
jww.WARN.Printf("Failed on gateway %s to get messages for round %v", jww.WARN.Printf("Failed on gateway %s to get messages for round %v",
host.GetId().String(), ri.ID) host.GetId().String(), ri.ID)
...@@ -86,7 +86,7 @@ func (m *Manager) processMessageRetrieval(comms messageRetrievalComms, ...@@ -86,7 +86,7 @@ func (m *Manager) processMessageRetrieval(comms messageRetrievalComms,
// getMessagesFromGateway attempts to get messages from their assigned // getMessagesFromGateway attempts to get messages from their assigned
// gateway host in the round specified. If successful // gateway host in the round specified. If successful
func (m *Manager) getMessagesFromGateway(roundID id.Round, identity reception.IdentityUse, func (m *Manager) getMessagesFromGateway(roundID id.Round, identity reception.IdentityUse,
comms messageRetrievalComms, gwHost *connect.Host) (message.Bundle, error) { comms messageRetrievalComms, gwHost *connect.Host, target *id.ID) (message.Bundle, error) {
jww.DEBUG.Printf("Trying to get messages for round %v for ephmeralID %d (%v) "+ jww.DEBUG.Printf("Trying to get messages for round %v for ephmeralID %d (%v) "+
"via Gateway: %s", roundID, identity.EphId.Int64(), identity.Source.String(), gwHost.GetId()) "via Gateway: %s", roundID, identity.EphId.Int64(), identity.Source.String(), gwHost.GetId())
...@@ -95,6 +95,7 @@ func (m *Manager) getMessagesFromGateway(roundID id.Round, identity reception.Id ...@@ -95,6 +95,7 @@ func (m *Manager) getMessagesFromGateway(roundID id.Round, identity reception.Id
msgReq := &pb.GetMessages{ msgReq := &pb.GetMessages{
ClientID: identity.EphId[:], ClientID: identity.EphId[:],
RoundID: uint64(roundID), RoundID: uint64(roundID),
Target: target.Marshal(),
} }
msgResp, err := comms.RequestMessages(gwHost, msgReq) msgResp, err := comms.RequestMessages(gwHost, msgReq)
// Fail the round if an error occurs so it can be tried again later // Fail the round if an error occurs so it can be tried again later
......
...@@ -28,7 +28,7 @@ func (m *manager) SendCMIX(msg format.Message, recipient *id.ID, param params.CM ...@@ -28,7 +28,7 @@ func (m *manager) SendCMIX(msg format.Message, recipient *id.ID, param params.CM
"network is not healthy") "network is not healthy")
} }
return m.message.SendCMIX(msg, recipient, param) return m.message.SendCMIX(m.GetSender(), msg, recipient, param)
} }
// SendUnsafe sends an unencrypted payload to the provided recipient // SendUnsafe sends an unencrypted payload to the provided recipient
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment