From a96ffd16ad9bd195efc07c06efaabe000869dbc8 Mon Sep 17 00:00:00 2001 From: "Richard T. Carback III" <rick.carback@gmail.com> Date: Mon, 17 Jun 2024 11:40:40 -0400 Subject: [PATCH] Added some sanity checks for key generation/derivation based on broken keygen. --- rpc/cmix.go | 30 ++++------------ rpc/send.go | 3 +- rpc/server.go | 23 +++++++++---- rpc/server_test.go | 85 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 109 insertions(+), 32 deletions(-) diff --git a/rpc/cmix.go b/rpc/cmix.go index 55d4b92de..a3af46f2b 100644 --- a/rpc/cmix.go +++ b/rpc/cmix.go @@ -19,12 +19,17 @@ import ( "gitlab.com/elixxir/client/v4/cmix/message" "gitlab.com/elixxir/client/v4/cmix/rounds" "gitlab.com/elixxir/crypto/fastRNG" - "gitlab.com/elixxir/crypto/nike" "gitlab.com/elixxir/primitives/format" "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/id/ephemeral" ) +func GenerateRandomID(net cMixClient) (*id.ID, error) { + rng := net.RNGStreamGenerator().GetStream() + defer rng.Close() + return generateRandomID(rng) +} + // cMix functions required for this module. type cMixClient interface { AddIdentityWithHistory(id *id.ID, validUntil, beginning time.Time, @@ -166,29 +171,6 @@ func createRandomService(rng io.Reader) message.Service { } } -func generateRandomKeys(scheme nike.Nike, rng io.Reader) (nike.PrivateKey, - nike.PublicKey, error) { - // Create compatible keypair - privateKeyBytes := make([]byte, scheme.PrivateKeySize()) - n, err := rng.Read(privateKeyBytes) - if err != nil { - return nil, nil, err - } - if n != len(privateKeyBytes) { - return nil, nil, errors.Errorf("short read: %d < %d", n, - len(privateKeyBytes)) - } - - privateKey := scheme.NewEmptyPrivateKey() - err = privateKey.FromBytes(privateKeyBytes) - if err != nil { - return nil, nil, err - } - publicKey := scheme.DerivePublicKey(privateKey) - - return privateKey, publicKey, nil -} - func generateRandomID(rng io.Reader) (*id.ID, error) { idBytes := make([]byte, id.ArrIDLen) n, err := rng.Read(idBytes) diff --git a/rpc/send.go b/rpc/send.go index 51e6a30ab..14c2bb14f 100644 --- a/rpc/send.go +++ b/rpc/send.go @@ -86,6 +86,7 @@ func Send(net cMixClient, serverID *id.ID, serverKey nike.PublicKey, return responseErr(errors.Errorf("[RPC] Query too big, %d > %d", len(request), int(headerSz)-2-len(myID.Bytes()))) } + // ciphertext is the encrypted part of the message msgToSend := make([]byte, maxPayloadSz) copy(msgToSend, public.Bytes()) @@ -165,7 +166,6 @@ type response struct { // listeners complete. func (r *response) Close() { close(r.listener) - r.wg.Done() } func (r *response) Callback(respFn func(response []byte), @@ -189,6 +189,7 @@ func (r *response) Callback(respFn func(response []byte), if r.cipher != nil { r.cipher.Reset() } + r.wg.Done() }() return r } diff --git a/rpc/server.go b/rpc/server.go index 20f2da9ae..1ba104d34 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -107,6 +107,7 @@ func (r *rpcServer) Process(cMixMsg format.Message, _ []string, _ []byte, if err != nil { jww.ERROR.Printf("[RPC] couldn't decrypt msg: %s, %+v", cMixMsg.GetKeyFP(), err) + jww.ERROR.Printf("[RPC] GARBLED: %v, %v", ct[:pkSz], ct[pkSz:]) return } @@ -160,14 +161,22 @@ func (r *rpcServer) Process(cMixMsg format.Message, _ []string, _ []byte, ct = append(replyMsgId[:], ct...) ciphertexts[i] = ct } - rnd, ids, err := send(r.net, ephCmixId, ciphertexts, - cmix.GetDefaultCMIXParams()) - if err != nil { - jww.ERROR.Printf("[RPC] bad reply to %s,%s: %+v", - ephCmixId, cMixMsg.GetKeyFP(), err) + for { + rnd, _, err := send(r.net, ephCmixId, + ciphertexts, + cmix.GetDefaultCMIXParams()) + if err != nil { + jww.ERROR.Printf( + "[RPC] bad reply to %s,%s: %+v", + ephCmixId, cMixMsg.GetKeyFP(), + err) + continue + } + jww.INFO.Printf("[RPC] reply to %s,%s "+ + "sent: %s", + ephCmixId, cMixMsg.GetKeyFP(), rnd.ID) + return } - jww.INFO.Printf("[RPC] reply to %s,%s sent: %v, %v", - ephCmixId, cMixMsg.GetKeyFP(), rnd, ids) } // Response with Internal Server Error on crash diff --git a/rpc/server_test.go b/rpc/server_test.go index da127aa42..c01959a0b 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -9,12 +9,15 @@ package rpc import ( "bytes" + "encoding/base64" "encoding/json" "testing" "github.com/stretchr/testify/require" "gitlab.com/elixxir/client/v4/cmix" + "gitlab.com/elixxir/crypto/nike" "gitlab.com/elixxir/crypto/nike/ecdh" + "gitlab.com/xx_network/crypto/csprng" "gitlab.com/xx_network/primitives/id" ) @@ -69,3 +72,85 @@ func TestServer(t *testing.T) { d := r.Wait() require.Equal(t, uint64(len(d)), 10*maxPayloadSz) } + +func TestMarshalling(t *testing.T) { + rng := csprng.NewSystemRNG() + + for i := 0; i < 10; i++ { + pr1, pu1 := ecdh.ECDHNIKE.NewKeypair(rng) + pr2, pu2 := ecdh.ECDHNIKE.NewKeypair(rng) + + // sanity check + nk1 := pr1.DeriveSecret(pu2) + nk2 := pr2.DeriveSecret(pu1) + require.Equal(t, nk1, nk2) + + // Now Dump to Base64 string + pr1Str := base64.RawStdEncoding.EncodeToString(pr1.Bytes()) + pu1Str := base64.RawStdEncoding.EncodeToString(pu1.Bytes()) + pr2Str := base64.RawStdEncoding.EncodeToString(pr2.Bytes()) + pu2Str := base64.RawStdEncoding.EncodeToString(pu2.Bytes()) + + srvPriv, srvPub := checkKeys(pr1Str, pu1Str, t) + cliPriv, cliPub := checkKeys(pr2Str, pu2Str, t) + + sk1 := srvPriv.DeriveSecret(cliPub) + sk2 := cliPriv.DeriveSecret(srvPub) + + t.Logf("sk1: %+v", sk1) + t.Logf("sk2: %+v", sk2) + + require.Equal(t, sk1, sk2) + require.Equal(t, nk1, sk1) + } +} + +func TestVectors(t *testing.T) { + rng := csprng.NewSystemRNG() + + serverPriv, serverPub := ecdh.ECDHNIKE.NewKeypair(rng) + clientPriv, clientPub := ecdh.ECDHNIKE.NewKeypair(rng) + + nk1 := serverPriv.DeriveSecret(clientPub) + nk2 := clientPriv.DeriveSecret(serverPub) + + require.Equal(t, nk1, nk2) + t.Logf("nk1: %v", nk1) + + srvPrivStr := "WBfCA63BbC9vaOeRi0FpqlfKjOedQM6TW89AwlfnVKfKrGHPN5J9Qh0cC9I4yzEG6brIveyvwI8RvTeT0OjfTw" + srvPubStr := "yqxhzzeSfUIdHAvSOMsxBum6yL3sr8CPEb03k9Do308" + + clientPrivStr := "NOPOKU2+SrsQv7DSXkiOkUytGdZeF/+H+j4bn50XbSNnvFqJawvtnofZUW1wlUXsHa55CKLrLUp6nMV8f5orgQ" + clientPubStr := "Z7xaiWsL7Z6H2VFtcJVF7B2ueQii6y1KepzFfH+aK4E" + + srvPriv, srvPub := checkKeys(srvPrivStr, srvPubStr, t) + cliPriv, cliPub := checkKeys(clientPrivStr, clientPubStr, t) + + sk1 := srvPriv.DeriveSecret(cliPub) + sk2 := cliPriv.DeriveSecret(srvPub) + + t.Logf("sk1: %+v", sk1) + t.Logf("sk2: %+v", sk2) + + require.Equal(t, sk1, sk2) +} + +func checkKeys(privStr, pubStr string, t *testing.T) (nike.PrivateKey, + nike.PublicKey) { + privBytes, err := base64.RawStdEncoding.DecodeString(privStr) + require.NoError(t, err) + + priv := ecdh.ECDHNIKE.NewEmptyPrivateKey() + err = priv.FromBytes(privBytes) + require.NoError(t, err) + + derivedPub := priv.Scheme().DerivePublicKey(priv) + require.Equal(t, pubStr, base64.RawStdEncoding.EncodeToString( + derivedPub.Bytes())) + pub := ecdh.ECDHNIKE.NewEmptyPublicKey() + pubBytes, err := base64.RawStdEncoding.DecodeString(pubStr) + require.NoError(t, err) + err = pub.FromBytes(pubBytes) + require.NoError(t, err) + return priv, pub +} -- GitLab