diff --git a/math/primes.go b/math/primes.go index 158fd83a7..53bba3111 100644 --- a/math/primes.go +++ b/math/primes.go @@ -32,3 +32,56 @@ func SafePrime(random io.Reader, bits int) (*big.Int, error) { } } } + +// SafePrimeConcurrent generates a safe prime concurrently. +func SafePrimeConcurrent(bits int, workers int) (*big.Int, error) { + found := make(chan *big.Int, 1) + errChan := make(chan error, workers) + exitFlag := false + + worker := func() { + defer func() { + exitFlag = true + }() + for { + if exitFlag { + return + } + // Generate a candidate prime q + q, err := rand.Prime(rand.Reader, bits-1) + if err != nil { + errChan <- err + return + } + + one := big.NewInt(1) + p := new(big.Int) + p.Lsh(q, 1).Add(p, one) + + // Check if p is prime + if p.ProbablyPrime(20) { + select { + case found <- p: + return + default: + return + } + } + } + } + + // Start worker goroutines + for i := 0; i < workers; i++ { + go worker() + } + + // Return the first result from any worker + for { + select { + case p := <-found: + return p, nil + case err := <-errChan: + return nil, err + } + } +} diff --git a/tss/rsa/rsa_threshold.go b/tss/rsa/rsa_threshold.go index 424b8236b..3354c3da6 100644 --- a/tss/rsa/rsa_threshold.go +++ b/tss/rsa/rsa_threshold.go @@ -71,6 +71,52 @@ func GenerateKey(random io.Reader, bits int) (*rsa.PrivateKey, error) { return priv, nil } +func GenerateKeyConcurrent(bits int, workers int) (*rsa.PrivateKey, error) { + p, err := cmath.SafePrimeConcurrent(bits/2, workers) + if err != nil { + return nil, err + } + + var q *big.Int + n := new(big.Int) + found := false + for !found { + q, err = cmath.SafePrimeConcurrent(bits-p.BitLen(), workers) + if err != nil { + return nil, err + } + + // check for different primes. + if p.Cmp(q) != 0 { + n.Mul(p, q) + // check n has the desired bitlength. + if n.BitLen() == bits { + found = true + } + } + } + + one := big.NewInt(1) + pminus1 := new(big.Int).Sub(p, one) + qminus1 := new(big.Int).Sub(q, one) + totient := new(big.Int).Mul(pminus1, qminus1) + + priv := new(rsa.PrivateKey) + priv.Primes = []*big.Int{p, q} + priv.N = n + priv.E = 65537 + priv.D = new(big.Int) + e := big.NewInt(int64(priv.E)) + ok := priv.D.ModInverse(e, totient) + if ok == nil { + return nil, errors.New("public key is not coprime to phi(n)") + } + + priv.Precompute() + + return priv, nil +} + // l or `Players`, the total number of Players. // t, the number of corrupted Players. // k=t+1 or `Threshold`, the number of signature shares needed to obtain a signature. diff --git a/tss/rsa/rsa_threshold_test.go b/tss/rsa/rsa_threshold_test.go index 82fc224d6..20297cd08 100644 --- a/tss/rsa/rsa_threshold_test.go +++ b/tss/rsa/rsa_threshold_test.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "math/big" + "runtime" "testing" "github.com/cloudflare/circl/internal/test" @@ -25,6 +26,15 @@ func TestGenerateKey(t *testing.T) { test.CheckOk(key.Validate() == nil, fmt.Sprintf("key is not valid: %v", key), t) } +func TestGenerateKeyConcurrent(t *testing.T) { + // [Warning]: this is only for tests, use a secure bitlen above 2048 bits. + bitlen := 128 + numCPU := runtime.NumCPU() + key, err := GenerateKeyConcurrent(bitlen, numCPU) + test.CheckNoErr(t, err, "failed to create key") + test.CheckOk(key.Validate() == nil, fmt.Sprintf("key is not valid: %v", key), t) +} + func createPrivateKey(p, q *big.Int, e int) *rsa.PrivateKey { return &rsa.PrivateKey{ PublicKey: rsa.PublicKey{