Skip to content
Snippets Groups Projects
Commit f0a828d9 authored by Jake Taylor's avatar Jake Taylor :lips:
Browse files

Merge branch 'NewHostPool' into 'release'

implemented new host pool

See merge request !472
parents 88578485 807ad261
No related branches found
No related tags found
2 merge requests!510Release,!472implemented new host pool
Showing
with 1920 additions and 1603 deletions
...@@ -18,7 +18,7 @@ import ( ...@@ -18,7 +18,7 @@ import (
) )
// Change this value to set the version for this build // Change this value to set the version for this build
const currentVersion = "4.4.2" const currentVersion = "4.4.3"
func Version() string { func Version() string {
out := fmt.Sprintf("Elixxir Client v%s -- %s\n\n", xxdk.SEMVER, out := fmt.Sprintf("Elixxir Client v%s -- %s\n\n", xxdk.SEMVER,
......
...@@ -186,9 +186,8 @@ func (c *client) initialize(ndfile *ndf.NetworkDefinition) error { ...@@ -186,9 +186,8 @@ func (c *client) initialize(ndfile *ndf.NetworkDefinition) error {
// Enable optimized HostPool initialization // Enable optimized HostPool initialization
poolParams.MaxPings = 50 poolParams.MaxPings = 50
poolParams.ForceConnection = true
sender, err := gateway.NewSender(poolParams, c.rng, ndfile, c.comms, sender, err := gateway.NewSender(poolParams, c.rng, ndfile, c.comms,
c.session, nodeChan) c.session, c.comms, nodeChan)
if err != nil { if err != nil {
return err return err
} }
...@@ -288,6 +287,9 @@ func (c *client) Follow(report ClientErrorReport) (stoppable.Stoppable, error) { ...@@ -288,6 +287,9 @@ func (c *client) Follow(report ClientErrorReport) (stoppable.Stoppable, error) {
//Start the critical processing thread //Start the critical processing thread
multi.Add(c.crit.startProcessies()) multi.Add(c.crit.startProcessies())
//start the host pool thread
multi.Add(c.Sender.StartProcesses())
return multi, nil return multi, nil
} }
......
////////////////////////////////////////////////////////////////////////////////
// Copyright © 2022 xx foundation //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file. //
////////////////////////////////////////////////////////////////////////////////
package gateway
import (
"bytes"
"crypto"
"crypto/sha256"
"fmt"
"github.com/pkg/errors"
jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/client/v4/storage/versioned"
pb "gitlab.com/elixxir/comms/mixmessages"
"gitlab.com/elixxir/crypto/hash"
"gitlab.com/xx_network/comms/connect"
"gitlab.com/xx_network/crypto/signature/rsa"
"gitlab.com/xx_network/primitives/id"
"time"
)
// Storage constants used by certChecker.
const (
	// certCheckerPrefix is the prefix applied to the versioned KV that
	// newCertChecker stores on the certChecker.
	certCheckerPrefix = "GwCertChecker"
	// keyTemplate is the format string getKey uses to build a storage key;
	// its verb is filled with the string form of the gateway ID.
	keyTemplate = "GatewayCertificate-%s"
	// certCheckerStorageVer is the object version used when reading and
	// writing stored certificate fingerprints.
	certCheckerStorageVer = uint64(1)
)
// CertCheckerCommInterface is an interface for client comms to be used in
// cert checker. It exposes the single call the checker needs: requesting a
// gateway's TLS certificate (and its signature) from a host.
type CertCheckerCommInterface interface {
	GetGatewayTLSCertificate(host *connect.Host,
		message *pb.RequestGatewayCert) (*pb.GatewayCertificate, error)
}
// certChecker stores verified certificates and handles verification checking.
type certChecker struct {
	// kv is the storage in which verified certificate fingerprints are kept,
	// one entry per gateway ID (see getKey).
	kv *versioned.KV
	// comms is used to request the gateway's declared TLS certificate.
	comms CertCheckerCommInterface
}
// newCertChecker builds a certChecker that requests certificates through comms
// and persists fingerprints in kv under the certCheckerPrefix prefix.
func newCertChecker(comms CertCheckerCommInterface, kv *versioned.KV) *certChecker {
	checker := new(certChecker)
	checker.kv = kv.Prefix(certCheckerPrefix)
	checker.comms = comms
	return checker
}
// CheckRemoteCertificate attempts to verify the tls certificate for a given
// host.
//
// Non-web hosts are accepted without any check. For web hosts the gateway is
// asked for its certificate and signature; the SHA-256 fingerprint of the
// declared certificate must equal the fingerprint of the certificate actually
// used on the connection. If that fingerprint is already stored for this
// gateway the check succeeds immediately; otherwise the signature is verified
// and, on success, the fingerprint is stored. Any failure is returned as an
// error.
func (cc *certChecker) CheckRemoteCertificate(gwHost *connect.Host) error {
	if !gwHost.IsWeb() {
		jww.TRACE.Printf("remote certificate verification is only " +
			"implemented for web connections")
		return nil
	}

	// The request has to be sent before the remote certificate is read from
	// the host: the host's remote certificate is populated from the response
	// of this send.
	resp, err := cc.comms.GetGatewayTLSCertificate(gwHost, &pb.RequestGatewayCert{})
	if err != nil {
		return err
	}
	signature := resp.GetSignature()
	declared := sha256.Sum256(resp.GetCertificate())

	// Fingerprint the certificate that was really used for the connection
	usedCert, err := gwHost.GetRemoteCertificate()
	if err != nil {
		return err
	}
	rawUsedCert := usedCert.Raw
	used := sha256.Sum256(rawUsedCert)

	// The gateway must declare the same certificate it is serving
	if used != declared {
		return errors.Errorf("Declared & used remote certificates "+
			"do not match\n\tDeclared: %+v\n\tUsed: %+v\n",
			declared, used)
	}

	// A matching stored fingerprint means this certificate was verified
	// before; a load error simply falls through to full verification
	stored, loadErr := cc.loadGatewayCertificateFingerprint(gwHost.GetId())
	if loadErr == nil && bytes.Equal(stored, used[:]) {
		return nil
	}

	// Verify the signature, then remember the fingerprint for next time
	if err = verifyRemoteCertificate(rawUsedCert, signature, gwHost); err != nil {
		return err
	}
	return cc.storeGatewayCertificateFingerprint(used[:], gwHost.GetId())
}
// verifyRemoteCertificate verifies the RSA signature of a gateway on its tls
// certificate. It hashes cert with SHA-256 and hands that digest, sig and the
// host's public key to rsa.Verify, returning whatever rsa.Verify returns.
func verifyRemoteCertificate(cert, sig []byte, gwHost *connect.Host) error {
	// opts is only used to select the hash function for the digest below; a
	// fresh set of default options is what is passed to rsa.Verify.
	opts := rsa.NewDefaultOptions()
	opts.Hash = crypto.SHA256
	h := opts.Hash.New()
	h.Write(cert)
	// NOTE(review): the digest is computed with SHA-256, but rsa.Verify is
	// told the hash is hash.CMixHash and receives options without the SHA-256
	// override — confirm this matches how the gateway signs its certificate.
	return rsa.Verify(gwHost.GetPubKey(), hash.CMixHash, h.Sum(nil), sig, rsa.NewDefaultOptions())
}
// loadGatewayCertificateFingerprint looks up the fingerprint stored for the
// given gateway ID. It returns the stored bytes, or a nil slice together with
// the storage error when the lookup fails.
func (cc *certChecker) loadGatewayCertificateFingerprint(id *id.ID) ([]byte, error) {
	stored, err := cc.kv.Get(getKey(id), certCheckerStorageVer)
	if err != nil {
		return nil, err
	}
	return stored.Data, nil
}
// storeGatewayCertificateFingerprint writes fingerprint to storage under the
// key derived from the gateway ID, stamped with the current time and the
// cert checker storage version. It returns the storage error, if any.
func (cc *certChecker) storeGatewayCertificateFingerprint(fingerprint []byte, id *id.ID) error {
	key := getKey(id)
	entry := &versioned.Object{
		Version:   certCheckerStorageVer,
		Timestamp: time.Now(),
		Data:      fingerprint,
	}
	return cc.kv.Set(key, entry)
}
// getKey builds the storage key for a gateway's certificate fingerprint by
// formatting the gateway ID's string form into keyTemplate.
func getKey(id *id.ID) string {
	gatewayName := id.String()
	return fmt.Sprintf(keyTemplate, gatewayName)
}
package gateway
import (
"bytes"
"gitlab.com/elixxir/client/v4/storage/versioned"
pb "gitlab.com/elixxir/comms/mixmessages"
"gitlab.com/elixxir/comms/testkeys"
"gitlab.com/elixxir/ekv"
"gitlab.com/xx_network/comms/connect"
"gitlab.com/xx_network/primitives/id"
"golang.org/x/crypto/blake2b"
"testing"
)
// mockCertCheckerComm is a stub CertCheckerCommInterface used by the tests.
type mockCertCheckerComm struct{}

// GetGatewayTLSCertificate ignores its arguments and always returns an empty
// certificate response with a nil error.
func (*mockCertCheckerComm) GetGatewayTLSCertificate(*connect.Host,
	*pb.RequestGatewayCert) (*pb.GatewayCertificate, error) {
	return new(pb.GatewayCertificate), nil
}
// Test_certChecker_loadStore checks that loading a fingerprint fails before
// anything is stored, and that a stored fingerprint is read back unchanged.
func Test_certChecker_loadStore(t *testing.T) {
	store := versioned.NewKV(ekv.MakeMemstore())
	checker := newCertChecker(&mockCertCheckerComm{}, store)

	gwCert := testkeys.LoadFromPath(testkeys.GetGatewayCertPath())
	gwID := id.NewIdFromString("testid01", id.Gateway, t)
	expectedFp := blake2b.Sum256(gwCert)

	// Nothing stored yet: expect an error and no data
	loaded, err := checker.loadGatewayCertificateFingerprint(gwID)
	if !(err != nil && loaded == nil) {
		t.Errorf("Should error & receive nil when nothing is in storage")
	}

	if err = checker.storeGatewayCertificateFingerprint(expectedFp[:], gwID); err != nil {
		t.Fatal("Failed to store certificate")
	}

	// The stored fingerprint must round-trip
	loaded, err = checker.loadGatewayCertificateFingerprint(gwID)
	if err != nil {
		t.Fatalf("Failed to load certificate for gwID %s: %+v", gwID, err)
	}
	if !bytes.Equal(loaded, expectedFp[:]) {
		t.Errorf("Did not receive expected fingerprint after load\n\tExpected: %+v\n\tReceived: %+v\n", expectedFp, loaded)
	}
}
package gateway
import (
"gitlab.com/xx_network/primitives/netTime"
"math"
"time"
)
// table is a piecewise-linear mapping from the fraction of the bucket that is
// full (in steps of 0.1) to the delay that should be imposed. getDelay
// interpolates linearly between adjacent entries.
var table = map[float64]time.Duration{
	0:   0,
	0.1: 0,
	0.2: 0,
	0.3: 0,
	0.4: 100 * time.Millisecond,
	0.5: 500 * time.Millisecond,
	0.6: 2 * time.Second,
	0.7: 5 * time.Second,
	0.8: 8 * time.Second,
	0.9: 9 * time.Second,
	1.0: 10 * time.Second,
	1.1: 10 * time.Second,
}

// getDelay computes the delay for a bucket holding `bucket` points out of a
// capacity of `poolsize` by linear interpolation between the two nearest
// entries of table.
//
// The fill ratio is clamped to [0, 1]; an undefined ratio (0 points with a
// pool size of 0) is treated as empty. The previous implementation floored
// and ceiled the ratio itself, which meant only table[0] and table[1] were
// ever consulted and the intermediate entries were dead.
func getDelay(bucket float64, poolsize uint) time.Duration {
	// Number of 0.1-wide segments between an empty and a full bucket
	const steps = 10

	// Position along the table, measured in segments. Scaling before the
	// division keeps whole-number fill levels exact.
	pos := steps * bucket / float64(poolsize)
	switch {
	case math.IsNaN(pos) || pos < 0:
		pos = 0
	case pos > steps:
		pos = steps
	}

	lower := math.Floor(pos)
	upper := math.Ceil(pos)
	upperRatio := pos - lower
	lowerRatio := 1 - upperRatio

	// k/steps is the correctly rounded quotient, which is the same float64
	// as the literal key 0.k used in table
	bottom := time.Duration(float64(table[lower/steps]) * lowerRatio)
	top := time.Duration(float64(table[upper/steps]) * upperRatio)

	return top + bottom
}
// bucket is a leaky bucket implementation. Points are added one at a time
// with Add and drain continuously over time; GetDelay converts the current
// fill level into a delay via getDelay.
type bucket struct {
	// leakRate is the time until the entire bucket is leaked
	leakRate time.Duration
	// points is the current fill level of the bucket
	points float64
	// lastEdit is the time up to which leakage has already been applied
	lastEdit time.Time
	// poolsize is the capacity the fill level is measured against
	poolsize uint
}

// newBucket initializes a new, empty bucket with a capacity of poolSize. A
// full bucket takes poolSize times the maximum table delay to drain.
func newBucket(poolSize int) *bucket {
	return &bucket{
		leakRate: time.Duration(poolSize) * table[1],
		points:   0,
		lastEdit: netTime.Now(),
		poolsize: uint(poolSize),
	}
}

// leak drains the points that have leaked since lastEdit and advances
// lastEdit to now. If the clock appears to have gone backwards nothing is
// drained and lastEdit is left untouched.
func (b *bucket) leak() {
	now := netTime.Now()

	delta := now.Sub(b.lastEdit)
	if delta < 0 {
		return
	}

	// Record that leakage up to now has been accounted for. Without this
	// every call would subtract the whole time since construction again,
	// draining the bucket far faster than leakRate.
	b.lastEdit = now

	leaked := (float64(delta) / float64(b.leakRate)) * float64(b.poolsize)
	b.points -= leaked
	if b.points < 0 {
		b.points = 0
	}
}

// Add applies pending leakage and then adds one point to the bucket.
func (b *bucket) Add() {
	b.leak()
	b.points += 1
}

// Reset empties the bucket.
func (b *bucket) Reset() {
	b.points = 0
}

// GetDelay applies pending leakage and returns the delay corresponding to the
// bucket's current fill level.
func (b *bucket) GetDelay() time.Duration {
	b.leak()
	return getDelay(b.points, b.poolsize)
}
////////////////////////////////////////////////////////////////////////////////
// Copyright © 2022 xx foundation //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file. //
////////////////////////////////////////////////////////////////////////////////
package gateway
import (
	"strings"

	"gitlab.com/xx_network/comms/connect"
	"gitlab.com/xx_network/primitives/ndf"
	"golang.org/x/net/context"
	"google.golang.org/grpc"
	"google.golang.org/grpc/balancer"
)
// errorsList is the list of error messages (or message fragments) that
// initiate a Host replacement.
var errorsList = []string{
	context.DeadlineExceeded.Error(),
	"connection refused",
	"host disconnected",
	"transport is closing",
	balancer.ErrTransientFailure.Error(),
	"Last try to connect",
	ndf.NO_NDF,
	"Host is in cool down",
	grpc.ErrClientConnClosing.Error(),
	connect.TooManyProxyError,
	"Failed to fetch",
	"NetworkError when attempting to fetch resource.",
}

// errorMap holds every entry of errorsList as a key, for constant-time
// detection of messages that match an entry exactly.
var errorMap = make(map[string]struct{})

// init populates errorMap from errorsList.
func init() {
	for _, str := range errorsList {
		errorMap[str] = struct{}{}
	}
}

// IsGuilty returns true if the error means the host
// will get kicked out of the pool.
//
// An error is guilty when its message equals, or contains, any entry of
// errorsList. Substring matching is required because several entries are
// fragments of longer messages and because errors are commonly wrapped with
// extra context; an exact-match lookup alone would miss them. A nil error is
// never guilty.
func IsGuilty(err error) bool {
	if err == nil {
		return false
	}
	msg := err.Error()

	// Fast path: the message is exactly one of the listed errors
	if _, exact := errorMap[msg]; exact {
		return true
	}

	for _, fragment := range errorsList {
		// An empty entry would match every message, so it is skipped
		if fragment != "" && strings.Contains(msg, fragment) {
			return true
		}
	}
	return false
}
This diff is collapsed.
This diff is collapsed.
////////////////////////////////////////////////////////////////////////////////
// Copyright © 2022 xx foundation //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file. //
////////////////////////////////////////////////////////////////////////////////
package gateway
import (
jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/client/v4/stoppable"
"gitlab.com/xx_network/comms/connect"
"sync"
"time"
)
// connectivityFailure is the latency recorded by hostPool.nodeTester for a
// node that failed its connectivity test.
const connectivityFailure = 0

// nodeTester is a long-running thread that tests the connectivity of nodes
// that may be added to the hostPool.
//
// For every list received on hp.testNodes it pings all hosts concurrently,
// connects to the one with the lowest latency and offers it on hp.newHost.
// The tested list is then handed back on hp.doneTesting. If no host ends up
// selected, a nil is sent on hp.addRequest to restart the process. All sends
// are non-blocking. The thread exits when stop is signalled.
func (hp *hostPool) nodeTester(stop *stoppable.Single) {
	for {
		select {
		case <-stop.Quit():
			stop.ToStopped()
			return

		case queryList := <-hp.testNodes:
			jww.DEBUG.Printf("[NodeTester] Received queryList of nodes to test: %v", queryList)

			// Ping every candidate in its own goroutine; each goroutine
			// writes only to its own slot of latencies
			latencies := make([]time.Duration, len(queryList))
			var wg sync.WaitGroup
			for idx, candidate := range queryList {
				wg.Add(1)
				go func(hostToQuery *connect.Host, slot int) {
					defer wg.Done()
					if latency, pinged := hostToQuery.IsOnline(); pinged {
						latencies[slot] = latency
					} else {
						latencies[slot] = connectivityFailure
					}
				}(candidate, idx)
			}

			// Block until every ping has reported
			wg.Wait()

			// Pick the fastest host, skipping failures (recorded as 0)
			var bestHost *connect.Host
			lowestLatency := time.Hour
			for idx, latency := range latencies {
				if latency == connectivityFailure || latency >= lowestLatency {
					continue
				}
				lowestLatency, bestHost = latency, queryList[idx]
			}

			if bestHost != nil {
				// Connect to the winner, then offer it to the main thread
				if err := bestHost.Connect(); err != nil {
					jww.WARN.Printf("Failed to connect to bestHost %s with error %+v, will be dropped", bestHost.GetId(), err)
					bestHost = nil
				} else {
					select {
					case hp.newHost <- bestHost:
					default:
						jww.ERROR.Printf("failed to send best host to main thread, " +
							"will be dropped, new addRequest to be sent")
						bestHost = nil
					}
				}
			}

			// Hand the tested nodes back so they become available again
			select {
			case hp.doneTesting <- queryList:
				jww.DEBUG.Printf("[NodeTester] Completed testing query list %s", queryList)
			default:
				jww.ERROR.Printf("Failed to send queryList to main thread, " +
					"nodes are stuck in testing, this should never happen")
				bestHost = nil
			}

			if bestHost != nil {
				continue
			}

			// Nothing was selected: ask for another node to be added
			jww.WARN.Printf("No host selected, restarting the request process")
			select {
			case hp.addRequest <- nil:
			default:
				jww.WARN.Printf("Failed to send a signal to add hosts after " +
					"testing failure")
			}
		}
	}
}
////////////////////////////////////////////////////////////////////////////////
// Copyright © 2022 xx foundation //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file. //
////////////////////////////////////////////////////////////////////////////////
package gateway
import (
"encoding/json"
"gitlab.com/xx_network/comms/connect"
"time"
)
// Params allows configuration of HostPool parameters. Defaults are produced
// by DefaultParams.
type Params struct {
	// MaxPoolSize is the maximum number of Hosts in the HostPool.
	MaxPoolSize uint32
	// PoolSize allows override of HostPool size. Set to zero for dynamic size
	// calculation.
	PoolSize uint32
	// ProxyAttempts dictates how many proxies will be used in event of send
	// failure.
	ProxyAttempts uint32
	// MaxPings is the number of gateways to concurrently test when selecting
	// a new member of HostPool. Must be at least 1.
	MaxPings uint32
	// NumConnectionsWorkers is the number of workers connecting to gateways.
	NumConnectionsWorkers int
	// MinBufferLength is the minimum length of input buffers
	// to the hostpool runner.
	MinBufferLength uint32
	// EnableRotation enables the system which auto rotates
	// gateways regularly. This system will auto disable
	// if the network size is less than 20.
	EnableRotation bool
	// RotationPeriod is how long until a single
	// host is rotated.
	RotationPeriod time.Duration
	// RotationPeriodVariability is the max that the rotation
	// period can randomly deviate from the stated amount.
	RotationPeriodVariability time.Duration
	// HostParams is the parameters for the creation of new Host objects.
	HostParams connect.HostParams
}
// DefaultParams returns the default Params: dynamic pool sizing, 5 proxy
// attempts, 5 concurrent pings, 5 connection workers, a minimum buffer length
// of 100, and rotation enabled every 7 minutes with up to 4 minutes of
// variability. Host parameters come from GetDefaultHostPoolHostParams.
func DefaultParams() Params {
	return Params{
		MaxPoolSize:               MaxPoolSize,
		PoolSize:                  0,
		ProxyAttempts:             5,
		MaxPings:                  5,
		NumConnectionsWorkers:     5,
		MinBufferLength:           100,
		EnableRotation:            true,
		RotationPeriod:            7 * time.Minute,
		RotationPeriodVariability: 4 * time.Minute,
		HostParams:                GetDefaultHostPoolHostParams(),
	}
}
// DefaultPoolParams returns exactly what DefaultParams returns; it exists
// only under the older function name.
//
// Deprecated: use DefaultParams.
func DefaultPoolParams() Params {
	return DefaultParams()
}
// GetParameters returns the default Params when params is empty. Otherwise
// params is decoded as JSON on top of the defaults, so fields absent from the
// JSON keep their default value. A decoding failure yields a zero Params and
// the error.
func GetParameters(params string) (Params, error) {
	p := DefaultParams()
	if params == "" {
		return p, nil
	}

	if err := json.Unmarshal([]byte(params), &p); err != nil {
		return Params{}, err
	}
	return p, nil
}
// MarshalJSON adheres to the json.Marshaler interface.
//
// The receiver is encoded through a local type that has the same fields as
// Params but none of its methods. Passing pp itself to json.Marshal would
// make the encoder call this method again and recurse without end.
func (pp *Params) MarshalJSON() ([]byte, error) {
	type plainParams Params
	return json.Marshal((*plainParams)(pp))
}
// UnmarshalJSON adheres to the json.Unmarshaler interface.
//
// The data is decoded through a local type that has the same fields as
// Params but none of its methods, writing directly into the receiver.
// Passing pp itself to json.Unmarshal would make the decoder call this
// method again and recurse without end. Fields absent from data keep the
// value they already had on the receiver.
func (pp *Params) UnmarshalJSON(data []byte) error {
	type plainParams Params
	return json.Unmarshal(data, (*plainParams)(pp))
}
// GetDefaultHostPoolHostParams returns the host parameters used for hosts in
// the host pool: the connect defaults with single-attempt retries, one-second
// send and ping timeouts, auth and cool-off disabled, and automatic
// connection turned off.
func GetDefaultHostPoolHostParams() connect.HostParams {
	hostParams := connect.GetDefaultHostParams()

	// Retry behaviour
	hostParams.MaxRetries = 1
	hostParams.MaxSendRetries = 1

	// Authentication
	hostParams.AuthEnabled = false

	// Cool-off
	hostParams.EnableCoolOff = false
	hostParams.NumSendsBeforeCoolOff = 1
	hostParams.CoolOffTimeout = 5 * time.Minute

	// Timeouts
	hostParams.SendTimeout = 1000 * time.Millisecond
	hostParams.PingTimeout = 1000 * time.Millisecond

	// Connection management
	hostParams.DisableAutoConnect = true

	return hostParams
}
////////////////////////////////////////////////////////////////////////////////
// Copyright © 2022 xx foundation //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file. //
////////////////////////////////////////////////////////////////////////////////
package gateway
import (
"github.com/pkg/errors"
jww "github.com/spf13/jwalterweatherman"
"gitlab.com/xx_network/comms/connect"
"gitlab.com/xx_network/crypto/csprng"
"gitlab.com/xx_network/crypto/randomness"
"gitlab.com/xx_network/primitives/id"
"io"
"strings"
)
// errHostPoolNotReady is the message of the error returned when the host pool
// has no connected member.
var errHostPoolNotReady = "Host pool is not ready, wait a " +
	"little then try again. if this persists, you may have connectivity issues"

// IsHostPoolNotReadyError reports whether err's message contains the "host
// pool is not ready" message, so it also recognises errors that wrap it with
// extra context. A nil error is reported as false rather than dereferenced.
func IsHostPoolNotReadyError(err error) bool {
	if err == nil {
		return false
	}
	return strings.Contains(err.Error(), errHostPoolNotReady)
}
// Pool is the read-only view of a host pool.
type Pool interface {
	// Get returns the member with the given ID and whether it is in the pool.
	Get(id *id.ID) (*connect.Host, bool)
	// Has reports whether the given ID is a member of the pool.
	Has(id *id.ID) bool
	// Size returns the number of currently connected members.
	Size() int
	// IsReady returns nil if at least one member is connected, and an error
	// otherwise.
	IsReady() error
	// GetAny returns up to length connected members chosen at random using
	// rng, skipping the excluded IDs.
	GetAny(length uint32, excluded []*id.ID, rng io.Reader) []*connect.Host
	// GetSpecific returns the member with the given ID only if it is in the
	// pool and connected.
	GetSpecific(target *id.ID) (*connect.Host, bool)
	// GetPreferred returns the requested targets that are available, topped
	// up with randomly chosen connected members.
	GetPreferred(targets []*id.ID, rng io.Reader) []*connect.Host
}
// pool holds the members of the host pool as a slice, with a map from each
// member's ID to its position in that slice.
type pool struct {
	hostMap  map[id.ID]uint  // Map key to its index in the slice
	hostList []*connect.Host // Each index in the slice contains the value
	// isConnected reports whether a host counts as connected. It is a field
	// so that it can be replaced (the tests override it).
	isConnected func(host *connect.Host) bool
}
// newPool creates an empty pool with room for size hosts. Its connectivity
// check defers to the host's own Connected method.
func newPool(size int) *pool {
	connected := func(host *connect.Host) bool {
		isConn, _ := host.Connected()
		return isConn
	}

	return &pool{
		hostMap:     make(map[id.ID]uint, size),
		hostList:    make([]*connect.Host, 0, size),
		isConnected: connected,
	}
}
// Get returns a specific member of the host pool if it exists. When the ID is
// not a member it returns nil and false.
//
// The membership check must happen before indexing: a failed map lookup
// yields index 0, which would panic on an empty pool and otherwise hand back
// an unrelated host.
func (p *pool) Get(id *id.ID) (*connect.Host, bool) {
	idx, exists := p.hostMap[*id]
	if !exists {
		return nil, false
	}
	return p.hostList[idx], true
}
// Has reports whether the given ID is a member of the host pool, regardless
// of whether that member is currently connected.
func (p *pool) Has(id *id.ID) bool {
	if _, member := p.hostMap[*id]; member {
		return true
	}
	return false
}
// Size counts the members of the host pool that are currently connected, as
// judged by the pool's isConnected check.
func (p *pool) Size() int {
	connected := 0
	for _, member := range p.hostList {
		if p.isConnected(member) {
			connected++
		}
	}
	return connected
}
// IsReady returns nil if there is at least one connected member of the
// hostPool, and the "host pool is not ready" error otherwise.
func (p *pool) IsReady() error {
	jww.TRACE.Printf("[IsReady] Length of Host List %d", len(p.hostList))

	for _, member := range p.hostList {
		if p.isConnected(member) {
			return nil
		}
	}

	return errors.New(errHostPoolNotReady)
}
// GetAny returns up to length host pool members chosen at random, leaving out
// any whose ID appears in excluded and any that are disconnected. If fewer
// members qualify than were requested, the smaller set is returned.
func (p *pool) GetAny(length uint32, excluded []*id.ID, rng io.Reader) []*connect.Host {
	poolLen := uint32(len(p.hostList))
	if poolLen < length {
		length = poolLen
	}

	// Indices that must not be drawn (again): excluded members are entered
	// up front, every other index is entered as it is visited
	visited := make(map[uint32]struct{}, len(p.hostList))
	for _, gwId := range excluded {
		if idx, inPool := p.hostMap[*gwId]; inPool {
			visited[uint32(idx)] = struct{}{}
		}
	}

	// Draw until enough hosts are collected or every index has been visited
	hosts := make([]*connect.Host, 0, length)
	for uint32(len(hosts)) < length && uint32(len(visited)) < poolLen {
		gwIdx := randomness.ReadRangeUint32(0, poolLen, rng)
		if _, seen := visited[gwIdx]; seen {
			continue
		}
		visited[gwIdx] = struct{}{}

		if candidate := p.hostList[gwIdx]; p.isConnected(candidate) {
			hosts = append(hosts, candidate)
		}
	}

	return hosts
}
// GetSpecific returns the connect.Host with the given ID when it is both a
// member of the pool and connected. In every other case, including a member
// that is disconnected, it returns nil and false.
func (p *pool) GetSpecific(target *id.ID) (*connect.Host, bool) {
	idx, inPool := p.hostMap[*target]
	if !inPool {
		return nil, false
	}

	member := p.hostList[idx]
	if p.isConnected(member) {
		return member, true
	}
	return nil, false
}
// GetPreferred tries to obtain the given targets from the HostPool. Targets
// that are present and connected are returned first; the remainder of the
// list is filled with randomly selected connected members, which will be used
// as proxies. The result holds at most as many random fill-ins as needed to
// reach min(len(targets), pool length).
func (p *pool) GetPreferred(targets []*id.ID, rng io.Reader) []*connect.Host {
	// IDs already taken or already rejected, to avoid duplicates
	seen := make(map[id.ID]struct{})

	// Never aim for more hosts than the pool contains
	wanted := len(targets)
	if poolLen := len(p.hostList); wanted > poolLen {
		wanted = poolLen
	}
	hosts := make([]*connect.Host, 0, wanted)

	// First take every target that is available in the pool
	for _, target := range targets {
		targeted, inPool := p.GetSpecific(target)
		if !inPool {
			continue
		}
		hosts = append(hosts, targeted)
		seen[*target] = struct{}{}
	}

	// Then fill the rest of the list with random proxies until full, or
	// until every member of the pool has been considered
	for len(hosts) < wanted && len(seen) < len(p.hostList) {
		gwIdx := randomness.ReadRangeUint32(0, uint32(len(p.hostList)),
			rng)
		candidate := p.hostList[gwIdx]

		gwID := candidate.GetId()
		if _, dup := seen[*gwID]; dup {
			continue
		}
		seen[*gwID] = struct{}{}

		if p.isConnected(candidate) {
			hosts = append(hosts, candidate)
		}
	}

	return hosts
}
// addOrReplace puts host into the pool. While the pool has spare capacity the
// host is appended and nil is returned. Once the pool is full a randomly
// chosen member is replaced instead and returned so it can be cleaned up.
func (p *pool) addOrReplace(rng io.Reader, host *connect.Host) *connect.Host {
	// Full pool: swap out a random member
	if len(p.hostList) >= cap(p.hostList) {
		jww.TRACE.Printf("[AddOrReplace] Internally replacing...")
		victim := uint(randomness.ReadRangeUint32(0, uint32(len(p.hostList)), rng))
		return p.internalReplace(victim, host)
	}

	// Spare capacity: append to the end and index the new position
	jww.TRACE.Printf("[AddOrReplace] Adding host %s to host list", host.GetId())
	slot := uint(len(p.hostList))
	p.hostList = append(p.hostList, host)
	p.hostMap[*host.GetId()] = slot
	return nil
}
// replaceSpecific swaps the member with ID toReplace for host and returns the
// member that was removed. It returns an error if toReplace is not in the
// pool.
func (p *pool) replaceSpecific(toReplace *id.ID,
	host *connect.Host) (*connect.Host, error) {
	if idx, inPool := p.hostMap[*toReplace]; inPool {
		return p.internalReplace(idx, host), nil
	}

	return nil, errors.Errorf("Cannot replace %s, host does not "+
		"exist in pool", toReplace)
}
// internalReplace overwrites the member at selectedIndex with host, updating
// both the hostList and the hostMap, and returns the member that was there
// before. The old ID is dropped from the map before the new one is inserted.
func (p *pool) internalReplace(selectedIndex uint, host *connect.Host) *connect.Host {
	evicted := p.hostList[selectedIndex]

	p.hostList[selectedIndex] = host

	// Re-point the index: forget the evicted ID, then register the new one
	delete(p.hostMap, *evicted.GetId())
	p.hostMap[*host.GetId()] = selectedIndex

	return evicted
}
// deepCopy returns a copy of the pool with its own hostMap and hostList, so
// changes to either pool's membership do not affect the other. The host
// pointers themselves and the isConnected function are shared. The copied
// hostList is sized to the current number of members.
func (p *pool) deepCopy() *pool {
	dup := new(pool)
	dup.isConnected = p.isConnected

	dup.hostList = make([]*connect.Host, len(p.hostList))
	copy(dup.hostList, p.hostList)

	dup.hostMap = make(map[id.ID]uint, len(p.hostMap))
	for gwID, idx := range p.hostMap {
		dup.hostMap[gwID] = idx
	}

	return dup
}
// selectNew picks up to numToSelect random nodes from allNodes that are
// neither in the pool nor in currentlyAddingNodes. Every selected node is
// added to currentlyAddingNodes, which is returned alongside the selection.
// If fewer candidates exist than requested, all of them are returned. It
// returns an error when there are no candidates at all.
func (p *pool) selectNew(rng csprng.Source, allNodes map[id.ID]int,
	currentlyAddingNodes map[id.ID]struct{}, numToSelect int) ([]*id.ID,
	map[id.ID]struct{}, error) {

	// Candidates are all nodes minus pool members and nodes already in the
	// process of being added
	candidates := make(map[id.ID]interface{})
	for nid := range allNodes {
		if _, inPool := p.hostMap[nid]; inPool {
			continue
		}
		if _, inAdd := currentlyAddingNodes[nid]; inAdd {
			continue
		}
		candidates[nid] = struct{}{}
	}

	if len(candidates) == 0 {
		return nil, nil, errors.New("no nodes available for selection")
	}

	// Not enough candidates to choose from: take every one of them
	if numToSelect > len(candidates) {
		selections := make([]*id.ID, 0, len(candidates))
		for gwID := range candidates {
			localGwid := gwID.DeepCopy()
			selections = append(selections, localGwid)
			currentlyAddingNodes[*localGwid] = struct{}{}
			jww.DEBUG.Printf("[SelectNew] Adding gwId %s to inProgress", localGwid)
		}
		return selections, currentlyAddingNodes, nil
	}

	// Draw distinct random positions until numToSelect have been collected;
	// a repeated draw leaves the set unchanged and is simply retried
	chosen := make(map[uint]struct{}, numToSelect)
	for len(chosen) < numToSelect {
		pick := uint(randomness.ReadRangeUint32(0, uint32(len(candidates)), rng))
		chosen[pick] = struct{}{}
	}

	// Walk the candidates and keep those whose position was drawn
	selections := make([]*id.ID, 0, numToSelect)
	position := uint(0)
	for gwID := range candidates {
		if _, picked := chosen[position]; picked {
			localGwid := gwID.DeepCopy()
			selections = append(selections, localGwid)
			currentlyAddingNodes[*localGwid] = struct{}{}
			if len(selections) == cap(selections) {
				break
			}
		}
		position++
	}

	return selections, currentlyAddingNodes, nil
}
////////////////////////////////////////////////////////////////////////////////
// Copyright © 2022 xx foundation //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file. //
////////////////////////////////////////////////////////////////////////////////
package gateway
import (
"gitlab.com/xx_network/comms/connect"
"gitlab.com/xx_network/primitives/id"
"math/rand"
"testing"
)
// Unit test, happy paths of GetAny: a normal request, a request larger than
// the pool, and a request for the whole pool with one member excluded.
func TestHostPool_GetAny(t *testing.T) {
	manager := newMockManager()
	rng := rand.New(rand.NewSource(42))
	testNdf := getTestNdf(t)
	params := DefaultParams()
	params.MaxPoolSize = uint32(len(testNdf.Gateways))

	// Call the constructor
	testPool := newPool(5)

	// Pull all gateways from NDF into host manager
	for _, gw := range testNdf.Gateways {

		gwId, err := id.Unmarshal(gw.ID)
		if err != nil {
			t.Fatalf("Failed to unmarshal ID in mock NDF: %+v", err)
		}
		// Add mock gateway to manager
		var h *connect.Host
		h, err = manager.AddHost(
			gwId, gw.Address, nil, connect.GetDefaultHostParams())
		if err != nil {
			t.Fatalf("Could not Add mock host to manager: %+v", err)
		}

		// Add to the host pool
		testPool.addOrReplace(rng, h)

	}

	// Treat every host as connected so none are filtered out
	testPool.isConnected = func(host *connect.Host) bool { return true }

	requested := 3
	anyList := testPool.GetAny(uint32(requested), nil, rng)
	if len(anyList) != requested {
		t.Errorf("GetAnyList did not get requested length."+
			"\n\tExpected: %v"+
			"\n\tReceived: %v", requested, len(anyList))
	}

	// Every returned host must be known to the manager
	for _, h := range anyList {
		_, ok := manager.GetHost(h.GetId())
		if !ok {
			t.Errorf("Host %s in retrieved list not in manager", h)
		}
	}

	// Request more than are in host list
	largeRequest := uint32(requested * 1000)
	largeRetrieved := testPool.GetAny(largeRequest, nil, rng)
	if len(largeRetrieved) != len(testPool.hostList) {
		t.Errorf("Large request should result in a list of all in host list")
	}

	// Request the whole host pool with a member excluded
	excluded := []*id.ID{testPool.hostList[2].GetId()}
	requestedExcluded := uint32(len(testPool.hostList))
	excludedRetrieved := testPool.GetAny(requestedExcluded, excluded, rng)

	if len(excludedRetrieved) != int(requestedExcluded-1) {
		t.Errorf("One member should not have been returned due to being excluded")
	}

	// The excluded ID must not appear anywhere in the result
	for i := 0; i < len(excludedRetrieved); i++ {
		if excludedRetrieved[i].GetId().Cmp(excluded[0]) {
			t.Errorf("index %d of the returned list includes the excluded id %s", i, excluded[0])
		}
	}
}
// Unit test, happy paths of GetSpecific: a member of the pool is returned,
// and an ID that is not in the pool is not.
func TestHostPool_GetSpecific(t *testing.T) {
	manager := newMockManager()
	rng := rand.New(rand.NewSource(42))
	testNdf := getTestNdf(t)
	params := DefaultParams()
	params.MaxPoolSize = uint32(len(testNdf.Gateways))

	// Call the constructor
	poolLen := 5
	testPool := newPool(poolLen)

	// Treat every host as connected so none are filtered out
	testPool.isConnected = func(host *connect.Host) bool { return true }

	// Pull all gateways from NDF into host manager
	for i, gw := range testNdf.Gateways {

		gwId, err := id.Unmarshal(gw.ID)
		if err != nil {
			t.Fatalf("Failed to unmarshal ID in mock NDF: %+v", err)
		}
		// Add mock gateway to manager
		var h *connect.Host
		h, err = manager.AddHost(
			gwId, gw.Address, nil, connect.GetDefaultHostParams())
		if err != nil {
			t.Fatalf("Could not Add mock host to manager: %+v", err)
		}

		// Add only the first poolLen gateways to the host pool
		if i < poolLen {
			testPool.addOrReplace(rng, h)
		}
	}

	testPool.isConnected = func(host *connect.Host) bool { return true }

	// Test that GetSpecific returns something that is in the host pool
	toGet := testPool.hostList[0].GetId()
	h, exists := testPool.GetSpecific(toGet)
	if !exists {
		t.Errorf("Failed to get member of host pool that should "+
			"be there, got: %s", h)
	}
	if h == nil || !h.GetId().Cmp(toGet) {
		t.Errorf("Wrong or invalid host returned")
	}

	// Test that GetSpecific returns nothing when the item is not in the host
	// pool; the gateway at poolLen+1 was never added above
	toGet, _ = testNdf.Gateways[poolLen+1].GetGatewayId()
	h, exists = testPool.GetSpecific(toGet)
	if exists || h != nil {
		t.Errorf("Got a member of host pool that should not be there")
	}
}
// Full test of GetPreferred: first with every target present in the pool,
// then with one target replaced by an ID that is not in the pool.
func TestHostPool_GetPreferred(t *testing.T) {
	manager := newMockManager()
	rng := rand.New(rand.NewSource(42))
	testNdf := getTestNdf(t)
	params := DefaultParams()
	params.PoolSize = uint32(len(testNdf.Gateways))

	poolLen := 12
	testPool := newPool(poolLen)

	// Treat every host as connected so none are filtered out
	testPool.isConnected = func(host *connect.Host) bool { return true }

	// Pull all gateways from NDF into host manager
	hostMap := make(map[id.ID]bool, 0)
	targets := make([]*id.ID, 0)
	for i, gw := range testNdf.Gateways {

		gwId, err := id.Unmarshal(gw.ID)
		if err != nil {
			t.Fatalf("Failed to unmarshal ID in mock NDF: %+v", err)
		}
		// Add mock gateway to manager
		h, err := manager.AddHost(
			gwId, gw.Address, nil, connect.GetDefaultHostParams())
		if err != nil {
			t.Fatalf("Could not Add mock host to manager: %+v", err)
		}

		// Every gateway is both a known host and a requested target
		hostMap[*gwId] = true
		targets = append(targets, gwId)

		// Add only the first poolLen gateways to the host pool
		if i < poolLen {
			testPool.addOrReplace(rng, h)
		}
	}

	retrievedList := testPool.GetPreferred(targets, rng)
	if len(retrievedList) != len(targets) {
		t.Errorf("Requested list did not output requested length."+
			"\n\tExpected: %d"+
			"\n\tReceived: %v", len(targets), len(retrievedList))
	}

	// In case where all requested gateways are present
	// ensure requested hosts were returned
	for _, h := range retrievedList {
		if !hostMap[*h.GetId()] {
			t.Errorf("A target gateways which should have been returned was not."+
				"\n\tExpected: %v", h.GetId())
		}
	}

	// Replace a request with a gateway not in pool
	targets[3] = id.NewIdFromUInt(74, id.Gateway, t)
	retrievedList = testPool.GetPreferred(targets, rng)
	if len(retrievedList) != len(targets) {
		t.Errorf("Requested list did not output requested length."+
			"\n\tExpected: %d"+
			"\n\tReceived: %v", len(targets), len(retrievedList))
	}

	// In case where a requested gateway is not present, it must not be
	// returned
	for _, h := range retrievedList {
		if h.GetId().Cmp(targets[3]) {
			t.Errorf("Should not have returned ID not in pool")
		}
	}
}
////////////////////////////////////////////////////////////////////////////////
// Copyright © 2022 xx foundation //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file. //
////////////////////////////////////////////////////////////////////////////////
package gateway
import (
jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/client/v4/stoppable"
"gitlab.com/elixxir/crypto/hash"
"gitlab.com/xx_network/crypto/randomness"
"math/big"
"time"
)
// Rotation is a long-running thread which will cause the runner thread
// to shuffle the hostPool. This will be done over a random interval, with
// the interval being randomly selected every Rotation. The Rotation
// is performed by sending a nil signal to hostPool.addRequest, which
// will force random updates to the hostPool.
//
// Each interval is RotationPeriod plus, when RotationPeriodVariability is
// non-zero, a random offset derived from a freshly read seed. The send on
// addRequest is non-blocking: if it cannot be delivered a warning is logged
// and the next interval starts. The thread exits when stop is signalled.
func (hp *hostPool) Rotation(stop *stoppable.Single) {
	for {
		delay := hp.params.RotationPeriod
		if hp.params.RotationPeriodVariability != 0 {
			stream := hp.rng.GetStream()

			seed := make([]byte, 32)
			_, err := stream.Read(seed)
			// Release the stream as soon as the seed has been read so that
			// a stream is not leaked on every pass through this loop
			stream.Close()
			if err != nil {
				jww.FATAL.Panicf("Failed to read (rng): %+v", err)
			}

			// Without a hash the random offset cannot be derived at all,
			// so a construction failure is fatal rather than ignored
			h, err := hash.NewCMixHash()
			if err != nil {
				jww.FATAL.Panicf("Failed to create hash: %+v", err)
			}

			r := randomness.RandInInterval(big.NewInt(int64(hp.params.RotationPeriodVariability)), seed, h).Int64()
			// NOTE(review): this halves r, so the offset is only ever
			// added to the period and never subtracted — confirm that a
			// one-sided deviation is what is intended.
			r = r - (r / 2)

			delay = delay + time.Duration(r)
		}

		t := time.NewTimer(delay)

		select {
		case <-stop.Quit():
			stop.ToStopped()
			t.Stop()
			return
		case <-t.C:
			select {
			case hp.addRequest <- nil:
			default:
				jww.WARN.Printf("Failed to send an Add request after %s delay", delay)
			}
		}
	}
}
////////////////////////////////////////////////////////////////////////////////
// Copyright © 2022 xx foundation //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file. //
////////////////////////////////////////////////////////////////////////////////
package gateway
import (
jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/client/v4/stoppable"
"gitlab.com/elixxir/comms/network"
"gitlab.com/xx_network/comms/connect"
"gitlab.com/xx_network/primitives/id"
"gitlab.com/xx_network/primitives/ndf"
"time"
)
// runner is the primary long-running thread for handling events. Each pass of
// its loop services exactly one signal, then (if the write pool changed)
// publishes a copy of the pool for readers and persists the host list, and
// finally waits for a delay obtained from the "online" bucket before looping.
// It handles the following signals:
//   - Requests to add hosts to the hostPool (hp.addRequest).
//   - Requests to remove hosts from the hostPool (hp.removeRequest).
//   - Indications that a host has been tested (hp.doneTesting, see nodeTester).
//   - Indications that a new host is ready to be added to the hostPool
//     (hp.newHost).
//   - Indications that a new NDF has been received (hp.newNdf).
//
// The thread returns, after marking stop as stopped, once stop.Quit() fires.
func (hp *hostPool) runner(stop *stoppable.Single) {
	// IDs of gateways currently out for testing
	inProgress := make(map[id.ID]struct{})

	// IDs of pool members slated to be swapped out when the next vetted
	// host arrives
	toRemoveList := make(map[id.ID]interface{}, 2*cap(hp.writePool.hostList))

	// Leaky bucket used to detect being offline; it is added to on each
	// accepted removal, reset on each vetted host, and supplies the
	// per-iteration delay below
	online := newBucket(cap(hp.writePool.hostList))

	for {
		// Set when the write pool was modified during this pass
		update := false

		// "break input" leaves the select below and falls through to the
		// update/delay handling at the bottom of the loop
	input:
		select {
		case <-stop.Quit():
			stop.ToStopped()
			return

		// Receives a request to add a node to the host pool. If a specific
		// node ID is sent, that node is sent off to testing. If no specific
		// node is sent (i.e. nil is received), randomly selected nodes are
		// sent instead (see processAddRequest).
		case toAdd := <-hp.addRequest:
			var hostList []*connect.Host
			hostList, inProgress = hp.processAddRequest(toAdd, inProgress)
			if len(hostList) == 0 {
				jww.ERROR.Printf("Host list for testing is empty, this " +
					"error should never occur")
				break input
			}

			// Hand the hosts off for testing; the send is non-blocking, so
			// the request is dropped (and logged) if it cannot be taken
			select {
			case hp.testNodes <- hostList:
			default:
				jww.ERROR.Printf("Failed to send add message")
			}

		// Handle requests to remove a node from the host pool
		case toRemove := <-hp.removeRequest:

			// If the host is already slated to be removed, ignore
			if _, exists := toRemoveList[*toRemove]; exists {
				break input
			}

			// Do not remove if it is not present in the pool
			if !hp.writePool.Has(toRemove) {
				jww.DEBUG.Printf("Skipping remove request for %s,"+
					" not in the host pool", toRemove)
				break input
			}

			// Add to the leaky bucket detecting if we are offline
			online.Add()

			// Add to the "to remove" list. This node will be replaced on
			// the next addition to the pool
			toRemoveList[*toRemove] = struct{}{}

			// Send a signal back to this thread to add a node to the pool.
			// Done from a goroutine so this loop never blocks on its own
			// channel.
			go func() {
				hp.addRequest <- nil
			}()

		// Internal signal on reception of vetted node to add to pool
		case newHost := <-hp.newHost:
			// Verify the new host is still in the NDF; because testing is
			// asynchronous, it can have been removed in the meantime
			if _, exists := hp.ndfMap[*newHost.GetId()]; !exists {
				jww.WARN.Printf("New vetted host (%s) is not in NDF, "+
					"this is theoretically possible but extremely unlikely. "+
					"If this is seen more than once, it is likely something is "+
					"wrong", newHost.GetId())

				// Send a signal back to this thread to add a node to the pool
				go func() {
					hp.addRequest <- nil
				}()
				break input
			}

			// A vetted host arrived, so reset the offline-detection bucket
			online.Reset()

			// Replace a node slated for replacement if there is one;
			// pop removes the chosen ID from the "to remove" list
			toRemove := pop(toRemoveList)
			if toRemove != nil {
				// On failure only a warning is logged: the new host is not
				// added on this pass and the popped ID is not put back
				if oldHost, err := hp.writePool.replaceSpecific(toRemove, newHost); err == nil {
					update = true
					if oldHost != nil {
						// Disconnect off-thread so the loop is not held up
						go func() {
							oldHost.Disconnect()
						}()
					}
				} else {
					jww.WARN.Printf("Failed to replace %s due to %s, skipping "+
						"addition to host pool", toRemove, err)
				}
			} else {
				// Nothing is slated for removal, let the pool pick the slot
				stream := hp.rng.GetStream()
				hp.writePool.addOrReplace(stream, newHost)
				stream.Close()
				update = true
			}

		// Tested gateways get passed back, so they can be removed from the
		// list of gateways which are being tested
		case tested := <-hp.doneTesting:
			for _, h := range tested {
				delete(inProgress, *h.GetId())
				jww.DEBUG.Printf("[Runner] Deleted %s from inProgress", h.GetId())
			}

		// New NDF updates come in over this channel
		case newNDF := <-hp.newNdf:
			hp.ndf = newNDF.DeepCopy()

			// Process the new NDF map
			newNDFMap := hp.processNdf(hp.ndf)

			// processNdf deletes from hp.ndfMap every gateway it carries
			// over into the new map, so whatever is left in hp.ndfMap did
			// not make it into the new map. Request removal of any of those
			// that are still in the host pool.
			//
			// NOTE(review): runner is itself the reader of removeRequest
			// (see the case above), so this blocking send relies on the
			// channel having spare buffer capacity — confirm its sizing.
			for gwID := range hp.ndfMap {
				if hp.writePool.Has(&gwID) {
					hp.removeRequest <- gwID.DeepCopy()
				}
			}

			// Replace the ndfMap
			hp.ndfMap = newNDFMap
		}

		// Handle updates by publishing a copy of the write pool for readers
		// and writing the host list into storage
		if update == true {
			poolCopy := hp.writePool.deepCopy()
			hp.readPool.Store(poolCopy)

			saveList := make([]*id.ID, 0, len(poolCopy.hostList))
			for i := 0; i < len(poolCopy.hostList); i++ {
				saveList = append(saveList, poolCopy.hostList[i].GetId())
			}

			// A failed save is not fatal; it is only logged
			err := saveHostList(hp.kv, saveList)
			if err != nil {
				jww.WARN.Printf("Host list could not be stored, updates will "+
					"not be available on load: %s", err)
			}
		}

		// Wait the delay until next iteration, remaining responsive to a
		// stop request while waiting
		delay := online.GetDelay()
		select {
		case <-time.After(delay):
		case <-stop.Quit():
			stop.ToStopped()
			return
		}
	}
}
// processAddRequest builds the list of hosts that should be sent for testing.
//
// If toAdd is non-nil and present in hp.ndfMap, the list contains only that
// gateway. Otherwise (toAdd is nil or is not in the NDF) gateways are chosen
// via hp.writePool.selectNew, which is also given inProgress and returns its
// updated value.
//
// It returns the hosts for the chosen gateways along with the inProgress map
// to use from now on. When selection fails, the returned host list is nil.
// It panics if a chosen gateway has no host registered with hp.manager.
func (hp *hostPool) processAddRequest(toAdd *id.ID,
	inProgress map[id.ID]struct{}) ([]*connect.Host, map[id.ID]struct{}) {
	// Get the nodes to add
	var toTest []*id.ID

	// Use the given ID only if it is in the NDF
	if toAdd != nil {
		if _, exist := hp.ndfMap[*toAdd]; exist {
			toTest = []*id.ID{toAdd}
		}
	}

	// If there are no nodes to add, randomly select some
	if len(toTest) == 0 {
		var err error
		stream := hp.rng.GetStream()
		toTest, inProgress, err = hp.writePool.selectNew(stream, hp.ndfMap, inProgress,
			hp.numNodesToTest)
		stream.Close()
		if err != nil {
			jww.DEBUG.Printf("[ProcessAndRequest] SelectNew returned error: %s", err)
			jww.WARN.Printf("Failed to select any nodes to test for adding, " +
				"skipping add. This error may be the result of being disconnected " +
				"from the internet or very old network credentials")
			return nil, inProgress
		}
	}

	// Get hosts for the selected nodes
	hostList := make([]*connect.Host, 0, len(toTest))
	for _, gwID := range toTest {
		h, exists := hp.manager.GetHost(gwID)
		if !exists {
			// Name the offending gateway so the failure can be traced
			jww.FATAL.Panicf("Gateway %s is not in host pool, this should "+
				"be impossible", gwID)
		}
		hostList = append(hostList, h)
	}

	return hostList, inProgress
}
// processNdf is a helper function which processes a new NDF, converting it to
// a map from each usable gateway's ID to its index in the NDF. The map is
// returned, and may be set as hostPool.ndfMap's new value.
//
// A gateway is left out of the returned map when its ID cannot be
// unmarshalled, when the NDF has no node entry at the same index, when that
// node is not active, or when it has no host yet and one cannot be added.
//
// Side effects:
//   - Gateways without an existing host in hp.manager get one added and are
//     announced on hp.addChan together with their node.
//   - Every gateway placed in the returned map is deleted from hp.ndfMap, so
//     that afterwards hp.ndfMap only holds entries that were not carried over.
func (hp *hostPool) processNdf(newNdf *ndf.NetworkDefinition) map[id.ID]int {
	// Size the map from the NDF being processed rather than from hp.ndf,
	// which may be stale or unset at the time of the call
	newNDFMap := make(map[id.ID]int, len(newNdf.Gateways))

	// Walk the list of all gateways
	for i, gw := range newNdf.Gateways {
		// Get the ID and bail if it cannot be retrieved
		gwID, err := gw.GetGatewayId()
		if err != nil {
			jww.WARN.Printf("Skipped gateway %d: %x, "+
				"ID couldn't be unmarshalled, %+v", i,
				newNdf.Gateways[i].ID, err)
			continue
		}

		// Nodes are looked up by the gateway's index; guard against an NDF
		// whose node list is shorter than its gateway list
		if i >= len(newNdf.Nodes) {
			jww.WARN.Printf("Skipped gateway %d: %s, "+
				"no node entry at that index in the NDF", i, gwID)
			continue
		}

		// Skip adding if the node is not active
		if newNdf.Nodes[i].Status != ndf.Active {
			continue
		}

		// Check if the host exists; if it does not, add it
		if _, exists := hp.manager.GetHost(gwID); !exists {
			var gwAddr string
			var cert []byte
			gwAddr, cert, err = getConnectionInfo(gwID, gw.Address, gw.TlsCertificate)
			if err == nil {
				_, err = hp.manager.AddHost(gwID, gwAddr,
					cert, hp.params.HostParams)
			}
			if err != nil {
				jww.WARN.Printf("Skipped gateway %d: %s, "+
					"host could not be added, %+v", i,
					gwID, err)
				continue
			}

			// Announce the newly added gateway and its node
			hp.addChan <- network.NodeGateway{
				Node:    newNdf.Nodes[i],
				Gateway: gw,
			}
		}

		// Add to the new map
		newNDFMap[*gwID] = i

		// Delete from the old ndf map so we can track which gateways are
		// missing
		delete(hp.ndfMap, *gwID)
	}

	return newNDFMap
}
// pop removes one key from the map and returns a pointer to a copy of it. The
// key is whichever one map iteration yields first; Go does not specify that
// order, so no particular element (oldest or otherwise) is favoured. It
// returns nil when the map is empty.
func pop(m map[id.ID]interface{}) *id.ID {
	for key := range m {
		picked := key
		delete(m, picked)
		return &picked
	}

	// Nothing to pop
	return nil
}
...@@ -14,13 +14,14 @@ import ( ...@@ -14,13 +14,14 @@ import (
jww "github.com/spf13/jwalterweatherman" jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/client/v4/stoppable" "gitlab.com/elixxir/client/v4/stoppable"
"gitlab.com/elixxir/client/v4/storage" "gitlab.com/elixxir/client/v4/storage"
"gitlab.com/elixxir/comms/network" commNetwork "gitlab.com/elixxir/comms/network"
"gitlab.com/elixxir/crypto/fastRNG" "gitlab.com/elixxir/crypto/fastRNG"
"gitlab.com/xx_network/comms/connect" "gitlab.com/xx_network/comms/connect"
"gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/id"
"gitlab.com/xx_network/primitives/ndf" "gitlab.com/xx_network/primitives/ndf"
"gitlab.com/xx_network/primitives/netTime" "gitlab.com/xx_network/primitives/netTime"
"strings" "strings"
"testing"
"time" "time"
) )
...@@ -34,60 +35,79 @@ type Sender interface { ...@@ -34,60 +35,79 @@ type Sender interface {
UpdateNdf(ndf *ndf.NetworkDefinition) UpdateNdf(ndf *ndf.NetworkDefinition)
SetGatewayFilter(f Filter) SetGatewayFilter(f Filter)
GetHostParams() connect.HostParams GetHostParams() connect.HostParams
StartProcesses() stoppable.Stoppable
} }
type sender struct { type sender struct {
*HostPool *hostPool
} }
const RetryableError = "Nonfatal error occurred, please retry" const RetryableError = "Nonfatal error occurred, please retry"
// NewSender Create a new Sender object wrapping a HostPool object // NewSender creates a new Sender object wrapping a HostPool object
func NewSender(poolParams PoolParams, rng *fastRNG.StreamGenerator, func NewSender(poolParams Params, rng *fastRNG.StreamGenerator,
ndf *ndf.NetworkDefinition, getter HostManager, storage storage.Session, ndf *ndf.NetworkDefinition, getter HostManager,
addGateway chan network.NodeGateway) (Sender, error) { storage storage.Session, comms CertCheckerCommInterface,
addChan chan commNetwork.NodeGateway) (Sender, error) {
hostPool, err := newHostPool( hp, err := newHostPool(poolParams, rng, ndf,
poolParams, rng, ndf, getter, storage, addGateway) getter, storage, addChan, comms)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &sender{hostPool}, nil return &sender{hp}, nil
} }
// SendToAny call given sendFunc to any Host in the HostPool, attempting with up // NewSender creates a new Sender object wrapping a HostPool object
// to numProxies destinations. func NewTestingSender(poolParams Params, rng *fastRNG.StreamGenerator,
ndf *ndf.NetworkDefinition, getter HostManager,
storage storage.Session, addChan chan commNetwork.NodeGateway,
t *testing.T) (Sender, error) {
if t == nil {
jww.FATAL.Panicf("can only be called in testing")
}
hp, err := newTestingHostPool(poolParams, rng, ndf,
getter, storage, addChan, nil, t)
if err != nil {
return nil, err
}
return &sender{hp}, nil
}
// SendToAny will call the given send function to any connect.Host in the host pool.
func (s *sender) SendToAny(sendFunc func(*connect.Host) (interface{}, error), func (s *sender) SendToAny(sendFunc func(*connect.Host) (interface{}, error),
stop *stoppable.Single) (interface{}, error) { stop *stoppable.Single) (interface{}, error) {
proxies := s.getAny(s.poolParams.ProxyAttempts, nil) p := s.getPool()
if err := p.IsReady(); err != nil {
return nil, errors.WithMessagef(err, "Failed to SendToAny")
}
rng := s.rng.GetStream()
proxies := p.GetAny(s.params.ProxyAttempts, nil, rng)
rng.Close()
for proxy := range proxies { for proxy := range proxies {
result, err := sendFunc(proxies[proxy]) proxyHost := proxies[proxy]
result, err := sendFunc(proxyHost)
if stop != nil && !stop.IsRunning() { if stop != nil && !stop.IsRunning() {
return nil, return nil,
errors.Errorf(stoppable.ErrMsg, stop.Name(), "SendToAny") errors.Errorf(stoppable.ErrMsg, stop.Name(), "SendToAny")
} else if err == nil { } else if err == nil {
return result, nil return result, nil
} else { } else {
// Now we must check whether the Host should be replaced // send a signal to remove from the host pool if it is a not
replaced, checkReplaceErr := s.checkReplace( // allowed error
proxies[proxy].GetId(), err) if IsGuilty(err) {
if replaced { s.Remove(proxyHost)
jww.WARN.Printf("Unable to SendToAny, replaced a proxy %s "+
"with error %s", proxies[proxy].GetId(), err.Error())
} else {
if checkReplaceErr != nil {
jww.WARN.Printf("Unable to SendToAny via %s: %s. "+
"Unable to replace host: %+v",
proxies[proxy].GetId(), err.Error(), checkReplaceErr)
} else {
jww.WARN.Printf("Unable to SendToAny via %s: %s. "+
"Did not replace host.",
proxies[proxy].GetId(), err.Error())
}
} }
// End for non-retryable errors // If the send function denotes the error are recoverable,
// try another host
if !strings.Contains(err.Error(), RetryableError) { if !strings.Contains(err.Error(), RetryableError) {
return nil, return nil,
errors.WithMessage(err, "Received error with SendToAny") errors.WithMessage(err, "Received error with SendToAny")
...@@ -102,16 +122,24 @@ func (s *sender) SendToAny(sendFunc func(*connect.Host) (interface{}, error), ...@@ -102,16 +122,24 @@ func (s *sender) SendToAny(sendFunc func(*connect.Host) (interface{}, error),
type SendToPreferredFunc func(host *connect.Host, target *id.ID, type SendToPreferredFunc func(host *connect.Host, target *id.ID,
timeout time.Duration) (interface{}, error) timeout time.Duration) (interface{}, error)
// SendToPreferred Call given sendFunc to any Host in the HostPool, attempting // SendToPreferred calls the given send function to any connect.Host in the
// with up to numProxies destinations. Returns an error if the timeout is // host pool, attempting. Returns an error if the timeout is reached.
// reached.
func (s *sender) SendToPreferred(targets []*id.ID, sendFunc SendToPreferredFunc, func (s *sender) SendToPreferred(targets []*id.ID, sendFunc SendToPreferredFunc,
stop *stoppable.Single, timeout time.Duration) (interface{}, error) { stop *stoppable.Single, timeout time.Duration) (interface{}, error) {
startTime := netTime.Now() startTime := netTime.Now()
p := s.getPool()
if err := p.IsReady(); err != nil {
return nil, errors.WithMessagef(err, "Failed to SendToPreferred")
}
rng := s.rng.GetStream()
defer rng.Close()
// get the hosts and shuffle randomly // get the hosts and shuffle randomly
targetHosts := s.getPreferred(targets) targetHosts := p.GetPreferred(targets, rng)
// Attempt to send directly to targets if they are in the HostPool // Attempt to send directly to targets if they are in the HostPool
for i := range targetHosts { for i := range targetHosts {
...@@ -129,36 +157,27 @@ func (s *sender) SendToPreferred(targets []*id.ID, sendFunc SendToPreferredFunc, ...@@ -129,36 +157,27 @@ func (s *sender) SendToPreferred(targets []*id.ID, sendFunc SendToPreferredFunc,
} else if err == nil { } else if err == nil {
return result, nil return result, nil
} else { } else {
// Now we must check whether the Host should be replaced // send a signal to remove from the host pool if it is a not
replaced, checkReplaceErr := s.checkReplace(targetHosts[i].GetId(), err) // allowed error
if replaced { if IsGuilty(err) {
jww.WARN.Printf("Unable to SendToPreferred first pass via %s, "+ s.Remove(targetHosts[i])
"replaced a proxy %s with error %s",
targets[i], targetHosts[i].GetId(), err.Error())
} else {
if checkReplaceErr != nil {
jww.WARN.Printf("Unable to SendToPreferred first pass %s "+
"via %s: %s. Unable to replace host: %+v", targets[i],
targetHosts[i].GetId(), err.Error(), checkReplaceErr)
} else {
jww.WARN.Printf("Unable to SendToPreferred first pass %s "+
"via %s: %s. Did not replace host.",
targets[i], targetHosts[i].GetId(), err.Error())
}
} }
// End for non-retryable errors // If the send function denotes the error are recoverable,
// try another host
if !strings.Contains(err.Error(), RetryableError) { if !strings.Contains(err.Error(), RetryableError) {
return nil, errors.WithMessage( return nil,
err, "Received error with SendToPreferred") errors.WithMessage(err, "Received error with SendToAny")
} }
} }
} }
//re-get the pool in case it has had an update
p = s.getPool()
// Build a list of proxies for every target // Build a list of proxies for every target
proxies := make([][]*connect.Host, len(targets)) proxies := make([][]*connect.Host, len(targets))
for i := 0; i < len(targets); i++ { for i := 0; i < len(targets); i++ {
proxies[i] = s.getAny(s.poolParams.ProxyAttempts, targets) proxies[i] = p.GetAny(s.params.ProxyAttempts, targets, rng)
} }
// Build a map of bad proxies // Build a map of bad proxies
...@@ -166,7 +185,7 @@ func (s *sender) SendToPreferred(targets []*id.ID, sendFunc SendToPreferredFunc, ...@@ -166,7 +185,7 @@ func (s *sender) SendToPreferred(targets []*id.ID, sendFunc SendToPreferredFunc,
// Iterate between each target's list of proxies, using the next target for // Iterate between each target's list of proxies, using the next target for
// each proxy // each proxy
for proxyIdx := uint32(0); proxyIdx < s.poolParams.ProxyAttempts; proxyIdx++ { for proxyIdx := uint32(0); proxyIdx < s.params.ProxyAttempts; proxyIdx++ {
for targetIdx := range proxies { for targetIdx := range proxies {
// Return an error if the timeout duration is reached // Return an error if the timeout duration is reached
if netTime.Since(startTime) > timeout { if netTime.Since(startTime) > timeout {
...@@ -203,23 +222,17 @@ func (s *sender) SendToPreferred(targets []*id.ID, sendFunc SendToPreferredFunc, ...@@ -203,23 +222,17 @@ func (s *sender) SendToPreferred(targets []*id.ID, sendFunc SendToPreferredFunc,
target, proxy, err) target, proxy, err)
continue continue
} else { } else {
// Check whether the Host should be replaced // send a signal to remove from the host pool if it is a not
replaced, checkReplaceErr := s.checkReplace(proxy.GetId(), err) // allowed error
badProxies[proxy.String()] = nil if IsGuilty(err) {
if replaced { s.Remove(proxy)
jww.WARN.Printf("Unable to SendToPreferred second pass "+
"via %s, replaced a proxy %s with error %s",
target, proxy.GetId(), err.Error())
} else {
if checkReplaceErr != nil {
jww.WARN.Printf("Unable to SendToPreferred second "+
"pass %s via %s: %s. Unable to replace host: %+v",
target, proxy.GetId(), err.Error(), checkReplaceErr)
} else {
jww.WARN.Printf("Unable to SendToPreferred second "+
"pass %s via %s: %s. Did not replace host.",
target, proxy.GetId(), err.Error())
} }
// If the send function denotes the error are recoverable,
// try another host
if !strings.Contains(err.Error(), RetryableError) {
return nil,
errors.WithMessage(err, "Received error with SendToAny")
} }
// End for non-retryable errors // End for non-retryable errors
......
...@@ -25,11 +25,11 @@ func TestNewSender(t *testing.T) { ...@@ -25,11 +25,11 @@ func TestNewSender(t *testing.T) {
rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG)
testNdf := getTestNdf(t) testNdf := getTestNdf(t)
testStorage := storage.InitTestingSession(t) testStorage := storage.InitTestingSession(t)
addGwChan := make(chan network.NodeGateway) params := DefaultParams()
params := DefaultPoolParams()
params.MaxPoolSize = uint32(len(testNdf.Gateways)) params.MaxPoolSize = uint32(len(testNdf.Gateways))
addChan := make(chan network.NodeGateway, len(testNdf.Gateways))
_, err := NewSender(params, rng, testNdf, manager, testStorage, addGwChan) mccc := &mockCertCheckerComm{}
_, err := NewSender(params, rng, testNdf, manager, testStorage, mccc, addChan)
if err != nil { if err != nil {
t.Fatalf("Failed to create mock sender: %v", err) t.Fatalf("Failed to create mock sender: %v", err)
} }
...@@ -41,50 +41,47 @@ func TestSender_SendToAny(t *testing.T) { ...@@ -41,50 +41,47 @@ func TestSender_SendToAny(t *testing.T) {
rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG)
testNdf := getTestNdf(t) testNdf := getTestNdf(t)
testStorage := storage.InitTestingSession(t) testStorage := storage.InitTestingSession(t)
addGwChan := make(chan network.NodeGateway) params := DefaultParams()
params := DefaultPoolParams()
params.PoolSize = uint32(len(testNdf.Gateways)) params.PoolSize = uint32(len(testNdf.Gateways))
addChan := make(chan network.NodeGateway, len(testNdf.Gateways))
// Pull all gateways from NDF into host manager mccc := &mockCertCheckerComm{}
for _, gw := range testNdf.Gateways {
gwId, err := id.Unmarshal(gw.ID)
if err != nil {
t.Fatalf("Failed to unmarshal ID in mock NDF: %+v", err)
}
// Add mock gateway to manager
_, err = manager.AddHost(
gwId, gw.Address, nil, connect.GetDefaultHostParams())
if err != nil {
t.Fatalf("Could not add mock host to manager: %+v", err)
}
}
senderFace, err := NewSender( senderFace, err := NewSender(
params, rng, testNdf, manager, testStorage, addGwChan) params, rng, testNdf, manager, testStorage, mccc, addChan)
s := senderFace.(*sender) s := senderFace.(*sender)
if err != nil { if err != nil {
t.Fatalf("Failed to create mock sender: %v", err) t.Fatalf("Failed to create mock sender: %v", err)
} }
// Add all gateways to hostPool's map stream := rng.GetStream()
for index, gw := range testNdf.Gateways { defer stream.Close()
// Put 3 gateways into the pool
for i := 0; i < cap(s.writePool.hostList); i++ {
gw := testNdf.Gateways[i]
gwId, err := id.Unmarshal(gw.ID) gwId, err := id.Unmarshal(gw.ID)
if err != nil { if err != nil {
t.Fatalf("Failed to unmarshal ID in mock NDF: %+v", err) t.Fatalf("Failed to unmarshal ID in mock NDF: %+v", err)
} }
// Add mock gateway to manager
err = s.replaceHost(gwId, uint32(index)) gwHost, err := manager.AddHost(
gwId, gw.Address, nil, connect.GetDefaultHostParams())
if err != nil { if err != nil {
t.Fatalf("Failed to replace host in set-up: %+v", err) t.Fatalf("Could not Add mock host to manager: %+v", err)
} }
s.writePool.addOrReplace(stream, gwHost)
} }
s.writePool.isConnected = func(host *connect.Host) bool { return true }
//update the read pool
s.readPool.Store(s.writePool)
// Test sendToAny with test interfaces // Test sendToAny with test interfaces
result, err := s.SendToAny(SendToAnyHappyPath, nil) result, err := s.SendToAny(SendToAnyHappyPath, nil)
if err != nil { if err != nil {
t.Errorf("Should not error in SendToAny happy path: %v", err) t.Errorf("Should not error in SendToAny happy path: %+v", err)
} }
if !reflect.DeepEqual(result, happyPathReturn) { if !reflect.DeepEqual(result, happyPathReturn) {
...@@ -111,38 +108,50 @@ func TestSender_SendToPreferred(t *testing.T) { ...@@ -111,38 +108,50 @@ func TestSender_SendToPreferred(t *testing.T) {
rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG)
testNdf := getTestNdf(t) testNdf := getTestNdf(t)
testStorage := storage.InitTestingSession(t) testStorage := storage.InitTestingSession(t)
addGwChan := make(chan network.NodeGateway) params := DefaultParams()
params := DefaultPoolParams()
params.PoolSize = uint32(len(testNdf.Gateways)) - 5 params.PoolSize = uint32(len(testNdf.Gateways)) - 5
// Do not test proxy attempts code in this test // Do not test proxy attempts code in this test
// (self contain to code specific in sendPreferred) // (self contain to code specific in sendPreferred)
params.ProxyAttempts = 0 params.ProxyAttempts = 0
mccc := &mockCertCheckerComm{}
addChan := make(chan network.NodeGateway, len(testNdf.Gateways))
sFace, err := NewSender(params, rng, testNdf, manager, testStorage, mccc, addChan)
if err != nil {
t.Fatalf("Failed to create mock sender: %v", err)
}
s := sFace.(*sender)
// Pull all gateways from NDF into host manager stream := rng.GetStream()
for _, gw := range testNdf.Gateways { defer stream.Close()
var preferredHost *connect.Host
// Put 3 gateways into the pool
for i := 0; i < cap(s.writePool.hostList); i++ {
gw := testNdf.Gateways[i]
gwId, err := id.Unmarshal(gw.ID) gwId, err := id.Unmarshal(gw.ID)
if err != nil { if err != nil {
t.Fatalf("Failed to unmarshal ID in mock NDF: %+v", err) t.Fatalf("Failed to unmarshal ID in mock NDF: %+v", err)
} }
// Add mock gateway to manager // Add mock gateway to manager
_, err = manager.AddHost( gwHost, err := manager.AddHost(
gwId, gw.Address, nil, connect.GetDefaultHostParams()) gwId, gw.Address, nil, connect.GetDefaultHostParams())
if err != nil { if err != nil {
t.Fatalf("Could not add mock host to manager: %+v", err) t.Fatalf("Could not Add mock host to manager: %+v", err)
} }
} s.writePool.addOrReplace(stream, gwHost)
sFace, err := NewSender(params, rng, testNdf, manager, testStorage, addGwChan) if i == 1 {
if err != nil { preferredHost = gwHost
t.Fatalf("Failed to create mock sender: %v", err) }
} }
s := sFace.(*sender)
preferredIndex := 0 s.writePool.isConnected = func(host *connect.Host) bool { return true }
preferredHost := s.hostList[preferredIndex]
//update the read pool
s.readPool.Store(s.writePool)
// Happy path // Happy path
result, err := s.SendToPreferred([]*id.ID{preferredHost.GetId()}, result, err := s.SendToPreferred([]*id.ID{preferredHost.GetId()},
...@@ -164,20 +173,21 @@ func TestSender_SendToPreferred(t *testing.T) { ...@@ -164,20 +173,21 @@ func TestSender_SendToPreferred(t *testing.T) {
t.Fatalf("Expected error path did not receive error") t.Fatalf("Expected error path did not receive error")
} }
// Check the host has been replaced // Check the host removal signal has been sent
if _, ok := s.hostMap[*preferredHost.GetId()]; ok { select {
t.Errorf("Expected host %s to be removed due to error", preferredHost) case removeSignal := <-s.removeRequest:
if !removeSignal.Cmp(preferredHost.GetId()) {
t.Errorf("Expected host %s to be removed due "+
"to error, instead %s removed", preferredHost, removeSignal)
} }
default:
// Ensure we are disconnected from the old host t.Errorf("Expected host %s error to trigger a removal signal",
if isConnected, _ := preferredHost.Connected(); isConnected {
t.Errorf("ForceReplace error: Failed to disconnect from old host %s",
preferredHost) preferredHost)
} }
// get a new host to test on // get a new host to test on
preferredIndex = 4 preferredIndex := 4
preferredHost = s.hostList[preferredIndex] preferredHost = s.writePool.hostList[preferredIndex]
// Unknown error return will not trigger replacement // Unknown error return will not trigger replacement
_, err = s.SendToPreferred([]*id.ID{preferredHost.GetId()}, _, err = s.SendToPreferred([]*id.ID{preferredHost.GetId()},
...@@ -187,7 +197,7 @@ func TestSender_SendToPreferred(t *testing.T) { ...@@ -187,7 +197,7 @@ func TestSender_SendToPreferred(t *testing.T) {
} }
// Check the host has not been replaced // Check the host has not been replaced
if _, ok := s.hostMap[*preferredHost.GetId()]; !ok { if _, ok := s.writePool.hostMap[*preferredHost.GetId()]; !ok {
t.Errorf("Host %s should not have been removed due on an unknown error", t.Errorf("Host %s should not have been removed due on an unknown error",
preferredHost) preferredHost)
} }
......
...@@ -77,7 +77,7 @@ func unmarshalHostList(data []byte) ([]*id.ID, error) { ...@@ -77,7 +77,7 @@ func unmarshalHostList(data []byte) ([]*id.ID, error) {
buff := bytes.NewBuffer(data) buff := bytes.NewBuffer(data)
list := make([]*id.ID, 0, len(data)/id.ArrIDLen) list := make([]*id.ID, 0, len(data)/id.ArrIDLen)
// Read each ID from data, unmarshal, and add to list // Read each ID from data, unmarshal, and Add to list
length := id.ArrIDLen length := id.ArrIDLen
for n := buff.Next(length); len(n) == length; n = buff.Next(length) { for n := buff.Next(length); len(n) == length; n = buff.Next(length) {
hid, err := id.Unmarshal(n) hid, err := id.Unmarshal(n)
......
package nodes
import (
pb "gitlab.com/elixxir/comms/mixmessages"
"gitlab.com/xx_network/comms/connect"
)
// mockCertCheckerComm is a stateless test double exposing
// GetGatewayTLSCertificate. NOTE(review): presumably it is what tests pass as
// the cert-checker comms argument of gateway.NewSender — confirm at call sites.
type mockCertCheckerComm struct{}

// GetGatewayTLSCertificate ignores both of its arguments and always returns an
// empty gateway certificate with a nil error.
func (m *mockCertCheckerComm) GetGatewayTLSCertificate(host *connect.Host,
	message *pb.RequestGatewayCert) (*pb.GatewayCertificate, error) {
	emptyCert := new(pb.GatewayCertificate)
	return emptyCert, nil
}
...@@ -35,7 +35,6 @@ func registerNodes(r *registrar, s session, stop *stoppable.Single, ...@@ -35,7 +35,6 @@ func registerNodes(r *registrar, s session, stop *stoppable.Single,
inProgress, attempts *sync.Map, index int) { inProgress, attempts *sync.Map, index int) {
atomic.AddInt64(r.numberRunning, 1) atomic.AddInt64(r.numberRunning, 1)
for { for {
select { select {
case <-r.pauser: case <-r.pauser:
...@@ -79,15 +78,6 @@ func registerNodes(r *registrar, s session, stop *stoppable.Single, ...@@ -79,15 +78,6 @@ func registerNodes(r *registrar, s session, stop *stoppable.Single,
continue continue
} }
// Keep track of how many times registering with this node
// has been attempted
numAttempts := uint(1)
if nunAttemptsInterface, hasValue := attempts.LoadOrStore(
nidStr, numAttempts); hasValue {
numAttempts = nunAttemptsInterface.(uint)
attempts.Store(nidStr, numAttempts+1)
}
// No need to register with stale nodes // No need to register with stale nodes
if isStale := gw.Node.Status == ndf.Stale; isStale { if isStale := gw.Node.Status == ndf.Stale; isStale {
jww.DEBUG.Printf( jww.DEBUG.Printf(
...@@ -104,7 +94,34 @@ func registerNodes(r *registrar, s session, stop *stoppable.Single, ...@@ -104,7 +94,34 @@ func registerNodes(r *registrar, s session, stop *stoppable.Single,
// Process the result // Process the result
if err != nil { if err != nil {
if gateway.IsHostPoolNotReadyError(err) {
jww.WARN.Printf("Failed to register node due to non ready host "+
"pool: %s", err.Error())
// retry registering without counting it against the node
go func() {
r.c <- gw
}()
//wait 5 seconds to give the host pool a chance to resolve
select {
case <-time.NewTimer(5 * time.Second).C:
case <-stop.Quit():
stop.ToStopped()
return
}
} else {
jww.ERROR.Printf("Failed to register node: %s", err.Error()) jww.ERROR.Printf("Failed to register node: %s", err.Error())
// Keep track of how many times registering with this node
// has been attempted
numAttempts := uint(1)
if nunAttemptsInterface, hasValue := attempts.LoadOrStore(
nidStr, numAttempts); hasValue {
numAttempts = nunAttemptsInterface.(uint)
attempts.Store(nidStr, numAttempts+1)
}
// If we have not reached the attempt limit for this gateway, // If we have not reached the attempt limit for this gateway,
// then send it back into the channel to retry // then send it back into the channel to retry
if numAttempts < maxAttempts { if numAttempts < maxAttempts {
...@@ -115,6 +132,8 @@ func registerNodes(r *registrar, s session, stop *stoppable.Single, ...@@ -115,6 +132,8 @@ func registerNodes(r *registrar, s session, stop *stoppable.Single,
}() }()
} }
} }
}
rng.Close() rng.Close()
} }
if index >= 2 { if index >= 2 {
......
...@@ -30,8 +30,11 @@ func TestLoadRegistrar_New(t *testing.T) { ...@@ -30,8 +30,11 @@ func TestLoadRegistrar_New(t *testing.T) {
rngGen := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) rngGen := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG)
p := gateway.DefaultPoolParams() p := gateway.DefaultPoolParams()
p.MaxPoolSize = 1 p.MaxPoolSize = 1
addChan := make(chan commNetwork.NodeGateway, 1)
mccc := &mockCertCheckerComm{}
sender, err := gateway.NewSender(gateway.DefaultPoolParams(), rngGen, sender, err := gateway.NewSender(gateway.DefaultPoolParams(), rngGen,
getNDF(), newMockManager(), session, nil) getNDF(), newMockManager(), session, mccc, addChan)
if err != nil { if err != nil {
t.Fatalf("Failed to create new sender: %+v", err) t.Fatalf("Failed to create new sender: %+v", err)
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment