diff --git a/connect/auth.go b/connect/auth.go index 2c8512a63cf154f08b51f3ea40f4ec5d50b88012..289f46ed30fb586ff6196fe0a23bff5860658ecb 100644 --- a/connect/auth.go +++ b/connect/auth.go @@ -14,11 +14,14 @@ import ( "context" "crypto/rand" "encoding/base64" + "fmt" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/crypto/signature/rsa" "gitlab.com/elixxir/crypto/xx" + "gitlab.com/xx_network/comms/connect/token" pb "gitlab.com/xx_network/comms/messages" "gitlab.com/xx_network/crypto/nonce" "gitlab.com/xx_network/crypto/signature/rsa" @@ -32,9 +35,13 @@ type Auth struct { IsAuthenticated bool // The information about the Host that sent the authenticated communication Sender *Host + // reason it isn't authenticated if authentication fails + Reason string } // Perform the client handshake to establish reverse-authentication +// no lock is taken because this is assumed to be done exclusively under the +// send lock taken in ProtoComms.transmit() func (c *ProtoComms) clientHandshake(host *Host) (err error) { // Set up the context @@ -49,6 +56,11 @@ func (c *ProtoComms) clientHandshake(host *Host) (err error) { return errors.New(err.Error()) } + remoteToken, err := token.Unmarshal(result.Token) + if err != nil { + return errors.Errorf("Failed to unmarshal token: %s", err) + } + // Pack the authenticated message with signature enabled msg, err := c.PackAuthenticatedMessage(&pb.AssignToken{ Token: result.Token, @@ -73,9 +85,9 @@ func (c *ProtoComms) clientHandshake(host *Host) (err error) { if err != nil { return errors.New(err.Error()) } - + jww.TRACE.Printf("Negotiatied Remote token: %v", remoteToken) // Assign the host token - host.transmissionToken = result.Token + host.transmissionToken.Set(remoteToken) return } @@ -93,7 +105,7 @@ func (c *ProtoComms) PackAuthenticatedMessage(msg proto.Message, host *Host, // Build the authenticated message authMsg := &pb.AuthenticatedMessage{ ID: c.Id.Marshal(), - Token: host.transmissionToken, + Token: host.transmissionToken.GetBytes(), Message: anyMsg, Client: &pb.ClientID{ Salt: make([]byte, 0), @@ -117,7 +129,8 @@ func (c *ProtoComms) PackAuthenticatedContext(host *Host, ctx context.Context) context.Context { ctx = metadata.AppendToOutgoingContext(ctx, "ID", c.Id.String()) - ctx = metadata.AppendToOutgoingContext(ctx, "TOKEN", base64.StdEncoding.EncodeToString(host.transmissionToken)) + ctx = metadata.AppendToOutgoingContext(ctx, "TOKEN", + base64.StdEncoding.EncodeToString(host.transmissionToken.GetBytes())) return ctx } @@ -140,7 +153,7 @@ func UnpackAuthenticatedContext(ctx context.Context) (*pb.AuthenticatedMessage, tokenStr := md.Get("TOKEN")[0] auth.Token, err = base64.StdEncoding.DecodeString(tokenStr) if err != nil { - return nil, errors.WithMessage(err, "could not decode authentication Token") + return nil, errors.WithMessage(err, "could not decode authentication Live") } return auth, nil @@ -148,14 +161,7 @@ func UnpackAuthenticatedContext(ctx context.Context) (*pb.AuthenticatedMessage, // Generates a new token and adds it to internal state func (c *ProtoComms) GenerateToken() ([]byte, error) { - token, err := nonce.NewNonce(nonce.RegistrationTTL) - if err != nil { - return nil, err - } - - c.tokens.Store(string(token.Bytes()), &token) - jww.DEBUG.Printf("Token generated: %v", token.Bytes()) - return token.Bytes(), nil + return c.tokens.Generate().Marshal(), nil } // Performs the dynamic authentication process such that Hosts that were not @@ -217,13 +223,6 @@ func (c *ProtoComms) ValidateToken(msg *pb.AuthenticatedMessage) (err error) { } } - // This logic prevents deadlocks when performing authentication with self - // TODO: This may require further review - if !msgID.Cmp(c.Id) || bytes.Compare(host.receptionToken, msg.Token) != 0 { - host.mux.Lock() - defer host.mux.Unlock() - } - // Get the signed token tokenMsg := &pb.AssignToken{} err = ptypes.UnmarshalAny(msg.Message, tokenMsg) @@ -231,6 +230,11 @@ func (c *ProtoComms) ValidateToken(msg *pb.AuthenticatedMessage) (err error) { return errors.Errorf("Unable to unmarshal token: %+v", err) } + remoteToken, err := token.Unmarshal(tokenMsg.Token) + if err != nil { + return errors.Errorf("Unable to unmarshal token: %+v", err) + } + // Verify the token signature unless disableAuth has been set for testing if !c.disableAuth { if err := c.verifyMessage(tokenMsg, msg.Signature, host); err != nil { @@ -238,20 +242,14 @@ func (c *ProtoComms) ValidateToken(msg *pb.AuthenticatedMessage) (err error) { } } - // Verify the signed token was actually assigned - token, ok := c.tokens.Load(string(tokenMsg.Token)) + ok = c.tokens.Validate(remoteToken) if !ok { - return errors.Errorf("Unable to locate token: %+v", msg.Token) + jww.ERROR.Printf("Failed to validate token %v from %s", remoteToken, host) + return errors.Errorf("Failed to validate token: %v", remoteToken) } - - // Verify the signed token is not expired - if !token.(*nonce.Nonce).IsValid() { - return errors.Errorf("Invalid or expired token: %+v", tokenMsg.Token) - } - // Token has been validated and can be safely stored - host.receptionToken = tokenMsg.Token - jww.DEBUG.Printf("Token validated: %v", tokenMsg.Token) + host.receptionToken.Set(remoteToken) + jww.DEBUG.Printf("Live validated: %v", tokenMsg.Token) return } @@ -261,7 +259,12 @@ func (c *ProtoComms) AuthenticatedReceiver(msg *pb.AuthenticatedMessage) (*Auth, // Convert EntityID to ID msgID, err := id.Unmarshal(msg.ID) if err != nil { - return nil, err + return &Auth{ + IsAuthenticated: false, + Sender: &Host{}, + Reason: fmt.Sprintf("Host {%v} cannot be "+ + "unmarshaled: %s", msg.ID, err), + }, nil } // Try to obtain the Host for the specified ID @@ -270,21 +273,46 @@ func (c *ProtoComms) AuthenticatedReceiver(msg *pb.AuthenticatedMessage) (*Auth, return &Auth{ IsAuthenticated: false, Sender: &Host{}, + Reason: fmt.Sprintf("Host {%s} cannot be found", msgID), + }, nil + } + + remoteToken, err := token.Unmarshal(msg.Token) + if err != nil { + return &Auth{ + IsAuthenticated: false, + Sender: host, + Reason: fmt.Sprintf("Token {%v} cannot be unmarshaled", msg.Token), }, nil } - // If host is found, mutex must be locked - host.mux.RLock() - defer host.mux.RUnlock() + // get the hosts reception token + receptionToken, ok := host.receptionToken.Get() + if !ok { + return &Auth{ + IsAuthenticated: false, + Sender: host, + Reason: fmt.Sprintf("failed to authenticate token %v, "+ + "no reception token for %s", remoteToken, host.id), + }, nil + } - // Check the token's validity - validToken := host.receptionToken != nil && msg.Token != nil && - bytes.Compare(host.receptionToken, msg.Token) == 0 + // check if the tokens are the same + if !receptionToken.Equals(remoteToken) { + return &Auth{ + IsAuthenticated: false, + Sender: host, + Reason: fmt.Sprintf("failed to authenticate token %v, "+ + "does not match reception token %v for %s", remoteToken, + receptionToken, host.id), + }, nil + } // Assemble the Auth object res := &Auth{ - IsAuthenticated: validToken, + IsAuthenticated: true, Sender: host, + Reason: "authenticated", } jww.TRACE.Printf("Authentication status: %v, ProvidedId: %v ProvidedToken: %v", diff --git a/connect/auth_test.go b/connect/auth_test.go index 62ea1f033abc1fde87c8742903e035af7dfd83b4..024f28024d32b9286670efafe0b41f984c52995d 100644 --- a/connect/auth_test.go +++ b/connect/auth_test.go @@ -12,6 +12,7 @@ import ( "github.com/golang/protobuf/ptypes" "gitlab.com/elixxir/crypto/csprng" "gitlab.com/elixxir/crypto/xx" + token "gitlab.com/xx_network/comms/connect/token" pb "gitlab.com/xx_network/comms/messages" "gitlab.com/xx_network/comms/testkeys" "gitlab.com/xx_network/crypto/signature/rsa" @@ -60,30 +61,32 @@ func TestSignVerify(t *testing.T) { func TestProtoComms_AuthenticatedReceiver(t *testing.T) { // Create comm object pc := ProtoComms{ - Manager: Manager{}, - tokens: sync.Map{}, + Manager: newManager(), + tokens: token.NewMap(), LocalServer: nil, ListeningAddr: "", privateKey: nil, } // Create id and token testID := id.NewIdFromString("testSender", id.Node, t) - token := []byte("testtoken") + expectedVal := []byte("testToken") + tkn := token.Token{} + copy(tkn[:], expectedVal) // Add host - _, err := pc.AddHost(testID, "", nil, false, true) + _, err := pc.AddHost(testID, "", nil, GetDefaultHostParams()) if err != nil { t.Errorf("Failed to add host: %+v", err) } // Get host h, _ := pc.GetHost(testID) - h.receptionToken = token + h.receptionToken.Set(tkn) msg := &pb.AuthenticatedMessage{ ID: testID.Marshal(), Signature: nil, - Token: token, + Token: tkn.Marshal(), Message: nil, } @@ -98,8 +101,9 @@ func TestProtoComms_AuthenticatedReceiver(t *testing.T) { } // Compare the tokens - if !bytes.Equal(auth.Sender.receptionToken, token) { - t.Errorf("Tokens do not match! \n\tExptected: %+v\n\tReceived: %+v", token, auth.Sender.receptionToken) + if !bytes.Equal(auth.Sender.receptionToken.GetBytes(), tkn[:]) { + t.Errorf("Tokens do not match! \n\tExptected: "+ + "%+v\n\tReceived: %+v", tkn, auth.Sender.receptionToken) } } @@ -107,49 +111,51 @@ func TestProtoComms_AuthenticatedReceiver(t *testing.T) { func TestProtoComms_AuthenticatedReceiver_BadId(t *testing.T) { // Create comm object pc := ProtoComms{ - Manager: Manager{}, - tokens: sync.Map{}, + Manager: newManager(), + tokens: token.NewMap(), LocalServer: nil, ListeningAddr: "", privateKey: nil, } // Create id and token testID := id.NewIdFromString("testSender", id.Node, t) - token := []byte("testtoken") + expectedVal := []byte("testToken") + tkn := token.Token{} + copy(tkn[:], expectedVal) // Add host - _, err := pc.AddHost(testID, "", nil, false, true) + _, err := pc.AddHost(testID, "", nil, GetDefaultHostParams()) if err != nil { t.Errorf("Failed to add host: %+v", err) } // Get host h, _ := pc.GetHost(testID) - h.receptionToken = token + h.receptionToken.Set(tkn) badId := []byte("badID") msg := &pb.AuthenticatedMessage{ ID: badId, Signature: nil, - Token: token, + Token: tkn.Marshal(), Message: nil, } // Try the authenticated received - _, err = pc.AuthenticatedReceiver(msg) - if err != nil { - return - } + a, _ := pc.AuthenticatedReceiver(msg) - t.Errorf("Expected error path!"+ - "Should not be able to marshal a message with id: %v", badId) + if a.IsAuthenticated { + t.Errorf("Expected error path!"+ + "Should not be able to marshal a message with id: %v", badId) + } } // Happy path func TestProtoComms_GenerateToken(t *testing.T) { comm := ProtoComms{ + tokens: token.NewMap(), LocalServer: nil, ListeningAddr: "", privateKey: nil, @@ -159,9 +165,14 @@ func TestProtoComms_GenerateToken(t *testing.T) { t.Errorf("Unable to generate token: %+v", err) } - token, ok := comm.tokens.Load(string(tokenBytes)) - if !ok || token == nil { - t.Errorf("Unable to find token stored in internal map") + tkn, err := token.Unmarshal(tokenBytes) + if err != nil { + t.Errorf("Should be able to unmarshal token: %s", err) + } + + ok := comm.tokens.Validate(tkn) + if !ok { + t.Errorf("Unable to validate token") } } @@ -173,6 +184,7 @@ func TestProtoComms_PackAuthenticatedMessage(t *testing.T) { LocalServer: nil, ListeningAddr: "", privateKey: nil, + tokens: token.NewMap(), } tokenBytes, err := comm.GenerateToken() @@ -180,13 +192,16 @@ func TestProtoComms_PackAuthenticatedMessage(t *testing.T) { t.Errorf("Unable to generate token: %+v", err) } + tkn := token.Token{} + copy(tkn[:], tokenBytes) + testId := id.NewIdFromString("test", id.Node, t) - host, err := NewHost(testId, "test", nil, false, true) + host, err := NewHost(testId, "test", nil, GetDefaultHostParams()) if err != nil { t.Errorf("Unable to create host: %+v", err) } - host.transmissionToken = tokenBytes + host.transmissionToken.Set(tkn) tokenMsg := &pb.AssignToken{ Token: tokenBytes, @@ -198,7 +213,7 @@ func TestProtoComms_PackAuthenticatedMessage(t *testing.T) { } // Compare the tokens and id's if bytes.Compare(msg.Token, tokenBytes) != 0 || !bytes.Equal(msg.ID, testServerId.Marshal()) { - t.Errorf("Expected packed message to have correct ID and Token: %+v", + t.Errorf("Expected packed message to have correct ID and Live: %+v", msg) } } @@ -211,6 +226,8 @@ func TestProtoComms_ValidateToken(t *testing.T) { LocalServer: nil, ListeningAddr: "", privateKey: nil, + tokens: token.NewMap(), + Manager: newManager(), } err := comm.setPrivateKey(testkeys.LoadFromPath(testkeys.GetNodeKeyPath())) if err != nil { @@ -221,13 +238,15 @@ func TestProtoComms_ValidateToken(t *testing.T) { if err != nil || tokenBytes == nil { t.Errorf("Unable to generate token: %+v", err) } + tkn := token.Token{} + copy(tkn[:], tokenBytes) pub := testkeys.LoadFromPath(testkeys.GetNodeCertPath()) - host, err := comm.AddHost(testId, "test", pub, false, true) + host, err := comm.AddHost(testId, "test", pub, GetDefaultHostParams()) if err != nil { t.Errorf("Unable to create host: %+v", err) } - host.transmissionToken = tokenBytes + host.transmissionToken.Set(tkn) tokenMsg := &pb.AssignToken{ Token: tokenBytes, @@ -246,7 +265,7 @@ func TestProtoComms_ValidateToken(t *testing.T) { t.Errorf("Expected to validate token: %+v", err) } - if !bytes.Equal(msg.Token, host.transmissionToken) { + if !bytes.Equal(msg.Token, host.transmissionToken.GetBytes()) { t.Errorf("Message token doesn't match message's token! "+ "Expected: %+v"+ "\n\tReceived: %+v", host.transmissionToken, msg.Token) @@ -261,6 +280,8 @@ func TestProtoComms_ValidateToken_BadId(t *testing.T) { LocalServer: nil, ListeningAddr: "", privateKey: nil, + tokens: token.NewMap(), + Manager: newManager(), } err := comm.setPrivateKey(testkeys.LoadFromPath(testkeys.GetNodeKeyPath())) if err != nil { @@ -271,13 +292,15 @@ func TestProtoComms_ValidateToken_BadId(t *testing.T) { if err != nil || tokenBytes == nil { t.Errorf("Unable to generate token: %+v", err) } + tkn := token.Token{} + copy(tkn[:], tokenBytes) pub := testkeys.LoadFromPath(testkeys.GetNodeCertPath()) - host, err := comm.AddHost(testId, "test", pub, false, true) + host, err := comm.AddHost(testId, "test", pub, GetDefaultHostParams()) if err != nil { t.Errorf("Unable to create host: %+v", err) } - host.transmissionToken = tokenBytes + host.transmissionToken.Set(tkn) tokenMsg := &pb.AssignToken{ Token: tokenBytes, @@ -326,6 +349,8 @@ func TestProtoComms_ValidateTokenDynamic(t *testing.T) { comm := ProtoComms{ Id: uid, ListeningAddr: "", + tokens: token.NewMap(), + Manager: newManager(), } err = comm.setPrivateKey(rsa.CreatePrivateKeyPem(privKey)) if err != nil { @@ -336,6 +361,8 @@ func TestProtoComms_ValidateTokenDynamic(t *testing.T) { if err != nil || tokenBytes == nil { t.Errorf("Unable to generate token: %+v", err) } + tkn := token.Token{} + copy(tkn[:], tokenBytes) // For this test we won't addHost to Manager, we'll just create a host // so we can compare to the dynamic one later @@ -343,7 +370,7 @@ func TestProtoComms_ValidateTokenDynamic(t *testing.T) { if err != nil { t.Errorf("Unable to create host: %+v", err) } - host.transmissionToken = tokenBytes + host.transmissionToken.Set(tkn) tokenMsg := &pb.AssignToken{ Token: tokenBytes, } @@ -373,7 +400,7 @@ func TestProtoComms_ValidateTokenDynamic(t *testing.T) { t.Errorf("Expected host to be dynamic!") } - if !bytes.Equal(msg.Token, host.receptionToken) { + if !bytes.Equal(msg.Token, host.receptionToken.GetBytes()) { t.Errorf("Message token doesn't match message's token! "+ "Expected: %+v"+ "\n\tReceived: %+v", host.receptionToken, msg.Token) diff --git a/connect/circuit_test.go b/connect/circuit_test.go index c95ea8a22caff9521e7b80d84b545812bca8f169..780f0e921f6f0529dffa4483941f740c24f94c55 100644 --- a/connect/circuit_test.go +++ b/connect/circuit_test.go @@ -218,7 +218,7 @@ func TestCircuit_GetHostAtIndex(t *testing.T) { cert, _ := utils.ReadFile(testkeys.GetNodeCertPath()) circuit := NewCircuit(nodeIdList) testID := id.NewIdFromString("test", id.Generic, t) - testHost, _ := NewHost(testID, "test", cert, false, false) + testHost, _ := NewHost(testID, "test", cert, GetDefaultHostParams()) circuit.AddHost(testHost) if !reflect.DeepEqual(circuit.hosts[0], testHost) { diff --git a/connect/comms.go b/connect/comms.go index ad7bfc9660f5fa6e35c83d5c7b48098f06185231..97cf6de7a2ca0e001525126e9828299e36fab7a6 100644 --- a/connect/comms.go +++ b/connect/comms.go @@ -16,13 +16,13 @@ import ( jww "github.com/spf13/jwalterweatherman" "gitlab.com/xx_network/crypto/signature/rsa" "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/comms/connect/token" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" "math" "net" "strings" - "sync" "time" ) @@ -56,7 +56,7 @@ var MaxConcurrentStreams = uint32(250000) // Proto object containing a gRPC server type ProtoComms struct { // Inherit the Manager object - Manager + *Manager // The network ID of this comms server Id *id.ID @@ -70,7 +70,7 @@ type ProtoComms struct { // SERVER-ONLY FIELDS ------------------------------------------------------ // A map of reverse-authentication tokens - tokens sync.Map + tokens *token.Map // Local network server LocalServer *grpc.Server @@ -97,6 +97,8 @@ func CreateCommClient(id *id.ID, pubKeyPem, privKeyPem, Id: id, pubKeyPem: pubKeyPem, salt: salt, + tokens: token.NewMap(), + Manager: newManager(), } // Set the private key if specified @@ -117,6 +119,8 @@ func StartCommServer(id *id.ID, localServer string, certPEMblock, pc := &ProtoComms{ Id: id, ListeningAddr: localServer, + tokens: token.NewMap(), + Manager: newManager(), } listen: @@ -201,14 +205,6 @@ const ( con = 1 auth = 2 send = 3 - //maximum number of connection attempts. should be 1 because a connection - //is unlikely to be successful after a failure - maxConnects = 1 - //maximum number of auth attempts. should be 3 because while a connection - //is unlikely to be successful after a failure, it is possible to attempt - //one on a dead connection and then need to do it again after the connection - //is resuscitated - maxAuths = 2 ) // Sets up or recovers the Host's connection @@ -216,99 +212,17 @@ const ( func (c *ProtoComms) Send(host *Host, f func(conn *grpc.ClientConn) (*any.Any, error)) (result *any.Any, err error) { - if host.GetAddress() == "" { - return nil, errors.New("Host address is blank, host might be receive only.") - } - - numConnects, numAuths, lastEvent := 0, 0, 0 - -connect: - // Ensure the connection is running - if !host.Connected() { - host.transmissionToken = nil - //do not attempt to connect again if multiple attempts have been made - if numConnects == maxConnects { - return nil, errors.WithMessage(err, "Maximum number of connects attempted") - } - - //denote that a connection is being tried - lastEvent = con - - //attempt to make the connection - jww.INFO.Printf("Host %s not connected, attempting to connect...", host.id.String()) - err = host.connect() - //if connection cannot be made, do not retry - if err != nil { - return nil, errors.WithMessage(err, "Failed to connect") - } - - //denote the connection attempt - numConnects++ - } - -authorize: - // Establish authentication if required - if host.authenticationRequired() && host.transmissionToken == nil { - //do not attempt to connect again if multiple attempts have been made - if numAuths == maxAuths { - return nil, errors.New("Maximum number of authorizations attempted") - } - - //do not try multiple auths in a row - if lastEvent == auth { - return nil, errors.New("Cannot attempt to authorize with host multiple times in a row") - } - - //denote that an auth is being tried - lastEvent = auth - - jww.INFO.Printf("Attempting to establish authentication with host %s", host.id.String()) - err = host.authenticate(c.clientHandshake) - if err != nil { - //if failure of connection, retry connection - if isConnError(err) { - jww.INFO.Printf("Failed to auth due to connection issue: %s", err) - goto connect - } - //otherwise, return the error - return nil, errors.New("Failed to authenticate") - } - - //denote the authorization attempt - numAuths++ + jww.TRACE.Printf("Attempting to send to host: %s", host) + fSh := func(conn *grpc.ClientConn) (interface{}, error) { + return f(conn) } - //denote that a send is being tried - lastEvent = send - // Attempt to send to host - result, err = host.send(f) - // If failed to authenticate, retry negotiation by jumping to the top of the loop + anyFace, err := c.transmit(host, fSh) if err != nil { - //if failure of connection, retry connection - if isConnError(err) { - jww.INFO.Printf("Failed send due to connection issue: %s", err) - goto connect - } - - // Handle resetting authentication - if strings.Contains(err.Error(), AuthError(host.id).Error()) { - jww.INFO.Printf("Failed send due to auth error, retrying authentication: %s", err.Error()) - host.transmissionToken = nil - goto authorize - } - - // otherwise, return the error - return nil, errors.WithMessage(err, "Failed to send") + return nil, err } - return result, err -} - -// returns true if the connection error is one of the connection errors which -// should be retried -func isConnError(err error) bool { - return strings.Contains(err.Error(), "context deadline exceeded") || - strings.Contains(err.Error(), "connection refused") + return anyFace.(*any.Any), err } // Sets up or recovers the Host's connection @@ -317,22 +231,13 @@ func (c *ProtoComms) Stream(host *Host, f func(conn *grpc.ClientConn) ( interface{}, error)) (client interface{}, err error) { // Ensure the connection is running - jww.TRACE.Printf("Attempting to send to host: %s", host) - if !host.Connected() { - err = host.connect() - if err != nil { - return - } - } - - //establish authentication if required - if host.authenticationRequired() { - err = host.authenticate(c.clientHandshake) - if err != nil { - return - } - } + jww.TRACE.Printf("Attempting to stream to host: %s", host) + return c.transmit(host, f) +} - // Run the send function - return host.stream(f) +// returns true if the connection error is one of the connection errors which +// should be retried +func isConnError(err error) bool { + return strings.Contains(err.Error(), "context deadline exceeded") || + strings.Contains(err.Error(), "connection refused") } diff --git a/connect/host.go b/connect/host.go index 184ad1865ef10b27d28691bc7adc6629b193be96..0622c11d508c6b81a1bd1307d89ce6d017ba184f 100644 --- a/connect/host.go +++ b/connect/host.go @@ -11,12 +11,12 @@ package connect import ( "fmt" - "github.com/golang/protobuf/ptypes/any" "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/xx_network/crypto/signature/rsa" tlsCreds "gitlab.com/xx_network/crypto/tls" "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/comms/connect/token" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" @@ -24,6 +24,7 @@ import ( "math" "net" "sync" + "sync/atomic" "testing" "time" ) @@ -39,33 +40,31 @@ var KaClientOpts = keepalive.ClientParameters{ PermitWithoutStream: true, } -// Represents a reverse-authentication token -type Token []byte - // Information used to describe a connection to a host type Host struct { // System-wide ID of the Host id *id.ID // address:Port being connected to - address string + addressAtomic atomic.Value // PEM-format TLS Certificate certificate []byte /* Tokens shared with this Host establishing reverse authentication */ - // Token used for receiving from this host - receptionToken Token + // Live used for receiving from this host + receptionToken *token.Live - // Token used for sending to this host - transmissionToken Token + // Live used for sending to this host + transmissionToken *token.Live // Configure the maximum number of connection attempts - maxRetries int + maxRetries uint32 // GRPC connection object - connection *grpc.ClientConn + connection *grpc.ClientConn + connectionCount uint64 // TLS credentials object used to establish the connection credentials credentials.TransportCredentials @@ -80,29 +79,29 @@ type Host struct { // This is useful for determining whether a Host's key was hardcoded dynamicHost bool - // Read/Write Mutex for thread safety - mux sync.RWMutex + // Send lock + sendMux sync.RWMutex } // Creates a new Host object -func NewHost(id *id.ID, address string, cert []byte, disableTimeout, - enableAuth bool) (host *Host, err error) { +func NewHost(id *id.ID, address string, cert []byte, params HostParams) (host *Host, err error) { // Initialize the Host object host = &Host{ - id: id, - address: address, - certificate: cert, - enableAuth: enableAuth, + id: id, + certificate: cert, + enableAuth: params.AuthEnabled, + transmissionToken: token.NewLive(), + receptionToken: token.NewLive(), + maxRetries: params.MaxRetries, } - // Set the max number of retries for establishing a connection - if disableTimeout { - host.maxRetries = math.MaxInt32 - } else { - host.maxRetries = 100 + if host.maxRetries == 0 { + host.maxRetries = math.MaxUint32 } + host.UpdateAddress(address) + // Configure the host credentials err = host.setCredentials() return @@ -115,8 +114,10 @@ func newDynamicHost(id *id.ID, publicKey []byte) (host *Host, err error) { // IMPORTANT: This flag must be set to true for all dynamic Hosts // because the security properties for these Hosts differ host = &Host{ - id: id, - dynamicHost: true, + id: id, + dynamicHost: true, + transmissionToken: token.NewLive(), + receptionToken: token.NewLive(), } // Create the RSA Public Key object @@ -138,11 +139,12 @@ func (h *Host) GetPubKey() *rsa.PublicKey { } // Connected checks if the given Host's connection is alive -func (h *Host) Connected() bool { - h.mux.RLock() - defer h.mux.RUnlock() +// the uint is the connection count, it increments every time a reconnect occurs +func (h *Host) Connected() (bool, uint64) { + h.sendMux.RLock() + defer h.sendMux.RUnlock() - return h.isAlive() + return h.isAlive() && !h.authenticationRequired(), h.connectionCount } // GetId returns the id of the host @@ -152,21 +154,35 @@ func (h *Host) GetId() *id.ID { // GetAddress returns the address of the host. func (h *Host) GetAddress() string { - return h.address + a := h.addressAtomic.Load() + if a == nil { + return "" + } + return a.(string) } // UpdateAddress updates the address of the host func (h *Host) UpdateAddress(address string) { - h.address = address + h.addressAtomic.Store(address) } // Disconnect closes a the Host connection under the write lock func (h *Host) Disconnect() { - h.mux.Lock() - defer h.mux.Unlock() + h.sendMux.Lock() + defer h.sendMux.Unlock() h.disconnect() - h.transmissionToken = nil +} + +// ConditionalDisconnect closes a the Host connection under the write lock only +// if the connection count has not increased +func (h *Host) conditionalDisconnect(count uint64) { + h.sendMux.Lock() + defer h.sendMux.Unlock() + + if count == h.connectionCount { + h.disconnect() + } } // Returns whether or not the Host is able to be contacted @@ -193,72 +209,34 @@ func (h *Host) IsOnline() bool { // send checks that the host has a connection and sends if it does. // Operates under the host's read lock. -func (h *Host) send(f func(conn *grpc.ClientConn) (*any.Any, - error)) (*any.Any, error) { - - h.mux.RLock() - defer h.mux.RUnlock() - - if !h.isAlive() { - return nil, errors.New("Could not send, connection is not alive") - } +func (h *Host) transmit(f func(conn *grpc.ClientConn) (interface{}, + error)) (interface{}, error) { + h.sendMux.RLock() a, err := f(h.connection) - return a, err -} - -// stream checks that the host has a connection and streams if it does. -// Operates under the host's read lock. -func (h *Host) stream(f func(conn *grpc.ClientConn) ( - interface{}, error)) (interface{}, error) { + h.sendMux.RUnlock() - h.mux.RLock() - defer h.mux.RUnlock() - - if !h.isAlive() { - return nil, errors.New("Could not stream, connection is not alive") - } - - a, err := f(h.connection) return a, err } // connect attempts to connect to the host if it does not have a valid connection func (h *Host) connect() error { - h.mux.Lock() - defer h.mux.Unlock() - - //checks if the connection is active and skips reconnecting if it is - if h.isAlive() { - return nil - } - - h.disconnect() - h.transmissionToken = nil //connect to remote if err := h.connectHelper(); err != nil { return err } + h.connectionCount++ + return nil } // authenticationRequired Checks if new authentication is required with -// the remote +// the remote. This is used exclusively under the lock in protocoms.transmit so +// no lock is needed func (h *Host) authenticationRequired() bool { - h.mux.RLock() - defer h.mux.RUnlock() - - return h.enableAuth && h.transmissionToken == nil -} - -// Checks if the given Host's connection is alive -func (h *Host) authenticate(handshake func(host *Host) error) error { - h.mux.Lock() - defer h.mux.Unlock() - - return handshake(h) + return h.enableAuth && !h.transmissionToken.Has() } // isAlive returns true if the connection is non-nil and alive @@ -281,11 +259,12 @@ func (h *Host) disconnect() { err := h.connection.Close() if err != nil { jww.ERROR.Printf("Unable to close connection to %s: %+v", - h.address, errors.New(err.Error())) + h.GetAddress(), errors.New(err.Error())) } else { h.connection = nil } } + h.transmissionToken.Clear() } // connectHelper creates a connection while not under a write lock. @@ -299,19 +278,20 @@ func (h *Host) connectHelper() (err error) { securityDial = grpc.WithTransportCredentials(h.credentials) } else { // Create the gRPC client without TLS - jww.WARN.Printf("Connecting to %v without TLS!", h.address) + jww.WARN.Printf("Connecting to %v without TLS!", h.GetAddress()) securityDial = grpc.WithInsecure() } jww.DEBUG.Printf("Attempting to establish connection to %s using"+ - " credentials: %+v", h.address, securityDial) + " credentials: %+v", h.GetAddress(), securityDial) // Attempt to establish a new connection - for numRetries := 0; numRetries < h.maxRetries && !h.isAlive(); numRetries++ { + var numRetries uint32 + for numRetries = 0; numRetries < h.maxRetries && !h.isAlive(); numRetries++ { h.disconnect() jww.INFO.Printf("Connecting to %+v. Attempt number %+v of %+v", - h.address, numRetries, h.maxRetries) + h.GetAddress(), numRetries, h.maxRetries) // If timeout is enabled, the max wait time becomes // ~14 seconds (with maxRetries=100) @@ -322,13 +302,13 @@ func (h *Host) connectHelper() (err error) { ctx, cancel := ConnectionContext(time.Duration(backoffTime)) // Create the connection - h.connection, err = grpc.DialContext(ctx, h.address, + h.connection, err = grpc.DialContext(ctx, h.GetAddress(), securityDial, grpc.WithBlock(), grpc.WithKeepaliveParams(KaClientOpts)) if err != nil { jww.ERROR.Printf("Attempt number %+v to connect to %s failed: %+v\n", - numRetries, h.address, errors.New(err.Error())) + numRetries, h.GetAddress(), errors.New(err.Error())) } cancel() } @@ -337,11 +317,12 @@ func (h *Host) connectHelper() (err error) { if !h.isAlive() { h.disconnect() return errors.New(fmt.Sprintf( - "Last try to connect to %s failed. Giving up", h.address)) + "Last try to connect to %s failed. Giving up", + h.GetAddress())) } // Add the successful connection to the Manager - jww.INFO.Printf("Successfully connected to %v", h.address) + jww.INFO.Printf("Successfully connected to %v", h.GetAddress()) return } @@ -383,9 +364,9 @@ func (h *Host) setCredentials() error { // Stringer interface for connection func (h *Host) String() string { - h.mux.RLock() - defer h.mux.RUnlock() - addr := h.address + h.sendMux.RLock() + defer h.sendMux.RUnlock() + addr := h.GetAddress() actualConnection := h.connection creds := h.credentials @@ -405,13 +386,13 @@ func (h *Host) String() string { securityProtocol = creds.Info().SecurityProtocol } return fmt.Sprintf( - "ID: %v\tAddr: %v\tCertificate: %s...\tTransmission Token: %v"+ - "\tReception Token: %+v \tEnableAuth: %v"+ + "ID: %v\tAddr: %v\tCertificate: %s...\tTransmission Live: %v"+ + "\tReception Live: %+v \tEnableAuth: %v"+ "\tMaxRetries: %v\tConnState: %v"+ "\tTLS ServerName: %v\tTLS ProtocolVersion: %v\t"+ "TLS SecurityVersion: %v\tTLS SecurityProtocol: %v\n", - h.id, addr, h.certificate, h.transmissionToken, - h.receptionToken, h.enableAuth, h.maxRetries, state, + h.id, addr, h.certificate, h.transmissionToken.GetBytes(), + h.receptionToken.GetBytes(), h.enableAuth, h.maxRetries, state, serverName, protocolVersion, securityVersion, securityProtocol) } diff --git a/connect/hostParams.go b/connect/hostParams.go new file mode 100644 index 0000000000000000000000000000000000000000..437b91c8e3c24a0d813cb8b0d12531713e762cf5 --- /dev/null +++ b/connect/hostParams.go @@ -0,0 +1,15 @@ +package connect + +// Params object for host creation +type HostParams struct { + MaxRetries uint32 + AuthEnabled bool +} + +// Get default set of host params +func GetDefaultHostParams() HostParams { + return HostParams{ + MaxRetries: 100, + AuthEnabled: true, + } +} diff --git a/connect/host_test.go b/connect/host_test.go index f90f00bbed82fe3dfd3dc9bf6e0e372350616221..f9be8a2d0ae23c262744fc3414f95447f91c9fe1 100644 --- a/connect/host_test.go +++ b/connect/host_test.go @@ -15,16 +15,16 @@ import ( ) func TestHost_address(t *testing.T) { - var mgr Manager + mgr := newManager() testId := id.NewIdFromString("test", id.Node, t) testAddress := "test" - host, err := mgr.AddHost(testId, testAddress, nil, false, false) + host, err := mgr.AddHost(testId, testAddress, nil, GetDefaultHostParams()) if err != nil { t.Errorf("Unable to add host") return } - if host.address != testAddress { + if host.GetAddress() != testAddress { t.Errorf("Expected addresses to match") } } @@ -33,7 +33,6 @@ func TestHost_GetCertificate(t *testing.T) { testCert := []byte("TEST") host := Host{ - address: "", certificate: testCert, maxRetries: 0, connection: nil, @@ -65,7 +64,8 @@ func TestHost_GetId(t *testing.T) { func TestHost_GetAddress(t *testing.T) { // Test values testAddress := "192.167.1.1:8080" - testHost := Host{address: testAddress} + testHost := Host{} + testHost.UpdateAddress(testAddress) // Test function if testHost.GetAddress() != testAddress { @@ -78,7 +78,8 @@ func TestHost_GetAddress(t *testing.T) { func TestHost_UpdateAddress(t *testing.T) { testAddress := "192.167.1.1:8080" testUpdatedAddress := "192.167.1.1:8080" - testHost := Host{address: testAddress} + testHost := Host{} + testHost.UpdateAddress(testAddress) // Test function if testHost.GetAddress() != testAddress { @@ -120,8 +121,7 @@ func TestHost_IsOnline(t *testing.T) { // Create the host host, err := NewHost(id.NewIdFromString("test", id.Gateway, t), addr, nil, - false, - false) + GetDefaultHostParams()) if err != nil { t.Errorf("Unable to create host: %+v", host) return diff --git a/connect/manager.go b/connect/manager.go index 19fc3119b553297215a2b00ea6563fa29b5a1883..69e66d336f89a95f94a953ae6dad179a4e48bca6 100644 --- a/connect/manager.go +++ b/connect/manager.go @@ -15,67 +15,93 @@ import ( jww "github.com/spf13/jwalterweatherman" "gitlab.com/xx_network/primitives/id" "sync" + "testing" ) // The Manager object provides thread-safe access // to Host objects for top-level libraries type Manager struct { // A map of id.IDs to Hosts - connections sync.Map + connections map[id.ID]*Host + mux sync.RWMutex +} + +func newManager() *Manager { + return &Manager{ + connections: make(map[id.ID]*Host), + mux: sync.RWMutex{}, + } +} + +func NewManagerTesting(t *testing.T) *Manager { + if t == nil { + jww.FATAL.Panicf("NewMangerTesting is for testing only") + } + return newManager() } // Fetch a Host from the internal map func (m *Manager) GetHost(hostId *id.ID) (*Host, bool) { - value, ok := m.connections.Load(*hostId) + m.mux.RLock() + defer m.mux.RUnlock() + host, ok := m.connections[*hostId] if !ok { return nil, false } - host, ok := value.(*Host) return host, ok } // Creates and adds a Host object to the Manager using the given id -func (m *Manager) AddHost(id *id.ID, address string, - cert []byte, disableTimeout, enableAuth bool) (host *Host, err error) { +func (m *Manager) AddHost(hid *id.ID, address string, + cert []byte, params HostParams) (host *Host, err error) { + m.mux.Lock() + defer m.mux.Unlock() + + //check if the host already exists, if it does return it + host, ok := m.connections[*hid] + if ok { + return host, nil + } - host, err = NewHost(id, address, cert, disableTimeout, enableAuth) + //create the new host + host, err = NewHost(hid, address, cert, params) if err != nil { return nil, err } + //add the host to the map m.addHost(host) - return -} -// Removes a host from the connection manager -func (m *Manager) RemoveHost(id *id.ID) { - jww.DEBUG.Printf("Removing host: %v", id) - m.connections.Delete(*id) + return host, nil } -// Internal helper function that can add Hosts directly func (m *Manager) addHost(host *Host) { jww.DEBUG.Printf("Adding host: %s", host) - m.connections.Store(*host.id, host) + m.connections[*(host.id)] = host +} + +// Removes a host from the connection manager +func (m *Manager) RemoveHost(hid *id.ID) { + m.mux.Lock() + defer m.mux.Unlock() + delete(m.connections, *hid) } // Closes all client connections and removes them from Manager func (m *Manager) DisconnectAll() { - m.connections.Range(func(key interface{}, value interface{}) bool { - value.(*Host).Disconnect() - return true - }) + m.mux.RLock() + defer m.mux.RUnlock() + for _, host := range m.connections { + host.Disconnect() + } } // Implements Stringer for debug printing func (m *Manager) String() string { var result bytes.Buffer - m.connections.Range(func(key interface{}, value interface{}) bool { - k := key.(id.ID) + for k, host := range m.connections { result.WriteString(fmt.Sprintf("[%s]: %+v", - (&k).String(), value.(*Host))) - return true - }) - + (&k).String(), host)) + } return result.String() } diff --git a/connect/manager_test.go b/connect/manager_test.go index a4a404cd501c89711b7b74bc0319d35b64b52f22..f2535e5230812d5833cc463625dc7f59c1a63a99 100644 --- a/connect/manager_test.go +++ b/connect/manager_test.go @@ -42,7 +42,6 @@ func TestMain(m *testing.M) { func TestSetCredentials_InvalidCert(t *testing.T) { host := &Host{ - address: "", certificate: []byte("test"), } err := host.setCredentials() @@ -58,14 +57,14 @@ func TestConnectionManager_Disconnect(t *testing.T) { test := 2 pass := 0 address := ServerAddress - var manager Manager + manager := newManager() testId := id.NewIdFromString("testId", id.Node, t) - host, err := manager.AddHost(testId, address, nil, false, false) + host, err := manager.AddHost(testId, address, nil, GetDefaultHostParams()) if err != nil { t.Errorf("Unable to call connnect: %+v", err) } - _, inMap := manager.connections.Load(*testId) + _, inMap := manager.connections[*testId] if !inMap { t.Errorf("connect Function didn't add connection to map") @@ -96,11 +95,11 @@ func TestConnectionManager_DisconnectAll(t *testing.T) { pass := 0 address := ServerAddress address2 := ServerAddress2 - var manager Manager + manager := newManager() testId := id.NewIdFromString("testId", id.Generic, t) testId2 := id.NewIdFromString("TestId2", id.Generic, t) - host, err := manager.AddHost(testId, address, nil, false, false) + host, err := manager.AddHost(testId, address, nil, GetDefaultHostParams()) if err != nil { t.Errorf("Unable to call NewHost: %+v", err) } @@ -113,7 +112,7 @@ func TestConnectionManager_DisconnectAll(t *testing.T) { pass++ } - host2, err := manager.AddHost(testId2, address2, nil, false, false) + host2, err := manager.AddHost(testId2, address2, nil, GetDefaultHostParams()) if err != nil { t.Errorf("Unable to call NewHost: %+v", err) } @@ -127,7 +126,7 @@ func TestConnectionManager_DisconnectAll(t *testing.T) { t.Errorf("Unable to call connnect: %+v", err) } - _, inMap = manager.connections.Load(*testId2) + _, inMap = manager.connections[*testId2] if !inMap { t.Errorf("connect Function didn't add connection to map") @@ -152,13 +151,13 @@ func TestConnectionManager_DisconnectAll(t *testing.T) { } func TestConnectionManager_String(t *testing.T) { - var manager Manager + manager := newManager() //t.Log(manager) certPath := testkeys.GetNodeCertPath() certData := testkeys.LoadFromPath(certPath) testID := id.NewIdFromString("test", id.Node, t) - _, err := manager.AddHost(testID, "test", certData, false, false) + _, err := manager.AddHost(testID, "test", certData, GetDefaultHostParams()) if err != nil { t.Errorf("Unable to call NewHost: %+v", err) } @@ -170,7 +169,7 @@ func TestConnectionManager_String(t *testing.T) { // Show that if a connection is in the map, // it's no longer in the map after RemoveHost is called func TestConnectionManager_RemoveHost(t *testing.T) { - var manager Manager + manager := newManager() // After adding the host, the connection should be accessible id := id.NewIdFromString("i am a connection", id.Gateway, t) diff --git a/connect/token/map.go b/connect/token/map.go new file mode 100644 index 0000000000000000000000000000000000000000..b106a8819468d41a67971524114a770f6b6379c6 --- /dev/null +++ b/connect/token/map.go @@ -0,0 +1,53 @@ +package token + +import ( + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/crypto/nonce" + "sync" +) + +type Map struct { + m map[Token]nonce.Nonce + mux sync.Mutex +} + +func NewMap() *Map { + return &Map{ + m: make(map[Token]nonce.Nonce), + mux: sync.Mutex{}, + } +} + +func (m *Map) Generate() Token { + newNonce, err := nonce.NewNonce(nonce.RegistrationTTL) + if err != nil { + jww.FATAL.Panicf("Failed to generate new Token/Nonce pair: %s", err) + } + + newToken := GenerateToken(newNonce) + + m.mux.Lock() + + m.m[newToken] = newNonce + + m.mux.Unlock() + + jww.DEBUG.Printf("Token generated: %v", newToken) + return newToken + +} + +func (m *Map) Validate(token Token) bool { + m.mux.Lock() + retrievedNonce, ok := m.m[token] + delete(m.m, token) + m.mux.Unlock() + + if !ok { + return false + } + + return retrievedNonce.IsValid() + +} + diff --git a/connect/token/map_test.go b/connect/token/map_test.go new file mode 100644 index 0000000000000000000000000000000000000000..d841170b9e7a761a82793f6b50e7b9d6e1a7e400 --- /dev/null +++ b/connect/token/map_test.go @@ -0,0 +1,48 @@ +package token + +import ( + nonce2 "gitlab.com/elixxir/crypto/nonce" + "testing" +) + +// Unit test for NewMap +func TestNewMap(t *testing.T) { + m := NewMap() + if m.m == nil { + t.Errorf("Failed to initialize map") + } +} + +// Unit test for Map.Generate() +func TestMap_Generate(t *testing.T) { + m := NewMap() + _ = m.Generate() +} + +// Unit test for Map.Validate() +func TestMap_Validate(t *testing.T) { + m := NewMap() + + token := m.Generate() + + valid := m.Validate(token) + if !valid { + t.Error("Failed to validate token") + } +} + +// Test for invalid token +func TestMap_ValidateError(t *testing.T) { + m := NewMap() + + nonce, err := nonce2.NewNonce(nonce2.NonceLen) + if err != nil { + t.Errorf("Failed to generate nonce for token") + } + token := GenerateToken(nonce) + + valid := m.Validate(token) + if valid { + t.Error("Token should not have validated") + } +} diff --git a/connect/token/token.go b/connect/token/token.go new file mode 100644 index 0000000000000000000000000000000000000000..54613584356ab9478ddc6bc763445e661c921c0e --- /dev/null +++ b/connect/token/token.go @@ -0,0 +1,99 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// +package token + +import ( + "bytes" + "github.com/pkg/errors" + "gitlab.com/elixxir/crypto/nonce" + "sync" +) + +type Token [nonce.NonceLen]byte + +// Generates a new token and adds it to internal state +func GenerateToken(newNonce nonce.Nonce) Token { + return Token(newNonce.Value) +} + +func Unmarshal(newVal []byte) (Token, error) { + if len(newVal) != nonce.NonceLen { + return Token{}, errors.Errorf("New value is not of expected length. "+ + "Expected length of %d, received length: %d", nonce.NonceLen, len(newVal)) + } + var t Token + copy(t[:], newVal) + return t, nil +} + +func (t Token) Marshal() []byte { + return t[:] +} + +func (t Token) Equals(u Token) bool { + return bytes.Equal(t[:], u[:]) +} + +// Represents a reverse-authentication token +type Live struct { + mux sync.RWMutex + t Token + has bool +} + +// Constructor which initializes a token for +// use by the associated host object +func NewLive() *Live { + return &Live{ + has: false, + } +} + +// Get reads and returns the token +func (l *Live) Get() (Token, bool) { + l.mux.RLock() + defer l.mux.RUnlock() + var tCopy Token + copy(tCopy[:], l.t[:]) + return tCopy, l.has +} + +// Get reads and returns the token +func (l *Live) GetBytes() []byte { + t, ok := l.Get() + if !ok { + return nil + } else { + return t[:] + } +} + +//Returns true if a token is present +func (l *Live) Has() bool { + l.mux.RLock() + defer l.mux.RUnlock() + return l.has +} + +// Set rewrites the token for negotiation or renegotiation +func (l *Live) Set(newToken Token) { + l.mux.Lock() + copy(l.t[:], newToken[:]) + l.has = true + l.mux.Unlock() +} + +// Clear is used to set token to a nil value +// as store will not let you do this explicitly +func (l *Live) Clear() { + l.mux.Lock() + for i := 0; i < len(l.t); i++ { + l.t[i] = 0 + } + l.has = false + l.mux.Unlock() +} diff --git a/connect/token/token_test.go b/connect/token/token_test.go new file mode 100644 index 0000000000000000000000000000000000000000..7d4bb9d0b528f592957d44a887e77d819ace4eac --- /dev/null +++ b/connect/token/token_test.go @@ -0,0 +1,257 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// +package token + +import ( + "bytes" + nonce2 "gitlab.com/elixxir/crypto/nonce" + "reflect" + "testing" +) + +// ############## +// # LIVE TESTS # +// ############## + +// Unit test for NewLive +func TestNewLive(t *testing.T) { + newToken := NewLive() + + // Test that token is empty on initialization + if !reflect.DeepEqual(newToken.t, Token{}) { + t.Errorf("New token's toke initialized incorrectly."+ + "\n\tExpected: %v"+ + "\n\tReceived: %v", nil, newToken.t) + } + if newToken.has { + t.Errorf("New token'starts off not clear.") + } +} + +// Unit test for Set +func TestLive_SetToken(t *testing.T) { + newToken := NewLive() + expectedVal := []byte("testToken") + tkn := Token{} + copy(tkn[:], expectedVal) + + // Set token's value + newToken.Set(tkn) + + // Check that the new value has been written to the token + if !bytes.Equal(tkn[:], newToken.GetBytes()) { + t.Errorf("Set did not write value as expected."+ + "\n\tExpected: %v"+ + "\n\tReceived: %v", tkn[:], newToken.GetBytes()) + } +} + +// Unit test for Get +func TestLive_GetToken(t *testing.T) { + newToken := NewLive() + + // Test Get on a newly initialized token object + if newToken.Has() { + t.Errorf("Get did not retrieve expected value on initialization.") + } + + // Set a new value for token + expectedVal := []byte("testToken") + tkn := Token{} + copy(tkn[:], expectedVal) + newToken.Set(tkn) + + // Test that the new value is successfully retrieved by Get + if !bytes.Equal(tkn[:], newToken.GetBytes()) { + t.Errorf("Get did not retrieve expected value after a Set call."+ + "\n\tExpected: %v"+ + "\n\tReceived: %v", expectedVal, newToken.GetBytes()) + } +} + +// Unit test for Live.Clear +func TestLive_Clear(t *testing.T) { + l := NewLive() + n1, err := nonce2.NewNonce(nonce2.NonceLen) + if err != nil { + t.Errorf("Failed to generate nonce") + } + t1 := GenerateToken(n1) + l.Set(t1) + + l.Clear() + + cleared := true + for i := 0; i < len(l.t); i++ { + if l.t[i] != 0 { + cleared = false + } + } + + if l.has || !cleared { + t.Errorf("Did not properly clear token") + } +} + +// Unit test for Live.Get +func TestLive_Get(t *testing.T) { + l := NewLive() + n1, err := nonce2.NewNonce(nonce2.NonceLen) + if err != nil { + t.Errorf("Failed to generate nonce") + } + t1 := GenerateToken(n1) + l.Set(t1) + + t2, _ := l.Get() + if !t2.Equals(t1) { + t.Error("Did not get same token we set") + } + + l.Clear() + + if l.t.Equals(t2) { + t.Error("token from live was not deep copied") + } +} + +// Unit test for Live.GetBytes +func TestLive_GetBytes(t *testing.T) { + l := NewLive() + n1, err := nonce2.NewNonce(nonce2.NonceLen) + if err != nil { + t.Errorf("Failed to generate nonce") + } + t1 := GenerateToken(n1) + l.Set(t1) + + b := l.GetBytes() + if bytes.Compare(b, t1.Marshal()) != 0 { + t.Error("Did not receive same token") + } + + l.Clear() + + if bytes.Compare(b, l.GetBytes()) == 0 { + t.Error("Clearing token also cleared received data") + } +} + +// Unit test for Live.Has +func TestLive_Has(t *testing.T) { + l := NewLive() + n1, err := nonce2.NewNonce(nonce2.NonceLen) + if err != nil { + t.Errorf("Failed to generate nonce") + } + t1 := GenerateToken(n1) + if l.Has() { + t.Error("Has was initialized as true") + } + + l.Set(t1) + if !l.Has() { + t.Error("Has was not set properly") + } +} + +// Unit test for Live.Set +func TestLive_Set(t *testing.T) { + l := NewLive() + n1, err := nonce2.NewNonce(nonce2.NonceLen) + if err != nil { + t.Errorf("Failed to generate nonce") + } + t1 := GenerateToken(n1) + + l.Set(t1) + + if !l.t.Equals(t1) || !l.has { + t.Error("Didn't properly set token") + } +} + +// ############### +// # TOKEN TESTS # +// ############### + +// Unit test for GenerateToken +func TestGenerateToken(t *testing.T) { + nonce, err := nonce2.NewNonce(nonce2.NonceLen) + if err != nil { + t.Errorf("Failed to generate nonce: %+v", err) + } + _ = GenerateToken(nonce) +} + +// Unit test for Unmarshal +func TestUnmarshal(t *testing.T) { + n1, err := nonce2.NewNonce(nonce2.NonceLen) + if err != nil { + t.Errorf("Failed to generate nonce") + } + t1 := GenerateToken(n1) + m := t1.Marshal() + + t2, err := Unmarshal(m) + if err != nil { + t.Errorf("Error unmarshalling token data: %+v", err) + } + if !t2.Equals(t1) { + t.Errorf("Unmarshalled and original tokens are different\nOriginal: %+v, Unmarshalled: %+v", t1, t2) + } +} + +// Error path test for Unmarshal +func TestUnmarshal_Error(t *testing.T) { + badData := []byte("test") + _, err := Unmarshal(badData) + if err == nil { + t.Error("Should have received an error for bad data") + } +} + +// Unit test for Token.Marshal +func TestToken_Marshal(t *testing.T) { + nonce, err := nonce2.NewNonce(nonce2.NonceLen) + if err != nil { + t.Errorf("Failed to generate nonce") + } + newToken := GenerateToken(nonce) + m := newToken.Marshal() + + if len(m) != len(newToken) { + t.Error("Lengths are different, Did not properly marshal token") + } + + m[0] = m[0] + 1 + if m[0] == newToken[0] { + t.Error("Token was not properly copied to new location") + } +} + +// Unit test for Token.Equals +func TestToken_Equals(t *testing.T) { + n1, err := nonce2.NewNonce(nonce2.NonceLen) + if err != nil { + t.Errorf("Failed to generate nonce") + } + t1 := GenerateToken(n1) + t1_copy := GenerateToken(n1) + if !t1.Equals(t1_copy) { + t.Errorf("Tokens made with same data were not equal") + } + + n2, err := nonce2.NewNonce(nonce2.NonceLen) + if err != nil { + t.Errorf("Failed to generate nonce 2") + } + t2 := GenerateToken(n2) + if t1.Equals(t2) { + t.Errorf("Tokens made with different data were equal") + } +} diff --git a/connect/transmit.go b/connect/transmit.go new file mode 100644 index 0000000000000000000000000000000000000000..30a1f70a9a96e84a1ce42e91865002de4d7349b3 --- /dev/null +++ b/connect/transmit.go @@ -0,0 +1,85 @@ +package connect + +import ( + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "google.golang.org/grpc" +) + +const MaxRetries = 3 + +// Sets up or recovers the Host's connection +// Then runs the given Send function +func (c *ProtoComms) transmit(host *Host, f func(conn *grpc.ClientConn) (interface{}, + error)) (result interface{}, err error) { + + if host.GetAddress() == "" { + return nil, errors.New("Host address is blank, host might be receive only.") + } + + for numRetries := 0; numRetries < MaxRetries; numRetries++ { + err = nil + //reconnect if necessary + connected, connectionCount := host.Connected() + if !connected { + connectionCount, err = c.connect(host, connectionCount) + if err != nil { + jww.WARN.Printf("Failed to connect to Host on attempt "+ + "%v/%v : %s", numRetries+1, MaxRetries, err) + continue + } + } + + //transmit + result, err = host.transmit(f) + + // if the transmission goes well or it is a domain specific error, return + if err == nil || !(isConnError(err) || IsAuthError(err)) { + return result, err + } + host.conditionalDisconnect(connectionCount) + jww.WARN.Printf("Failed to send to Host on attempt %v/%v", + numRetries+1, MaxRetries) + } + + return nil, err +} + +func (c *ProtoComms) connect(host *Host, count uint64) (uint64, error) { + host.sendMux.Lock() + defer host.sendMux.Unlock() + + //if the connection is alive return, it is possible for another transmission + //to connect between releasing the read lock and taking the write lick + if !host.isAlive() { + //connect to host + jww.INFO.Printf("Host %s not connected, attempting to connect...", + host.id) + err := host.connect() + + count = host.connectionCount + + //if connection cannot be made, do not retry + if err != nil { + host.disconnect() + return count, errors.WithMessagef(err, "Failed to connect to Host %s", + host.id) + } + } + + //check if authentication is needed + if host.authenticationRequired() { + jww.INFO.Printf("Attempting to establish authentication with host %s", + host.id) + err := c.clientHandshake(host) + + //if authentication cannot be made, do not retry + if err != nil { + host.disconnect() + return count, errors.WithMessagef(err, "Failed to authenticate with "+ + "host: %s", host.id) + } + } + + return count, nil +}