From 2e7fb078722decb07ab8da7fb59c8ec46fb32fd1 Mon Sep 17 00:00:00 2001
From: David Stainton <dstainton415@gmail.com>
Date: Mon, 17 Oct 2022 23:30:36 -0400
Subject: [PATCH] Add GenerateKeyPairWithRNG

---
 binding.go      | 57 +++++++++++++++++++++++++++++++++++++++++++++++--
 binding_test.go | 11 ++++++++++
 2 files changed, 66 insertions(+), 2 deletions(-)

diff --git a/binding.go b/binding.go
index 986174d..df19c0e 100644
--- a/binding.go
+++ b/binding.go
@@ -1,14 +1,31 @@
 package ctidh
 
-// #include "binding.h"
-// #include <csidh.h>
+/*
+#include "binding.h"
+#include <csidh.h>
+
+void custom_gen_private(private_key *priv) {
+  csidh_private_withrng(priv, fillrandom_custom);
+}
+
+void fillrandom_custom(
+  void *const outptr,
+  const size_t outsz,
+  const uintptr_t context)
+{
+  (void) context;
+  go_fillrandom(outptr, outsz);
+}
+*/
 import "C"
 import (
 	"bytes"
 	"crypto/hmac"
 	"encoding/pem"
 	"fmt"
+	"io"
 	"io/ioutil"
+	"sync"
 	"unsafe"
 )
 
@@ -33,6 +50,9 @@ var (
 
 	// ErrCTIDH indicates a group action failure.
 	ErrCTIDH error = fmt.Errorf("%s: group action failure", Name())
+
+	privateKeyRNGLock sync.Mutex
+	privateKeyRNG     io.Reader = nil
 )
 
 // ErrPEMKeyTypeMismatch returns an error indicating that we tried
@@ -300,6 +320,39 @@ func GenerateKeyPair() (*PrivateKey, *PublicKey) {
 	return privKey, DerivePublicKey(privKey)
 }
 
+//export go_fillrandom
+func go_fillrandom(outptr unsafe.Pointer, outsz C.size_t) {
+	buf := make([]byte, outsz)
+	privateKeyRNGLock.Lock()
+	rng := privateKeyRNG
+	privateKeyRNGLock.Unlock()
+	count, err := rng.Read(buf)
+	if err != nil {
+		panic(err)
+	}
+	if count != int(outsz) {
+		panic("rng fail")
+	}
+	p := uintptr(outptr)
+	for i := 0; i < int(outsz); i++ {
+		(*(*uint8)(unsafe.Pointer(p))) = uint8(buf[i])
+		p += 1
+	}
+}
+
+// GenerateKeyPairWithRNG uses the given RNG to derive a new keypair.
+// HOWEVER, if GenerateKeyPairWithRNG is called by multiple threads
+// then the rng is overwritten with each call which can unintentionally
+// cause multiple RNGs to be used to generate the keypair.
+func GenerateKeyPairWithRNG(rng io.Reader) (*PrivateKey, *PublicKey) {
+	privKey := new(PrivateKey)
+	privateKeyRNGLock.Lock()
+	privateKeyRNG = rng
+	privateKeyRNGLock.Unlock()
+	C.custom_gen_private(&privKey.privateKey)
+	return privKey, DerivePublicKey(privKey)
+}
+
 func groupAction(privateKey *PrivateKey, publicKey *PublicKey) *PublicKey {
 	sharedKey := new(PublicKey)
 	ok := C.csidh(&sharedKey.publicKey, &publicKey.publicKey, &privateKey.privateKey)
diff --git a/binding_test.go b/binding_test.go
index e18bba6..2efbfac 100644
--- a/binding_test.go
+++ b/binding_test.go
@@ -1,6 +1,7 @@
 package ctidh
 
 import (
+	"crypto/rand"
 	"os"
 	"path/filepath"
 	"testing"
@@ -8,6 +9,16 @@ import (
 	"github.com/stretchr/testify/require"
 )
 
+func TestGenerateKeyPairWithRNG(t *testing.T) {
+	privateKey, publicKey := GenerateKeyPairWithRNG(rand.Reader)
+
+	zeros := make([]byte, PublicKeySize)
+	require.NotEqual(t, privateKey.Bytes(), zeros)
+	require.NotEqual(t, publicKey.Bytes(), zeros)
+
+	t.Logf("privateKey.Bytes() %x", privateKey.Bytes())
+}
+
 func TestPrivateKeyPEMSerialization(t *testing.T) {
 	privateKey, _ := GenerateKeyPair()
 
-- 
GitLab