diff --git a/network/gateway/gateway.go b/network/gateway/gateway.go index 1c2c247f33eb64993da2bd1288f979e454e5a9a9..2af245821076802716ffd2bd1412f5a3fdc43540 100644 --- a/network/gateway/gateway.go +++ b/network/gateway/gateway.go @@ -23,10 +23,14 @@ type HostGetter interface { GetHost(hostId *id.ID) (*connect.Host, bool) } +// Get the Host of a random gateway in the NDF func Get(ndf *ndf.NetworkDefinition, hg HostGetter, rng io.Reader) (*connect.Host, error) { - // Get a random gateway - gateways := ndf.Gateways - gwIdx := ReadRangeUint32(0, uint32(len(gateways)), rng) + gwLen := uint32(len(ndf.Gateways)) + if gwLen == 0 { + return nil, errors.Errorf("no gateways available") + } + + gwIdx := ReadRangeUint32(0, gwLen, rng) gwID, err := id.Unmarshal(ndf.Nodes[gwIdx].ID) if err != nil { return nil, errors.WithMessage(err, "failed to get Gateway") @@ -40,6 +44,7 @@ func Get(ndf *ndf.NetworkDefinition, hg HostGetter, rng io.Reader) (*connect.Hos return gwHost, nil } +// Get the last gateway Host from the given RoundInfo func GetLast(hg HostGetter, ri *mixmessages.RoundInfo) (*connect.Host, error) { roundTop := ri.GetTopology() lastGw, err := id.Unmarshal(roundTop[len(roundTop)-1])