Skip to content
Snippets Groups Projects
Commit 981df392 authored by Jake Taylor's avatar Jake Taylor :lips:
Browse files

Merge branch 'Mufasa/CommsRestructure' into 'release'

Mufasa/comms restructure

Closes XX-2394

See merge request elixxir/comms!252
parents 3f3b22b6 0f6e55c1
No related branches found
No related tags found
No related merge requests found
Showing with 299 additions and 2189 deletions
......@@ -19,7 +19,7 @@ before_script:
- ssh-keyscan -t rsa gitlab.com > ~/.ssh/known_hosts
- git config --global url."git@gitlab.com:".insteadOf "https://gitlab.com/"
- export PATH=$HOME/go/bin:$PATH
- export GOPRIVATE=gitlab.com/elixxir/*
- export GOPRIVATE=gitlab.com/elixxir/*,gitlab.com/xx_network/*
stages:
- build
......
......@@ -17,10 +17,12 @@ build:
update_release:
GOFLAGS="" go get -u gitlab.com/elixxir/primitives@release
GOFLAGS="" go get -u gitlab.com/elixxir/crypto@release
GOFLAGS="" go get -u gitlab.com/xx_network/comms@release
update_master:
GOFLAGS="" go get -u gitlab.com/elixxir/primitives@master
GOFLAGS="" go get -u gitlab.com/elixxir/crypto@master
GOFLAGS="" go get -u gitlab.com/xx_network/comms@master
master: clean update_master build
......
......@@ -59,7 +59,7 @@ run the following command in the terminal in order to regenerate the
`mixmessage.pb.go` file:
```
protoc -I mixmessages/ mixmessages/mixmessages.proto --go_out=plugins=grpc:mixmessages
protoc -I vendor/ -I mixmessages/ mixmessages/mixmessages.proto --go_out=plugins=grpc:mixmessages
```
## Repository Organization
......
......@@ -11,8 +11,8 @@ package client
import (
"github.com/pkg/errors"
"gitlab.com/elixxir/comms/connect"
"gitlab.com/elixxir/primitives/id"
"gitlab.com/xx_network/comms/connect"
)
// Client object used to implement endpoints and top-level comms functionality
......
......@@ -14,8 +14,8 @@ import (
"github.com/golang/protobuf/ptypes/any"
"github.com/pkg/errors"
jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/comms/connect"
pb "gitlab.com/elixxir/comms/mixmessages"
"gitlab.com/xx_network/comms/connect"
"google.golang.org/grpc"
)
......
......@@ -8,10 +8,10 @@
package client
import (
"gitlab.com/elixxir/comms/connect"
"gitlab.com/elixxir/comms/gateway"
pb "gitlab.com/elixxir/comms/mixmessages"
"gitlab.com/elixxir/primitives/id"
"gitlab.com/xx_network/comms/connect"
"testing"
)
......
......@@ -11,12 +11,34 @@ package client
import (
"fmt"
"github.com/pkg/errors"
pb "gitlab.com/elixxir/comms/mixmessages"
"gitlab.com/elixxir/comms/testutils"
"gitlab.com/elixxir/primitives/id"
"gitlab.com/elixxir/primitives/ndf"
"gitlab.com/xx_network/comms/connect"
"sync"
)
// ------------------------- Testing globals -------------------------------------
var GetHostErrBool = true
var RequestNdfErr error = nil
var NdfToreturn = pb.NDF{Ndf: []byte(testutils.ExampleJSON)}
const RegistrationAddr = "0.0.0.0:5558"
var ExampleBadNdfJSON = "badNDF"
var RegistrationHandler = &MockRegistration{}
const RegistrationAddrErr = "0.0.0.0:5559"
var portLock sync.Mutex
var port = 5800
var RegistrationError = &MockRegistrationError{}
var Retries = 0
// Utility function to avoid address collisions in testing suite
func getNextAddress() string {
portLock.Lock()
defer func() {
......@@ -25,3 +47,76 @@ func getNextAddress() string {
}()
return fmt.Sprintf("0.0.0.0:%d", port)
}
// ------------------------- Mock Registration Server Handler ---------------------------
type MockRegistration struct {
}
func (s *MockRegistration) RegisterNode(NodeID *id.ID,
NodeTLSCert, GatewayTLSCert, RegistrationCode, Addr, Addr2 string) error {
return nil
}
func (s *MockRegistration) PollNdf(clientNdfHash []byte, auth *connect.Auth) ([]byte, error) {
return []byte(testutils.ExampleJSON), nil
}
func (s *MockRegistration) Poll(*pb.PermissioningPoll, *connect.Auth, string) (*pb.PermissionPollResponse, error) {
return &pb.PermissionPollResponse{}, nil
}
// Registers a user and returns a signed public key
func (s *MockRegistration) RegisterUser(registrationCode,
key string) (hash []byte, err error) {
return nil, nil
}
func (s *MockRegistration) GetCurrentClientVersion() (version string, err error) {
return "", nil
}
func (s *MockRegistration) CheckRegistration(msg *pb.RegisteredNodeCheck) (*pb.RegisteredNodeConfirmation, error) {
return nil, nil
}
// ------------------------- Mock Error Registration Server Handler ---------------------------
type MockRegistrationError struct {
}
func (s *MockRegistrationError) RegisterNode(NodeID *id.ID,
NodeTLSCert, GatewayTLSCert, RegistrationCode, Addr, Addr2 string) error {
return nil
}
func (s *MockRegistrationError) PollNdf(clientNdfHash []byte, auth *connect.Auth) ([]byte, error) {
if Retries < 5 {
Retries++
return nil, errors.New(ndf.NO_NDF)
}
return []byte(testutils.ExampleJSON), nil
}
func (s *MockRegistrationError) Poll(*pb.PermissioningPoll, *connect.Auth, string) (*pb.PermissionPollResponse, error) {
if Retries < 5 {
Retries++
return nil, errors.New(ndf.NO_NDF)
}
return &pb.PermissionPollResponse{}, nil
}
// Registers a user and returns a signed public key
func (s *MockRegistrationError) RegisterUser(registrationCode,
key string) (hash []byte, err error) {
return nil, nil
}
func (s *MockRegistrationError) GetCurrentClientVersion() (version string, err error) {
return "", nil
}
func (s *MockRegistrationError) CheckRegistration(msg *pb.RegisteredNodeCheck) (*pb.RegisteredNodeConfirmation, error) {
return nil, nil
}
......@@ -14,14 +14,15 @@ import (
"github.com/golang/protobuf/ptypes/any"
"github.com/pkg/errors"
jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/comms/connect"
pb "gitlab.com/elixxir/comms/mixmessages"
"gitlab.com/xx_network/comms/connect"
"gitlab.com/xx_network/comms/messages"
"google.golang.org/grpc"
)
// Client -> NotificationBot
func (c *Comms) RegisterForNotifications(host *connect.Host,
message *pb.NotificationToken) (*pb.Ack, error) {
message *pb.NotificationToken) (*messages.Ack, error) {
// Create the Send Function
f := func(conn *grpc.ClientConn) (*any.Any, error) {
// Set up the context
......@@ -50,20 +51,20 @@ func (c *Comms) RegisterForNotifications(host *connect.Host,
}
// Marshall the result
result := &pb.Ack{}
result := &messages.Ack{}
return result, ptypes.UnmarshalAny(resultMsg, result)
}
// Client -> NotificationBot
func (c *Comms) UnregisterForNotifications(host *connect.Host) (*pb.Ack, error) {
func (c *Comms) UnregisterForNotifications(host *connect.Host) (*messages.Ack, error) {
// Create the Send Function
f := func(conn *grpc.ClientConn) (*any.Any, error) {
// Set up the context
ctx, cancel := connect.MessagingContext()
defer cancel()
authMsg, err := c.PackAuthenticatedMessage(&pb.Ping{}, host, false)
authMsg, err := c.PackAuthenticatedMessage(&messages.Ping{}, host, false)
if err != nil {
return nil, errors.New(err.Error())
}
......@@ -85,7 +86,7 @@ func (c *Comms) UnregisterForNotifications(host *connect.Host) (*pb.Ack, error)
}
// Marshall the result
result := &pb.Ack{}
result := &messages.Ack{}
return result, ptypes.UnmarshalAny(resultMsg, result)
}
......@@ -8,10 +8,10 @@
package client
import (
"gitlab.com/elixxir/comms/connect"
pb "gitlab.com/elixxir/comms/mixmessages"
"gitlab.com/elixxir/comms/notificationBot"
"gitlab.com/elixxir/primitives/id"
"gitlab.com/xx_network/comms/connect"
"testing"
)
......
......@@ -10,13 +10,19 @@
package client
import (
"crypto/sha256"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/any"
"github.com/pkg/errors"
jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/comms/connect"
pb "gitlab.com/elixxir/comms/mixmessages"
"gitlab.com/elixxir/primitives/id"
"gitlab.com/elixxir/primitives/ndf"
"gitlab.com/xx_network/comms/connect"
"gitlab.com/xx_network/comms/messages"
"google.golang.org/grpc"
"strings"
"time"
)
// Client -> Registration Send Function
......@@ -62,7 +68,7 @@ func (c *Comms) SendGetCurrentClientVersionMessage(
// Send the message
resultMsg, err := pb.NewRegistrationClient(
conn).GetCurrentClientVersion(ctx, &pb.Ping{})
conn).GetCurrentClientVersion(ctx, &messages.Ping{})
if err != nil {
return nil, errors.New(err.Error())
}
......@@ -81,8 +87,98 @@ func (c *Comms) SendGetCurrentClientVersionMessage(
return result, ptypes.UnmarshalAny(resultMsg, result)
}
// Client -> Registration Send Function
func (c *Comms) RequestNdf(host *connect.Host, message *pb.NDFHash) (*pb.NDF, error) {
// Call Protocomms Request NDF
return c.ProtoComms.RequestNdf(host, message)
// RequestNdf is used to Request an ndf from permissioning
// Used by gateway, client, nodes and gateways
func (c *Comms) RequestNdf(host *connect.Host,
message *pb.NDFHash) (*pb.NDF, error) {
// Create the Send Function
f := func(conn *grpc.ClientConn) (*any.Any, error) {
// Set up the context
ctx, cancel := connect.MessagingContext()
defer cancel()
authMsg, err := c.PackAuthenticatedMessage(message, host, false)
if err != nil {
return nil, errors.New(err.Error())
}
// Send the message
resultMsg, err := pb.NewRegistrationClient(
conn).PollNdf(ctx, authMsg)
if err != nil {
return nil, errors.New(err.Error())
}
return ptypes.MarshalAny(resultMsg)
}
// Execute the Send function
jww.DEBUG.Printf("Sending Request Ndf message: %+v", message)
resultMsg, err := c.Send(host, f)
if err != nil {
return nil, err
}
result := &pb.NDF{}
return result, ptypes.UnmarshalAny(resultMsg, result)
}
// RetrieveNdf, attempts to connect to the permissioning server to retrieve the latest ndf for the notifications bot
func (c *Comms) RetrieveNdf(currentDef *ndf.NetworkDefinition) (*ndf.NetworkDefinition, error) {
//Hash the notifications bot ndf for comparison with registration's ndf
var ndfHash []byte
// If the ndf passed not nil, serialize and hash it
if currentDef != nil {
//Hash the notifications bot ndf for comparison with registration's ndf
hash := sha256.New()
ndfBytes, err := currentDef.Marshal()
if err != nil {
return nil, err
}
hash.Write(ndfBytes)
ndfHash = hash.Sum(nil)
}
//Put the hash in a message
msg := &pb.NDFHash{Hash: ndfHash}
regHost, ok := c.Manager.GetHost(&id.Permissioning)
if !ok {
return nil, errors.New("Failed to find permissioning host")
}
//Send the hash to registration
response, err := c.RequestNdf(regHost, msg)
// Keep going until we get a grpc error or we get an ndf
for err != nil {
// If there is an unexpected error
if !strings.Contains(err.Error(), ndf.NO_NDF) {
// If it is not an issue with no ndf, return the error up the stack
errMsg := errors.Errorf("Failed to get ndf from permissioning: %v", err)
return nil, errMsg
}
// If the error is that the permissioning server is not ready, ask again
jww.WARN.Println("Failed to get an ndf, possibly not ready yet. Retying now...")
time.Sleep(250 * time.Millisecond)
response, err = c.RequestNdf(regHost, msg)
}
//If there was no error and the response is nil, client's ndf is up-to-date
if response == nil || response.Ndf == nil {
jww.DEBUG.Printf("Our NDF is up-to-date")
return nil, nil
}
jww.INFO.Printf("Remote NDF: %s", string(response.Ndf))
//Otherwise pull the ndf out of the response
updatedNdf, _, err := ndf.DecodeNDF(string(response.Ndf))
if err != nil {
//If there was an error decoding ndf
errMsg := errors.Errorf("Failed to decode response to ndf: %v", err)
return nil, errMsg
}
return updatedNdf, nil
}
......@@ -8,10 +8,12 @@
package client
import (
"gitlab.com/elixxir/comms/connect"
pb "gitlab.com/elixxir/comms/mixmessages"
"gitlab.com/elixxir/comms/registration"
"gitlab.com/elixxir/comms/testutils"
"gitlab.com/elixxir/primitives/id"
"gitlab.com/elixxir/primitives/ndf"
"gitlab.com/xx_network/comms/connect"
"testing"
)
......@@ -93,3 +95,87 @@ func TestSendGetUpdatedNDF(t *testing.T) {
t.Errorf("RequestNdf: Error received: %s", err)
}
}
// Test that Poll NDF handles all comms errors returned properly, and that it decodes and successfully returns an ndf
func TestProtoComms_PollNdf(t *testing.T) {
// Define a client object
clientId := id.NewIdFromString("client", id.Generic, t)
c, err := NewClientComms(clientId, nil, nil, nil)
if err != nil {
t.Errorf("Can't create client comms: %+v", err)
}
mockPermServer := registration.StartRegistrationServer(&id.Permissioning, RegistrationAddr, RegistrationHandler, nil, nil)
defer mockPermServer.Shutdown()
newNdf := &ndf.NetworkDefinition{}
// Test that poll ndf fails if getHost returns an error
GetHostErrBool = false
RequestNdfErr = nil
_, err = c.RetrieveNdf(newNdf)
if err == nil {
t.Errorf("GetHost should have failed but it didnt't: %+v", err)
t.Fail()
}
// Test that pollNdf returns an error in this case
// This enters an infinite loop is there a way to fix this test?
// Test that pollNdf Fails if it cant decode the request msg
RequestNdfErr = nil
GetHostErrBool = true
NdfToreturn.Ndf = []byte(ExampleBadNdfJSON)
_, err = c.RetrieveNdf(newNdf)
if err == nil {
t.Logf("RequestNdf should have failed to parse bad ndf: %+v", err)
t.Fail()
}
_, err = c.ProtoComms.AddHost(&id.Permissioning, RegistrationAddr, nil, false, false)
if err != nil {
t.Errorf("Failed to add permissioning as a host: %+v", err)
}
// Test that pollNDf Is successful with expected result
RequestNdfErr = nil
GetHostErrBool = true
NdfToreturn.Ndf = []byte(testutils.ExampleJSON)
_, err = c.RetrieveNdf(newNdf)
//comms.mockManager.AddHost()
if err != nil {
t.Logf("Ndf failed to parse: %+v", err)
t.Fail()
}
}
// Happy path
func TestProtoComms_PollNdfRepeatedly(t *testing.T) {
// Define a client object
clientId := id.NewIdFromString("client", id.Generic, t)
c, err := NewClientComms(clientId, nil, nil, nil)
if err != nil {
t.Errorf("Can't create client comms: %+v", err)
}
// Start up the mock reg server
mockPermServer := registration.StartRegistrationServer(&id.Permissioning, RegistrationAddrErr, RegistrationError, nil, nil)
defer mockPermServer.Shutdown()
// Add the host to the comms object
_, err = c.ProtoComms.AddHost(&id.Permissioning, RegistrationAddrErr, nil, false, false)
if err != nil {
t.Errorf("Failed to add permissioning as a host: %+v", err)
}
newNdf := &ndf.NetworkDefinition{}
// This should hit the loop until the number of retries is satisfied in the error handler
_, err = c.RetrieveNdf(newNdf)
if err != nil {
t.Errorf("Expected error case, should not return non-error until attempt #5")
}
}
///////////////////////////////////////////////////////////////////////////////
// Copyright © 2020 xx network SEZC //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file //
///////////////////////////////////////////////////////////////////////////////
// Handles authentication logic for the top-level comms object
package connect
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
"github.com/pkg/errors"
jww "github.com/spf13/jwalterweatherman"
pb "gitlab.com/elixxir/comms/mixmessages"
"gitlab.com/elixxir/crypto/nonce"
"gitlab.com/elixxir/crypto/signature/rsa"
"gitlab.com/elixxir/crypto/xx"
"gitlab.com/elixxir/primitives/id"
"google.golang.org/grpc/metadata"
)
// Auth represents an authorization state for a message or host
type Auth struct {
// Indicates whether authentication was successful
IsAuthenticated bool
// The information about the Host that sent the authenticated communication
Sender *Host
}
// Perform the client handshake to establish reverse-authentication
func (c *ProtoComms) clientHandshake(host *Host) (err error) {
// Set up the context
client := pb.NewGenericClient(host.connection)
ctx, cancel := MessagingContext()
defer cancel()
// Send the token request message
result, err := client.RequestToken(ctx,
&pb.Ping{})
if err != nil {
return errors.New(err.Error())
}
// Pack the authenticated message with signature enabled
msg, err := c.PackAuthenticatedMessage(&pb.AssignToken{
Token: result.Token,
}, host, true)
if err != nil {
return errors.New(err.Error())
}
// Add special client-specific info to the
// newly generated authenticated message if needed
if c.pubKeyPem != nil && c.salt != nil {
msg.Client.Salt = c.salt
msg.Client.PublicKey = string(c.pubKeyPem)
}
// Set up the context
ctx, cancel = MessagingContext()
defer cancel()
// Send the authenticate token message
_, err = client.AuthenticateToken(ctx, msg)
if err != nil {
return errors.New(err.Error())
}
// Assign the host token
host.transmissionToken = result.Token
return
}
// Convert any message type into a authenticated message
func (c *ProtoComms) PackAuthenticatedMessage(msg proto.Message, host *Host,
enableSignature bool) (*pb.AuthenticatedMessage, error) {
// Marshall the provided message into an Any type
anyMsg, err := ptypes.MarshalAny(msg)
if err != nil {
return nil, errors.New(err.Error())
}
// Build the authenticated message
authMsg := &pb.AuthenticatedMessage{
ID: c.Id.Marshal(),
Token: host.transmissionToken,
Message: anyMsg,
Client: &pb.ClientID{
Salt: make([]byte, 0),
PublicKey: "",
},
}
// If signature is enabled, sign the message and add to payload
if enableSignature && !c.disableAuth {
authMsg.Signature, err = c.SignMessage(msg)
if err != nil {
return nil, err
}
}
return authMsg, nil
}
// Add authentication fields to a given context and return it
func (c *ProtoComms) PackAuthenticatedContext(host *Host,
ctx context.Context) context.Context {
ctx = metadata.AppendToOutgoingContext(ctx, "ID", c.Id.String())
ctx = metadata.AppendToOutgoingContext(ctx, "TOKEN", base64.StdEncoding.EncodeToString(host.transmissionToken))
return ctx
}
// Returns authentication packed into a context
func UnpackAuthenticatedContext(ctx context.Context) (*pb.AuthenticatedMessage, error) {
auth := &pb.AuthenticatedMessage{}
var err error
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, errors.New("unable to retrieve meta data / header")
}
idStr := md.Get("ID")[0]
auth.ID, err = base64.StdEncoding.DecodeString(idStr)
if err != nil {
return nil, errors.WithMessage(err, "could not decode authentication ID")
}
tokenStr := md.Get("TOKEN")[0]
auth.Token, err = base64.StdEncoding.DecodeString(tokenStr)
if err != nil {
return nil, errors.WithMessage(err, "could not decode authentication Token")
}
return auth, nil
}
// Generates a new token and adds it to internal state
func (c *ProtoComms) GenerateToken() ([]byte, error) {
token, err := nonce.NewNonce(nonce.RegistrationTTL)
if err != nil {
return nil, err
}
c.tokens.Store(string(token.Bytes()), &token)
jww.DEBUG.Printf("Token generated: %v", token.Bytes())
return token.Bytes(), nil
}
// Performs the dynamic authentication process such that Hosts that were not
// already added to the Manager can establish authentication
func (c *ProtoComms) dynamicAuth(msg *pb.AuthenticatedMessage) (
host *Host, err error) {
// Verify the client is attempting a dynamic authentication
if msg.Client == nil || msg.Client.PublicKey == "" || msg.Client.Salt == nil {
return nil, errors.New("Invalid dynamic authentication attempt!")
}
// Process the public key
pubKey, err := rsa.LoadPublicKeyFromPem([]byte(msg.Client.PublicKey))
if err != nil {
return nil, errors.New(err.Error())
}
// Generate the user's ID from supplied Client information
uid, err := xx.NewID(pubKey, msg.Client.Salt, id.User)
if err != nil {
return nil, err
}
// Verify the ID provided correctly matches the generated ID
if !bytes.Equal(msg.ID, uid.Marshal()) {
return nil, errors.Errorf(
"Provided ID does not match. Expected: %s, Actual: %s",
uid.String(), msg.ID)
}
// Create and add the new host to the manager
host, err = newDynamicHost(uid, []byte(msg.Client.PublicKey))
if err != nil {
return
}
c.addHost(host)
return
}
// Validates a signed token using internal state
func (c *ProtoComms) ValidateToken(msg *pb.AuthenticatedMessage) (err error) {
// Convert EntityID to ID
msgID, err := id.Unmarshal(msg.ID)
if err != nil {
return err
}
// Verify the Host exists for the provided ID
host, ok := c.GetHost(msgID)
if !ok {
// If the host does not already exist, attempt dynamic authentication
jww.DEBUG.Printf("Attempting dynamic authentication: %s", msgID.String())
host, err = c.dynamicAuth(msg)
if err != nil {
return errors.Errorf(
"Unable to complete dynamic authentication: %+v", err)
}
}
// This logic prevents deadlocks when performing authentication with self
// TODO: This may require further review
if !msgID.Cmp(c.Id) || bytes.Compare(host.receptionToken, msg.Token) != 0 {
host.mux.Lock()
defer host.mux.Unlock()
}
// Get the signed token
tokenMsg := &pb.AssignToken{}
err = ptypes.UnmarshalAny(msg.Message, tokenMsg)
if err != nil {
return errors.Errorf("Unable to unmarshal token: %+v", err)
}
// Verify the token signature unless disableAuth has been set for testing
if !c.disableAuth {
if err := c.VerifyMessage(tokenMsg, msg.Signature, host); err != nil {
return errors.Errorf("Invalid token signature: %+v", err)
}
}
// Verify the signed token was actually assigned
token, ok := c.tokens.Load(string(tokenMsg.Token))
if !ok {
return errors.Errorf("Unable to locate token: %+v", msg.Token)
}
// Verify the signed token is not expired
if !token.(*nonce.Nonce).IsValid() {
return errors.Errorf("Invalid or expired token: %+v", tokenMsg.Token)
}
// Token has been validated and can be safely stored
host.receptionToken = tokenMsg.Token
jww.DEBUG.Printf("Token validated: %v", tokenMsg.Token)
return
}
// AuthenticatedReceiver handles reception of an AuthenticatedMessage,
// checking if the host is authenticated & returning an Auth state
func (c *ProtoComms) AuthenticatedReceiver(msg *pb.AuthenticatedMessage) (*Auth, error) {
// Convert EntityID to ID
msgID, err := id.Unmarshal(msg.ID)
if err != nil {
return nil, err
}
// Try to obtain the Host for the specified ID
host, ok := c.GetHost(msgID)
if !ok {
return &Auth{
IsAuthenticated: false,
Sender: &Host{},
}, nil
}
// If host is found, mutex must be locked
host.mux.RLock()
defer host.mux.RUnlock()
// Check the token's validity
validToken := host.receptionToken != nil && msg.Token != nil &&
bytes.Compare(host.receptionToken, msg.Token) == 0
// Assemble the Auth object
res := &Auth{
IsAuthenticated: validToken,
Sender: host,
}
jww.TRACE.Printf("Authentication status: %v, ProvidedId: %v ProvidedToken: %v",
res.IsAuthenticated, msg.ID, msg.Token)
return res, nil
}
// DisableAuth makes the authentication code skip signing and signature verification if the
// set. Can only be set while in a testing structure. Is not thread safe.
func (c *ProtoComms) DisableAuth() {
jww.WARN.Print("Auth checking disabled, running insecurely")
c.disableAuth = true
}
// Takes a message and returns its signature
// The message is signed with the ProtoComms RSA PrivateKey
func (c *ProtoComms) SignMessage(msg proto.Message) ([]byte, error) {
// Hash the message data
msgBytes, err := proto.Marshal(msg)
if err != nil {
return nil, errors.New(err.Error())
}
options := rsa.NewDefaultOptions()
hash := options.Hash.New()
hash.Write([]byte(msgBytes))
hashed := hash.Sum(nil)
// Obtain the private key
key := c.GetPrivateKey()
if key == nil {
return nil, errors.Errorf("Cannot sign message: No private key")
}
// Sign the message and return the signature
signature, err := rsa.Sign(rand.Reader, key, options.Hash, hashed, nil)
if err != nil {
return nil, errors.New(err.Error())
}
return signature, nil
}
// Takes a message and a Host, verifies the signature
// using Host public key, returning an error if invalid
func (c *ProtoComms) VerifyMessage(msg proto.Message, signature []byte, host *Host) error {
// Get hashed data of the message
msgBytes, err := proto.Marshal(msg)
if err != nil {
return errors.New(err.Error())
}
options := rsa.NewDefaultOptions()
hash := options.Hash.New()
hash.Write([]byte(msgBytes))
hashed := hash.Sum(nil)
// Verify signature of message using host public key
err = rsa.Verify(host.rsaPublicKey, options.Hash, hashed, signature, nil)
if err != nil {
return errors.New(err.Error())
}
return nil
}
///////////////////////////////////////////////////////////////////////////////
// Copyright © 2020 xx network SEZC //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file //
///////////////////////////////////////////////////////////////////////////////
package connect
import (
"errors"
"gitlab.com/elixxir/primitives/id"
"strings"
)
const baseAuthErr = "Failed to authenticate"
// AuthError returns a valid authorization error on the given id
func AuthError(id *id.ID) error {
return errors.New(baseAuthErr + " id: " + id.String())
}
// IsAuthError returns true if the passed error is a valid auth error
func IsAuthError(err error) bool {
return strings.Contains(err.Error(), baseAuthErr)
}
///////////////////////////////////////////////////////////////////////////////
// Copyright © 2020 xx network SEZC //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file //
///////////////////////////////////////////////////////////////////////////////
package connect
import (
"errors"
"gitlab.com/elixxir/primitives/id"
"testing"
)
func TestAuthError(t *testing.T) {
expectedAuthErrorStr := "Failed to authenticate id: c29pc29pc29pAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
result := AuthError(id.NewIdFromString("soisoisoi", id.Generic, t))
if result == nil {
t.Error("AuthError did not return an error object")
}
if result.Error() != expectedAuthErrorStr {
t.Errorf("returned error not as expected: Expected: %s, received: %s",
expectedAuthErrorStr, result.Error())
}
}
func TestIsAuthError(t *testing.T) {
isAuthError := errors.New("Failed to authenticate id: soisoisoi")
if !IsAuthError(isAuthError) {
t.Errorf("IsAuthError returned that authError is not an authError")
}
notAuthError := errors.New("dont feed the plants")
if IsAuthError(notAuthError) {
t.Errorf("IsAuthError returned that a non authError is an authError")
}
}
///////////////////////////////////////////////////////////////////////////////
// Copyright © 2020 xx network SEZC //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file //
///////////////////////////////////////////////////////////////////////////////
package connect
import (
"bytes"
"github.com/golang/protobuf/ptypes"
pb "gitlab.com/elixxir/comms/mixmessages"
"gitlab.com/elixxir/comms/testkeys"
"gitlab.com/elixxir/crypto/csprng"
"gitlab.com/elixxir/crypto/signature/rsa"
"gitlab.com/elixxir/crypto/xx"
"gitlab.com/elixxir/primitives/id"
"sync"
"testing"
)
func TestSignVerify(t *testing.T) {
c := *new(ProtoComms)
key := testkeys.GetNodeKeyPath()
err := c.setPrivateKey(testkeys.LoadFromPath(key))
if err != nil {
t.Errorf("Error setting private key: %+v", err)
}
private := c.GetPrivateKey()
pub := private.Public().(*rsa.PublicKey)
message := pb.NDF{
Ndf: []byte("test"),
}
wrappedMessage, err := ptypes.MarshalAny(&message)
if err != nil {
t.Errorf("Error converting to Any type: %+v", err)
}
signature, err := c.SignMessage(wrappedMessage)
if err != nil {
t.Errorf("Error signing message: %+v", err)
}
host := &Host{
rsaPublicKey: pub,
}
err = c.VerifyMessage(wrappedMessage, signature, host)
if err != nil {
t.Errorf("Error verifying signature")
}
}
func TestProtoComms_AuthenticatedReceiver(t *testing.T) {
// Create comm object
pc := ProtoComms{
Manager: Manager{},
tokens: sync.Map{},
LocalServer: nil,
ListeningAddr: "",
privateKey: nil,
}
// Create id and token
testID := id.NewIdFromString("testSender", id.Node, t)
token := []byte("testtoken")
// Add host
_, err := pc.AddHost(testID, "", nil, false, true)
if err != nil {
t.Errorf("Failed to add host: %+v", err)
}
// Get host
h, _ := pc.GetHost(testID)
h.receptionToken = token
msg := &pb.AuthenticatedMessage{
ID: testID.Marshal(),
Signature: nil,
Token: token,
Message: nil,
}
// Try the authenticated received
auth, err := pc.AuthenticatedReceiver(msg)
if err != nil {
t.Errorf("AuthenticatedReceiver() produced an error: %v", err)
}
if !auth.IsAuthenticated {
t.Errorf("Failed: authenticated receiver")
}
// Compare the tokens
if !bytes.Equal(auth.Sender.receptionToken, token) {
t.Errorf("Tokens do not match! \n\tExptected: %+v\n\tReceived: %+v", token, auth.Sender.receptionToken)
}
}
// Error path
func TestProtoComms_AuthenticatedReceiver_BadId(t *testing.T) {
// Create comm object
pc := ProtoComms{
Manager: Manager{},
tokens: sync.Map{},
LocalServer: nil,
ListeningAddr: "",
privateKey: nil,
}
// Create id and token
testID := id.NewIdFromString("testSender", id.Node, t)
token := []byte("testtoken")
// Add host
_, err := pc.AddHost(testID, "", nil, false, true)
if err != nil {
t.Errorf("Failed to add host: %+v", err)
}
// Get host
h, _ := pc.GetHost(testID)
h.receptionToken = token
badId := []byte("badID")
msg := &pb.AuthenticatedMessage{
ID: badId,
Signature: nil,
Token: token,
Message: nil,
}
// Try the authenticated received
_, err = pc.AuthenticatedReceiver(msg)
if err != nil {
return
}
t.Errorf("Expected error path!"+
"Should not be able to marshal a message with id: %v", badId)
}
// Happy path
func TestProtoComms_GenerateToken(t *testing.T) {
comm := ProtoComms{
LocalServer: nil,
ListeningAddr: "",
privateKey: nil,
}
tokenBytes, err := comm.GenerateToken()
if err != nil || tokenBytes == nil {
t.Errorf("Unable to generate token: %+v", err)
}
token, ok := comm.tokens.Load(string(tokenBytes))
if !ok || token == nil {
t.Errorf("Unable to find token stored in internal map")
}
}
// Happy path
func TestProtoComms_PackAuthenticatedMessage(t *testing.T) {
testServerId := id.NewIdFromString("test12345", id.Node, t)
comm := ProtoComms{
Id: testServerId,
LocalServer: nil,
ListeningAddr: "",
privateKey: nil,
}
tokenBytes, err := comm.GenerateToken()
if err != nil || tokenBytes == nil {
t.Errorf("Unable to generate token: %+v", err)
}
testId := id.NewIdFromString("test", id.Node, t)
host, err := NewHost(testId, "test", nil, false, true)
if err != nil {
t.Errorf("Unable to create host: %+v", err)
}
host.transmissionToken = tokenBytes
tokenMsg := &pb.AssignToken{
Token: tokenBytes,
}
msg, err := comm.PackAuthenticatedMessage(tokenMsg, host, false)
if err != nil {
t.Errorf("Expected no error packing authenticated message: %+v", err)
}
// Compare the tokens and id's
if bytes.Compare(msg.Token, tokenBytes) != 0 || !bytes.Equal(msg.ID, testServerId.Marshal()) {
t.Errorf("Expected packed message to have correct ID and Token: %+v",
msg)
}
}
// Happy path
func TestProtoComms_ValidateToken(t *testing.T) {
testId := id.NewIdFromString("test", id.Node, t)
comm := ProtoComms{
Id: testId,
LocalServer: nil,
ListeningAddr: "",
privateKey: nil,
}
err := comm.setPrivateKey(testkeys.LoadFromPath(testkeys.GetNodeKeyPath()))
if err != nil {
t.Errorf("Expected to set private key: %+v", err)
}
tokenBytes, err := comm.GenerateToken()
if err != nil || tokenBytes == nil {
t.Errorf("Unable to generate token: %+v", err)
}
pub := testkeys.LoadFromPath(testkeys.GetNodeCertPath())
host, err := comm.AddHost(testId, "test", pub, false, true)
if err != nil {
t.Errorf("Unable to create host: %+v", err)
}
host.transmissionToken = tokenBytes
tokenMsg := &pb.AssignToken{
Token: tokenBytes,
}
msg, err := comm.PackAuthenticatedMessage(tokenMsg, host, true)
if err != nil {
t.Errorf("Expected no error packing authenticated message: %+v", err)
}
msg.Client.PublicKey = string(pub)
// Check the token
err = comm.ValidateToken(msg)
if err != nil {
t.Errorf("Expected to validate token: %+v", err)
}
if !bytes.Equal(msg.Token, host.transmissionToken) {
t.Errorf("Message token doesn't match message's token! "+
"Expected: %+v"+
"\n\tReceived: %+v", host.transmissionToken, msg.Token)
}
}
// Error Path
func TestProtoComms_ValidateToken_BadId(t *testing.T) {
testId := id.NewIdFromString("test", id.Node, t)
comm := ProtoComms{
Id: testId,
LocalServer: nil,
ListeningAddr: "",
privateKey: nil,
}
err := comm.setPrivateKey(testkeys.LoadFromPath(testkeys.GetNodeKeyPath()))
if err != nil {
t.Errorf("Expected to set private key: %+v", err)
}
tokenBytes, err := comm.GenerateToken()
if err != nil || tokenBytes == nil {
t.Errorf("Unable to generate token: %+v", err)
}
pub := testkeys.LoadFromPath(testkeys.GetNodeCertPath())
host, err := comm.AddHost(testId, "test", pub, false, true)
if err != nil {
t.Errorf("Unable to create host: %+v", err)
}
host.transmissionToken = tokenBytes
tokenMsg := &pb.AssignToken{
Token: tokenBytes,
}
msg, err := comm.PackAuthenticatedMessage(tokenMsg, host, true)
if err != nil {
t.Errorf("Expected no error packing authenticated message: %+v", err)
}
// Assign message a bad id
badId := []byte("badID")
msg.ID = badId
msg.Client.PublicKey = string(pub)
// Check the token
err = comm.ValidateToken(msg)
if err != nil {
return
}
t.Errorf("Expected error path!"+
"Should not be able to marshal a message with id: %v", badId)
}
// Dynamic authentication happy path (e.g. host not pre-added)
func TestProtoComms_ValidateTokenDynamic(t *testing.T) {
// All of this is setup for UID ----
privKey, err := rsa.GenerateKey(csprng.NewSystemRNG(), rsa.DefaultRSABitLen)
if err != nil {
t.Errorf("Could not generate private key: %+v", err)
}
pubKey := privKey.GetPublic()
salt := []byte("0123456789ABCDEF0123456789ABCDEF")
uid, err := xx.NewID(pubKey, salt, id.User)
if err != nil {
t.Errorf("Could not generate user ID: %+v", err)
}
testId := uid.String()
// ------
// Now we set up the client comms object
comm := ProtoComms{
Id: uid,
ListeningAddr: "",
}
err = comm.setPrivateKey(rsa.CreatePrivateKeyPem(privKey))
if err != nil {
t.Errorf("Expected to set private key: %+v", err)
}
tokenBytes, err := comm.GenerateToken()
if err != nil || tokenBytes == nil {
t.Errorf("Unable to generate token: %+v", err)
}
// For this test we won't addHost to Manager, we'll just create a host
// so we can compare to the dynamic one later
host, err := newDynamicHost(uid, rsa.CreatePublicKeyPem(pubKey))
if err != nil {
t.Errorf("Unable to create host: %+v", err)
}
host.transmissionToken = tokenBytes
tokenMsg := &pb.AssignToken{
Token: tokenBytes,
}
// Set up auth msg
msg, err := comm.PackAuthenticatedMessage(tokenMsg, host, true)
if err != nil {
t.Errorf("Expected no error packing authenticated message: %+v", err)
}
msg.Client = &pb.ClientID{
Salt: salt,
PublicKey: string(rsa.CreatePublicKeyPem(pubKey)),
}
// Here's the method we're testing
err = comm.ValidateToken(msg)
if err != nil {
t.Errorf("Expected to validate token: %+v", err)
}
// Check the output values behaved as expected
host, ok := comm.GetHost(uid)
if !ok {
t.Errorf("Expected dynamic auth to add host %s!", testId)
}
if !host.IsDynamicHost() {
t.Errorf("Expected host to be dynamic!")
}
if !bytes.Equal(msg.Token, host.receptionToken) {
t.Errorf("Message token doesn't match message's token! "+
"Expected: %+v"+
"\n\tReceived: %+v", host.receptionToken, msg.Token)
}
}
func TestProtoComms_DisableAuth(t *testing.T) {
testId := id.NewIdFromString("test", id.Node, t)
comm := ProtoComms{
Id: testId,
LocalServer: nil,
ListeningAddr: "",
privateKey: nil,
}
comm.DisableAuth()
if !comm.disableAuth {
t.Error("Auth was not disabled when DisableAuth was called")
}
}
///////////////////////////////////////////////////////////////////////////////
// Copyright © 2020 xx network SEZC //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file //
///////////////////////////////////////////////////////////////////////////////
// Handles the basic top-level comms object used across all packages
package connect
import (
"crypto/sha256"
"crypto/tls"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/any"
"github.com/pkg/errors"
jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/comms/mixmessages"
"gitlab.com/elixxir/crypto/signature/rsa"
"gitlab.com/elixxir/primitives/id"
"gitlab.com/elixxir/primitives/ndf"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"math"
"net"
"strings"
"sync"
"time"
)
// TODO: Set these via config
// KaOpts are Keepalive options for servers
var KaOpts = keepalive.ServerParameters{
// Idle for at most 5s
MaxConnectionIdle: 5 * time.Second,
// Reset after an hour
MaxConnectionAge: 1 * time.Hour,
// w/ 1m grace shutdown
MaxConnectionAgeGrace: 1 * time.Minute,
// ping if no activity after 1s
Time: 1 * time.Second,
// Close conn 2 seconds afer ping
Timeout: 2 * time.Second,
}
// KaEnforcement are keepalive enforcement options for servers
var KaEnforcement = keepalive.EnforcementPolicy{
// Client should wait at least 250ms
MinTime: 250 * time.Millisecond,
// Doing KA on non-streams is OK
PermitWithoutStream: true,
}
// MaxConcurrentStreams is the number of server-side streams to allow open
var MaxConcurrentStreams = uint32(250000)
// Proto object containing a gRPC server
type ProtoComms struct {
// Inherit the Manager object
Manager
// The network ID of this comms server
Id *id.ID
// Private key of the local comms instance
privateKey *rsa.PrivateKey
// Disables the checking of authentication signatures for testing setups
disableAuth bool
// SERVER-ONLY FIELDS ------------------------------------------------------
// A map of reverse-authentication tokens
tokens sync.Map
// Local network server
LocalServer *grpc.Server
// Listening address of the local server
ListeningAddr string
// CLIENT-ONLY FIELDS ------------------------------------------------------
// Used to store the public key used for generating Client Id
pubKeyPem []byte
// Used to store the salt used for generating Client Id
salt []byte
// -------------------------------------------------------------------------
}
// Creates a ProtoComms client-type object to be used in various initializers
func CreateCommClient(id *id.ID, pubKeyPem, privKeyPem,
salt []byte) (*ProtoComms, error) {
// Build the ProtoComms object
pc := &ProtoComms{
Id: id,
pubKeyPem: pubKeyPem,
salt: salt,
}
// Set the private key if specified
if privKeyPem != nil {
err := pc.setPrivateKey(privKeyPem)
if err != nil {
return nil, errors.Errorf("Could not set private key: %+v", err)
}
}
return pc, nil
}
// Creates a ProtoComms server-type object to be used in various initializers
func StartCommServer(id *id.ID, localServer string, certPEMblock,
keyPEMblock []byte) (*ProtoComms, net.Listener, error) {
// Build the ProtoComms object
pc := &ProtoComms{
Id: id,
ListeningAddr: localServer,
}
listen:
// Listen on the given address
lis, err := net.Listen("tcp", localServer)
if err != nil {
if strings.Contains(err.Error(), "bind: address already in use") {
jww.WARN.Printf("Could not listen on %s, is port in use? waiting 30s: %s", localServer, err.Error())
time.Sleep(30 * time.Second)
goto listen
}
return nil, nil, errors.New(err.Error())
}
// If TLS was specified
if certPEMblock != nil && keyPEMblock != nil {
// Create the TLS certificate
x509cert, err := tls.X509KeyPair(certPEMblock, keyPEMblock)
if err != nil {
return nil, nil, errors.Errorf("Could not load TLS keys: %+v", err)
}
// Set the private key
err = pc.setPrivateKey(keyPEMblock)
if err != nil {
return nil, nil, errors.Errorf("Could not set private key: %+v", err)
}
// Create the gRPC server with TLS
jww.INFO.Printf("Starting server with TLS...")
creds := credentials.NewServerTLSFromCert(&x509cert)
pc.LocalServer = grpc.NewServer(grpc.Creds(creds),
grpc.MaxConcurrentStreams(MaxConcurrentStreams),
grpc.MaxRecvMsgSize(math.MaxInt32),
grpc.KeepaliveParams(KaOpts),
grpc.KeepaliveEnforcementPolicy(KaEnforcement))
} else {
// Create the gRPC server without TLS
jww.WARN.Printf("Starting server with TLS disabled...")
pc.LocalServer = grpc.NewServer(
grpc.MaxConcurrentStreams(MaxConcurrentStreams),
grpc.MaxRecvMsgSize(math.MaxInt32),
grpc.KeepaliveParams(KaOpts),
grpc.KeepaliveEnforcementPolicy(KaEnforcement))
}
return pc, lis, nil
}
// Performs a graceful shutdown of the local server
func (c *ProtoComms) Shutdown() {
c.DisconnectAll()
c.LocalServer.GracefulStop()
time.Sleep(time.Millisecond * 500)
}
// Stringer method
func (c *ProtoComms) String() string {
return c.ListeningAddr
}
// Setter for local server's private key
func (c *ProtoComms) setPrivateKey(data []byte) error {
key, err := rsa.LoadPrivateKeyFromPem(data)
if err != nil {
return errors.Errorf("Failed to form private key file from data at %s: %+v", data, err)
}
c.privateKey = key
return nil
}
// Getter for local server's private key
func (c *ProtoComms) GetPrivateKey() *rsa.PrivateKey {
return c.privateKey
}
const (
//numerical designators for 3 operations of send, used to ensure the same
//operation isn't repeated
con = 1
auth = 2
send = 3
//maximum number of connection attempts. should be 1 because a connection
//is unlikely to be successful after a failure
maxConnects = 1
//maximum number of auth attempts. should be 3 because while a connection
//is unlikely to be successful after a failure, it is possible to attempt
//one on a dead connection and then need to do it again after the connection
//is resuscitated
maxAuths = 2
)
// Sets up or recovers the Host's connection
// Then runs the given Send function
func (c *ProtoComms) Send(host *Host, f func(conn *grpc.ClientConn) (*any.Any,
error)) (result *any.Any, err error) {
if host.GetAddress() == "" {
return nil, errors.New("Host address is blank, host might be receive only.")
}
numConnects, numAuths, lastEvent := 0, 0, 0
connect:
// Ensure the connection is running
if !host.Connected() {
host.transmissionToken = nil
//do not attempt to connect again if multiple attempts have been made
if numConnects == maxConnects {
return nil, errors.WithMessage(err, "Maximum number of connects attempted")
}
//denote that a connection is being tried
lastEvent = con
//attempt to make the connection
jww.INFO.Printf("Host %s not connected, attempting to connect...", host.id.String())
err = host.connect()
//if connection cannot be made, do not retry
if err != nil {
return nil, errors.WithMessage(err, "Failed to connect")
}
//denote the connection attempt
numConnects++
}
authorize:
// Establish authentication if required
if host.authenticationRequired() && host.transmissionToken == nil {
//do not attempt to connect again if multiple attempts have been made
if numAuths == maxAuths {
return nil, errors.New("Maximum number of authorizations attempted")
}
//do not try multiple auths in a row
if lastEvent == auth {
return nil, errors.New("Cannot attempt to authorize with host multiple times in a row")
}
//denote that an auth is being tried
lastEvent = auth
jww.INFO.Printf("Attempting to establish authentication with host %s", host.id.String())
err = host.authenticate(c.clientHandshake)
if err != nil {
//if failure of connection, retry connection
if isConnError(err) {
jww.INFO.Printf("Failed to auth due to connection issue: %s", err)
goto connect
}
//otherwise, return the error
return nil, errors.New("Failed to authenticate")
}
//denote the authorization attempt
numAuths++
}
//denote that a send is being tried
lastEvent = send
// Attempt to send to host
result, err = host.send(f)
// If failed to authenticate, retry negotiation by jumping to the top of the loop
if err != nil {
//if failure of connection, retry connection
if isConnError(err) {
jww.INFO.Printf("Failed send due to connection issue: %s", err)
goto connect
}
// Handle resetting authentication
if strings.Contains(err.Error(), AuthError(host.id).Error()) {
jww.INFO.Printf("Failed send due to auth error, retrying authentication: %s", err.Error())
host.transmissionToken = nil
goto authorize
}
// otherwise, return the error
return nil, errors.WithMessage(err, "Failed to send")
}
return result, err
}
// returns true if the connection error is one of the connection errors which
// should be retried
func isConnError(err error) bool {
return strings.Contains(err.Error(), "context deadline exceeded") ||
strings.Contains(err.Error(), "connection refused")
}
// Sets up or recovers the Host's connection
// Then runs the given Stream function
func (c *ProtoComms) Stream(host *Host, f func(conn *grpc.ClientConn) (
interface{}, error)) (client interface{}, err error) {
// Ensure the connection is running
jww.TRACE.Printf("Attempting to send to host: %s", host)
if !host.Connected() {
err = host.connect()
if err != nil {
return
}
}
//establish authentication if required
if host.authenticationRequired() {
err = host.authenticate(c.clientHandshake)
if err != nil {
return
}
}
// Run the send function
return host.stream(f)
}
// RequestNdf is used to Request an ndf from permissioning
// Used by gateway, client, nodes and gateways
func (c *ProtoComms) RequestNdf(host *Host,
message *mixmessages.NDFHash) (*mixmessages.NDF, error) {
// Create the Send Function
f := func(conn *grpc.ClientConn) (*any.Any, error) {
// Set up the context
ctx, cancel := MessagingContext()
defer cancel()
authMsg, err := c.PackAuthenticatedMessage(message, host, false)
if err != nil {
return nil, errors.New(err.Error())
}
// Send the message
resultMsg, err := mixmessages.NewRegistrationClient(
conn).PollNdf(ctx, authMsg)
if err != nil {
return nil, errors.New(err.Error())
}
return ptypes.MarshalAny(resultMsg)
}
// Execute the Send function
jww.DEBUG.Printf("Sending Request Ndf message: %+v", message)
resultMsg, err := c.Send(host, f)
if err != nil {
return nil, err
}
result := &mixmessages.NDF{}
return result, ptypes.UnmarshalAny(resultMsg, result)
}
// RetrieveNdf, attempts to connect to the permissioning server to retrieve the latest ndf for the notifications bot
func (c *ProtoComms) RetrieveNdf(currentDef *ndf.NetworkDefinition) (*ndf.NetworkDefinition, error) {
//Hash the notifications bot ndf for comparison with registration's ndf
var ndfHash []byte
// If the ndf passed not nil, serialize and hash it
if currentDef != nil {
//Hash the notifications bot ndf for comparison with registration's ndf
hash := sha256.New()
ndfBytes, err := currentDef.Marshal()
if err != nil {
return nil, err
}
hash.Write(ndfBytes)
ndfHash = hash.Sum(nil)
}
//Put the hash in a message
msg := &mixmessages.NDFHash{Hash: ndfHash}
regHost, ok := c.Manager.GetHost(&id.Permissioning)
if !ok {
return nil, errors.New("Failed to find permissioning host")
}
//Send the hash to registration
response, err := c.RequestNdf(regHost, msg)
// Keep going until we get a grpc error or we get an ndf
for err != nil {
// If there is an unexpected error
if !strings.Contains(err.Error(), ndf.NO_NDF) {
// If it is not an issue with no ndf, return the error up the stack
errMsg := errors.Errorf("Failed to get ndf from permissioning: %v", err)
return nil, errMsg
}
// If the error is that the permissioning server is not ready, ask again
jww.WARN.Println("Failed to get an ndf, possibly not ready yet. Retying now...")
time.Sleep(250 * time.Millisecond)
response, err = c.RequestNdf(regHost, msg)
}
//If there was no error and the response is nil, client's ndf is up-to-date
if response == nil || response.Ndf == nil {
jww.DEBUG.Printf("Our NDF is up-to-date")
return nil, nil
}
jww.INFO.Printf("Remote NDF: %s", string(response.Ndf))
//Otherwise pull the ndf out of the response
updatedNdf, _, err := ndf.DecodeNDF(string(response.Ndf))
if err != nil {
//If there was an error decoding ndf
errMsg := errors.Errorf("Failed to decode response to ndf: %v", err)
return nil, errMsg
}
return updatedNdf, nil
}
This diff is collapsed.
///////////////////////////////////////////////////////////////////////////////
// Copyright © 2020 xx network SEZC //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file //
///////////////////////////////////////////////////////////////////////////////
// Contains functionality related to context objects
package connect
import (
jww "github.com/spf13/jwalterweatherman"
"golang.org/x/net/context"
"google.golang.org/grpc/peer"
"net"
"time"
)
// Used for creating connections with the default timeout
func ConnectionContext(seconds time.Duration) (context.Context, context.CancelFunc) {
waitingPeriod := seconds * time.Second
jww.DEBUG.Printf("Timing out in: %s", waitingPeriod)
ctx, cancel := context.WithTimeout(context.Background(),
waitingPeriod)
return ctx, cancel
}
// Used for sending messages with the default timeout
func MessagingContext() (context.Context, context.CancelFunc) {
ctx, cancel := context.WithTimeout(context.Background(),
2*time.Minute)
return ctx, cancel
}
// Creates a context object with the default context
// for all client streaming messages. This is primarily used to
// allow a cancel option for clients and is suitable for unary streaming.
func StreamingContext() (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(context.Background())
return ctx, cancel
}
// Obtain address:port from the context of an incoming communication
func GetAddressFromContext(ctx context.Context) (address string, port string, err error) {
info, _ := peer.FromContext(ctx)
address, port, err = net.SplitHostPort(info.Addr.String())
return
}
///////////////////////////////////////////////////////////////////////////////
// Copyright © 2020 xx network SEZC //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file //
///////////////////////////////////////////////////////////////////////////////
// Contains functionality for describing and creating connections
package connect
import (
"fmt"
"github.com/golang/protobuf/ptypes/any"
"github.com/pkg/errors"
jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/crypto/signature/rsa"
tlsCreds "gitlab.com/elixxir/crypto/tls"
"gitlab.com/elixxir/primitives/id"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"math"
"net"
"sync"
"testing"
"time"
)
// KaClientOpts are the keepalive options for clients
// TODO: Set via configuration
var KaClientOpts = keepalive.ClientParameters{
// Wait 1s before pinging to keepalive
Time: 1 * time.Second,
// 2s after ping before closing
Timeout: 2 * time.Second,
// For all connections, streaming and nonstreaming
PermitWithoutStream: true,
}
// Represents a reverse-authentication token
type Token []byte
// Information used to describe a connection to a host
type Host struct {
// System-wide ID of the Host
id *id.ID
// address:Port being connected to
address string
// PEM-format TLS Certificate
certificate []byte
/* Tokens shared with this Host establishing reverse authentication */
// Token used for receiving from this host
receptionToken Token
// Token used for sending to this host
transmissionToken Token
// Configure the maximum number of connection attempts
maxRetries int
// GRPC connection object
connection *grpc.ClientConn
// TLS credentials object used to establish the connection
credentials credentials.TransportCredentials
// RSA Public Key corresponding to the TLS Certificate
rsaPublicKey *rsa.PublicKey
// If set, reverse authentication will be established with this Host
enableAuth bool
// Indicates whether dynamic authentication was used for this Host
// This is useful for determining whether a Host's key was hardcoded
dynamicHost bool
// Read/Write Mutex for thread safety
mux sync.RWMutex
}
// Creates a new Host object
func NewHost(id *id.ID, address string, cert []byte, disableTimeout,
enableAuth bool) (host *Host, err error) {
// Initialize the Host object
host = &Host{
id: id,
address: address,
certificate: cert,
enableAuth: enableAuth,
}
// Set the max number of retries for establishing a connection
if disableTimeout {
host.maxRetries = math.MaxInt32
} else {
host.maxRetries = 100
}
// Configure the host credentials
err = host.setCredentials()
return
}
// Creates a new dynamic-authenticated Host object
func newDynamicHost(id *id.ID, publicKey []byte) (host *Host, err error) {
// Initialize the Host object
// IMPORTANT: This flag must be set to true for all dynamic Hosts
// because the security properties for these Hosts differ
host = &Host{
id: id,
dynamicHost: true,
}
// Create the RSA Public Key object
host.rsaPublicKey, err = rsa.LoadPublicKeyFromPem(publicKey)
if err != nil {
err = errors.Errorf("Error extracting PublicKey: %+v", err)
}
return
}
// Simple getter for the dynamicHost value
func (h *Host) IsDynamicHost() bool {
return h.dynamicHost
}
// Simple getter for the public key
func (h *Host) GetPubKey() *rsa.PublicKey {
return h.rsaPublicKey
}
// Connected checks if the given Host's connection is alive
func (h *Host) Connected() bool {
h.mux.RLock()
defer h.mux.RUnlock()
return h.isAlive()
}
// GetId returns the id of the host
func (h *Host) GetId() *id.ID {
return h.id
}
// GetAddress returns the address of the host.
func (h *Host) GetAddress() string {
return h.address
}
// UpdateAddress updates the address of the host
func (h *Host) UpdateAddress(address string) {
h.address = address
}
// Disconnect closes a the Host connection under the write lock
func (h *Host) Disconnect() {
h.mux.Lock()
defer h.mux.Unlock()
h.disconnect()
h.transmissionToken = nil
}
// Returns whether or not the Host is able to be contacted
// by attempting to dial a tcp connection
func (h *Host) IsOnline() bool {
addr := h.GetAddress()
timeout := 5 * time.Second
conn, err := net.DialTimeout("tcp", addr, timeout)
if err != nil {
// If we cannot connect, mark the node as failed
jww.DEBUG.Printf("Failed to verify connectivity for address %s", addr)
return false
}
// Attempt to close the connection
if conn != nil {
errClose := conn.Close()
if errClose != nil {
jww.DEBUG.Printf("Failed to close connection for address %s",
addr)
}
}
return true
}
// send checks that the host has a connection and sends if it does.
// Operates under the host's read lock.
func (h *Host) send(f func(conn *grpc.ClientConn) (*any.Any,
error)) (*any.Any, error) {
h.mux.RLock()
defer h.mux.RUnlock()
if !h.isAlive() {
return nil, errors.New("Could not send, connection is not alive")
}
a, err := f(h.connection)
return a, err
}
// stream checks that the host has a connection and streams if it does.
// Operates under the host's read lock.
func (h *Host) stream(f func(conn *grpc.ClientConn) (
interface{}, error)) (interface{}, error) {
h.mux.RLock()
defer h.mux.RUnlock()
if !h.isAlive() {
return nil, errors.New("Could not stream, connection is not alive")
}
a, err := f(h.connection)
return a, err
}
// connect attempts to connect to the host if it does not have a valid connection
func (h *Host) connect() error {
h.mux.Lock()
defer h.mux.Unlock()
//checks if the connection is active and skips reconnecting if it is
if h.isAlive() {
return nil
}
h.disconnect()
h.transmissionToken = nil
//connect to remote
if err := h.connectHelper(); err != nil {
return err
}
return nil
}
// authenticationRequired Checks if new authentication is required with
// the remote
func (h *Host) authenticationRequired() bool {
h.mux.RLock()
defer h.mux.RUnlock()
return h.enableAuth && h.transmissionToken == nil
}
// Checks if the given Host's connection is alive
func (h *Host) authenticate(handshake func(host *Host) error) error {
h.mux.Lock()
defer h.mux.Unlock()
return handshake(h)
}
// isAlive returns true if the connection is non-nil and alive
func (h *Host) isAlive() bool {
if h.connection == nil {
return false
}
state := h.connection.GetState()
return state == connectivity.Idle || state == connectivity.Connecting ||
state == connectivity.Ready
}
// disconnect closes a the Host connection while not under a write lock.
// undefined behavior if the caller has not taken the write lock
func (h *Host) disconnect() {
// its possible to close a host which never sent so it never made a
// connection. In that case, we should not close a connection which does not
// exist
if h.connection != nil {
err := h.connection.Close()
if err != nil {
jww.ERROR.Printf("Unable to close connection to %s: %+v",
h.address, errors.New(err.Error()))
} else {
h.connection = nil
}
}
}
// connectHelper creates a connection while not under a write lock.
// undefined behavior if the caller has not taken the write lock
func (h *Host) connectHelper() (err error) {
// Configure TLS options
var securityDial grpc.DialOption
if h.credentials != nil {
// Create the gRPC client with TLS
securityDial = grpc.WithTransportCredentials(h.credentials)
} else {
// Create the gRPC client without TLS
jww.WARN.Printf("Connecting to %v without TLS!", h.address)
securityDial = grpc.WithInsecure()
}
jww.DEBUG.Printf("Attempting to establish connection to %s using"+
" credentials: %+v", h.address, securityDial)
// Attempt to establish a new connection
for numRetries := 0; numRetries < h.maxRetries && !h.isAlive(); numRetries++ {
h.disconnect()
jww.INFO.Printf("Connecting to %+v. Attempt number %+v of %+v",
h.address, numRetries, h.maxRetries)
// If timeout is enabled, the max wait time becomes
// ~14 seconds (with maxRetries=100)
backoffTime := 2 * (numRetries/16 + 1)
if backoffTime > 15 {
backoffTime = 15
}
ctx, cancel := ConnectionContext(time.Duration(backoffTime))
// Create the connection
h.connection, err = grpc.DialContext(ctx, h.address,
securityDial,
grpc.WithBlock(),
grpc.WithKeepaliveParams(KaClientOpts))
if err != nil {
jww.ERROR.Printf("Attempt number %+v to connect to %s failed: %+v\n",
numRetries, h.address, errors.New(err.Error()))
}
cancel()
}
// Verify that the connection was established successfully
if !h.isAlive() {
h.disconnect()
return errors.New(fmt.Sprintf(
"Last try to connect to %s failed. Giving up", h.address))
}
// Add the successful connection to the Manager
jww.INFO.Printf("Successfully connected to %v", h.address)
return
}
// setCredentials sets TransportCredentials and RSA PublicKey objects
// using a PEM-encoded TLS Certificate
func (h *Host) setCredentials() error {
// If no TLS Certificate specified, print a warning and do nothing
if h.certificate == nil || len(h.certificate) == 0 {
jww.WARN.Printf("No TLS Certificate specified!")
return nil
}
// Obtain the DNS name included with the certificate
dnsName := ""
cert, err := tlsCreds.LoadCertificate(string(h.certificate))
if err != nil {
return errors.Errorf("Error forming transportCredentials: %+v", err)
}
if len(cert.DNSNames) > 0 {
dnsName = cert.DNSNames[0]
}
// Create the TLS Credentials object
h.credentials, err = tlsCreds.NewCredentialsFromPEM(string(h.certificate),
dnsName)
if err != nil {
return errors.Errorf("Error forming transportCredentials: %+v", err)
}
// Create the RSA Public Key object
h.rsaPublicKey, err = tlsCreds.NewPublicKeyFromPEM(h.certificate)
if err != nil {
err = errors.Errorf("Error extracting PublicKey: %+v", err)
}
return err
}
// Stringer interface for connection
func (h *Host) String() string {
h.mux.RLock()
defer h.mux.RUnlock()
addr := h.address
actualConnection := h.connection
creds := h.credentials
var state connectivity.State
if actualConnection != nil {
state = actualConnection.GetState()
}
serverName := "<nil>"
protocolVersion := "<nil>"
securityVersion := "<nil>"
securityProtocol := "<nil>"
if creds != nil {
serverName = creds.Info().ServerName
securityVersion = creds.Info().SecurityVersion
protocolVersion = creds.Info().ProtocolVersion
securityProtocol = creds.Info().SecurityProtocol
}
return fmt.Sprintf(
"ID: %v\tAddr: %v\tCertificate: %s...\tTransmission Token: %v"+
"\tReception Token: %+v \tEnableAuth: %v"+
"\tMaxRetries: %v\tConnState: %v"+
"\tTLS ServerName: %v\tTLS ProtocolVersion: %v\t"+
"TLS SecurityVersion: %v\tTLS SecurityProtocol: %v\n",
h.id, addr, h.certificate, h.transmissionToken,
h.receptionToken, h.enableAuth, h.maxRetries, state,
serverName, protocolVersion, securityVersion, securityProtocol)
}
func (h *Host) SetTestPublicKey(key *rsa.PublicKey, t interface{}) {
switch t.(type) {
case *testing.T:
break
case *testing.M:
break
default:
jww.FATAL.Panicf("SetTestPublicKey is restricted to testing only. Got %T", t)
}
h.rsaPublicKey = key
}
///////////////////////////////////////////////////////////////////////////////
// Copyright © 2020 xx network SEZC //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file //
///////////////////////////////////////////////////////////////////////////////
package connect
import (
"bytes"
"gitlab.com/elixxir/primitives/id"
"net"
"testing"
)
func TestHost_address(t *testing.T) {
var mgr Manager
testId := id.NewIdFromString("test", id.Node, t)
testAddress := "test"
host, err := mgr.AddHost(testId, testAddress, nil, false, false)
if err != nil {
t.Errorf("Unable to add host")
return
}
if host.address != testAddress {
t.Errorf("Expected addresses to match")
}
}
func TestHost_GetCertificate(t *testing.T) {
testCert := []byte("TEST")
host := Host{
address: "",
certificate: testCert,
maxRetries: 0,
connection: nil,
credentials: nil,
rsaPublicKey: nil,
}
if bytes.Compare(host.certificate, testCert) != 0 {
t.Errorf("Expected certs to match!")
}
}
// Tests that getID returns the correct ID
func TestHost_GetId(t *testing.T) {
testID := id.NewIdFromString("xXx_420No1337ScopeH4xx0r_xXx", id.Generic, t)
host := Host{
id: testID,
}
if host.GetId() != testID {
t.Errorf("Correct id not returned. Expected %s, got %s",
testID, host.id)
}
}
// Tests that GetAddress() returns the address of the host.
func TestHost_GetAddress(t *testing.T) {
// Test values
testAddress := "192.167.1.1:8080"
testHost := Host{address: testAddress}
// Test function
if testHost.GetAddress() != testAddress {
t.Errorf("GetAddress() did not return the expected address."+
"\n\texpected: %v\n\treceived: %v",
testAddress, testHost.GetAddress())
}
}
func TestHost_UpdateAddress(t *testing.T) {
testAddress := "192.167.1.1:8080"
testUpdatedAddress := "192.167.1.1:8080"
testHost := Host{address: testAddress}
// Test function
if testHost.GetAddress() != testAddress {
t.Errorf("GetAddress() did not return the expected address before update."+
"\n\texpected: %v\n\treceived: %v",
testAddress, testHost.GetAddress())
}
testHost.UpdateAddress(testUpdatedAddress)
if testHost.GetAddress() != testUpdatedAddress {
t.Errorf("GetAddress() did not return the expected address after update."+
"\n\texpected: %v\n\treceived: %v",
testUpdatedAddress, testHost.GetAddress())
}
}
// Validate that dynamic host defaults to false and can be set to true
func TestHost_IsDynamicHost(t *testing.T) {
host := Host{}
if host.IsDynamicHost() != false {
t.Errorf("Correct bool not returned. Expected false, got %v",
host.dynamicHost)
}
host.dynamicHost = true
if host.IsDynamicHost() != true {
t.Errorf("Correct bool not returned. Expected true, got %v",
host.dynamicHost)
}
}
// Full test
func TestHost_IsOnline(t *testing.T) {
addr := "0.0.0.0:10234"
// Create the host
host, err := NewHost(id.NewIdFromString("test", id.Gateway, t), addr, nil,
false,
false)
if err != nil {
t.Errorf("Unable to create host: %+v", host)
return
}
// Test that host is offline
if host.IsOnline() {
t.Errorf("Expected host to be offline!")
}
// Listen on the given address
lis, err := net.Listen("tcp", addr)
if err != nil {
t.Errorf("Unable to listen: %+v", err)
}
// Test that host is online
if !host.IsOnline() {
t.Errorf("Expected host to be online!")
}
// Close listening address
err = lis.Close()
if err != nil {
t.Errorf("Unable to close listening server: %+v", err)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment