diff --git a/rpc/cmix.go b/rpc/cmix.go index 55d4b92def992b9e0123bf6af6d1eb7df910f799..a3af46f2b9a61e9983d3f8fc48cd149793937f7a 100644 --- a/rpc/cmix.go +++ b/rpc/cmix.go @@ -19,12 +19,17 @@ import ( "gitlab.com/elixxir/client/v4/cmix/message" "gitlab.com/elixxir/client/v4/cmix/rounds" "gitlab.com/elixxir/crypto/fastRNG" - "gitlab.com/elixxir/crypto/nike" "gitlab.com/elixxir/primitives/format" "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/id/ephemeral" ) +func GenerateRandomID(net cMixClient) (*id.ID, error) { + rng := net.RNGStreamGenerator().GetStream() + defer rng.Close() + return generateRandomID(rng) +} + // cMix functions required for this module. type cMixClient interface { AddIdentityWithHistory(id *id.ID, validUntil, beginning time.Time, @@ -166,29 +171,6 @@ func createRandomService(rng io.Reader) message.Service { } } -func generateRandomKeys(scheme nike.Nike, rng io.Reader) (nike.PrivateKey, - nike.PublicKey, error) { - // Create compatible keypair - privateKeyBytes := make([]byte, scheme.PrivateKeySize()) - n, err := rng.Read(privateKeyBytes) - if err != nil { - return nil, nil, err - } - if n != len(privateKeyBytes) { - return nil, nil, errors.Errorf("short read: %d < %d", n, - len(privateKeyBytes)) - } - - privateKey := scheme.NewEmptyPrivateKey() - err = privateKey.FromBytes(privateKeyBytes) - if err != nil { - return nil, nil, err - } - publicKey := scheme.DerivePublicKey(privateKey) - - return privateKey, publicKey, nil -} - func generateRandomID(rng io.Reader) (*id.ID, error) { idBytes := make([]byte, id.ArrIDLen) n, err := rng.Read(idBytes) diff --git a/rpc/send.go b/rpc/send.go index 51e6a30ab282367849f5840a535723590c7e351c..14c2bb14faf763c0c1f446ca455614ce9ed25084 100644 --- a/rpc/send.go +++ b/rpc/send.go @@ -86,6 +86,7 @@ func Send(net cMixClient, serverID *id.ID, serverKey nike.PublicKey, return responseErr(errors.Errorf("[RPC] Query too big, %d > %d", len(request), int(headerSz)-2-len(myID.Bytes()))) } + // ciphertext is the encrypted part of the message msgToSend := make([]byte, maxPayloadSz) copy(msgToSend, public.Bytes()) @@ -165,7 +166,6 @@ type response struct { // listeners complete. func (r *response) Close() { close(r.listener) - r.wg.Done() } func (r *response) Callback(respFn func(response []byte), @@ -189,6 +189,7 @@ func (r *response) Callback(respFn func(response []byte), if r.cipher != nil { r.cipher.Reset() } + r.wg.Done() }() return r } diff --git a/rpc/server.go b/rpc/server.go index 20f2da9ae475a17313fbb242bef9780844c2b5da..1ba104d34b335ae825d4da99c166af82ae068343 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -107,6 +107,7 @@ func (r *rpcServer) Process(cMixMsg format.Message, _ []string, _ []byte, if err != nil { jww.ERROR.Printf("[RPC] couldn't decrypt msg: %s, %+v", cMixMsg.GetKeyFP(), err) + jww.ERROR.Printf("[RPC] GARBLED: %v, %v", ct[:pkSz], ct[pkSz:]) return } @@ -160,14 +161,22 @@ func (r *rpcServer) Process(cMixMsg format.Message, _ []string, _ []byte, ct = append(replyMsgId[:], ct...) ciphertexts[i] = ct } - rnd, ids, err := send(r.net, ephCmixId, ciphertexts, - cmix.GetDefaultCMIXParams()) - if err != nil { - jww.ERROR.Printf("[RPC] bad reply to %s,%s: %+v", - ephCmixId, cMixMsg.GetKeyFP(), err) + for { + rnd, _, err := send(r.net, ephCmixId, + ciphertexts, + cmix.GetDefaultCMIXParams()) + if err != nil { + jww.ERROR.Printf( + "[RPC] bad reply to %s,%s: %+v", + ephCmixId, cMixMsg.GetKeyFP(), + err) + continue + } + jww.INFO.Printf("[RPC] reply to %s,%s "+ + "sent: %s", + ephCmixId, cMixMsg.GetKeyFP(), rnd.ID) + return } - jww.INFO.Printf("[RPC] reply to %s,%s sent: %v, %v", - ephCmixId, cMixMsg.GetKeyFP(), rnd, ids) } // Response with Internal Server Error on crash diff --git a/rpc/server_test.go b/rpc/server_test.go index da127aa42c966e200cf3eac6b6d4fb0c7a3c0d47..c01959a0bfa6685367ba8e0f0727d5a216635f45 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -9,12 +9,15 @@ package rpc import ( "bytes" + "encoding/base64" "encoding/json" "testing" "github.com/stretchr/testify/require" "gitlab.com/elixxir/client/v4/cmix" + "gitlab.com/elixxir/crypto/nike" "gitlab.com/elixxir/crypto/nike/ecdh" + "gitlab.com/xx_network/crypto/csprng" "gitlab.com/xx_network/primitives/id" ) @@ -69,3 +72,85 @@ func TestServer(t *testing.T) { d := r.Wait() require.Equal(t, uint64(len(d)), 10*maxPayloadSz) } + +func TestMarshalling(t *testing.T) { + rng := csprng.NewSystemRNG() + + for i := 0; i < 10; i++ { + pr1, pu1 := ecdh.ECDHNIKE.NewKeypair(rng) + pr2, pu2 := ecdh.ECDHNIKE.NewKeypair(rng) + + // sanity check + nk1 := pr1.DeriveSecret(pu2) + nk2 := pr2.DeriveSecret(pu1) + require.Equal(t, nk1, nk2) + + // Now Dump to Base64 string + pr1Str := base64.RawStdEncoding.EncodeToString(pr1.Bytes()) + pu1Str := base64.RawStdEncoding.EncodeToString(pu1.Bytes()) + pr2Str := base64.RawStdEncoding.EncodeToString(pr2.Bytes()) + pu2Str := base64.RawStdEncoding.EncodeToString(pu2.Bytes()) + + srvPriv, srvPub := checkKeys(pr1Str, pu1Str, t) + cliPriv, cliPub := checkKeys(pr2Str, pu2Str, t) + + sk1 := srvPriv.DeriveSecret(cliPub) + sk2 := cliPriv.DeriveSecret(srvPub) + + t.Logf("sk1: %+v", sk1) + t.Logf("sk2: %+v", sk2) + + require.Equal(t, sk1, sk2) + require.Equal(t, nk1, sk1) + } +} + +func TestVectors(t *testing.T) { + rng := csprng.NewSystemRNG() + + serverPriv, serverPub := ecdh.ECDHNIKE.NewKeypair(rng) + clientPriv, clientPub := ecdh.ECDHNIKE.NewKeypair(rng) + + nk1 := serverPriv.DeriveSecret(clientPub) + nk2 := clientPriv.DeriveSecret(serverPub) + + require.Equal(t, nk1, nk2) + t.Logf("nk1: %v", nk1) + + srvPrivStr := "WBfCA63BbC9vaOeRi0FpqlfKjOedQM6TW89AwlfnVKfKrGHPN5J9Qh0cC9I4yzEG6brIveyvwI8RvTeT0OjfTw" + srvPubStr := "yqxhzzeSfUIdHAvSOMsxBum6yL3sr8CPEb03k9Do308" + + clientPrivStr := "NOPOKU2+SrsQv7DSXkiOkUytGdZeF/+H+j4bn50XbSNnvFqJawvtnofZUW1wlUXsHa55CKLrLUp6nMV8f5orgQ" + clientPubStr := "Z7xaiWsL7Z6H2VFtcJVF7B2ueQii6y1KepzFfH+aK4E" + + srvPriv, srvPub := checkKeys(srvPrivStr, srvPubStr, t) + cliPriv, cliPub := checkKeys(clientPrivStr, clientPubStr, t) + + sk1 := srvPriv.DeriveSecret(cliPub) + sk2 := cliPriv.DeriveSecret(srvPub) + + t.Logf("sk1: %+v", sk1) + t.Logf("sk2: %+v", sk2) + + require.Equal(t, sk1, sk2) +} + +func checkKeys(privStr, pubStr string, t *testing.T) (nike.PrivateKey, + nike.PublicKey) { + privBytes, err := base64.RawStdEncoding.DecodeString(privStr) + require.NoError(t, err) + + priv := ecdh.ECDHNIKE.NewEmptyPrivateKey() + err = priv.FromBytes(privBytes) + require.NoError(t, err) + + derivedPub := priv.Scheme().DerivePublicKey(priv) + require.Equal(t, pubStr, base64.RawStdEncoding.EncodeToString( + derivedPub.Bytes())) + pub := ecdh.ECDHNIKE.NewEmptyPublicKey() + pubBytes, err := base64.RawStdEncoding.DecodeString(pubStr) + require.NoError(t, err) + err = pub.FromBytes(pubBytes) + require.NoError(t, err) + return priv, pub +}