From bf00cca44e70989f2e89e97f8d96d8f67944b3ea Mon Sep 17 00:00:00 2001 From: Jack Grigg Date: Tue, 26 Apr 2022 08:48:16 +0000 Subject: [PATCH] Add `Field::sum_of_products` method Closes zkcrypto/ff#79. --- CHANGELOG.md | 2 ++ src/lib.rs | 34 +++++++++++++++++++++++++ tests/derive.rs | 68 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 562a80d..242a9c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ and this library adheres to Rust's notion of [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Added +- `ff::Field::{sum_of_products, sum_of_products_iter}` ## [0.13.0] - 2022-12-06 ### Added diff --git a/src/lib.rs b/src/lib.rs index f9eee3c..a771800 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -189,6 +189,40 @@ pub trait Field: res } + + /// Returns `a.into_iter().zip(b).fold(Self::ZERO, |acc, (a_i, b_i)| acc + a_i * b_i)`. + /// + /// This computes the "dot product" or "inner product" `a ⋅ b`. + /// + /// The provided implementation of this trait method uses the direct calculation given + /// above. Implementations of `Field` should override this to use more efficient + /// methods that take advantage of their internal representation, such as interleaving + /// or sharing modular reductions. + fn sum_of_products(a: [Self; T], b: [Self; T]) -> Self { + a.into_iter() + .zip(b) + .fold(Self::ZERO, |acc, (a_i, b_i)| acc + a_i * b_i) + } + + /// Returns `pairs.into_iter().fold(Self::ZERO, |acc, (a_i, b_i)| acc + (*a_i * b_i))`. + /// + /// This computes the "dot product" or "inner product" `a ⋅ b` of two equal-length + /// sequences of elements `a` and `b`, such that `pairs = a.iter().zip(b.iter())`. + /// + /// This method is generally slower than [`Self::sum_of_products`] but allows for the + /// number of pairs to be determined at runtime. + /// + /// The provided implementation of this trait method uses the direct calculation given + /// above. Implementations of `Field` should override this to use more efficient + /// methods that take advantage of their internal representation, such as interleaving + /// or sharing modular reductions. + fn sum_of_products_iter<'a, I: IntoIterator + Clone>( + pairs: I, + ) -> Self { + pairs + .into_iter() + .fold(Self::ZERO, |acc, (a_i, b_i)| acc + (*a_i * b_i)) + } } /// This represents an element of a non-binary prime field. diff --git a/tests/derive.rs b/tests/derive.rs index 5baf435..b3724c8 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -86,6 +86,74 @@ fn from_u128() { ); } +#[test] +fn sum_of_products() { + use ff::{Field, PrimeField}; + + let one = Bls381K12Scalar::one(); + + // [1, 2, 3, 4] + let values = { + let mut iter = (0..4).scan(one, |acc, _| { + let ret = *acc; + *acc += &one; + Some(ret) + }); + [ + iter.next().unwrap(), + iter.next().unwrap(), + iter.next().unwrap(), + iter.next().unwrap(), + ] + }; + + // We'll pair each value with itself. + let expected = Bls381K12Scalar::from_str_vartime("30").unwrap(); + + assert_eq!(Bls381K12Scalar::sum_of_products(values, values), expected,); +} + +#[test] +fn sum_of_products_iter() { + use ff::{Field, PrimeField}; + + let one = Bls381K12Scalar::one(); + + // [1, 2, 3, 4] + let values: Vec<_> = (0..4) + .scan(one, |acc, _| { + let ret = *acc; + *acc += &one; + Some(ret) + }) + .collect(); + + // We'll pair each value with itself. + let expected = Bls381K12Scalar::from_str_vartime("30").unwrap(); + + // Check that we can produce the necessary input from two iterators. + assert_eq!( + // Directly produces (&v, &v) + Bls381K12Scalar::sum_of_products_iter(values.iter().zip(values.iter())), + expected, + ); + + // Check that we can produce the necessary input from an iterator of values. + assert_eq!( + // Maps &v to (&v, &v) + Bls381K12Scalar::sum_of_products_iter(values.iter().map(|v| (v, v))), + expected, + ); + + // Check that we can produce the necessary input from an iterator of tuples. + let tuples: Vec<_> = values.into_iter().map(|v| (v, v)).collect(); + assert_eq!( + // Maps &(a, b) to (&a, &b) + Bls381K12Scalar::sum_of_products_iter(tuples.iter().map(|(a, b)| (a, b))), + expected, + ); +} + #[test] fn batch_inversion() { use ff::{BatchInverter, Field};