Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ and this library adheres to Rust's notion of
[Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- `ff::Field::{sum_of_products, sum_of_products_iter}`

## [0.13.0] - 2022-12-06
### Added
Expand Down
34 changes: 34 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,40 @@ pub trait Field:

res
}

/// Returns `a.into_iter().zip(b).fold(Self::ZERO, |acc, (a_i, b_i)| acc + a_i * b_i)`.
///
/// This computes the "dot product" or "inner product" `a ⋅ b`.
///
/// The provided implementation of this trait method uses the direct calculation given
/// above. Implementations of `Field` should override this to use more efficient
/// methods that take advantage of their internal representation, such as interleaving
/// or sharing modular reductions.
fn sum_of_products<const T: usize>(a: [Self; T], b: [Self; T]) -> Self {
a.into_iter()
.zip(b)
.fold(Self::ZERO, |acc, (a_i, b_i)| acc + a_i * b_i)
}

/// Returns `pairs.into_iter().fold(Self::ZERO, |acc, (a_i, b_i)| acc + (*a_i * b_i))`.
///
/// This computes the "dot product" or "inner product" `a ⋅ b` of two equal-length
/// sequences of elements `a` and `b`, such that `pairs = a.iter().zip(b.iter())`.
///
/// This method is generally slower than [`Self::sum_of_products`] but allows for the
/// number of pairs to be determined at runtime.
///
/// The provided implementation of this trait method uses the direct calculation given
/// above. Implementations of `Field` should override this to use more efficient
/// methods that take advantage of their internal representation, such as interleaving
/// or sharing modular reductions.
fn sum_of_products_iter<'a, I: IntoIterator<Item = (&'a Self, &'a Self)> + Clone>(
pairs: I,
) -> Self {
pairs
.into_iter()
.fold(Self::ZERO, |acc, (a_i, b_i)| acc + (*a_i * b_i))
}
}

/// This represents an element of a non-binary prime field.
Expand Down
68 changes: 68 additions & 0 deletions tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,74 @@ fn from_u128() {
);
}

#[test]
fn sum_of_products() {
use ff::{Field, PrimeField};

let one = Bls381K12Scalar::one();

// [1, 2, 3, 4]
let values = {
let mut iter = (0..4).scan(one, |acc, _| {
let ret = *acc;
*acc += &one;
Some(ret)
});
[
iter.next().unwrap(),
iter.next().unwrap(),
iter.next().unwrap(),
iter.next().unwrap(),
]
};

// We'll pair each value with itself.
let expected = Bls381K12Scalar::from_str_vartime("30").unwrap();

assert_eq!(Bls381K12Scalar::sum_of_products(values, values), expected,);
}

#[test]
fn sum_of_products_iter() {
use ff::{Field, PrimeField};

let one = Bls381K12Scalar::one();

// [1, 2, 3, 4]
let values: Vec<_> = (0..4)
.scan(one, |acc, _| {
let ret = *acc;
*acc += &one;
Some(ret)
})
.collect();

// We'll pair each value with itself.
let expected = Bls381K12Scalar::from_str_vartime("30").unwrap();

// Check that we can produce the necessary input from two iterators.
assert_eq!(
// Directly produces (&v, &v)
Bls381K12Scalar::sum_of_products_iter(values.iter().zip(values.iter())),
expected,
);

// Check that we can produce the necessary input from an iterator of values.
assert_eq!(
// Maps &v to (&v, &v)
Bls381K12Scalar::sum_of_products_iter(values.iter().map(|v| (v, v))),
expected,
);

// Check that we can produce the necessary input from an iterator of tuples.
let tuples: Vec<_> = values.into_iter().map(|v| (v, v)).collect();
assert_eq!(
// Maps &(a, b) to (&a, &b)
Bls381K12Scalar::sum_of_products_iter(tuples.iter().map(|(a, b)| (a, b))),
expected,
);
}

#[test]
fn batch_inversion() {
use ff::{BatchInverter, Field};
Expand Down
Loading