Skip to content
Snippets Groups Projects
Commit 6cf8213d authored by Josh Brooks's avatar Josh Brooks
Browse files

Update keys for new registration values

parent 41adc129
No related branches found
No related tags found
2 merge requests!53Release,!29Josh/databaseless
...@@ -56,9 +56,9 @@ func Test_attemptSendCmix(t *testing.T) { ...@@ -56,9 +56,9 @@ func Test_attemptSendCmix(t *testing.T) {
nid2 := id.NewIdFromString("jakexx360", id.Node, t) nid2 := id.NewIdFromString("jakexx360", id.Node, t)
nid3 := id.NewIdFromString("westparkhome", id.Node, t) nid3 := id.NewIdFromString("westparkhome", id.Node, t)
grp := cyclic.NewGroup(large.NewInt(7), large.NewInt(13)) grp := cyclic.NewGroup(large.NewInt(7), large.NewInt(13))
sess1.Cmix().Add(nid1, grp.NewInt(1)) sess1.Cmix().Add(nid1, grp.NewInt(1), 0, nil)
sess1.Cmix().Add(nid2, grp.NewInt(2)) sess1.Cmix().Add(nid2, grp.NewInt(2), 0, nil)
sess1.Cmix().Add(nid3, grp.NewInt(3)) sess1.Cmix().Add(nid3, grp.NewInt(3), 0, nil)
timestamps := []uint64{ timestamps := []uint64{
uint64(now.Add(-30 * time.Second).UnixNano()), //PENDING uint64(now.Add(-30 * time.Second).UnixNano()), //PENDING
......
...@@ -55,9 +55,9 @@ func Test_attemptSendManyCmix(t *testing.T) { ...@@ -55,9 +55,9 @@ func Test_attemptSendManyCmix(t *testing.T) {
nid2 := id.NewIdFromString("jakexx360", id.Node, t) nid2 := id.NewIdFromString("jakexx360", id.Node, t)
nid3 := id.NewIdFromString("westparkhome", id.Node, t) nid3 := id.NewIdFromString("westparkhome", id.Node, t)
grp := cyclic.NewGroup(large.NewInt(7), large.NewInt(13)) grp := cyclic.NewGroup(large.NewInt(7), large.NewInt(13))
sess1.Cmix().Add(nid1, grp.NewInt(1)) sess1.Cmix().Add(nid1, grp.NewInt(1), 0, nil)
sess1.Cmix().Add(nid2, grp.NewInt(2)) sess1.Cmix().Add(nid2, grp.NewInt(2), 0, nil)
sess1.Cmix().Add(nid3, grp.NewInt(3)) sess1.Cmix().Add(nid3, grp.NewInt(3), 0, nil)
timestamps := []uint64{ timestamps := []uint64{
uint64(now.Add(-30 * time.Second).UnixNano()), // PENDING uint64(now.Add(-30 * time.Second).UnixNano()), // PENDING
......
...@@ -114,6 +114,8 @@ func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface, ...@@ -114,6 +114,8 @@ func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface,
jww.INFO.Printf("registerWithNode() begin registration with node: %s", nodeID) jww.INFO.Printf("registerWithNode() begin registration with node: %s", nodeID)
var transmissionKey *cyclic.Int var transmissionKey *cyclic.Int
var validUntil uint64
var keyId []byte
// TODO: should move this to a precanned user initialization // TODO: should move this to a precanned user initialization
if uci.IsPrecanned() { if uci.IsPrecanned() {
userNum := int(uci.GetTransmissionID().Bytes()[7]) userNum := int(uci.GetTransmissionID().Bytes()[7])
...@@ -125,7 +127,7 @@ func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface, ...@@ -125,7 +127,7 @@ func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface,
jww.INFO.Printf("transmissionKey: %v", transmissionKey.Bytes()) jww.INFO.Printf("transmissionKey: %v", transmissionKey.Bytes())
} else { } else {
// Request key from server // Request key from server
transmissionKey, err = requestKey(sender, comms, ngw, regSig, transmissionKey, keyId, validUntil, err = requestKey(sender, comms, ngw, regSig,
registrationTimestampNano, uci, store, rng, stop) registrationTimestampNano, uci, store, rng, stop)
if err != nil { if err != nil {
...@@ -134,16 +136,17 @@ func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface, ...@@ -134,16 +136,17 @@ func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface,
} }
store.Add(nodeID, transmissionKey) store.Add(nodeID, transmissionKey, validUntil, keyId)
jww.INFO.Printf("Completed registration with node %s", nodeID) jww.INFO.Printf("Completed registration with node %s", nodeID)
return nil return nil
} }
func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw network.NodeGateway, func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface,
regSig []byte, registrationTimestampNano int64, uci *user.CryptographicIdentity, ngw network.NodeGateway, regSig []byte, registrationTimestampNano int64,
store *cmix.Store, rng csprng.Source, stop *stoppable.Single) (*cyclic.Int, error) { uci *user.CryptographicIdentity, store *cmix.Store, rng csprng.Source,
stop *stoppable.Single) (*cyclic.Int, []byte, uint64, error) {
dhPub := store.GetDHPublicKey().Bytes() dhPub := store.GetDHPublicKey().Bytes()
...@@ -152,7 +155,7 @@ func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw ne ...@@ -152,7 +155,7 @@ func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw ne
confirmation := &pb.ClientRegistrationConfirmation{RSAPubKey: string(userPubKeyRSA), Timestamp: registrationTimestampNano} confirmation := &pb.ClientRegistrationConfirmation{RSAPubKey: string(userPubKeyRSA), Timestamp: registrationTimestampNano}
confirmationSerialized, err := proto.Marshal(confirmation) confirmationSerialized, err := proto.Marshal(confirmation)
if err != nil { if err != nil {
return nil, err return nil, nil, 0, err
} }
keyRequest := &pb.ClientKeyRequest{ keyRequest := &pb.ClientKeyRequest{
...@@ -168,7 +171,7 @@ func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw ne ...@@ -168,7 +171,7 @@ func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw ne
serializedMessage, err := proto.Marshal(keyRequest) serializedMessage, err := proto.Marshal(keyRequest)
if err != nil { if err != nil {
return nil, err return nil, nil, 0, err
} }
opts := rsa.NewDefaultOptions() opts := rsa.NewDefaultOptions()
...@@ -181,14 +184,14 @@ func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw ne ...@@ -181,14 +184,14 @@ func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw ne
clientSig, err := rsa.Sign(rng, uci.GetTransmissionRSA(), opts.Hash, clientSig, err := rsa.Sign(rng, uci.GetTransmissionRSA(), opts.Hash,
data, opts) data, opts)
if err != nil { if err != nil {
return nil, err return nil, nil, 0, err
} }
gwid := ngw.Gateway.ID gwid := ngw.Gateway.ID
gatewayID, err := id.Unmarshal(gwid) gatewayID, err := id.Unmarshal(gwid)
if err != nil { if err != nil {
jww.ERROR.Println("registerWithNode() failed to decode gatewayID") jww.ERROR.Println("registerWithNode() failed to decode gatewayID")
return nil, err return nil, nil, 0, err
} }
// Request nonce message from gateway // Request nonce message from gateway
...@@ -211,16 +214,14 @@ func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw ne ...@@ -211,16 +214,14 @@ func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw ne
}, stop) }, stop)
if err != nil { if err != nil {
return nil, err return nil, nil, 0, err
} }
signedKeyResponse := result.(*pb.SignedKeyResponse) signedKeyResponse := result.(*pb.SignedKeyResponse)
if signedKeyResponse.Error != "" { if signedKeyResponse.Error != "" {
return nil, errors.New(signedKeyResponse.Error) return nil, nil, 0, errors.New(signedKeyResponse.Error)
} }
// fixme: node signs the response but client has a partial ndf
// w/ no node cert info, so it can't verify the signature
// Hash the response // Hash the response
h.Reset() h.Reset()
h.Write(signedKeyResponse.KeyResponse) h.Write(signedKeyResponse.KeyResponse)
...@@ -229,27 +230,27 @@ func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw ne ...@@ -229,27 +230,27 @@ func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw ne
// Load node certificate // Load node certificate
gatewayCert, err := tls.LoadCertificate(ngw.Gateway.TlsCertificate) gatewayCert, err := tls.LoadCertificate(ngw.Gateway.TlsCertificate)
if err != nil { if err != nil {
return nil, errors.WithMessagef(err, "Unable to load node's certificate") return nil, nil, 0, errors.WithMessagef(err, "Unable to load node's certificate")
} }
// Extract public key // Extract public key
nodePubKey, err := tls.ExtractPublicKey(gatewayCert) nodePubKey, err := tls.ExtractPublicKey(gatewayCert)
if err != nil { if err != nil {
return nil, errors.WithMessagef(err, "Unable to load node's public key") return nil, nil, 0, errors.WithMessagef(err, "Unable to load node's public key")
} }
// Verify the response signature // Verify the response signature
err = rsa.Verify(nodePubKey, opts.Hash, hashedResponse, err = rsa.Verify(nodePubKey, opts.Hash, hashedResponse,
signedKeyResponse.KeyResponseSignedByGateway.Signature, opts) signedKeyResponse.KeyResponseSignedByGateway.Signature, opts)
if err != nil { if err != nil {
return nil, errors.WithMessagef(err, "Could not verify node's signature") return nil, nil, 0, errors.WithMessagef(err, "Could not verify node's signature")
} }
// Unmarshal the response // Unmarshal the response
keyResponse := &pb.ClientKeyResponse{} keyResponse := &pb.ClientKeyResponse{}
err = proto.Unmarshal(signedKeyResponse.KeyResponse, keyResponse) err = proto.Unmarshal(signedKeyResponse.KeyResponse, keyResponse)
if err != nil { if err != nil {
return nil, errors.WithMessagef(err, "Failed to unmarshal client key response") return nil, nil, 0, errors.WithMessagef(err, "Failed to unmarshal client key response")
} }
h.Reset() h.Reset()
...@@ -266,18 +267,18 @@ func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw ne ...@@ -266,18 +267,18 @@ func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw ne
h.Reset() h.Reset()
if !registration.VerifyClientHMAC(sessionKey.Bytes(), keyResponse.EncryptedClientKey, if !registration.VerifyClientHMAC(sessionKey.Bytes(), keyResponse.EncryptedClientKey,
h, keyResponse.EncryptedClientKeyHMAC) { h, keyResponse.EncryptedClientKeyHMAC) {
return nil, errors.WithMessagef(err, "Failed to verify client HMAC") return nil, nil, 0, errors.WithMessagef(err, "Failed to verify client HMAC")
} }
// Decrypt the client key // Decrypt the client key
clientKey, err := chacha.Decrypt(sessionKey.Bytes(), keyResponse.EncryptedClientKey) clientKey, err := chacha.Decrypt(sessionKey.Bytes(), keyResponse.EncryptedClientKey)
if err != nil { if err != nil {
return nil, errors.WithMessagef(err, "Failed to decrypt client key") return nil, nil, 0, errors.WithMessagef(err, "Failed to decrypt client key")
} }
// Construct the transmission key from the client key // Construct the transmission key from the client key
transmissionKey := store.GetGroup().NewIntFromBytes(clientKey) transmissionKey := store.GetGroup().NewIntFromBytes(clientKey)
// Use Client keypair to sign Server nonce // Use Client keypair to sign Server nonce
return transmissionKey, nil return transmissionKey, keyResponse.KeyID, keyResponse.ValidUntil, nil
} }
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
package cmix package cmix
import ( import (
"bytes"
"encoding/binary"
jww "github.com/spf13/jwalterweatherman" jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/client/storage/versioned" "gitlab.com/elixxir/client/storage/versioned"
"gitlab.com/elixxir/crypto/cyclic" "gitlab.com/elixxir/crypto/cyclic"
...@@ -20,14 +22,17 @@ const currentKeyVersion = 0 ...@@ -20,14 +22,17 @@ const currentKeyVersion = 0
type key struct { type key struct {
kv *versioned.KV kv *versioned.KV
k *cyclic.Int k *cyclic.Int
keyId []byte
validUntil uint64
storeKey string storeKey string
} }
func newKey(kv *versioned.KV, k *cyclic.Int, id *id.ID) *key { func newKey(kv *versioned.KV, k *cyclic.Int, id *id.ID, validUntil uint64, keyId []byte) *key {
nk := &key{ nk := &key{
kv: kv, kv: kv,
k: k, k: k,
keyId: keyId,
validUntil: validUntil,
storeKey: keyKey(id), storeKey: keyKey(id),
} }
...@@ -89,15 +94,59 @@ func (k *key) delete(kv *versioned.KV, id *id.ID) { ...@@ -89,15 +94,59 @@ func (k *key) delete(kv *versioned.KV, id *id.ID) {
} }
} }
// makes a binary representation of the given key in the keystore // makes a binary representation of the given key and key values
// in the keystore
func (k *key) marshal() ([]byte, error) { func (k *key) marshal() ([]byte, error) {
return k.k.GobEncode() buff := bytes.NewBuffer(nil)
keyBytes, err := k.k.GobEncode()
if err != nil {
return nil, err
}
// Write key length
b := make([]byte, 8)
binary.LittleEndian.PutUint64(b, uint64(len(keyBytes)))
buff.Write(b)
// Write key
buff.Write(keyBytes)
// Write the keyId length
binary.LittleEndian.PutUint64(b, uint64(len(k.keyId)))
buff.Write(b)
// Write keyID
buff.Write(k.keyId)
// Write valid until
binary.LittleEndian.PutUint64(b, k.validUntil)
buff.Write(b)
return buff.Bytes(), nil
} }
// resets the data of the key from the binary representation of the key passed in // resets the data of the key from the binary representation of the key passed in
func (k *key) unmarshal(b []byte) error { func (k *key) unmarshal(b []byte) error {
buff := bytes.NewBuffer(b)
// Get the key length
keyLen := int(binary.LittleEndian.Uint64(buff.Next(8)))
// Decode the key length
k.k = &cyclic.Int{} k.k = &cyclic.Int{}
return k.k.GobDecode(b) err := k.k.GobDecode(buff.Next(keyLen))
if err != nil {
return err
}
// Get the keyID length
keyIDLen := int(binary.LittleEndian.Uint64(buff.Next(8)))
k.keyId = buff.Next(keyIDLen)
// Get the valid until value
k.validUntil = binary.LittleEndian.Uint64(buff.Next(8))
return nil
} }
// Adheres to the stringer interface // Adheres to the stringer interface
......
...@@ -34,7 +34,8 @@ type Store struct { ...@@ -34,7 +34,8 @@ type Store struct {
nodes map[id.ID]*key nodes map[id.ID]*key
dhPrivateKey *cyclic.Int dhPrivateKey *cyclic.Int
dhPublicKey *cyclic.Int dhPublicKey *cyclic.Int
validUntil uint64
keyId []byte
grp *cyclic.Group grp *cyclic.Group
kv *versioned.KV kv *versioned.KV
mux sync.RWMutex mux sync.RWMutex
...@@ -98,11 +99,12 @@ func LoadStore(kv *versioned.KV) (*Store, error) { ...@@ -98,11 +99,12 @@ func LoadStore(kv *versioned.KV) (*Store, error) {
// Add adds the key for a round to the cMix storage object. Saves the updated // Add adds the key for a round to the cMix storage object. Saves the updated
// list of nodes and the key to disk. // list of nodes and the key to disk.
func (s *Store) Add(nid *id.ID, k *cyclic.Int) { func (s *Store) Add(nid *id.ID, k *cyclic.Int,
validUntil uint64, keyId []byte) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
nodeKey := newKey(s.kv, k, nid) nodeKey := newKey(s.kv, k, nid, validUntil, keyId)
s.nodes[*nid] = nodeKey s.nodes[*nid] = nodeKey
if err := s.save(); err != nil { if err := s.save(); err != nil {
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
package cmix package cmix
import ( import (
"bytes"
"gitlab.com/elixxir/client/storage/versioned" "gitlab.com/elixxir/client/storage/versioned"
"gitlab.com/elixxir/crypto/cyclic" "gitlab.com/elixxir/crypto/cyclic"
"gitlab.com/elixxir/crypto/diffieHellman" "gitlab.com/elixxir/crypto/diffieHellman"
...@@ -16,6 +17,7 @@ import ( ...@@ -16,6 +17,7 @@ import (
"gitlab.com/xx_network/crypto/large" "gitlab.com/xx_network/crypto/large"
"gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/id"
"testing" "testing"
"time"
) )
// Happy path // Happy path
...@@ -57,9 +59,9 @@ func TestStore_AddRemove(t *testing.T) { ...@@ -57,9 +59,9 @@ func TestStore_AddRemove(t *testing.T) {
testStore, _ := makeTestStore() testStore, _ := makeTestStore()
nodeId := id.NewIdFromString("test", id.Node, t) nodeId := id.NewIdFromString("test", id.Node, t)
key := testStore.grp.NewInt(5) k := testStore.grp.NewInt(5)
keyId := []byte("keyId")
testStore.Add(nodeId, key) testStore.Add(nodeId, k, 0, keyId)
if _, exists := testStore.nodes[*nodeId]; !exists { if _, exists := testStore.nodes[*nodeId]; !exists {
t.Fatal("Failed to add node key") t.Fatal("Failed to add node key")
} }
...@@ -80,7 +82,7 @@ func TestStore_AddHas(t *testing.T) { ...@@ -80,7 +82,7 @@ func TestStore_AddHas(t *testing.T) {
nodeId := id.NewIdFromString("test", id.Node, t) nodeId := id.NewIdFromString("test", id.Node, t)
key := testStore.grp.NewInt(5) key := testStore.grp.NewInt(5)
testStore.Add(nodeId, key) testStore.Add(nodeId, key, 0, nil)
if _, exists := testStore.nodes[*nodeId]; !exists { if _, exists := testStore.nodes[*nodeId]; !exists {
t.Fatal("Failed to add node key") t.Fatal("Failed to add node key")
} }
...@@ -113,9 +115,17 @@ func TestLoadStore(t *testing.T) { ...@@ -113,9 +115,17 @@ func TestLoadStore(t *testing.T) {
// Add a test node key // Add a test node key
nodeId := id.NewIdFromString("test", id.Node, t) nodeId := id.NewIdFromString("test", id.Node, t)
key := testStore.grp.NewInt(5) k := testStore.grp.NewInt(5)
testTime, err := time.Parse(time.RFC3339,
"2012-12-21T22:08:41+00:00")
if err != nil {
t.Fatalf("Could not parse precanned time: %v", err.Error())
}
expectedValid := uint64(testTime.UnixNano())
testStore.Add(nodeId, key) expectedKeyId := []byte("expectedKeyID")
testStore.Add(nodeId, k, uint64(expectedValid), expectedKeyId)
// Load the store and check its attributes // Load the store and check its attributes
store, err := LoadStore(kv) store, err := LoadStore(kv)
...@@ -131,6 +141,18 @@ func TestLoadStore(t *testing.T) { ...@@ -131,6 +141,18 @@ func TestLoadStore(t *testing.T) {
if len(store.nodes) != len(testStore.nodes) { if len(store.nodes) != len(testStore.nodes) {
t.Errorf("LoadStore failed to load node keys") t.Errorf("LoadStore failed to load node keys")
} }
circuit := connect.NewCircuit([]*id.ID{nodeId})
keys, _ := store.GetRoundKeys(circuit)
if keys.keys[0].validUntil != expectedValid {
t.Errorf("Unexpected valid until value loaded from store."+
"\n\tExpected: %v\n\tReceived: %v", expectedValid, keys.keys[0].validUntil)
}
if !bytes.Equal(keys.keys[0].keyId, expectedKeyId) {
t.Errorf("Unexpected keyID value loaded from store."+
"\n\tExpected: %v\n\tReceived: %v", expectedKeyId, keys.keys[0].keyId)
}
} }
// Happy path // Happy path
...@@ -145,7 +167,7 @@ func TestStore_GetRoundKeys(t *testing.T) { ...@@ -145,7 +167,7 @@ func TestStore_GetRoundKeys(t *testing.T) {
for i := 0; i < numIds; i++ { for i := 0; i < numIds; i++ {
nodeIds[i] = id.NewIdFromUInt(uint64(i)+1, id.Node, t) nodeIds[i] = id.NewIdFromUInt(uint64(i)+1, id.Node, t)
key := testStore.grp.NewInt(int64(i) + 1) key := testStore.grp.NewInt(int64(i) + 1)
testStore.Add(nodeIds[i], key) testStore.Add(nodeIds[i], key, 0, nil)
// This is wack but it cleans up after the test // This is wack but it cleans up after the test
defer testStore.Remove(nodeIds[i]) defer testStore.Remove(nodeIds[i])
...@@ -176,8 +198,8 @@ func TestStore_GetRoundKeys_Missing(t *testing.T) { ...@@ -176,8 +198,8 @@ func TestStore_GetRoundKeys_Missing(t *testing.T) {
// Only add every other node so there are missing nodes // Only add every other node so there are missing nodes
if i%2 == 0 { if i%2 == 0 {
testStore.Add(nodeIds[i], key) testStore.Add(nodeIds[i], key, 0, nil)
testStore.Add(nodeIds[i], key) testStore.Add(nodeIds[i], key, 0, nil)
// This is wack but it cleans up after the test // This is wack but it cleans up after the test
defer testStore.Remove(nodeIds[i]) defer testStore.Remove(nodeIds[i])
...@@ -211,7 +233,7 @@ func TestStore_Count(t *testing.T) { ...@@ -211,7 +233,7 @@ func TestStore_Count(t *testing.T) {
count := 50 count := 50
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
store.Add(id.NewIdFromUInt(uint64(i), id.Node, t), grp.NewInt(int64(42+i))) store.Add(id.NewIdFromUInt(uint64(i), id.Node, t), grp.NewInt(int64(42+i)), 0, nil)
} }
if store.Count() != count { if store.Count() != count {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment