diff --git a/network/message/sendCmix_test.go b/network/message/sendCmix_test.go index d3d7e116bd4eb5337628adccc87d422b0ae4b327..984e5cd2d01d4f4e687b72f6e8217c9440c980cd 100644 --- a/network/message/sendCmix_test.go +++ b/network/message/sendCmix_test.go @@ -56,9 +56,9 @@ func Test_attemptSendCmix(t *testing.T) { nid2 := id.NewIdFromString("jakexx360", id.Node, t) nid3 := id.NewIdFromString("westparkhome", id.Node, t) grp := cyclic.NewGroup(large.NewInt(7), large.NewInt(13)) - sess1.Cmix().Add(nid1, grp.NewInt(1)) - sess1.Cmix().Add(nid2, grp.NewInt(2)) - sess1.Cmix().Add(nid3, grp.NewInt(3)) + sess1.Cmix().Add(nid1, grp.NewInt(1), 0, nil) + sess1.Cmix().Add(nid2, grp.NewInt(2), 0, nil) + sess1.Cmix().Add(nid3, grp.NewInt(3), 0, nil) timestamps := []uint64{ uint64(now.Add(-30 * time.Second).UnixNano()), //PENDING diff --git a/network/message/sendManyCmix_test.go b/network/message/sendManyCmix_test.go index 1b0da6c083739765496e111628c5c3a74bdce05a..b787b3abc4445f5266a2cea51ced2429cdddba40 100644 --- a/network/message/sendManyCmix_test.go +++ b/network/message/sendManyCmix_test.go @@ -55,9 +55,9 @@ func Test_attemptSendManyCmix(t *testing.T) { nid2 := id.NewIdFromString("jakexx360", id.Node, t) nid3 := id.NewIdFromString("westparkhome", id.Node, t) grp := cyclic.NewGroup(large.NewInt(7), large.NewInt(13)) - sess1.Cmix().Add(nid1, grp.NewInt(1)) - sess1.Cmix().Add(nid2, grp.NewInt(2)) - sess1.Cmix().Add(nid3, grp.NewInt(3)) + sess1.Cmix().Add(nid1, grp.NewInt(1), 0, nil) + sess1.Cmix().Add(nid2, grp.NewInt(2), 0, nil) + sess1.Cmix().Add(nid3, grp.NewInt(3), 0, nil) timestamps := []uint64{ uint64(now.Add(-30 * time.Second).UnixNano()), // PENDING diff --git a/network/node/register.go b/network/node/register.go index 497cf22ad040c1d53a1eab765fbd3814fbe438bc..a4424eb18b16b98f2b1766f68101fcf7c8fe11b4 100644 --- a/network/node/register.go +++ b/network/node/register.go @@ -114,6 +114,8 @@ func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface, jww.INFO.Printf("registerWithNode() begin registration with node: %s", nodeID) var transmissionKey *cyclic.Int + var validUntil uint64 + var keyId []byte // TODO: should move this to a precanned user initialization if uci.IsPrecanned() { userNum := int(uci.GetTransmissionID().Bytes()[7]) @@ -125,7 +127,7 @@ func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface, jww.INFO.Printf("transmissionKey: %v", transmissionKey.Bytes()) } else { // Request key from server - transmissionKey, err = requestKey(sender, comms, ngw, regSig, + transmissionKey, keyId, validUntil, err = requestKey(sender, comms, ngw, regSig, registrationTimestampNano, uci, store, rng, stop) if err != nil { @@ -134,16 +136,17 @@ func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface, } - store.Add(nodeID, transmissionKey) + store.Add(nodeID, transmissionKey, validUntil, keyId) jww.INFO.Printf("Completed registration with node %s", nodeID) return nil } -func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw network.NodeGateway, - regSig []byte, registrationTimestampNano int64, uci *user.CryptographicIdentity, - store *cmix.Store, rng csprng.Source, stop *stoppable.Single) (*cyclic.Int, error) { +func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, + ngw network.NodeGateway, regSig []byte, registrationTimestampNano int64, + uci *user.CryptographicIdentity, store *cmix.Store, rng csprng.Source, + stop *stoppable.Single) (*cyclic.Int, []byte, uint64, error) { dhPub := store.GetDHPublicKey().Bytes() @@ -152,7 +155,7 @@ func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw ne confirmation := &pb.ClientRegistrationConfirmation{RSAPubKey: string(userPubKeyRSA), Timestamp: registrationTimestampNano} confirmationSerialized, err := proto.Marshal(confirmation) if err != nil { - return nil, err + return nil, nil, 0, err } keyRequest := &pb.ClientKeyRequest{ @@ -168,7 +171,7 @@ func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw ne serializedMessage, err := proto.Marshal(keyRequest) if err != nil { - return nil, err + return nil, nil, 0, err } opts := rsa.NewDefaultOptions() @@ -181,14 +184,14 @@ func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw ne clientSig, err := rsa.Sign(rng, uci.GetTransmissionRSA(), opts.Hash, data, opts) if err != nil { - return nil, err + return nil, nil, 0, err } gwid := ngw.Gateway.ID gatewayID, err := id.Unmarshal(gwid) if err != nil { jww.ERROR.Println("registerWithNode() failed to decode gatewayID") - return nil, err + return nil, nil, 0, err } // Request nonce message from gateway @@ -211,16 +214,14 @@ func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw ne }, stop) if err != nil { - return nil, err + return nil, nil, 0, err } signedKeyResponse := result.(*pb.SignedKeyResponse) if signedKeyResponse.Error != "" { - return nil, errors.New(signedKeyResponse.Error) + return nil, nil, 0, errors.New(signedKeyResponse.Error) } - // fixme: node signs the response but client has a partial ndf - // w/ no node cert info, so it can't verify the signature // Hash the response h.Reset() h.Write(signedKeyResponse.KeyResponse) @@ -229,27 +230,27 @@ func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw ne // Load node certificate gatewayCert, err := tls.LoadCertificate(ngw.Gateway.TlsCertificate) if err != nil { - return nil, errors.WithMessagef(err, "Unable to load node's certificate") + return nil, nil, 0, errors.WithMessagef(err, "Unable to load node's certificate") } // Extract public key nodePubKey, err := tls.ExtractPublicKey(gatewayCert) if err != nil { - return nil, errors.WithMessagef(err, "Unable to load node's public key") + return nil, nil, 0, errors.WithMessagef(err, "Unable to load node's public key") } // Verify the response signature err = rsa.Verify(nodePubKey, opts.Hash, hashedResponse, signedKeyResponse.KeyResponseSignedByGateway.Signature, opts) if err != nil { - return nil, errors.WithMessagef(err, "Could not verify node's signature") + return nil, nil, 0, errors.WithMessagef(err, "Could not verify node's signature") } // Unmarshal the response keyResponse := &pb.ClientKeyResponse{} err = proto.Unmarshal(signedKeyResponse.KeyResponse, keyResponse) if err != nil { - return nil, errors.WithMessagef(err, "Failed to unmarshal client key response") + return nil, nil, 0, errors.WithMessagef(err, "Failed to unmarshal client key response") } h.Reset() @@ -266,18 +267,18 @@ func requestKey(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw ne h.Reset() if !registration.VerifyClientHMAC(sessionKey.Bytes(), keyResponse.EncryptedClientKey, h, keyResponse.EncryptedClientKeyHMAC) { - return nil, errors.WithMessagef(err, "Failed to verify client HMAC") + return nil, nil, 0, errors.WithMessagef(err, "Failed to verify client HMAC") } // Decrypt the client key clientKey, err := chacha.Decrypt(sessionKey.Bytes(), keyResponse.EncryptedClientKey) if err != nil { - return nil, errors.WithMessagef(err, "Failed to decrypt client key") + return nil, nil, 0, errors.WithMessagef(err, "Failed to decrypt client key") } // Construct the transmission key from the client key transmissionKey := store.GetGroup().NewIntFromBytes(clientKey) // Use Client keypair to sign Server nonce - return transmissionKey, nil + return transmissionKey, keyResponse.KeyID, keyResponse.ValidUntil, nil } diff --git a/storage/cmix/key.go b/storage/cmix/key.go index 3ab6fc084cff007e879cd541afc44dbf910ea10d..6c2954ac4fa65f30b50cd771f73652357b7c19bb 100644 --- a/storage/cmix/key.go +++ b/storage/cmix/key.go @@ -8,6 +8,8 @@ package cmix import ( + "bytes" + "encoding/binary" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/storage/versioned" "gitlab.com/elixxir/crypto/cyclic" @@ -18,17 +20,20 @@ import ( const currentKeyVersion = 0 type key struct { - kv *versioned.KV - k *cyclic.Int - - storeKey string + kv *versioned.KV + k *cyclic.Int + keyId []byte + validUntil uint64 + storeKey string } -func newKey(kv *versioned.KV, k *cyclic.Int, id *id.ID) *key { +func newKey(kv *versioned.KV, k *cyclic.Int, id *id.ID, validUntil uint64, keyId []byte) *key { nk := &key{ - kv: kv, - k: k, - storeKey: keyKey(id), + kv: kv, + k: k, + keyId: keyId, + validUntil: validUntil, + storeKey: keyKey(id), } if err := nk.save(); err != nil { @@ -89,15 +94,59 @@ func (k *key) delete(kv *versioned.KV, id *id.ID) { } } -// makes a binary representation of the given key in the keystore +// makes a binary representation of the given key and key values +// in the keystore func (k *key) marshal() ([]byte, error) { - return k.k.GobEncode() + buff := bytes.NewBuffer(nil) + keyBytes, err := k.k.GobEncode() + if err != nil { + return nil, err + } + + // Write key length + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(len(keyBytes))) + buff.Write(b) + + // Write key + buff.Write(keyBytes) + + // Write the keyId length + binary.LittleEndian.PutUint64(b, uint64(len(k.keyId))) + buff.Write(b) + + // Write keyID + buff.Write(k.keyId) + + // Write valid until + binary.LittleEndian.PutUint64(b, k.validUntil) + buff.Write(b) + + return buff.Bytes(), nil } // resets the data of the key from the binary representation of the key passed in func (k *key) unmarshal(b []byte) error { + buff := bytes.NewBuffer(b) + + // Get the key length + keyLen := int(binary.LittleEndian.Uint64(buff.Next(8))) + + // Decode the key length k.k = &cyclic.Int{} - return k.k.GobDecode(b) + err := k.k.GobDecode(buff.Next(keyLen)) + if err != nil { + return err + } + + // Get the keyID length + keyIDLen := int(binary.LittleEndian.Uint64(buff.Next(8))) + k.keyId = buff.Next(keyIDLen) + + // Get the valid until value + k.validUntil = binary.LittleEndian.Uint64(buff.Next(8)) + + return nil } // Adheres to the stringer interface diff --git a/storage/cmix/store.go b/storage/cmix/store.go index 95a67d7789c3f8f35642e96ef9d81cc864f4b0cb..32b7d506c7de63b4fa901773ff6ba5551747d293 100644 --- a/storage/cmix/store.go +++ b/storage/cmix/store.go @@ -34,10 +34,11 @@ type Store struct { nodes map[id.ID]*key dhPrivateKey *cyclic.Int dhPublicKey *cyclic.Int - - grp *cyclic.Group - kv *versioned.KV - mux sync.RWMutex + validUntil uint64 + keyId []byte + grp *cyclic.Group + kv *versioned.KV + mux sync.RWMutex } // NewStore returns a new cMix storage object. @@ -98,11 +99,12 @@ func LoadStore(kv *versioned.KV) (*Store, error) { // Add adds the key for a round to the cMix storage object. Saves the updated // list of nodes and the key to disk. -func (s *Store) Add(nid *id.ID, k *cyclic.Int) { +func (s *Store) Add(nid *id.ID, k *cyclic.Int, + validUntil uint64, keyId []byte) { s.mux.Lock() defer s.mux.Unlock() - nodeKey := newKey(s.kv, k, nid) + nodeKey := newKey(s.kv, k, nid, validUntil, keyId) s.nodes[*nid] = nodeKey if err := s.save(); err != nil { diff --git a/storage/cmix/store_test.go b/storage/cmix/store_test.go index 63329b769b1706dc4a341988a03e28fe84a37642..d84b2dff66723020d5c0b16f7b7d2b0d059911d1 100644 --- a/storage/cmix/store_test.go +++ b/storage/cmix/store_test.go @@ -8,6 +8,7 @@ package cmix import ( + "bytes" "gitlab.com/elixxir/client/storage/versioned" "gitlab.com/elixxir/crypto/cyclic" "gitlab.com/elixxir/crypto/diffieHellman" @@ -16,6 +17,7 @@ import ( "gitlab.com/xx_network/crypto/large" "gitlab.com/xx_network/primitives/id" "testing" + "time" ) // Happy path @@ -57,9 +59,9 @@ func TestStore_AddRemove(t *testing.T) { testStore, _ := makeTestStore() nodeId := id.NewIdFromString("test", id.Node, t) - key := testStore.grp.NewInt(5) - - testStore.Add(nodeId, key) + k := testStore.grp.NewInt(5) + keyId := []byte("keyId") + testStore.Add(nodeId, k, 0, keyId) if _, exists := testStore.nodes[*nodeId]; !exists { t.Fatal("Failed to add node key") } @@ -80,7 +82,7 @@ func TestStore_AddHas(t *testing.T) { nodeId := id.NewIdFromString("test", id.Node, t) key := testStore.grp.NewInt(5) - testStore.Add(nodeId, key) + testStore.Add(nodeId, key, 0, nil) if _, exists := testStore.nodes[*nodeId]; !exists { t.Fatal("Failed to add node key") } @@ -113,9 +115,17 @@ func TestLoadStore(t *testing.T) { // Add a test node key nodeId := id.NewIdFromString("test", id.Node, t) - key := testStore.grp.NewInt(5) + k := testStore.grp.NewInt(5) + testTime, err := time.Parse(time.RFC3339, + "2012-12-21T22:08:41+00:00") + if err != nil { + t.Fatalf("Could not parse precanned time: %v", err.Error()) + } + expectedValid := uint64(testTime.UnixNano()) - testStore.Add(nodeId, key) + expectedKeyId := []byte("expectedKeyID") + + testStore.Add(nodeId, k, uint64(expectedValid), expectedKeyId) // Load the store and check its attributes store, err := LoadStore(kv) @@ -131,6 +141,18 @@ func TestLoadStore(t *testing.T) { if len(store.nodes) != len(testStore.nodes) { t.Errorf("LoadStore failed to load node keys") } + + circuit := connect.NewCircuit([]*id.ID{nodeId}) + keys, _ := store.GetRoundKeys(circuit) + if keys.keys[0].validUntil != expectedValid { + t.Errorf("Unexpected valid until value loaded from store."+ + "\n\tExpected: %v\n\tReceived: %v", expectedValid, keys.keys[0].validUntil) + } + if !bytes.Equal(keys.keys[0].keyId, expectedKeyId) { + t.Errorf("Unexpected keyID value loaded from store."+ + "\n\tExpected: %v\n\tReceived: %v", expectedKeyId, keys.keys[0].keyId) + } + } // Happy path @@ -145,7 +167,7 @@ func TestStore_GetRoundKeys(t *testing.T) { for i := 0; i < numIds; i++ { nodeIds[i] = id.NewIdFromUInt(uint64(i)+1, id.Node, t) key := testStore.grp.NewInt(int64(i) + 1) - testStore.Add(nodeIds[i], key) + testStore.Add(nodeIds[i], key, 0, nil) // This is wack but it cleans up after the test defer testStore.Remove(nodeIds[i]) @@ -176,8 +198,8 @@ func TestStore_GetRoundKeys_Missing(t *testing.T) { // Only add every other node so there are missing nodes if i%2 == 0 { - testStore.Add(nodeIds[i], key) - testStore.Add(nodeIds[i], key) + testStore.Add(nodeIds[i], key, 0, nil) + testStore.Add(nodeIds[i], key, 0, nil) // This is wack but it cleans up after the test defer testStore.Remove(nodeIds[i]) @@ -211,7 +233,7 @@ func TestStore_Count(t *testing.T) { count := 50 for i := 0; i < count; i++ { - store.Add(id.NewIdFromUInt(uint64(i), id.Node, t), grp.NewInt(int64(42+i))) + store.Add(id.NewIdFromUInt(uint64(i), id.Node, t), grp.NewInt(int64(42+i)), 0, nil) } if store.Count() != count {