diff --git a/.gitignore b/.gitignore index b3a38e13b..fe666147b 100644 --- a/.gitignore +++ b/.gitignore @@ -14,11 +14,16 @@ *.pkp *.pkv *.np +*.sp +*.spc +*.spctx +spark_proofs/ params_for_recursive_verifier params artifacts/ spartan_vm_debug/ mavros_debug/ +mavros/ # Don't ignore benchmarking artifacts !tooling/provekit-bench/benches/* diff --git a/Cargo.lock b/Cargo.lock index b1f19ea22..b9c4c5f57 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4584,7 +4584,7 @@ dependencies = [ [[package]] name = "provekit-bench" -version = "0.1.0" +version = "1.0.0" dependencies = [ "acir", "anyhow", @@ -4606,14 +4606,16 @@ dependencies = [ [[package]] name = "provekit-cli" -version = "0.1.0" +version = "1.0.0" dependencies = [ "acir", "anyhow", "argh", "ark-ff 0.5.0", "base64", + "bincode", "hex", + "mavros-artifacts", "nargo", "nargo_toml", "noir_artifact_cli", @@ -4624,6 +4626,7 @@ dependencies = [ "provekit-gnark", "provekit-prover", "provekit-r1cs-compiler", + "provekit-spark", "provekit-verifier", "rayon", "serde_json", @@ -4635,7 +4638,7 @@ dependencies = [ [[package]] name = "provekit-common" -version = "0.1.0" +version = "1.0.0" dependencies = [ "acir", "anyhow", @@ -4675,7 +4678,7 @@ dependencies = [ [[package]] name = "provekit-ffi" -version = "0.1.0" +version = "1.0.0" dependencies = [ "anyhow", "libc", @@ -4705,7 +4708,7 @@ dependencies = [ [[package]] name = "provekit-prover" -version = "0.1.0" +version = "1.0.0" dependencies = [ "acir", "anyhow", @@ -4726,7 +4729,7 @@ dependencies = [ [[package]] name = "provekit-r1cs-compiler" -version = "0.1.0" +version = "1.0.0" dependencies = [ "acir", "anyhow", @@ -4746,8 +4749,23 @@ dependencies = [ ] [[package]] -name = "provekit-verifier" +name = "provekit-spark" version = "0.1.0" +dependencies = [ + "anyhow", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", + "ark-std 0.5.0", + "provekit-common", + "rayon", + "serde", + "tracing", + "whir", +] + +[[package]] +name = "provekit-verifier" +version = "1.0.0" dependencies = [ "anyhow", "ark-std 0.5.0", @@ -7034,7 +7052,7 @@ dependencies = [ [[package]] name = "verifier-server" -version = "0.1.0" +version = "1.0.0" dependencies = [ "anyhow", "axum", diff --git a/Cargo.toml b/Cargo.toml index dc6ce9676..4f21e9c26 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ members = [ "tooling/provekit-wasm", "tooling/verifier-server", "ntt", + "provekit/spark", "poseidon2", "playground/passport-input-gen", ] @@ -102,6 +103,7 @@ provekit-ffi = { path = "tooling/provekit-ffi" } provekit-gnark = { path = "tooling/provekit-gnark" } provekit-prover = { path = "provekit/prover", default-features = false } provekit-r1cs-compiler = { path = "provekit/r1cs-compiler" } +provekit-spark = { path = "provekit/spark" } provekit-verifier = { path = "provekit/verifier" } provekit-verifier-server = { path = "tooling/verifier-server" } provekit-wasm = { path = "tooling/provekit-wasm" } diff --git a/provekit/common/src/file/binary_format.rs b/provekit/common/src/file/binary_format.rs index 44ff55717..ed4c44284 100644 --- a/provekit/common/src/file/binary_format.rs +++ b/provekit/common/src/file/binary_format.rs @@ -18,10 +18,16 @@ pub const PROVER_FORMAT: [u8; 8] = *b"PrvKitPr"; pub const PROVER_VERSION: (u16, u16) = (1, 2); pub const VERIFIER_FORMAT: [u8; 8] = *b"PrvKitVr"; -pub const VERIFIER_VERSION: (u16, u16) = (1, 3); +pub const VERIFIER_VERSION: (u16, u16) = (1, 4); pub const NOIR_PROOF_SCHEME_FORMAT: [u8; 8] = *b"NrProScm"; pub const NOIR_PROOF_SCHEME_VERSION: (u16, u16) = (1, 2); pub const NOIR_PROOF_FORMAT: [u8; 8] = *b"NPSProof"; -pub const NOIR_PROOF_VERSION: (u16, u16) = (1, 1); +pub const NOIR_PROOF_VERSION: (u16, u16) = (1, 2); + +pub const SPARK_PROOF_FORMAT: [u8; 8] = *b"SparkPrf"; +pub const SPARK_PROOF_VERSION: (u16, u16) = (1, 0); + +pub const SPARK_CONTEXT_FORMAT: [u8; 8] = *b"SparkCtx"; +pub const SPARK_CONTEXT_VERSION: (u16, u16) = (1, 0); diff --git a/provekit/common/src/file/io/mod.rs b/provekit/common/src/file/io/mod.rs index 049c984a7..964a8f17c 100644 --- a/provekit/common/src/file/io/mod.rs +++ b/provekit/common/src/file/io/mod.rs @@ -3,11 +3,12 @@ mod buf_ext; mod counting_writer; mod json; +pub use self::bin::Compression; use { self::{ bin::{ deserialize_from_bytes, read_bin, read_hash_config as read_hash_config_bin, - serialize_to_bytes, write_bin, Compression, + serialize_to_bytes, write_bin, }, buf_ext::BufExt, counting_writer::CountingWriter, @@ -29,7 +30,7 @@ pub trait FileFormat: Serialize + for<'a> Deserialize<'a> { } /// Helper trait to optionally extract hash config. -pub(crate) trait MaybeHashAware { +pub trait MaybeHashAware { fn maybe_hash_config(&self) -> Option; } diff --git a/provekit/common/src/lib.rs b/provekit/common/src/lib.rs index 3953207d8..593b064d7 100644 --- a/provekit/common/src/lib.rs +++ b/provekit/common/src/lib.rs @@ -1,7 +1,7 @@ pub mod file; pub use file::binary_format; pub mod hash_config; -mod interner; +pub mod interner; mod mavros; mod noir_proof_scheme; pub mod ntt; @@ -11,6 +11,7 @@ pub mod prefix_covector; mod prover; mod r1cs; pub mod skyscraper; +pub mod spark; pub mod sparse_matrix; mod transcript_sponge; pub mod u256_arith; @@ -32,6 +33,7 @@ pub use { prefix_covector::{OffsetCovector, PrefixCovector, SparseCovector}, prover::{NoirProver, Prover}, r1cs::R1CS, + spark::{MatrixDimensions, SparkSetup, SparkWhirConfigs}, transcript_sponge::TranscriptSponge, verifier::Verifier, whir_r1cs::{R1csHash, WhirConfig, WhirR1CSProof, WhirR1CSScheme, WhirZkConfig}, diff --git a/provekit/common/src/spark.rs b/provekit/common/src/spark.rs new file mode 100644 index 000000000..acf029bba --- /dev/null +++ b/provekit/common/src/spark.rs @@ -0,0 +1,72 @@ +use { + crate::{utils::serde_ark, FieldElement, HashConfig, WhirConfig, WhirR1CSProof}, + anyhow::{Context, Result}, + serde::{Deserialize, Serialize}, + sha3::{Digest, Sha3_256}, +}; + +/// A single column-axis SPARK query: an evaluation point on the column axis +/// plus three claimed evaluations (for the A, B, C matrices). The row axis is +/// shared across all queries in a [`SparkQueryBatch`]. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct SparkColQuery { + #[serde(with = "serde_ark")] + pub col: Vec, + #[serde(with = "serde_ark")] + pub claimed_a: FieldElement, + #[serde(with = "serde_ark")] + pub claimed_b: FieldElement, + #[serde(with = "serde_ark")] + pub claimed_c: FieldElement, +} + +/// A batch of SPARK queries that all share the same row evaluation point. +/// The shared-row invariant is structural: a batch *cannot* express a +/// mixed-row set, so the SPARK prover and verifier do not need a runtime +/// check. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct SparkQueryBatch { + #[serde(with = "serde_ark")] + pub row: Vec, + pub queries: Vec, +} + +impl SparkQueryBatch { + /// Stable Fiat-Shamir instance binding for the batch. + pub fn hash_bytes(&self) -> Result<[u8; 32]> { + let bytes = + postcard::to_allocvec(self).context("serializing SparkQueryBatch for hash_bytes")?; + Ok(Sha3_256::digest(&bytes).into()) + } +} + +/// Dimensions of the (padded) sparse R1CS matrix that SPARK is committing to. +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct MatrixDimensions { + pub num_rows: usize, + pub num_cols: usize, + pub nonzero_terms: usize, +} + +/// WHIR configurations used by SPARK for each committed polynomial axis. +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct SparkWhirConfigs { + pub row: WhirConfig, + pub col: WhirConfig, + pub num_terms_2batched: WhirConfig, + pub num_terms_5batched: WhirConfig, +} + +/// Verifier-side SPARK setup: WHIR configs, matrix dimensions, the preprocessed +/// commitment transcript, and the hash config used to seed Fiat-Shamir. +/// +/// This struct is embedded in [`Verifier`](crate::Verifier) so the SPARK +/// commitments come from the trusted `.pkv` key rather than an attacker- +/// supplied setup file. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SparkSetup { + pub whir_configs: SparkWhirConfigs, + pub matrix_dimensions: MatrixDimensions, + pub transcript: WhirR1CSProof, + pub hash_config: HashConfig, +} diff --git a/provekit/common/src/utils/sumcheck.rs b/provekit/common/src/utils/sumcheck.rs index d703d8cd4..859d1d54e 100644 --- a/provekit/common/src/utils/sumcheck.rs +++ b/provekit/common/src/utils/sumcheck.rs @@ -152,6 +152,11 @@ fn eval_eq( } } +/// Evaluates a quadratic polynomial on a value +pub fn eval_quadratic_poly(poly: [FieldElement; 3], point: FieldElement) -> FieldElement { + poly[0] + point * (poly[1] + point * poly[2]) +} + /// Evaluates a cubic polynomial on a value pub fn eval_cubic_poly(poly: [FieldElement; 4], point: FieldElement) -> FieldElement { poly[0] + point * (poly[1] + point * (poly[2] + point * poly[3])) diff --git a/provekit/common/src/verifier.rs b/provekit/common/src/verifier.rs index 2663cff61..0861a4538 100644 --- a/provekit/common/src/verifier.rs +++ b/provekit/common/src/verifier.rs @@ -1,7 +1,7 @@ use { crate::{ - noir_proof_scheme::NoirProofScheme, utils::serde_jsonify, whir_r1cs::WhirR1CSScheme, - HashConfig, R1CS, + noir_proof_scheme::NoirProofScheme, spark::SparkSetup, utils::serde_jsonify, + whir_r1cs::WhirR1CSScheme, HashConfig, R1CS, }, noirc_abi::Abi, serde::{Deserialize, Serialize}, @@ -11,13 +11,15 @@ use { /// serialized to a `.pkv` file by `prepare` and loaded by `verify` (or by /// `generate-gnark-inputs` for the recursive path). /// -/// Holds the R1CS, the WHIR-for-witness commitment configuration, and the -/// ABI needed to bind public inputs back to their Noir-level names. +/// Holds the R1CS, the WHIR-for-witness commitment configuration, the SPARK +/// setup (when `prepare --spark` was used), and the ABI needed to bind public +/// inputs back to their Noir-level names. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Verifier { pub hash_config: HashConfig, pub r1cs: R1CS, pub whir_for_witness: Option, + pub spark_setup: Option, #[serde(with = "serde_jsonify")] pub abi: Abi, } @@ -28,12 +30,14 @@ impl Verifier { NoirProofScheme::Noir(d) => Self { r1cs: d.r1cs, whir_for_witness: Some(d.whir_for_witness), + spark_setup: None, abi: d.witness_generator.abi.clone(), hash_config: d.hash_config, }, NoirProofScheme::Mavros(d) => Self { r1cs: d.r1cs, whir_for_witness: Some(d.whir_for_witness), + spark_setup: None, abi: d.abi.clone(), hash_config: d.hash_config, }, diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index 1f5474a9f..80361b875 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -9,8 +9,8 @@ use { acir::native_types::{Witness, WitnessMap}, anyhow::{Context, Result}, provekit_common::{ - utils::noir_to_native, FieldElement, NoirElement, NoirProof, NoirProver, Prover, - PublicInputs, TranscriptSponge, + spark::SparkQueryBatch, utils::noir_to_native, FieldElement, NoirElement, NoirProof, + NoirProver, Prover, PublicInputs, TranscriptSponge, }, std::mem::{size_of, take}, whir::transcript::ProverState, @@ -41,6 +41,10 @@ pub use {ec_arith::ec_scalar_mul, r1cs::solve_witness_vec}; /// `prove` and `prove_with_toml` are native-only (cfg-gated out on wasm32). /// `prove_with_witness` is available on all targets. `MavrosProver` does not /// support `prove_with_witness` (errors at runtime). +/// +/// Callers that also need the SPARK query batch (produced as a side output) +/// use [`prove_with_spark_toml`](Prove::prove_with_spark_toml), which returns +/// the proof and the batch together. pub trait Prove { #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))] fn prove(self, input_map: InputMap) -> Result; @@ -49,6 +53,12 @@ pub trait Prove { fn prove_with_toml(self, prover_toml: impl AsRef) -> Result; fn prove_with_witness(self, witness: WitnessMap) -> Result; + + #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))] + fn prove_with_spark_toml( + self, + prover_toml: impl AsRef, + ) -> Result<(NoirProof, SparkQueryBatch)>; } #[instrument(skip_all)] @@ -83,318 +93,375 @@ fn generate_noir_witness( .witness) } -impl Prove for NoirProver { - #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))] - #[instrument(skip_all)] - fn prove(mut self, input_map: InputMap) -> Result { - let witness = generate_noir_witness(&mut self, input_map)?; - self.prove_with_witness(witness) - } - - #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))] - #[instrument(skip_all)] - fn prove_with_toml(self, prover_toml: impl AsRef) -> Result { - let (input_map, _return_value) = - read_inputs_from_file(prover_toml.as_ref(), self.witness_generator.abi())?; - self.prove(input_map) +#[instrument(skip_all)] +fn prove_noir_inner( + prover: NoirProver, + acir_witness_idx_to_value_map: WitnessMap, + produce_spark_query: bool, +) -> Result<(NoirProof, Option)> { + provekit_common::register_ntt(); + + let mut public_input_indices = prover.program.functions[0].public_inputs().indices(); + public_input_indices.sort_unstable(); + let public_inputs = if public_input_indices.is_empty() { + PublicInputs::new() + } else { + let values = public_input_indices + .iter() + .map(|&idx| { + let noir_val = acir_witness_idx_to_value_map + .get(&Witness::from(idx)) + .ok_or_else(|| anyhow::anyhow!("Missing public input at index {idx}"))?; + Ok(noir_to_native(*noir_val)) + }) + .collect::>>()?; + PublicInputs::from_vec(values) + }; + + drop(prover.program); + drop(prover.witness_generator); + + // R1CS matrices are only needed at sumcheck; compress to free memory during + // commits. + let compressed_r1cs = + CompressedR1CS::compress(prover.r1cs).context("While compressing R1CS")?; + let num_witnesses = compressed_r1cs.num_witnesses(); + let num_constraints = compressed_r1cs.num_constraints(); + + // Set up transcript with public inputs bound to the instance. + let instance = public_inputs.hash_bytes(prover.hash_config); + let ds = prover + .whir_for_witness + .create_domain_separator() + .instance(&instance); + let mut merlin = ProverState::new(&ds, TranscriptSponge::from_config(prover.hash_config)); + + // Allocate space for real + virtual witnesses. Virtual witnesses are + // computation-only (zero entries in A/B/C) but needed by builders. + let mut witness: Vec> = + vec![None; compressed_r1cs.num_witnesses_for_solving()]; + + // Solve w1 (or all witnesses if no challenges). + { + let _s = info_span!("solve_w1").entered(); + crate::r1cs::solve_witness_vec( + &mut witness, + prover.split_witness_builders.w1_layers, + &acir_witness_idx_to_value_map, + &mut merlin, + ) + .context("While solving w1 witnesses")?; } - #[instrument(skip_all)] - fn prove_with_witness( - self, - acir_witness_idx_to_value_map: WitnessMap, - ) -> Result { - provekit_common::register_ntt(); - - let mut public_input_indices = self.program.functions[0].public_inputs().indices(); - public_input_indices.sort_unstable(); - let public_inputs = if public_input_indices.is_empty() { - PublicInputs::new() - } else { - let values = public_input_indices - .iter() - .map(|&idx| { - let noir_val = acir_witness_idx_to_value_map - .get(&Witness::from(idx)) - .ok_or_else(|| anyhow::anyhow!("Missing public input at index {idx}"))?; - Ok(noir_to_native(*noir_val)) - }) - .collect::>>()?; - PublicInputs::from_vec(values) - }; - - drop(self.program); - drop(self.witness_generator); - - // R1CS matrices are only needed at sumcheck; compress to free memory during - // commits. - let compressed_r1cs = - CompressedR1CS::compress(self.r1cs).context("While compressing R1CS")?; - let num_witnesses = compressed_r1cs.num_witnesses(); - let num_constraints = compressed_r1cs.num_constraints(); - - // Set up transcript with public inputs bound to the instance. - let instance = public_inputs.hash_bytes(self.hash_config); - let ds = self - .whir_for_witness - .create_domain_separator() - .instance(&instance); - - let mut merlin = ProverState::new(&ds, TranscriptSponge::from_config(self.hash_config)); - - // Allocate space for real + virtual witnesses. Virtual witnesses are - // computation-only (zero entries in A/B/C) but needed by builders. - let mut witness: Vec> = - vec![None; compressed_r1cs.num_witnesses_for_solving()]; - - // Solve w1 (or all witnesses if no challenges). + // Compress w2 layers to free memory during w1 commit (only when + // challenges exist; otherwise just drop them). + let has_challenges = prover.whir_for_witness.num_challenges > 0; + let compressed_w2_layers = if has_challenges { + Some( + CompressedLayers::compress(prover.split_witness_builders.w2_layers) + .context("While compressing w2 layers")?, + ) + } else { + drop(prover.split_witness_builders.w2_layers); + None + }; + + debug!( + witness_heap_bytes = witness.capacity() * size_of::>(), + compressed_r1cs_blob_bytes = compressed_r1cs.blob_len(), + "component sizes after solve_w1" + ); + + let w1 = { + let _s = info_span!("allocate_w1").entered(); + witness[..prover.whir_for_witness.w1_size] + .iter() + .map(|w| w.ok_or_else(|| anyhow::anyhow!("Some witnesses in w1 are missing"))) + .collect::>>()? + }; + + crate::logging::log_commit_input("noir_w1", &w1, prover.whir_for_witness.domain_size()); + let commitment_1 = prover + .whir_for_witness + .commit(&mut merlin, num_witnesses, num_constraints, w1, true) + .context("While committing to w1")?; + + let commitments = if has_challenges { + let w2_layers = compressed_w2_layers + .unwrap() + .decompress() + .context("While decompressing w2 layers")?; { - let _s = info_span!("solve_w1").entered(); + let _s = info_span!("solve_w2").entered(); crate::r1cs::solve_witness_vec( &mut witness, - self.split_witness_builders.w1_layers, + w2_layers, &acir_witness_idx_to_value_map, &mut merlin, ) - .context("While solving w1 witnesses")?; + .context("While solving w2 witnesses")?; } - - // Compress w2 layers to free memory during w1 commit (only when - // challenges exist; otherwise just drop them). - let has_challenges = self.whir_for_witness.num_challenges > 0; - let compressed_w2_layers = if has_challenges { - Some( - CompressedLayers::compress(self.split_witness_builders.w2_layers) - .context("While compressing w2 layers")?, - ) - } else { - drop(self.split_witness_builders.w2_layers); - None - }; - - debug!( - witness_heap_bytes = witness.capacity() * size_of::>(), - compressed_r1cs_blob_bytes = compressed_r1cs.blob_len(), - "component sizes after solve_w1" - ); - - let w1 = { - let _s = info_span!("allocate_w1").entered(); - witness[..self.whir_for_witness.w1_size] + drop(acir_witness_idx_to_value_map); + + let w2 = { + let _s = info_span!("allocate_w2").entered(); + // Only real w2 witnesses (exclude virtual at the end). + debug_assert!( + prover.whir_for_witness.w1_size <= num_witnesses, + "w1_size ({}) exceeds num_witnesses ({})", + prover.whir_for_witness.w1_size, + num_witnesses + ); + witness[prover.whir_for_witness.w1_size..num_witnesses] .iter() - .map(|w| w.ok_or_else(|| anyhow::anyhow!("Some witnesses in w1 are missing"))) + .map(|w| w.ok_or_else(|| anyhow::anyhow!("Some witnesses in w2 are missing"))) .collect::>>()? }; - crate::logging::log_commit_input("noir_w1", &w1, self.whir_for_witness.domain_size()); - let commitment_1 = self + crate::logging::log_commit_input("noir_w2", &w2, prover.whir_for_witness.domain_size()); + let commitment_2 = prover .whir_for_witness - .commit(&mut merlin, num_witnesses, num_constraints, w1, true) - .context("While committing to w1")?; - - let commitments = if has_challenges { - let w2_layers = compressed_w2_layers - .unwrap() - .decompress() - .context("While decompressing w2 layers")?; - { - let _s = info_span!("solve_w2").entered(); - crate::r1cs::solve_witness_vec( - &mut witness, - w2_layers, - &acir_witness_idx_to_value_map, - &mut merlin, - ) - .context("While solving w2 witnesses")?; - } - drop(acir_witness_idx_to_value_map); - - let w2 = { - let _s = info_span!("allocate_w2").entered(); - // Only real w2 witnesses (exclude virtual at the end). - debug_assert!( - self.whir_for_witness.w1_size <= num_witnesses, - "w1_size ({}) exceeds num_witnesses ({})", - self.whir_for_witness.w1_size, - num_witnesses - ); - witness[self.whir_for_witness.w1_size..num_witnesses] - .iter() - .map(|w| w.ok_or_else(|| anyhow::anyhow!("Some witnesses in w2 are missing"))) - .collect::>>()? - }; - - crate::logging::log_commit_input("noir_w2", &w2, self.whir_for_witness.domain_size()); - let commitment_2 = self - .whir_for_witness - .commit(&mut merlin, num_witnesses, num_constraints, w2, false) - .context("While committing to w2")?; - - vec![commitment_1, commitment_2] - } else { - drop(acir_witness_idx_to_value_map); - vec![commitment_1] - }; - - // Decompress R1CS for the sumcheck and matrix operations. - let r1cs = compressed_r1cs - .decompress() - .context("While decompressing R1CS")?; - - #[cfg(test)] - r1cs.test_witness_satisfaction( - &witness[..num_witnesses] - .iter() - .map(|w| w.unwrap()) - .collect::>(), - ) - .context("While verifying R1CS instance")?; - - // Extract only real witnesses (first num_witnesses) for the sumcheck. - // Virtual witnesses at [num_witnesses, num_witnesses+num_virtual) were - // needed for builder computation but have zero entries in A/B/C. - let full_witness: Vec = witness[..num_witnesses] + .commit(&mut merlin, num_witnesses, num_constraints, w2, false) + .context("While committing to w2")?; + + vec![commitment_1, commitment_2] + } else { + drop(acir_witness_idx_to_value_map); + vec![commitment_1] + }; + + // Decompress R1CS for the sumcheck and matrix operations. + let r1cs = compressed_r1cs + .decompress() + .context("While decompressing R1CS")?; + + #[cfg(test)] + r1cs.test_witness_satisfaction( + &witness[..num_witnesses] .iter() - .enumerate() - .map(|(i, w)| w.ok_or_else(|| anyhow::anyhow!("Witness {i} unsolved after solving"))) - .collect::>>()?; - - let whir_r1cs_proof = self + .map(|w| w.unwrap()) + .collect::>(), + ) + .context("While verifying R1CS instance")?; + + // Extract only real witnesses (first num_witnesses) for the sumcheck. + // Virtual witnesses at [num_witnesses, num_witnesses+num_virtual) were + // needed for builder computation but have zero entries in A/B/C. + let full_witness: Vec = witness[..num_witnesses] + .iter() + .enumerate() + .map(|(i, w)| w.ok_or_else(|| anyhow::anyhow!("Witness {i} unsolved after solving"))) + .collect::>>()?; + + let (whir_r1cs_proof, r1cs_spark_queries) = if produce_spark_query { + let (proof, batch) = prover + .whir_for_witness + .prove_noir_with_spark(merlin, r1cs, commitments, full_witness, &public_inputs) + .context("While proving R1CS instance")?; + (proof, Some(batch)) + } else { + let proof = prover .whir_for_witness .prove_noir(merlin, r1cs, commitments, full_witness, &public_inputs) .context("While proving R1CS instance")?; + (proof, None) + }; - Ok(NoirProof { + Ok(( + NoirProof { public_inputs, whir_r1cs_proof, - }) - } + }, + r1cs_spark_queries, + )) } -#[cfg(not(target_arch = "wasm32"))] -impl Prove for MavrosProver { - #[cfg(feature = "witness-generation")] +impl Prove for NoirProver { + #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))] + #[instrument(skip_all)] fn prove(mut self, input_map: InputMap) -> Result { - provekit_common::register_ntt(); - - let params = crate::input_utils::ordered_params_from_btreemap(&self.abi, &input_map)?; - let phase1 = mavros_interpreter::run_phase1( - &mut self.witgen_binary, - self.witness_layout, - self.constraints_layout, - ¶ms, - ); - drop(self.witgen_binary); + let witness = generate_noir_witness(&mut self, input_map)?; + self.prove_with_witness(witness) + } - let num_public_inputs = self.num_public_inputs; - let public_inputs = if num_public_inputs == 0 { - PublicInputs::new() - } else { - PublicInputs::from_vec(phase1.out_wit_pre_comm[1..=num_public_inputs].to_vec()) - }; + #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))] + #[instrument(skip_all)] + fn prove_with_toml(self, prover_toml: impl AsRef) -> Result { + let (input_map, _return_value) = + read_inputs_from_file(prover_toml.as_ref(), self.witness_generator.abi())?; + self.prove(input_map) + } - // Set up transcript with public inputs bound to the instance. - let instance = public_inputs.hash_bytes(self.hash_config); - let ds = self - .whir_for_witness - .create_domain_separator() - .instance(&instance); - let mut merlin = ProverState::new(&ds, TranscriptSponge::from_config(self.hash_config)); - - info!( - ?self.witness_layout, - ?self.constraints_layout, - scheme_domain_len = self.whir_for_witness.domain_size(), - "Mavros witness layout" + #[instrument(skip_all)] + fn prove_with_witness(self, witness: WitnessMap) -> Result { + let (proof, _) = prove_noir_inner(self, witness, false)?; + Ok(proof) + } + + #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))] + #[instrument(skip_all)] + fn prove_with_spark_toml( + mut self, + prover_toml: impl AsRef, + ) -> Result<(NoirProof, SparkQueryBatch)> { + let (input_map, _return_value) = + read_inputs_from_file(prover_toml.as_ref(), self.witness_generator.abi())?; + let witness = generate_noir_witness(&mut self, input_map)?; + let (proof, batch) = prove_noir_inner(self, witness, true)?; + Ok(( + proof, + batch.expect("spark batch must be produced when requested"), + )) + } +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "witness-generation"))] +fn prove_mavros_inner( + mut prover: MavrosProver, + input_map: InputMap, + produce_spark_query: bool, +) -> Result<(NoirProof, Option)> { + provekit_common::register_ntt(); + + let params = crate::input_utils::ordered_params_from_btreemap(&prover.abi, &input_map)?; + let phase1 = mavros_interpreter::run_phase1( + &mut prover.witgen_binary, + prover.witness_layout, + prover.constraints_layout, + ¶ms, + ); + drop(prover.witgen_binary); + + let num_public_inputs = prover.num_public_inputs; + let public_inputs = if num_public_inputs == 0 { + PublicInputs::new() + } else { + PublicInputs::from_vec(phase1.out_wit_pre_comm[1..=num_public_inputs].to_vec()) + }; + + // Set up transcript with public inputs bound to the instance. + let instance = public_inputs.hash_bytes(prover.hash_config); + let ds = prover + .whir_for_witness + .create_domain_separator() + .instance(&instance); + let mut merlin = ProverState::new(&ds, TranscriptSponge::from_config(prover.hash_config)); + + info!( + ?prover.witness_layout, + ?prover.constraints_layout, + scheme_domain_len = prover.whir_for_witness.domain_size(), + "Mavros witness layout" + ); + + let w1 = phase1.out_wit_pre_comm.clone(); + crate::logging::log_commit_input( + "mavros_w1_pre_commitment", + &w1, + prover.whir_for_witness.domain_size(), + ); + let commitment_1 = prover + .whir_for_witness + .commit( + &mut merlin, + prover.witness_layout.size(), + prover.constraints_layout.algebraic_size, + w1, + true, + ) + .context("While committing to w1")?; + + let (commitments, witgen_result) = if prover.whir_for_witness.num_challenges > 0 { + let challenges: Vec = (0..prover.witness_layout.challenges_size) + .map(|_| merlin.verifier_message()) + .collect(); + + let mut witgen_result = mavros_interpreter::run_phase2( + phase1, + &challenges, + prover.witness_layout, + prover.constraints_layout, ); - let w1 = phase1.out_wit_pre_comm.clone(); + let w2 = take(&mut witgen_result.out_wit_post_comm); crate::logging::log_commit_input( - "mavros_w1_pre_commitment", - &w1, - self.whir_for_witness.domain_size(), + "mavros_w2_post_commitment", + &w2, + prover.whir_for_witness.domain_size(), ); - let commitment_1 = self + let commitment_2 = prover .whir_for_witness .commit( &mut merlin, - self.witness_layout.size(), - self.constraints_layout.algebraic_size, - w1, - true, + prover.witness_layout.size(), + prover.constraints_layout.algebraic_size, + w2, + false, ) - .context("While committing to w1")?; - - let (commitments, witgen_result) = if self.whir_for_witness.num_challenges > 0 { - let challenges: Vec = (0..self.witness_layout.challenges_size) - .map(|_| merlin.verifier_message()) - .collect(); - - let witgen_result = mavros_interpreter::run_phase2( - phase1, - &challenges, - self.witness_layout, - self.constraints_layout, - ); - - let mut witgen_result = witgen_result; - let w2 = take(&mut witgen_result.out_wit_post_comm); - crate::logging::log_commit_input( - "mavros_w2_post_commitment", - &w2, - self.whir_for_witness.domain_size(), - ); - let commitment_2 = self - .whir_for_witness - .commit( - &mut merlin, - self.witness_layout.size(), - self.constraints_layout.algebraic_size, - w2, - false, - ) - .context("While committing to w2")?; - - (vec![commitment_1, commitment_2], witgen_result) - } else { - let witgen_result = mavros_interpreter::run_phase2( - phase1, - &[], - self.witness_layout, - self.constraints_layout, - ); - (vec![commitment_1], witgen_result) - }; + .context("While committing to w2")?; + + (vec![commitment_1, commitment_2], witgen_result) + } else { + let witgen_result = mavros_interpreter::run_phase2( + phase1, + &[], + prover.witness_layout, + prover.constraints_layout, + ); + (vec![commitment_1], witgen_result) + }; - let whir_r1cs_proof = self + let (whir_r1cs_proof, r1cs_spark_queries) = if produce_spark_query { + let (proof, batch) = prover + .whir_for_witness + .prove_mavros_with_spark( + merlin, + witgen_result, + commitments, + &public_inputs, + prover.witness_layout, + prover.constraints_layout, + &prover.ad_binary, + ) + .context("While proving R1CS instance")?; + (proof, Some(batch)) + } else { + let proof = prover .whir_for_witness .prove_mavros( merlin, witgen_result, commitments, &public_inputs, - self.witness_layout, - self.constraints_layout, - &self.ad_binary, + prover.witness_layout, + prover.constraints_layout, + &prover.ad_binary, ) .context("While proving R1CS instance")?; + (proof, None) + }; - Ok(NoirProof { + Ok(( + NoirProof { public_inputs, whir_r1cs_proof, - }) + }, + r1cs_spark_queries, + )) +} + +#[cfg(not(target_arch = "wasm32"))] +impl Prove for MavrosProver { + #[cfg(feature = "witness-generation")] + fn prove(self, input_map: InputMap) -> Result { + let (proof, _) = prove_mavros_inner(self, input_map, false)?; + Ok(proof) } #[cfg(feature = "witness-generation")] #[instrument(skip_all)] fn prove_with_toml(self, prover_toml: impl AsRef) -> Result { - let project_path = prover_toml - .as_ref() - .parent() - .context("Could not derive project path from Prover.toml path")?; - - let input_map = - crate::input_utils::read_prover_inputs(&project_path.to_path_buf(), &self.abi)?; + let input_map = mavros_input_map_from_toml(&self.abi, prover_toml.as_ref())?; self.prove(input_map) } @@ -403,6 +470,27 @@ impl Prove for MavrosProver { "prove_with_witness is not supported for Mavros prover" )) } + + #[cfg(feature = "witness-generation")] + fn prove_with_spark_toml( + self, + prover_toml: impl AsRef, + ) -> Result<(NoirProof, SparkQueryBatch)> { + let input_map = mavros_input_map_from_toml(&self.abi, prover_toml.as_ref())?; + let (proof, batch) = prove_mavros_inner(self, input_map, true)?; + Ok(( + proof, + batch.expect("spark batch must be produced when requested"), + )) + } +} + +#[cfg(all(not(target_arch = "wasm32"), feature = "witness-generation"))] +fn mavros_input_map_from_toml(abi: &noirc_abi::Abi, prover_toml: &Path) -> Result { + let project_path = prover_toml + .parent() + .context("Could not derive project path from Prover.toml path")?; + crate::input_utils::read_prover_inputs(&project_path.to_path_buf(), abi) } impl Prove for Prover { @@ -433,4 +521,15 @@ impl Prove for Prover { } } } + + #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))] + fn prove_with_spark_toml( + self, + prover_toml: impl AsRef, + ) -> Result<(NoirProof, SparkQueryBatch)> { + match self { + Prover::Noir(p) => p.prove_with_spark_toml(prover_toml), + Prover::Mavros(p) => p.prove_with_spark_toml(prover_toml), + } + } } diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index 3b1a5a97e..c939ea0ac 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -9,6 +9,7 @@ use { compute_public_eval, expand_powers, make_challenge_weight, make_public_weight, OffsetCovector, }, + spark::{SparkColQuery, SparkQueryBatch}, utils::{ pad_to_power_of_two, sumcheck::{ @@ -64,7 +65,17 @@ pub trait WhirR1CSProver { public_inputs: &PublicInputs, ) -> Result; + fn prove_noir_with_spark( + &self, + merlin: ProverState, + r1cs: R1CS, + commitments: Vec, + full_witness: Vec, + public_inputs: &PublicInputs, + ) -> Result<(WhirR1CSProof, SparkQueryBatch)>; + #[cfg(not(target_arch = "wasm32"))] + #[allow(clippy::too_many_arguments)] fn prove_mavros( &self, merlin: ProverState, @@ -75,6 +86,19 @@ pub trait WhirR1CSProver { constraints_layout: ConstraintsLayout, ad_binary: &[u64], ) -> Result; + + #[cfg(not(target_arch = "wasm32"))] + #[allow(clippy::too_many_arguments)] + fn prove_mavros_with_spark( + &self, + merlin: ProverState, + witgen: WitgenResult, + commitments: Vec, + public_inputs: &PublicInputs, + witness_layout: WitnessLayout, + constraints_layout: ConstraintsLayout, + ad_binary: &[u64], + ) -> Result<(WhirR1CSProof, SparkQueryBatch)>; } impl WhirR1CSProver for WhirR1CSScheme { @@ -142,55 +166,54 @@ impl WhirR1CSProver for WhirR1CSScheme { #[instrument(skip_all)] fn prove_noir( &self, - mut merlin: ProverState, + merlin: ProverState, r1cs: R1CS, commitments: Vec, full_witness: Vec, public_inputs: &PublicInputs, ) -> Result { - ensure!(!commitments.is_empty(), "Need at least one commitment"); - - let (a, b, c) = calculate_witness_bounds(&r1cs, &full_witness); - drop(full_witness); - - let blinding = commitments[0] - .blinding - .as_ref() - .expect("c1 must carry blinding state"); - - let (alpha, blinding_eval) = run_zk_sumcheck_prover( - a, - b, - c, - &mut merlin, - self.m_0, - &blinding.polynomial, - &commitments[0].polynomial, - blinding.offset, - ); - - let (at, bt, ct) = transpose_r1cs_matrices(&r1cs); - let alphas = multiply_transposed_by_eq_alpha(&at, &bt, &ct, &alpha, &r1cs); + let (proof, _) = prove_noir_inner( + self, + merlin, + r1cs, + commitments, + full_witness, + public_inputs, + false, + )?; + Ok(proof) + } - let blinding_offset = blinding.offset; - let blinding_weights = expand_powers::<4>(&alpha); - prove_from_alphas( + #[instrument(skip_all)] + fn prove_noir_with_spark( + &self, + merlin: ProverState, + r1cs: R1CS, + commitments: Vec, + full_witness: Vec, + public_inputs: &PublicInputs, + ) -> Result<(WhirR1CSProof, SparkQueryBatch)> { + let (proof, batch) = prove_noir_inner( self, merlin, - alphas, - blinding_eval, - blinding_offset, - blinding_weights, + r1cs, commitments, + full_witness, public_inputs, - ) + true, + )?; + Ok(( + proof, + batch.expect("spark batch must be produced when requested"), + )) } #[cfg(not(target_arch = "wasm32"))] #[instrument(skip_all)] + #[allow(clippy::too_many_arguments)] fn prove_mavros( &self, - mut merlin: ProverState, + merlin: ProverState, witgen: WitgenResult, commitments: Vec, public_inputs: &PublicInputs, @@ -198,62 +221,195 @@ impl WhirR1CSProver for WhirR1CSScheme { constraints_layout: ConstraintsLayout, ad_binary: &[u64], ) -> Result { - ensure!(!commitments.is_empty(), "Need at least one commitment"); - - let blinding = commitments[0] - .blinding - .as_ref() - .expect("c1 must carry blinding state"); - - let [a, b, c] = [witgen.out_a, witgen.out_b, witgen.out_c]; - let (alpha, blinding_eval) = run_zk_sumcheck_prover( - a, - b, - c, - &mut merlin, - self.m_0, - &blinding.polynomial, - &commitments[0].polynomial, - blinding.offset, - ); - - let eq_alpha = - calculate_evaluations_over_boolean_hypercube_for_eq(&alpha, 1 << alpha.len()); - let (ad_a, ad_b, ad_c, _) = mavros_vm::interpreter::run_ad( - ad_binary, - &eq_alpha[..constraints_layout.size()], + let (proof, _) = prove_mavros_inner( + self, + merlin, + witgen, + commitments, + public_inputs, witness_layout, constraints_layout, - ); - let alphas = [ad_a, ad_b, ad_c]; - - let blinding_offset = blinding.offset; - let blinding_weights = expand_powers::<4>(&alpha); + ad_binary, + false, + )?; + Ok(proof) + } - prove_from_alphas( + #[cfg(not(target_arch = "wasm32"))] + #[instrument(skip_all)] + #[allow(clippy::too_many_arguments)] + fn prove_mavros_with_spark( + &self, + merlin: ProverState, + witgen: WitgenResult, + commitments: Vec, + public_inputs: &PublicInputs, + witness_layout: WitnessLayout, + constraints_layout: ConstraintsLayout, + ad_binary: &[u64], + ) -> Result<(WhirR1CSProof, SparkQueryBatch)> { + let (proof, batch) = prove_mavros_inner( self, merlin, + witgen, + commitments, + public_inputs, + witness_layout, + constraints_layout, + ad_binary, + true, + )?; + Ok(( + proof, + batch.expect("spark batch must be produced when requested"), + )) + } +} + +#[instrument(skip_all)] +#[allow(clippy::too_many_arguments)] +fn prove_noir_inner( + scheme: &WhirR1CSScheme, + mut merlin: ProverState, + r1cs: R1CS, + commitments: Vec, + full_witness: Vec, + public_inputs: &PublicInputs, + produce_spark_query: bool, +) -> Result<(WhirR1CSProof, Option)> { + ensure!(!commitments.is_empty(), "Need at least one commitment"); + + let (a, b, c) = calculate_witness_bounds(&r1cs, &full_witness); + drop(full_witness); + + let blinding = commitments[0] + .blinding + .as_ref() + .expect("c1 must carry blinding state"); + + let (alpha, blinding_eval) = run_zk_sumcheck_prover( + a, + b, + c, + &mut merlin, + scheme.m_0, + &blinding.polynomial, + &commitments[0].polynomial, + blinding.offset, + ); + + let (at, bt, ct) = transpose_r1cs_matrices(&r1cs); + let alphas = multiply_transposed_by_eq_alpha(&at, &bt, &ct, &alpha, &r1cs); + + let blinding_offset = blinding.offset; + let blinding_weights = expand_powers::<4>(&alpha); + + prove_from_alphas( + scheme, + merlin, + ProveFromAlphasCtx { + alpha, alphas, blinding_eval, blinding_offset, blinding_weights, commitments, - public_inputs, - ) - } + produce_spark_query, + }, + public_inputs, + ) } +#[cfg(not(target_arch = "wasm32"))] #[instrument(skip_all)] -fn prove_from_alphas( +#[allow(clippy::too_many_arguments)] +fn prove_mavros_inner( scheme: &WhirR1CSScheme, mut merlin: ProverState, - alphas: [Vec; 3], - blinding_eval: FieldElement, - blinding_offset: usize, - blinding_weights: Vec, + witgen: WitgenResult, commitments: Vec, public_inputs: &PublicInputs, -) -> Result { + witness_layout: WitnessLayout, + constraints_layout: ConstraintsLayout, + ad_binary: &[u64], + produce_spark_query: bool, +) -> Result<(WhirR1CSProof, Option)> { + ensure!(!commitments.is_empty(), "Need at least one commitment"); + + let blinding = commitments[0] + .blinding + .as_ref() + .expect("c1 must carry blinding state"); + + let [a, b, c] = [witgen.out_a, witgen.out_b, witgen.out_c]; + let (alpha, blinding_eval) = run_zk_sumcheck_prover( + a, + b, + c, + &mut merlin, + scheme.m_0, + &blinding.polynomial, + &commitments[0].polynomial, + blinding.offset, + ); + + let eq_alpha = calculate_evaluations_over_boolean_hypercube_for_eq(&alpha, 1 << alpha.len()); + let (ad_a, ad_b, ad_c, _) = mavros_vm::interpreter::run_ad( + ad_binary, + &eq_alpha[..constraints_layout.size()], + witness_layout, + constraints_layout, + ); + let alphas = [ad_a, ad_b, ad_c]; + + let blinding_offset = blinding.offset; + let blinding_weights = expand_powers::<4>(&alpha); + + prove_from_alphas( + scheme, + merlin, + ProveFromAlphasCtx { + alpha, + alphas, + blinding_eval, + blinding_offset, + blinding_weights, + commitments, + produce_spark_query, + }, + public_inputs, + ) +} + +/// Owned inputs to the post-sumcheck proving stage ([`prove_from_alphas`]), +/// produced by [`prove_noir_inner`] / [`prove_mavros_inner`]. +struct ProveFromAlphasCtx { + alpha: Vec, + alphas: [Vec; 3], + blinding_eval: FieldElement, + blinding_offset: usize, + blinding_weights: Vec, + commitments: Vec, + produce_spark_query: bool, +} + +#[instrument(skip_all)] +fn prove_from_alphas( + scheme: &WhirR1CSScheme, + mut merlin: ProverState, + ctx: ProveFromAlphasCtx, + public_inputs: &PublicInputs, +) -> Result<(WhirR1CSProof, Option)> { + let ProveFromAlphasCtx { + alpha, + alphas, + blinding_eval, + blinding_offset, + blinding_weights, + commitments, + produce_spark_query, + } = ctx; + let public_inputs_hash = public_inputs.hash(scheme.hash_config); let public_inputs_len = public_inputs.len(); @@ -263,7 +419,7 @@ fn prove_from_alphas( let domain_size = 1usize << scheme.m; - if is_single { + let spark_queries: Option = if is_single { // Single commitment path let commitment = commitments .into_iter() @@ -290,19 +446,54 @@ fn prove_from_alphas( let blinding_covector = OffsetCovector::new(blinding_weights, blinding_offset, domain_size); + let alpha_weight_data: Option, usize)>> = + produce_spark_query.then(|| { + weights + .iter() + .map(|w| (w.vector().to_vec(), w.size())) + .collect() + }); + let mut boxed_weights: Vec>> = weights .into_iter() .map(|w| Box::new(w) as Box>) .collect(); boxed_weights.push(Box::new(blinding_covector)); - let _ = scheme.whir_witness.prove( + let public_offset = if public_inputs.is_empty() { 0 } else { 1 }; + + let final_claim = scheme.whir_witness.prove( &mut merlin, vec![Cow::Borrowed(commitment.polynomial.as_slice())], commitment.witness, boxed_weights, Cow::Borrowed(&evaluations), ); + + if let Some(alpha_weight_data) = alpha_weight_data { + let evaluations: [FieldElement; 3] = alpha_weight_data + [public_offset..(public_offset + 3)] + .iter() + .map(|(vec, ds)| { + let w = PrefixCovector::new(vec.clone(), *ds); + w.mle_evaluate(&final_claim.evaluation_point) + }) + .collect::>() + .try_into() + .expect("exactly 3 alpha-weight evaluations"); + + Some(SparkQueryBatch { + row: alpha, + queries: vec![SparkColQuery { + col: final_claim.evaluation_point, + claimed_a: evaluations[0], + claimed_b: evaluations[1], + claimed_c: evaluations[2], + }], + }) + } else { + None + } } else { // Dual commitment path let mut commitments = commitments.into_iter(); @@ -355,12 +546,15 @@ fn prove_from_alphas( None }; + let has_public = public_1.is_some(); + let public_offset_1 = if has_public { 1 } else { 0 }; + let WhirR1CSCommitment { witness: w1, polynomial: p1, .. } = c1; - { + let (final_claim1, rlc1) = { let mut weights = build_prefix_covectors(scheme.m, alphas_1); let mut evaluations: Vec = Vec::new(); if let Some(pe) = public_1 { @@ -373,20 +567,43 @@ fn prove_from_alphas( let blinding_covector = OffsetCovector::new(blinding_weights, blinding_offset, domain_size); + let alpha_weight_data_1: Option, usize)>> = produce_spark_query + .then(|| { + weights[public_offset_1..public_offset_1 + 3] + .iter() + .map(|w| (w.vector().to_vec(), w.size())) + .collect() + }); + let mut boxed_weights: Vec>> = weights .into_iter() .map(|w| Box::new(w) as Box>) .collect(); boxed_weights.push(Box::new(blinding_covector)); - let _ = scheme.whir_witness.prove( + let final_claim1 = scheme.whir_witness.prove( &mut merlin, vec![Cow::Borrowed(p1.as_slice())], w1, boxed_weights, Cow::Borrowed(&evaluations), ); - } + + let claimed1: Option<[FieldElement; 3]> = + alpha_weight_data_1.map(|alpha_weight_data_1| { + alpha_weight_data_1 + .iter() + .map(|(vec, ds)| { + let w = PrefixCovector::new(vec.clone(), *ds); + w.mle_evaluate(&final_claim1.evaluation_point) + }) + .collect::>() + .try_into() + .expect("exactly 3 alpha-weight evaluations") + }); + + (final_claim1, claimed1) + }; drop(p1); let WhirR1CSCommitment { @@ -394,10 +611,18 @@ fn prove_from_alphas( polynomial: p2, .. } = c2; - { + let (final_claim2, rlc2) = { let weights = build_prefix_covectors(scheme.m, alphas_2); let mut evaluations: Vec = evals_2; + let alpha_weight_data_2: Option, usize)>> = produce_spark_query + .then(|| { + weights[0..3] + .iter() + .map(|w| (w.vector().to_vec(), w.size())) + .collect() + }); + let mut boxed_weights: Vec>> = weights .into_iter() .map(|w| Box::new(w) as Box>) @@ -409,23 +634,69 @@ fn prove_from_alphas( boxed_weights.push(Box::new(cw)); } - let _ = scheme.whir_witness.prove( + let final_claim2 = scheme.whir_witness.prove( &mut merlin, vec![Cow::Borrowed(p2.as_slice())], w2, boxed_weights, Cow::Borrowed(&evaluations), ); + + let claimed2: Option<[FieldElement; 3]> = + alpha_weight_data_2.map(|alpha_weight_data_2| { + alpha_weight_data_2 + .iter() + .map(|(vec, ds)| { + let w = PrefixCovector::new(vec.clone(), *ds); + w.mle_evaluate(&final_claim2.evaluation_point) + }) + .collect::>() + .try_into() + .expect("exactly 3 alpha-weight evaluations") + }); + + (final_claim2, claimed2) + }; + + match (rlc1, rlc2) { + (Some(claimed1), Some(claimed2)) => { + let mut col1 = final_claim1.evaluation_point.clone(); + col1.insert(0, FieldElement::zero()); + let query1 = SparkColQuery { + col: col1, + claimed_a: claimed1[0], + claimed_b: claimed1[1], + claimed_c: claimed1[2], + }; + + let mut col2 = final_claim2.evaluation_point.clone(); + col2.insert(0, FieldElement::one()); + let query2 = SparkColQuery { + col: col2, + claimed_a: claimed2[0], + claimed_b: claimed2[1], + claimed_c: claimed2[2], + }; + + Some(SparkQueryBatch { + row: alpha, + queries: vec![query1, query2], + }) + } + _ => None, } - } + }; let proof = merlin.proof(); - Ok(WhirR1CSProof { - narg_string: proof.narg_string, - hints: proof.hints, - #[cfg(debug_assertions)] - pattern: proof.pattern, - }) + Ok(( + WhirR1CSProof { + narg_string: proof.narg_string, + hints: proof.hints, + #[cfg(debug_assertions)] + pattern: proof.pattern, + }, + spark_queries, + )) } pub fn compute_blinding_coefficients_for_round( diff --git a/provekit/spark/Cargo.toml b/provekit/spark/Cargo.toml new file mode 100644 index 000000000..386f119d7 --- /dev/null +++ b/provekit/spark/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "provekit-spark" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true +authors.workspace = true +license.workspace = true +homepage.workspace = true +repository.workspace = true + +[dependencies] +provekit-common.workspace = true +ark-ff.workspace = true +ark-serialize.workspace = true +ark-std.workspace = true +anyhow.workspace = true +serde.workspace = true +whir.workspace = true +tracing.workspace = true +rayon.workspace = true + +[lints] +workspace = true diff --git a/provekit/spark/SPARK.md b/provekit/spark/SPARK.md new file mode 100644 index 000000000..01b8bf3ff --- /dev/null +++ b/provekit/spark/SPARK.md @@ -0,0 +1,128 @@ +# SPARK + +Reference for this implementation +- SPARK: https://eprint.iacr.org/2019/550 +- Stronger security analysis of SPARK: https://people.cs.georgetown.edu/jthaler/Lasso-paper.pdf + +## Proposed prototype workflow +1. Provekit prepare step + - Compiles the circuit and writes the prover/verifier artifacts (`.pkp`, `.pkv`). + - With `--spark`, also runs SPARK preprocessing once: the SPARK setup is + folded into the verifier key (`.pkv`) and the bundled SPARK prover + context (matrix, witnesses, setup) is written to `.spctx`. + +2. Provekit prove step + - Runs the provekit prover and obtains the Noir proof plus the deferred + matrix evaluations (SPARK queries). + - Writes each query as `spark_query_.json` to `--spark-queries-dir`. + +3. Provekit prove-spark step + - Reads the SPARK prover context (`--spctx`) produced in step 1 and the + queries from `--spark-dir`, then produces a single batched SPARK proof + (`spark_proof.sp` written back to the same directory). When the query + set has more than one entry the prover RLC's them with a + transcript-derived `beta` and runs a parallel sumcheck before falling + into the single-query SPARK protocol; with one query it goes straight to + that protocol. + +4. Provekit and SPARK verify step + - Verifies Provekit and SPARK proofs + +## Design decisions + +### Pack $A$, $B$, $C$ into one block matrix Z: +This is a result from Marcin (https://gist.github.com/kustosz/14b62de666f721ab855536e575891bd1) + +**The trick:** + +$$Z = \begin{bmatrix} A & B \\ 0 & C \end{bmatrix}$$ + +Same total non-zeros, double the dimensions. Then for any $\beta$, $p$, and $q$: + +$$A(p,q) + \beta B(p,q) + \beta^2 C(p,q) = (1+\beta)^2 \cdot Z\!\left(\tfrac{\beta}{1+\beta}, p,\ \tfrac{\beta}{1+\beta}, q\right)$$ + +One matrix, one commitment, one opening. + +### Batching GPA and WHIR proofs + +- Combining GPA + - Products of hashes corresponding to read sets and write sets of row-wise and column-wise memory check are combined into one GPA + - Products of hashes corresponding to init and final vectors are combined into one GPA (separate for row-wise and col-wise memory). Possible optimization - if number of rows and columns for the matrix are ensured to be equal, we can combine them into one GPA. + +- WHIR Batching +| `num_terms_2batched` e-values are committed and opened together. Opened once in sumcheck and once in rs_ws GPA +| `num_terms_4batched` | Address/timestamp values for row-wise and col-wise memory checks are committed and opened together + +### Split witness: two SPARK queries +The current ZK WHIR doesn't support batching which would enable easier handling of split witness commitment. + +The current implementation emits **two SPARK queries** for the dual-commitment +path — one per split half. Both queries are then batched into a single SPARK +proof by RLC'ing their per-matrix claims with a transcript-derived `beta` and +running one parallel sumcheck of `Σ_i β^i · eq(col_i, x) · M(α, x)` for +M ∈ {A, B, C}. The folded values become the claims of a single synthesized +query passed into the single-query SPARK protocol. + + +## Full workflow for the `range-check-u8` Noir passport circuit: + +```bash +cargo build --release --bin provekit-cli + +cd noir-examples/noir-r1cs-test-programs/range-check-u8 +nargo compile + + +# 1. Prepare the circuit (compiles and writes prover/verifier artifacts; the +# SPARK setup is folded into the .pkv, and the bundled SPARK prover context +# is written to .spctx). +cargo run --release --bin provekit-cli -- prepare ./target/main.json \ + --pkp ./spark-artifacts/range-check-u8.pkp \ + --pkv ./spark-artifacts/range-check-u8.pkv \ + --spark \ + --spctx ./spark-artifacts/range-check-u8.spctx + +# 2. Prove (generates Noir proof + writes SPARK queries to disk). +# `--produce-spark-query` is required, otherwise no queries are written. +cargo run --release --bin provekit-cli -- prove \ + -p ./spark-artifacts/range-check-u8.pkp \ + -i ./Prover.toml \ + -o ./spark-artifacts/range-check-u8-proof.np \ + --spark-queries-dir ./spark_proofs \ + --produce-spark-query + +# 3. Generate one batched SPARK proof covering every query written in step 2. +# The prover reads all `spark_query_*.json` files in --spark-dir plus the +# SPARK prover context from --spctx, batches the queries, and writes a +# single ./spark_proofs/spark_proof.sp. +cargo run --release --bin provekit-cli -- prove-spark ./spark-artifacts/range-check-u8.pkp \ + --spark-dir ./spark_proofs \ + --spctx ./spark-artifacts/range-check-u8.spctx + +# 4. Natively verify the Noir proof. Native verification evaluates MLE directly. Spark proofs are useful only in the recursive verifier. +cargo run --release --bin provekit-cli -- verify \ + -v ./spark-artifacts/range-check-u8.pkv \ + --proof ./spark-artifacts/range-check-u8-proof.np + +# 5. Verify the batched SPARK proof. The verifier pulls the SPARK setup from +# the trusted .pkv. Pass every query that the prover saw, in index order +# (`_0`, `_1`, ...) — the transcript instance is bound to the postcard- +# serialized query slice, so order matters. +cargo run --release --bin provekit-cli -- verify-spark \ + ./spark_proofs/spark_proof.sp \ + ./spark-artifacts/range-check-u8.pkv \ + ./spark_proofs/spark_queries.json + +# Or, equivalently, with a glob (single-digit indices sort lexically): +# cargo run --release --bin provekit-cli -- verify-spark \ +# ./spark_proofs/spark_proof.sp \ +# ./spark-artifacts/range-check-u8.pkv \ +# ./spark_proofs/spark_query_*.json + +# TODO: 6. Recursively verify the Noir proof and SPARK. +``` + +The `range-check-u8` circuit uses the multi-challenge Noir API, so the +provekit prover takes the dual-commitment path and emits **two** spark +queries (`spark_query_0.json` and `spark_query_1.json`). Single-commitment +circuits emit just `spark_query_0.json` and step 5 needs only that one path. \ No newline at end of file diff --git a/provekit/spark/src/gpa.rs b/provekit/spark/src/gpa.rs new file mode 100644 index 000000000..4d1d3721f --- /dev/null +++ b/provekit/spark/src/gpa.rs @@ -0,0 +1,378 @@ +use { + anyhow::{ensure, Context}, + ark_ff::{AdditiveGroup, Field}, + provekit_common::{ + utils::{ + next_power_of_two, + sumcheck::{ + calculate_eq, calculate_evaluations_over_boolean_hypercube_for_eq, eval_cubic_poly, + sumcheck_fold_map_reduce, + }, + HALF, + }, + FieldElement, TranscriptSponge, + }, + tracing::instrument, + whir::transcript::{ProverState, VerifierMessage, VerifierState}, +}; + +#[instrument(skip_all)] +pub fn run_gpa2( + merlin: &mut ProverState, + leaves: Vec, +) -> anyhow::Result> { + let mut layers = calculate_binary_multiplication_tree(leaves)?; + + let mut drain = layers.drain(1..); + + let first_layer = drain.next().context("GPA tree has fewer than 2 layers")?; + let (accumulated_randomness, mut sumcheck_claim) = add_line_to_transcript(merlin, first_layer); + let mut accumulated_randomness = accumulated_randomness.to_vec(); + + for layer in drain { + (sumcheck_claim, accumulated_randomness) = + run_gpa_sumcheck(merlin, layer, sumcheck_claim, accumulated_randomness)?; + } + + Ok(accumulated_randomness) +} + +#[instrument(skip_all)] +pub fn run_gpa4( + merlin: &mut ProverState, + leaves: Vec, +) -> anyhow::Result> { + let mut layers = calculate_binary_multiplication_tree(leaves)?; + + let mut drain = layers.drain(2..); + + let coeffs = drain.next().context("GPA tree has fewer than 3 layers")?; + let coeffs = [ + coeffs[0], + coeffs[1] - coeffs[0], + coeffs[2] - coeffs[0], + coeffs[3] - coeffs[2] - coeffs[1] + coeffs[0], + ]; + + for c in &coeffs { + merlin.prover_message(c); + } + + let r0: FieldElement = merlin.verifier_message(); + let r1: FieldElement = merlin.verifier_message(); + let mut accumulated_randomness = vec![r0, r1]; + + let mut sumcheck_claim = coeffs[0] + coeffs[1] * r1 + coeffs[2] * r0 + coeffs[3] * r0 * r1; + + for layer in drain { + (sumcheck_claim, accumulated_randomness) = + run_gpa_sumcheck(merlin, layer, sumcheck_claim, accumulated_randomness)?; + } + + Ok(accumulated_randomness) +} + +fn calculate_binary_multiplication_tree( + array_to_prove: Vec, +) -> anyhow::Result>> { + use rayon::prelude::*; + + ensure!( + array_to_prove.len() == (1 << next_power_of_two(array_to_prove.len())), + "Input length must be power of two" + ); + + let mut layers = vec![]; + let mut current_layer = array_to_prove; + + while current_layer.len() > 1 { + let next_layer: Vec = current_layer + .par_chunks_exact(2) + .map(|pair| pair[0] * pair[1]) + .collect(); + + layers.push(current_layer); + current_layer = next_layer; + } + + layers.push(current_layer); + layers.reverse(); + Ok(layers) +} + +fn add_line_to_transcript( + merlin: &mut ProverState, + arr: Vec, +) -> ([FieldElement; 1], FieldElement) { + let line_poly = [arr[0], arr[1] - arr[0]]; + + for c in line_poly.iter() { + merlin.prover_message(c); + } + + let challenge: FieldElement = merlin.verifier_message(); + + let next_claim = line_poly[0] + line_poly[1] * challenge; + + ([challenge], next_claim) +} + +fn run_gpa_sumcheck( + merlin: &mut ProverState, + layer: Vec, + mut sumcheck_claim: FieldElement, + accumulated_randomness: Vec, +) -> anyhow::Result<(FieldElement, Vec)> { + let (mut even_layer, mut odd_layer) = split_even_odd(layer); + + let mut eq_evaluations = calculate_evaluations_over_boolean_hypercube_for_eq( + &accumulated_randomness, + 1 << accumulated_randomness.len(), + ); + let mut challenge; + let mut round_randomness = Vec::::new(); + let mut fold = None; + + loop { + let [eval_at_0, eval_at_neg1, eval_at_inf_over_x3] = sumcheck_fold_map_reduce( + [&mut eq_evaluations, &mut even_layer, &mut odd_layer], + fold, + |[eq, v0, v1]| { + [ + eq.0 * v0.0 * v1.0, + (eq.0 + eq.0 - eq.1) * (v0.0 + v0.0 - v0.1) * (v1.0 + v1.0 - v1.1), + (eq.1 - eq.0) * (v0.1 - v0.0) * (v1.1 - v1.0), + ] + }, + ); + + if fold.is_some() { + eq_evaluations.truncate(eq_evaluations.len() / 2); + even_layer.truncate(even_layer.len() / 2); + odd_layer.truncate(odd_layer.len() / 2); + } + + let poly_coeffs = reconstruct_cubic_from_evaluations( + sumcheck_claim, + eval_at_0, + eval_at_neg1, + eval_at_inf_over_x3, + ); + + ensure!( + sumcheck_claim + == poly_coeffs[0] + + poly_coeffs[0] + + poly_coeffs[1] + + poly_coeffs[2] + + poly_coeffs[3], + "Sumcheck binding check failed" + ); + + for coeff in &poly_coeffs { + merlin.prover_message(coeff); + } + challenge = merlin.verifier_message(); + + fold = Some(challenge); + sumcheck_claim = eval_cubic_poly(poly_coeffs, challenge); + round_randomness.push(challenge); + + if eq_evaluations.len() <= 2 { + break; + } + } + + let final_v0 = even_layer[0] + (even_layer[1] - even_layer[0]) * challenge; + let final_v1 = odd_layer[0] + (odd_layer[1] - odd_layer[0]) * challenge; + let final_v2 = eq_evaluations[0] + (eq_evaluations[1] - eq_evaluations[0]) * challenge; + + ensure!( + sumcheck_claim == final_v0 * final_v1 * final_v2, + "GPA sumcheck claim mismatch" + ); + + let line_coeffs = [final_v0, final_v1 - final_v0]; + + for c in &line_coeffs { + merlin.prover_message(c); + } + + let line_challenge: FieldElement = merlin.verifier_message(); + let next_claim = line_coeffs[0] + line_coeffs[1] * line_challenge; + round_randomness.push(line_challenge); + + Ok((next_claim, round_randomness)) +} + +fn reconstruct_cubic_from_evaluations( + binding_value: FieldElement, + at_0: FieldElement, + at_neg1: FieldElement, + at_inf_over_x3: FieldElement, +) -> [FieldElement; 4] { + let mut coeffs = [FieldElement::ZERO; 4]; + + coeffs[0] = at_0; + coeffs[2] = HALF * (binding_value + at_neg1 - at_0 - at_0 - at_0); + coeffs[3] = at_inf_over_x3; + coeffs[1] = binding_value - coeffs[0] - coeffs[0] - coeffs[3] - coeffs[2]; + + coeffs +} + +fn split_even_odd(input: Vec) -> (Vec, Vec) { + input + .chunks_exact(2) + .map(|chunk| (chunk[0], chunk[1])) + .unzip() +} + +pub struct GPASumcheckResult { + pub claimed_values: Vec, + pub last_sumcheck_value: FieldElement, + pub randomness: Vec, +} + +fn read_msgs( + arthur: &mut VerifierState<'_, TranscriptSponge>, + label: &str, +) -> anyhow::Result<[FieldElement; N]> { + let mut out = [FieldElement::ZERO; N]; + for (i, slot) in out.iter_mut().enumerate() { + *slot = arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read {label} [{i}]"))?; + } + Ok(out) +} + +fn run_gpa_layers( + arthur: &mut VerifierState<'_, TranscriptSponge>, + start_layer: usize, + height_of_binary_tree: usize, + mut sumcheck_value: FieldElement, + mut prev_randomness: Vec, + cubic_label: &str, + line_label: &str, +) -> anyhow::Result<(FieldElement, Vec)> { + let mut current_randomness = Vec::::new(); + + for layer_idx in start_layer..height_of_binary_tree - 1 { + for _ in 0..layer_idx { + let cubic_coeffs = read_msgs::<4>(arthur, cubic_label)?; + let sumcheck_challenge: FieldElement = arthur.verifier_message(); + + ensure!( + eval_cubic_poly(cubic_coeffs, FieldElement::ZERO) + + eval_cubic_poly(cubic_coeffs, FieldElement::ONE) + == sumcheck_value, + "Sumcheck verification failed at layer {layer_idx}" + ); + + current_randomness.push(sumcheck_challenge); + sumcheck_value = eval_cubic_poly(cubic_coeffs, sumcheck_challenge); + } + + let line_coeffs = read_msgs::<2>(arthur, line_label)?; + let line_challenge: FieldElement = arthur.verifier_message(); + + let expected_line_value = calculate_eq(&prev_randomness, ¤t_randomness) + * eval_line(&line_coeffs, &FieldElement::ZERO) + * eval_line(&line_coeffs, &FieldElement::ONE); + ensure!( + expected_line_value == sumcheck_value, + "Line evaluation mismatch" + ); + + current_randomness.push(line_challenge); + prev_randomness = current_randomness; + current_randomness = Vec::new(); + sumcheck_value = eval_line(&line_coeffs, &line_challenge); + } + + Ok((sumcheck_value, prev_randomness)) +} + +#[instrument(skip_all)] +pub fn gpa_sumcheck_verifier2( + arthur: &mut VerifierState<'_, TranscriptSponge>, + height_of_binary_tree: usize, +) -> anyhow::Result { + let claimed_values = read_msgs::<2>(arthur, "GPA2 claimed value")?; + let line_challenge: FieldElement = arthur.verifier_message(); + + let initial_sumcheck_value = eval_line(&claimed_values, &line_challenge); + let initial_randomness = vec![line_challenge]; + + let (last_sumcheck_value, randomness) = run_gpa_layers( + arthur, + 1, + height_of_binary_tree, + initial_sumcheck_value, + initial_randomness, + "GPA2 cubic coeff", + "GPA2 line coeff", + )?; + + let claimed_values = vec![claimed_values[0], claimed_values[0] + claimed_values[1]]; + + Ok(GPASumcheckResult { + claimed_values, + last_sumcheck_value, + randomness, + }) +} + +#[instrument(skip_all)] +pub fn gpa_sumcheck_verifier4( + arthur: &mut VerifierState<'_, TranscriptSponge>, + height_of_binary_tree: usize, +) -> anyhow::Result { + let claimed_values = read_msgs::<4>(arthur, "GPA4 claimed value")?; + let r0: FieldElement = arthur.verifier_message(); + let r1: FieldElement = arthur.verifier_message(); + let initial_randomness = vec![r0, r1]; + + let initial_sumcheck_value = claimed_values[0] + + claimed_values[1] * r1 + + claimed_values[2] * r0 + + claimed_values[3] * r0 * r1; + + let (last_sumcheck_value, randomness) = run_gpa_layers( + arthur, + 2, + height_of_binary_tree, + initial_sumcheck_value, + initial_randomness, + "GPA4 cubic coeff", + "GPA4 line coeff", + )?; + + let claimed_values = vec![ + claimed_values[0], + claimed_values[0] + claimed_values[1], + claimed_values[0] + claimed_values[2], + claimed_values[0] + claimed_values[1] + claimed_values[2] + claimed_values[3], + ]; + + Ok(GPASumcheckResult { + claimed_values, + last_sumcheck_value, + randomness, + }) +} + +pub fn eval_line(poly: &[FieldElement], point: &FieldElement) -> FieldElement { + poly[0] + *point * poly[1] +} + +pub fn calculate_adr(randomness: &[FieldElement]) -> FieldElement { + randomness + .iter() + .rev() + .enumerate() + .fold(FieldElement::ZERO, |acc, (i, &r)| { + acc + r * FieldElement::from(1u64 << i) + }) +} diff --git a/provekit/spark/src/lib.rs b/provekit/spark/src/lib.rs new file mode 100644 index 000000000..f2c9d2d38 --- /dev/null +++ b/provekit/spark/src/lib.rs @@ -0,0 +1,18 @@ +pub(crate) mod gpa; +pub(crate) mod memory; +pub(crate) mod prover; +mod serde_whir_witness; +pub(crate) mod setup; +pub(crate) mod sumcheck; +pub(crate) mod types; +pub(crate) mod utils; +pub(crate) mod verifier; + +pub use { + provekit_common::{MatrixDimensions, SparkSetup, SparkWhirConfigs}, + prover::SparkProverScheme, + setup::preprocess_spark, + types::{SparkMatrix, SparkProof, SparkProverContext, SparkWitnesses}, + utils::calculate_memory, + verifier::SparkVerifierScheme, +}; diff --git a/provekit/spark/src/memory.rs b/provekit/spark/src/memory.rs new file mode 100644 index 000000000..653c38c1e --- /dev/null +++ b/provekit/spark/src/memory.rs @@ -0,0 +1,164 @@ +use { + crate::{ + gpa::{calculate_adr, gpa_sumcheck_verifier2, run_gpa2}, + types::{Challenges, WhirWitness}, + utils::read_hint, + }, + anyhow::{ensure, Result}, + ark_ff::AdditiveGroup, + ark_std::One, + provekit_common::{FieldElement, TranscriptSponge, WhirConfig}, + rayon::prelude::*, + std::borrow::Cow, + tracing::instrument, + whir::{ + algebra::{linear_form::MultilinearExtension, multilinear_extend}, + protocols::irs_commit::Commitment, + transcript::{ProverState, VerifierState}, + }, +}; + +pub struct AxisConfig<'a> { + pub eq_memory: &'a [FieldElement], + pub final_timestamp: &'a [FieldElement], + pub whir_config: &'a WhirConfig, +} + +#[instrument(skip_all)] +pub fn prove_axis_init_final_product( + merlin: &mut ProverState, + config: AxisConfig<'_>, + final_ts_witness: &WhirWitness, + challenges: &Challenges, +) -> Result<()> { + let gamma = &challenges.gamma; + let tau = &challenges.tau; + let gamma_sq = *gamma * *gamma; + + let n = config.eq_memory.len(); + debug_assert_eq!( + config.final_timestamp.len(), + n, + "eq_memory and final_timestamp must have equal length" + ); + + // Per element: + // init[i] = i*gamma_sq + eq[i]*gamma - tau + // final[i] = init[i] + final_ts[i] + let gpa_leaves = tracing::info_span!("build_init_final_vecs").in_scope(|| { + let mut buf = vec![FieldElement::ZERO; 2 * n]; + let (init_section, final_section) = buf.split_at_mut(n); + + init_section + .par_iter_mut() + .zip(final_section.par_iter_mut()) + .enumerate() + .for_each(|(i, (init_slot, final_slot))| { + let base = + FieldElement::from(i as u64) * gamma_sq + config.eq_memory[i] * gamma - tau; + *init_slot = base; + *final_slot = base + config.final_timestamp[i]; + }); + buf + }); + + let gpa_randomness = run_gpa2(merlin, gpa_leaves)?; + let (_combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1); + + let final_ts_eval = multilinear_extend(config.final_timestamp, evaluation_randomness); + merlin.prover_hint_ark(&final_ts_eval); + + produce_whir_proof( + merlin, + evaluation_randomness, + &[config.final_timestamp], + config.whir_config, + final_ts_witness, + )?; + + Ok(()) +} + +#[instrument(skip_all)] +pub fn verify_axis( + arthur: &mut VerifierState<'_, TranscriptSponge>, + num_axis_items: usize, + whir_config: &WhirConfig, + finalts_commitment: &Commitment, + init_mem_fn: impl Fn(&[FieldElement]) -> FieldElement, + tau: &FieldElement, + gamma: &FieldElement, + claimed_rs: &FieldElement, + claimed_ws: &FieldElement, +) -> Result<()> { + let gpa_result = gpa_sumcheck_verifier2( + arthur, + provekit_common::utils::next_power_of_two(num_axis_items) + 2, + )?; + + let claimed_init = gpa_result.claimed_values[0]; + let claimed_final = gpa_result.claimed_values[1]; + let (last_randomness, evaluation_randomness) = gpa_result.randomness.split_at(1); + + let gamma_sq = *gamma * *gamma; + + let init_adr = calculate_adr(evaluation_randomness); + let init_mem = init_mem_fn(evaluation_randomness); + let init_opening = init_adr * gamma_sq + init_mem * gamma - tau; + + let final_cntr: FieldElement = read_hint(arthur, "final counter")?; + + let eval_weight = MultilinearExtension::new(evaluation_randomness.to_vec()); + let finalts_claim = whir_config + .verify(arthur, &[finalts_commitment], &[final_cntr]) + .map_err(|e| anyhow::anyhow!("WHIR verify failed: {e}"))?; + finalts_claim + .verify([&eval_weight as &dyn whir::algebra::linear_form::LinearForm]) + .map_err(|e| anyhow::anyhow!("FinalClaim check failed for final timestamps: {e}"))?; + + let final_opening = init_adr * gamma_sq + init_mem * gamma + final_cntr - tau; + + let evaluated_value = init_opening * (FieldElement::one() - last_randomness[0]) + + final_opening * last_randomness[0]; + + ensure!( + evaluated_value == gpa_result.last_sumcheck_value, + "init/final GPA sumcheck final value inconsistent with evaluated multilinear extension" + ); + + ensure!( + claimed_init * claimed_ws == claimed_final * claimed_rs, + "memory-checking product mismatch: init * ws != final * rs" + ); + + Ok(()) +} + +#[instrument(skip_all)] +pub fn produce_whir_proof( + merlin: &mut ProverState, + evaluation_point: &[FieldElement], + vectors: &[&[FieldElement]], + config: &WhirConfig, + witness: &WhirWitness, +) -> Result<()> { + let lf = MultilinearExtension::new(evaluation_point.to_vec()); + + let evaluations: Vec = vectors + .iter() + .map(|v| multilinear_extend(v, evaluation_point)) + .collect(); + + _ = config.prove( + merlin, + vectors.iter().map(|v| Cow::Borrowed(*v)).collect(), + vec![Cow::Borrowed(witness)], + vec![Box::new(lf) + as Box< + dyn whir::algebra::linear_form::LinearForm, + >], + Cow::Borrowed(&evaluations), + ); + + Ok(()) +} diff --git a/provekit/spark/src/prover.rs b/provekit/spark/src/prover.rs new file mode 100644 index 000000000..09c3cd4bf --- /dev/null +++ b/provekit/spark/src/prover.rs @@ -0,0 +1,608 @@ +use { + crate::{ + gpa::run_gpa4, + memory::{prove_axis_init_final_product, AxisConfig}, + sumcheck::run_spark_sumcheck, + types::{ + Challenges, EValuesForMatrix, MatrixDimensions, Memory, SparkMatrix, SparkProof, + SparkProverContext, SparkWhirConfigs, WhirWitness, + }, + utils::{alphas_from_spark, calculate_memory}, + }, + anyhow::{ensure, Result}, + ark_ff::{AdditiveGroup, Field, Zero}, + provekit_common::{ + spark::{SparkColQuery, SparkQueryBatch}, + utils::{ + next_power_of_two, + sumcheck::{ + calculate_evaluations_over_boolean_hypercube_for_eq, eval_quadratic_poly, + sumcheck_fold_map_reduce, + }, + }, + FieldElement, HashConfig, TranscriptSponge, WhirConfig, WhirR1CSProof, + }, + rayon::{join, prelude::*}, + std::borrow::Cow, + tracing::instrument, + whir::{ + algebra::{linear_form::MultilinearExtension, multilinear_extend}, + engines::EngineId, + parameters::ProtocolParameters, + transcript::{DomainSeparator, ProverState, VerifierMessage}, + }, +}; + +pub struct SparkProverScheme { + pub whir_configs: SparkWhirConfigs, + pub matrix_dimensions: MatrixDimensions, +} + +pub fn new_whir_config_for_size( + log_size: usize, + batch_size: usize, + hash_id: EngineId, +) -> WhirConfig { + let nv = log_size.max(4); + + let whir_params = ProtocolParameters { + unique_decoding: false, + initial_folding_factor: 3, + security_level: 128, + pow_bits: 10, + folding_factor: 3, + starting_log_inv_rate: 2, + batch_size, + hash_id, + }; + + WhirConfig::new(1 << nv, &whir_params) +} + +impl SparkProverScheme { + pub fn new_for_r1cs(r1cs: &provekit_common::R1CS, hash_config: HashConfig) -> Self { + let num_rows = 2 * r1cs.num_constraints(); + let num_cols = 2 * r1cs.num_witnesses(); + let nonzero_terms = + r1cs.a().iter().count() + r1cs.b().iter().count() + r1cs.c().iter().count(); + + Self::new(num_rows, num_cols, nonzero_terms, hash_config) + } + + pub fn new( + num_rows: usize, + num_cols: usize, + nonzero_terms: usize, + hash_config: HashConfig, + ) -> Self { + let padded_num_entries = 1 << next_power_of_two(nonzero_terms); + let hash_id = hash_config.engine_id(); + + let row_config = new_whir_config_for_size(next_power_of_two(num_rows), 1, hash_id); + let col_config = new_whir_config_for_size(next_power_of_two(num_cols), 1, hash_id); + let num_terms_2batched_config = + new_whir_config_for_size(next_power_of_two(padded_num_entries), 2, hash_id); + let num_terms_5batched_config = + new_whir_config_for_size(next_power_of_two(padded_num_entries), 5, hash_id); + + Self { + whir_configs: SparkWhirConfigs { + row: row_config, + col: col_config, + num_terms_2batched: num_terms_2batched_config, + num_terms_5batched: num_terms_5batched_config, + }, + matrix_dimensions: MatrixDimensions { + num_rows, + num_cols, + nonzero_terms, + }, + } + } +} + +impl SparkProverScheme { + #[instrument(skip_all)] + pub fn prove( + &self, + spark_data: &SparkProverContext, + batch: &SparkQueryBatch, + ) -> Result { + ensure!( + !batch.queries.is_empty(), + "SPARK prover needs at least one query" + ); + + let padded_num_entries = spark_data.matrix.coo.val.len(); + + let mut merlin = ProverState::new( + &DomainSeparator::protocol(&self.whir_configs) + .session(&spark_data.setup.transcript.narg_string) + .instance(&batch.hash_bytes()?), + TranscriptSponge::from_config(spark_data.setup.hash_config), + ); + + let request: SparkColQuery = if batch.queries.len() == 1 { + batch.queries[0].clone() + } else { + let alphas = alphas_from_spark( + &spark_data.matrix, + &spark_data.setup.matrix_dimensions, + &batch.row, + ); + + let beta: FieldElement = merlin.verifier_message(); + + let domain_size = spark_data.setup.matrix_dimensions.num_cols / 2; + let mut hypercube = vec![FieldElement::ZERO; domain_size]; + let mut claimed_evals = [FieldElement::ZERO; 3]; + let mut beta_pow = FieldElement::ONE; + for query in &batch.queries { + let eq = + calculate_evaluations_over_boolean_hypercube_for_eq(&query.col, domain_size); + for (slot, &e) in hypercube.iter_mut().zip(eq.iter()) { + *slot += beta_pow * e; + } + claimed_evals[0] += beta_pow * query.claimed_a; + claimed_evals[1] += beta_pow * query.claimed_b; + claimed_evals[2] += beta_pow * query.claimed_c; + beta_pow *= beta; + } + + let alpha_refs: [&[FieldElement]; 3] = [&alphas[0], &alphas[1], &alphas[2]]; + let (folded_values, folding_randomness) = + run_parallel_sumchecks(&mut merlin, &hypercube, alpha_refs, claimed_evals)?; + + SparkColQuery { + col: folding_randomness, + claimed_a: folded_values[1], + claimed_b: folded_values[2], + claimed_c: folded_values[3], + } + }; + + let r: FieldElement = merlin.verifier_message(); + ensure!( + !(FieldElement::ONE + r).is_zero(), + "SPARK RLC randomness must not equal -1 (would zero the denominator)" + ); + + let b1 = r / (FieldElement::ONE + r); + let combined = request.claimed_a + r * request.claimed_b + r * r * request.claimed_c; + let claimed_value = combined / (FieldElement::ONE + r) / (FieldElement::ONE + r); + + let (memory, e_values) = + compute_spark_data(&batch.row, &request, b1, spark_data, padded_num_entries); + + prove_spark( + &mut merlin, + spark_data, + &e_values, + claimed_value, + &memory, + &self.whir_configs, + )?; + + let proof = merlin.proof(); + Ok(SparkProof(WhirR1CSProof { + narg_string: proof.narg_string, + hints: proof.hints, + #[cfg(debug_assertions)] + pattern: proof.pattern, + })) + } +} + +#[instrument(skip_all)] +fn compute_spark_data( + row: &[FieldElement], + request: &SparkColQuery, + b1: FieldElement, + spark_data: &SparkProverContext, + padded_num_entries: usize, +) -> (Memory, EValuesForMatrix) { + let memory = calculate_memory(b1, row, &request.col); + let e_values = compute_e_values(spark_data, &memory, padded_num_entries); + (memory, e_values) +} + +#[instrument(skip_all)] +fn compute_e_values( + spark_data: &SparkProverContext, + memory: &Memory, + padded_num_entries: usize, +) -> EValuesForMatrix { + let (e_rx, e_ry) = rayon::join( + || { + spark_data.matrix.coo.row[..padded_num_entries] + .par_iter() + .map(|&r| memory.eq_rx[r]) + .collect() + }, + || { + spark_data.matrix.coo.col[..padded_num_entries] + .par_iter() + .map(|&c| memory.eq_ry[c]) + .collect() + }, + ); + EValuesForMatrix { e_rx, e_ry } +} + +#[instrument(skip_all)] +fn prove_spark( + merlin: &mut ProverState, + data: &SparkProverContext, + e_values: &EValuesForMatrix, + claimed_value: FieldElement, + memory: &Memory, + whir_configs: &SparkWhirConfigs, +) -> Result<()> { + let e_values_witness = commit_e_values(merlin, whir_configs, e_values); + + let (folding_randomness, sumcheck_final_folds) = + sumcheck_and_its_proofs(merlin, &data.matrix, e_values, claimed_value)?; + + memory_checking( + merlin, + data, + e_values, + &e_values_witness, + memory, + whir_configs, + &folding_randomness, + sumcheck_final_folds, + )?; + + Ok(()) +} + +#[instrument(skip_all)] +fn memory_checking( + merlin: &mut ProverState, + data: &SparkProverContext, + e_values: &EValuesForMatrix, + e_values_witness: &WhirWitness, + memory: &Memory, + whir_configs: &SparkWhirConfigs, + folding_randomness: &[FieldElement], + sumcheck_final_folds: [FieldElement; 3], +) -> Result<()> { + let tau: FieldElement = merlin.verifier_message(); + let gamma: FieldElement = merlin.verifier_message(); + let challenges = Challenges { tau, gamma }; + + prove_combined_rs_ws_product( + merlin, + &data.matrix, + e_values, + e_values_witness, + &data.witnesses.vals_rs_ws_witness, + whir_configs, + &challenges, + folding_randomness, + sumcheck_final_folds, + )?; + + let final_row_field = data.matrix.timestamps.final_row_field(); + prove_axis_init_final_product( + merlin, + AxisConfig { + eq_memory: &memory.eq_rx, + final_timestamp: &final_row_field, + whir_config: &whir_configs.row, + }, + &data.witnesses.final_row_ts_witness, + &challenges, + )?; + drop(final_row_field); + + let final_col_field = data.matrix.timestamps.final_col_field(); + prove_axis_init_final_product( + merlin, + AxisConfig { + eq_memory: &memory.eq_ry, + final_timestamp: &final_col_field, + whir_config: &whir_configs.col, + }, + &data.witnesses.final_col_ts_witness, + &challenges, + )?; + drop(final_col_field); + + Ok(()) +} + +#[instrument(skip_all)] +fn sumcheck_and_its_proofs( + merlin: &mut ProverState, + matrix: &SparkMatrix, + e_values: &EValuesForMatrix, + claimed_value: FieldElement, +) -> Result<(Vec, [FieldElement; 3])> { + let mles: [&[FieldElement]; 3] = [&matrix.coo.val, &e_values.e_rx, &e_values.e_ry]; + let (sumcheck_final_folds, folding_randomness) = + run_spark_sumcheck(merlin, mles, claimed_value)?; + + merlin.prover_hint_ark(&[ + sumcheck_final_folds[0], + sumcheck_final_folds[1], + sumcheck_final_folds[2], + ]); + + Ok((folding_randomness, [ + sumcheck_final_folds[0], + sumcheck_final_folds[1], + sumcheck_final_folds[2], + ])) +} + +#[instrument(skip_all)] +fn prove_combined_rs_ws_product( + merlin: &mut ProverState, + matrix: &SparkMatrix, + e_values: &EValuesForMatrix, + e_values_witness: &WhirWitness, + vals_rs_ws_witness: &WhirWitness, + whir_configs: &SparkWhirConfigs, + challenges: &Challenges, + folding_randomness: &[FieldElement], + sumcheck_final_folds: [FieldElement; 3], +) -> Result<()> { + let gamma_sq = challenges.gamma * challenges.gamma; + let one = FieldElement::ONE; + + let row_field = matrix.coo.row_field(); + let col_field = matrix.coo.col_field(); + let read_row_field = matrix.timestamps.read_row_field(); + let read_col_field = matrix.timestamps.read_col_field(); + let n = row_field.len(); + + let gpa_leaves_flat = tracing::info_span!("build_rs_ws_pairs").in_scope(|| { + let mut buf = vec![FieldElement::ZERO; 4 * n]; + let (row_section, col_section) = buf.split_at_mut(2 * n); + let (row_rs, row_ws) = row_section.split_at_mut(n); + let (col_rs, col_ws) = col_section.split_at_mut(n); + + row_rs + .par_iter_mut() + .zip(row_ws.par_iter_mut()) + .zip(col_rs.par_iter_mut()) + .zip(col_ws.par_iter_mut()) + .enumerate() + .for_each(|(i, (((rs_r, ws_r), rs_c), ws_c))| { + let row_base = row_field[i] * gamma_sq + + e_values.e_rx[i] * challenges.gamma + + read_row_field[i] + - challenges.tau; + let col_base = col_field[i] * gamma_sq + + e_values.e_ry[i] * challenges.gamma + + read_col_field[i] + - challenges.tau; + *rs_r = row_base; + *ws_r = row_base + one; + *rs_c = col_base; + *ws_c = col_base + one; + }); + buf + }); + let gpa_randomness = run_gpa4(merlin, gpa_leaves_flat)?; + + let (_combination_randomness, evaluation_randomness) = gpa_randomness.split_at(2); + + let polys: [&[FieldElement]; 4] = [&row_field, &read_row_field, &col_field, &read_col_field]; + let [row_address_eval, row_timestamp_eval, col_address_eval, col_timestamp_eval]: [_; 4] = + tracing::info_span!("multilinear_extend_rs_ws") + .in_scope(|| { + polys + .par_iter() + .map(|p| multilinear_extend(p, evaluation_randomness)) + .collect::>() + }) + .try_into() + .expect("4 polys"); + + merlin.prover_hint_ark(&row_address_eval); + merlin.prover_hint_ark(&row_timestamp_eval); + merlin.prover_hint_ark(&col_address_eval); + merlin.prover_hint_ark(&col_timestamp_eval); + + let pairs: [(&[FieldElement], &[FieldElement]); 5] = [ + (&row_field, folding_randomness), + (&read_row_field, folding_randomness), + (&col_field, folding_randomness), + (&read_col_field, folding_randomness), + (&matrix.coo.val, evaluation_randomness), + ]; + let [row_field_at_fold, read_row_at_fold, col_field_at_fold, read_col_at_fold, vals_at_eval]: [_; 5] = + tracing::info_span!("multilinear_extend_vals_rs_ws_cross") + .in_scope(|| { + pairs + .par_iter() + .map(|(p, r)| multilinear_extend(p, r)) + .collect::>() + }) + .try_into() + .expect("5 polys"); + + merlin.prover_hint_ark(&row_field_at_fold); + merlin.prover_hint_ark(&read_row_at_fold); + merlin.prover_hint_ark(&col_field_at_fold); + merlin.prover_hint_ark(&read_col_at_fold); + merlin.prover_hint_ark(&vals_at_eval); + + let fold_lf_for_vals_rs_ws = MultilinearExtension::new(folding_randomness.to_vec()); + let eval_lf_for_vals_rs_ws = MultilinearExtension::new(evaluation_randomness.to_vec()); + // Layout: linear_form-major, then poly. Polys: [vals, row_field, read_row, + // col_field, read_col]. + let vals_rs_ws_evaluations = [ + // fold_lf + sumcheck_final_folds[0], + row_field_at_fold, + read_row_at_fold, + col_field_at_fold, + read_col_at_fold, + // eval_lf + vals_at_eval, + row_address_eval, + row_timestamp_eval, + col_address_eval, + col_timestamp_eval, + ]; + let _ = whir_configs.num_terms_5batched.prove( + merlin, + vec![ + Cow::Borrowed(&matrix.coo.val), + Cow::Borrowed(&row_field), + Cow::Borrowed(&read_row_field), + Cow::Borrowed(&col_field), + Cow::Borrowed(&read_col_field), + ], + vec![Cow::Borrowed(vals_rs_ws_witness)], + vec![ + Box::new(fold_lf_for_vals_rs_ws) + as Box>, + Box::new(eval_lf_for_vals_rs_ws) + as Box>, + ], + Cow::Borrowed(&vals_rs_ws_evaluations), + ); + + let (row_value_eval, col_value_eval) = tracing::info_span!("multilinear_extend_e_values") + .in_scope(|| { + join( + || multilinear_extend(&e_values.e_rx, evaluation_randomness), + || multilinear_extend(&e_values.e_ry, evaluation_randomness), + ) + }); + merlin.prover_hint_ark(&row_value_eval); + merlin.prover_hint_ark(&col_value_eval); + + let fold_lf = MultilinearExtension::new(folding_randomness.to_vec()); + let eval_lf = MultilinearExtension::new(evaluation_randomness.to_vec()); + let evaluations = [ + sumcheck_final_folds[1], + sumcheck_final_folds[2], + row_value_eval, + col_value_eval, + ]; + let _ = whir_configs.num_terms_2batched.prove( + merlin, + vec![Cow::Borrowed(&e_values.e_rx), Cow::Borrowed(&e_values.e_ry)], + vec![Cow::Borrowed(e_values_witness)], + vec![ + Box::new(fold_lf) as Box>, + Box::new(eval_lf) as Box>, + ], + Cow::Borrowed(&evaluations), + ); + + Ok(()) +} + +#[instrument(skip_all)] +fn commit_e_values( + merlin: &mut ProverState, + whir_configs: &SparkWhirConfigs, + e_values: &EValuesForMatrix, +) -> WhirWitness { + whir_configs + .num_terms_2batched + .commit(merlin, &[&e_values.e_rx, &e_values.e_ry]) +} + +pub fn run_parallel_sumchecks( + merlin: &mut ProverState, + hypercube: &[FieldElement], + alphas: [&[FieldElement]; 3], + mut claimed_values: [FieldElement; 3], +) -> Result<([FieldElement; 4], Vec)> { + let mut sumcheck_randomness; + let mut sumcheck_randomness_accumulator = Vec::::new(); + let mut fold = None; + + let mut h_mle = hypercube.to_vec(); + let mut a_mle = alphas[0].to_vec(); + let mut b_mle = alphas[1].to_vec(); + let mut c_mle = alphas[2].to_vec(); + loop { + let [a_hhat_i_at_0, a_highest_coeff, b_hhat_i_at_0, b_highest_coeff, c_hhat_i_at_0, c_highest_coeff] = + sumcheck_fold_map_reduce( + [&mut h_mle, &mut a_mle, &mut b_mle, &mut c_mle], + fold, + |[h_mle, a_mle, b_mle, c_mle]| { + [ + h_mle.0 * a_mle.0, + (h_mle.1 - h_mle.0) * (a_mle.1 - a_mle.0), + h_mle.0 * b_mle.0, + (h_mle.1 - h_mle.0) * (b_mle.1 - b_mle.0), + h_mle.0 * c_mle.0, + (h_mle.1 - h_mle.0) * (c_mle.1 - c_mle.0), + ] + }, + ); + + if fold.is_some() { + h_mle.truncate(h_mle.len() / 2); + a_mle.truncate(a_mle.len() / 2); + b_mle.truncate(b_mle.len() / 2); + c_mle.truncate(c_mle.len() / 2); + } + + let mut a_hhat_i_coeffs = [FieldElement::zero(); 3]; + + a_hhat_i_coeffs[0] = a_hhat_i_at_0; + a_hhat_i_coeffs[2] = a_highest_coeff; + a_hhat_i_coeffs[1] = + claimed_values[0] - a_hhat_i_coeffs[0] - a_hhat_i_coeffs[0] - a_hhat_i_coeffs[2]; + + for a_coeff in &a_hhat_i_coeffs { + merlin.prover_message(a_coeff); + } + + let mut b_hhat_i_coeffs = [FieldElement::zero(); 3]; + + b_hhat_i_coeffs[0] = b_hhat_i_at_0; + b_hhat_i_coeffs[2] = b_highest_coeff; + b_hhat_i_coeffs[1] = + claimed_values[1] - b_hhat_i_coeffs[0] - b_hhat_i_coeffs[0] - b_hhat_i_coeffs[2]; + + for b_coeff in &b_hhat_i_coeffs { + merlin.prover_message(b_coeff); + } + + let mut c_hhat_i_coeffs = [FieldElement::zero(); 3]; + + c_hhat_i_coeffs[0] = c_hhat_i_at_0; + c_hhat_i_coeffs[2] = c_highest_coeff; + c_hhat_i_coeffs[1] = + claimed_values[2] - c_hhat_i_coeffs[0] - c_hhat_i_coeffs[0] - c_hhat_i_coeffs[2]; + + for c_coeff in &c_hhat_i_coeffs { + merlin.prover_message(c_coeff); + } + + sumcheck_randomness = merlin.verifier_message(); + fold = Some(sumcheck_randomness); + claimed_values[0] = eval_quadratic_poly(a_hhat_i_coeffs, sumcheck_randomness); + claimed_values[1] = eval_quadratic_poly(b_hhat_i_coeffs, sumcheck_randomness); + claimed_values[2] = eval_quadratic_poly(c_hhat_i_coeffs, sumcheck_randomness); + + sumcheck_randomness_accumulator.push(sumcheck_randomness); + if h_mle.len() <= 2 { + break; + } + } + + let folded_h = h_mle[0] + (h_mle[1] - h_mle[0]) * sumcheck_randomness; + let folded_a = a_mle[0] + (a_mle[1] - a_mle[0]) * sumcheck_randomness; + let folded_b = b_mle[0] + (b_mle[1] - b_mle[0]) * sumcheck_randomness; + let folded_c = c_mle[0] + (c_mle[1] - c_mle[0]) * sumcheck_randomness; + + merlin.prover_hint_ark(&[folded_a, folded_b, folded_c]); + + Ok(( + [folded_h, folded_a, folded_b, folded_c], + sumcheck_randomness_accumulator, + )) +} diff --git a/provekit/spark/src/serde_whir_witness.rs b/provekit/spark/src/serde_whir_witness.rs new file mode 100644 index 000000000..81c9f51bf --- /dev/null +++ b/provekit/spark/src/serde_whir_witness.rs @@ -0,0 +1,68 @@ +use { + crate::types::WhirWitness, + provekit_common::{utils::serde_ark_vec, FieldElement}, + serde::{ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer}, + whir::protocols::{ + irs_commit::{Evaluations, Witness}, + matrix_commit, + }, +}; + +pub fn serialize(w: &WhirWitness, s: S) -> Result { + let mut st = s.serialize_struct("WhirWitness", 4)?; + st.serialize_field("masks", &ArkVecRef(&w.masks))?; + st.serialize_field("matrix", &ArkVecRef(&w.matrix))?; + st.serialize_field("matrix_witness", &w.matrix_witness)?; + st.serialize_field("out_of_domain", &EvaluationsRef(&w.out_of_domain))?; + st.end() +} + +pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result { + let m = WitnessMirror::deserialize(d)?; + Ok(Witness { + masks: m.masks, + matrix: m.matrix, + matrix_witness: m.matrix_witness, + out_of_domain: Evaluations { + points: m.out_of_domain.points, + matrix: m.out_of_domain.matrix, + }, + }) +} + +struct ArkVecRef<'a>(&'a Vec); + +impl Serialize for ArkVecRef<'_> { + fn serialize(&self, s: S) -> Result { + serde_ark_vec::serialize(self.0, s) + } +} + +struct EvaluationsRef<'a>(&'a Evaluations); + +impl Serialize for EvaluationsRef<'_> { + fn serialize(&self, s: S) -> Result { + let mut st = s.serialize_struct("Evaluations", 2)?; + st.serialize_field("points", &ArkVecRef(&self.0.points))?; + st.serialize_field("matrix", &ArkVecRef(&self.0.matrix))?; + st.end() + } +} + +#[derive(Deserialize)] +struct WitnessMirror { + #[serde(with = "serde_ark_vec")] + masks: Vec, + #[serde(with = "serde_ark_vec")] + matrix: Vec, + matrix_witness: matrix_commit::Witness, + out_of_domain: EvaluationsMirror, +} + +#[derive(Deserialize)] +struct EvaluationsMirror { + #[serde(with = "serde_ark_vec")] + points: Vec, + #[serde(with = "serde_ark_vec")] + matrix: Vec, +} diff --git a/provekit/spark/src/setup.rs b/provekit/spark/src/setup.rs new file mode 100644 index 000000000..d18210561 --- /dev/null +++ b/provekit/spark/src/setup.rs @@ -0,0 +1,117 @@ +use { + crate::{ + prover::SparkProverScheme, + types::{SparkMatrix, SparkSetup, SparkWitnesses}, + }, + anyhow::Result, + provekit_common::{FieldElement, HashConfig, TranscriptSponge, WhirR1CSProof}, + tracing::instrument, + whir::{ + protocols::irs_commit::Commitment, + transcript::{codecs::Empty, DomainSeparator, Proof, ProverState, VerifierState}, + }, +}; + +pub(crate) struct PrecomputedCommitments { + pub vals_rsws: Commitment, + pub a_row_finalts: Commitment, + pub a_col_finalts: Commitment, +} + +#[instrument(skip_all)] +pub fn preprocess_spark( + matrix: &SparkMatrix, + hash_config: HashConfig, +) -> (SparkSetup, SparkWitnesses) { + let num_rows = matrix.timestamps.final_row.len(); + let num_cols = matrix.timestamps.final_col.len(); + let nonzero_terms = matrix.coo.val.len(); + let scheme = SparkProverScheme::new(num_rows, num_cols, nonzero_terms, hash_config); + + let ds = DomainSeparator::protocol(&scheme.whir_configs).instance(&Empty); + let mut merlin = ProverState::new(&ds, TranscriptSponge::from_config(hash_config)); + + let row_field = matrix.coo.row_field(); + let col_field = matrix.coo.col_field(); + let read_row_field = matrix.timestamps.read_row_field(); + let read_col_field = matrix.timestamps.read_col_field(); + let vals_rs_ws_witness = scheme + .whir_configs + .num_terms_5batched + .commit(&mut merlin, &[ + &matrix.coo.val, + &row_field, + &read_row_field, + &col_field, + &read_col_field, + ]); + drop((row_field, col_field, read_row_field, read_col_field)); + let final_row_field = matrix.timestamps.final_row_field(); + let final_row_ts_witness = scheme + .whir_configs + .row + .commit(&mut merlin, &[&final_row_field]); + drop(final_row_field); + let final_col_field = matrix.timestamps.final_col_field(); + let final_col_ts_witness = scheme + .whir_configs + .col + .commit(&mut merlin, &[&final_col_field]); + drop(final_col_field); + + let proof = merlin.proof(); + let setup = SparkSetup { + whir_configs: scheme.whir_configs, + matrix_dimensions: scheme.matrix_dimensions, + transcript: WhirR1CSProof { + narg_string: proof.narg_string, + hints: proof.hints, + #[cfg(debug_assertions)] + pattern: proof.pattern, + }, + hash_config, + }; + let witnesses = SparkWitnesses { + vals_rs_ws_witness, + final_row_ts_witness, + final_col_ts_witness, + }; + (setup, witnesses) +} + +pub(crate) fn extract_commitments(setup: &SparkSetup) -> Result { + let setup_ds = DomainSeparator::protocol(&setup.whir_configs).instance(&Empty); + let setup_proof = Proof { + narg_string: setup.transcript.narg_string.clone(), + hints: setup.transcript.hints.clone(), + #[cfg(debug_assertions)] + pattern: setup.transcript.pattern.clone(), + }; + let mut side = VerifierState::new( + &setup_ds, + &setup_proof, + TranscriptSponge::from_config(setup.hash_config), + ); + + let vals_rsws = setup + .whir_configs + .num_terms_5batched + .receive_commitment(&mut side) + .map_err(|e| anyhow::anyhow!("Failed to reconstruct vals_rsws commitment: {e}"))?; + let a_row_finalts = setup + .whir_configs + .row + .receive_commitment(&mut side) + .map_err(|e| anyhow::anyhow!("Failed to reconstruct row finalts commitment: {e}"))?; + let a_col_finalts = setup + .whir_configs + .col + .receive_commitment(&mut side) + .map_err(|e| anyhow::anyhow!("Failed to reconstruct col finalts commitment: {e}"))?; + + Ok(PrecomputedCommitments { + vals_rsws, + a_row_finalts, + a_col_finalts, + }) +} diff --git a/provekit/spark/src/sumcheck.rs b/provekit/spark/src/sumcheck.rs new file mode 100644 index 000000000..db0ad7dc2 --- /dev/null +++ b/provekit/spark/src/sumcheck.rs @@ -0,0 +1,197 @@ +use { + crate::utils::read_hint, + anyhow::{ensure, Result}, + ark_std::{One, Zero}, + provekit_common::{ + utils::{ + sumcheck::{eval_cubic_poly, eval_quadratic_poly, sumcheck_fold_map_reduce}, + HALF, + }, + FieldElement, TranscriptSponge, + }, + tracing::instrument, + whir::transcript::{ProverState, VerifierMessage, VerifierState}, +}; + +#[instrument(skip_all)] +pub fn run_spark_sumcheck( + merlin: &mut ProverState, + mles: [&[FieldElement]; 3], + mut claimed_value: FieldElement, +) -> Result<([FieldElement; 3], Vec)> { + let mut sumcheck_randomness; + let mut sumcheck_randomness_accumulator = Vec::::new(); + let mut fold = None; + + let mut m0 = mles[0].to_vec(); + let mut m1 = mles[1].to_vec(); + let mut m2 = mles[2].to_vec(); + + loop { + let [hhat_i_at_0, hhat_i_at_em1, hhat_i_at_inf_over_x_cube] = + sumcheck_fold_map_reduce([&mut m0, &mut m1, &mut m2], fold, |[m0, m1, m2]| { + [ + m0.0 * m1.0 * m2.0, + (m0.0 + m0.0 - m0.1) * (m1.0 + m1.0 - m1.1) * (m2.0 + m2.0 - m2.1), + (m0.1 - m0.0) * (m1.1 - m1.0) * (m2.1 - m2.0), + ] + }); + + if fold.is_some() { + m0.truncate(m0.len() / 2); + m1.truncate(m1.len() / 2); + m2.truncate(m2.len() / 2); + } + + let mut hhat_i_coeffs = [FieldElement::zero(); 4]; + + hhat_i_coeffs[0] = hhat_i_at_0; + hhat_i_coeffs[2] = + HALF * (claimed_value + hhat_i_at_em1 - hhat_i_at_0 - hhat_i_at_0 - hhat_i_at_0); + hhat_i_coeffs[3] = hhat_i_at_inf_over_x_cube; + hhat_i_coeffs[1] = claimed_value + - hhat_i_coeffs[0] + - hhat_i_coeffs[0] + - hhat_i_coeffs[3] + - hhat_i_coeffs[2]; + + ensure!( + claimed_value + == hhat_i_coeffs[0] + + hhat_i_coeffs[0] + + hhat_i_coeffs[1] + + hhat_i_coeffs[2] + + hhat_i_coeffs[3], + "Sumcheck binding check failed" + ); + + for coeff in &hhat_i_coeffs { + merlin.prover_message(coeff); + } + sumcheck_randomness = merlin.verifier_message(); + fold = Some(sumcheck_randomness); + claimed_value = eval_cubic_poly(hhat_i_coeffs, sumcheck_randomness); + sumcheck_randomness_accumulator.push(sumcheck_randomness); + if m0.len() <= 2 { + break; + } + } + + let folded_v0 = m0[0] + (m0[1] - m0[0]) * sumcheck_randomness; + let folded_v1 = m1[0] + (m1[1] - m1[0]) * sumcheck_randomness; + let folded_v2 = m2[0] + (m2[1] - m2[0]) * sumcheck_randomness; + + Ok(( + [folded_v0, folded_v1, folded_v2], + sumcheck_randomness_accumulator, + )) +} + +#[instrument(skip_all)] +pub fn run_parallel_sumchecks_verifier( + arthur: &mut VerifierState<'_, TranscriptSponge>, + variable_count: usize, + mut claimed_values: [FieldElement; 3], +) -> Result<([FieldElement; 3], [FieldElement; 3], Vec)> { + let mut folding_randomness = Vec::with_capacity(variable_count); + + for _ in 0..variable_count { + let a_coeffs: [FieldElement; 3] = [ + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read parallel sumcheck a_coeffs[0]"))?, + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read parallel sumcheck a_coeffs[1]"))?, + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read parallel sumcheck a_coeffs[2]"))?, + ]; + let b_coeffs: [FieldElement; 3] = [ + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read parallel sumcheck b_coeffs[0]"))?, + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read parallel sumcheck b_coeffs[1]"))?, + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read parallel sumcheck b_coeffs[2]"))?, + ]; + let c_coeffs: [FieldElement; 3] = [ + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read parallel sumcheck c_coeffs[0]"))?, + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read parallel sumcheck c_coeffs[1]"))?, + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read parallel sumcheck c_coeffs[2]"))?, + ]; + + ensure!( + claimed_values[0] == a_coeffs[0] + a_coeffs[0] + a_coeffs[1] + a_coeffs[2], + "parallel sumcheck equality failed (a)" + ); + ensure!( + claimed_values[1] == b_coeffs[0] + b_coeffs[0] + b_coeffs[1] + b_coeffs[2], + "parallel sumcheck equality failed (b)" + ); + ensure!( + claimed_values[2] == c_coeffs[0] + c_coeffs[0] + c_coeffs[1] + c_coeffs[2], + "parallel sumcheck equality failed (c)" + ); + + let alpha_i: FieldElement = arthur.verifier_message(); + folding_randomness.push(alpha_i); + + claimed_values[0] = eval_quadratic_poly(a_coeffs, alpha_i); + claimed_values[1] = eval_quadratic_poly(b_coeffs, alpha_i); + claimed_values[2] = eval_quadratic_poly(c_coeffs, alpha_i); + } + + let folded: [FieldElement; 3] = read_hint(arthur, "parallel sumcheck folded values")?; + + Ok((claimed_values, folded, folding_randomness)) +} + +#[instrument(skip_all)] +pub fn run_sumcheck_verifier_spark( + arthur: &mut VerifierState<'_, TranscriptSponge>, + variable_count: usize, + initial_sumcheck_val: FieldElement, +) -> Result<(Vec, FieldElement)> { + let mut saved_val_for_sumcheck_equality_assertion = initial_sumcheck_val; + + let mut alpha = vec![FieldElement::zero(); variable_count]; + + for i in 0..variable_count { + let hhat_i: [FieldElement; 4] = [ + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read SPARK sumcheck hhat_i[0]"))?, + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read SPARK sumcheck hhat_i[1]"))?, + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read SPARK sumcheck hhat_i[2]"))?, + arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read SPARK sumcheck hhat_i[3]"))?, + ]; + let alpha_i: FieldElement = arthur.verifier_message(); + alpha[i] = alpha_i; + + let hhat_i_at_zero = eval_cubic_poly(hhat_i, FieldElement::zero()); + let hhat_i_at_one = eval_cubic_poly(hhat_i, FieldElement::one()); + ensure!( + saved_val_for_sumcheck_equality_assertion == hhat_i_at_zero + hhat_i_at_one, + "Sumcheck equality check failed" + ); + saved_val_for_sumcheck_equality_assertion = eval_cubic_poly(hhat_i, alpha_i); + } + + Ok((alpha, saved_val_for_sumcheck_equality_assertion)) +} diff --git a/provekit/spark/src/types.rs b/provekit/spark/src/types.rs new file mode 100644 index 000000000..16f1fd54c --- /dev/null +++ b/provekit/spark/src/types.rs @@ -0,0 +1,211 @@ +pub use provekit_common::{MatrixDimensions, SparkSetup, SparkWhirConfigs}; +use { + provekit_common::{ + file::{ + binary_format::{ + SPARK_CONTEXT_FORMAT, SPARK_CONTEXT_VERSION, SPARK_PROOF_FORMAT, + SPARK_PROOF_VERSION, + }, + Compression, FileFormat, MaybeHashAware, + }, + utils::serde_ark_vec, + FieldElement, HashConfig, WhirR1CSProof, + }, + serde::{Deserialize, Serialize}, + whir::protocols::irs_commit, +}; + +pub type WhirWitness = irs_commit::Witness; + +#[derive(Serialize, Deserialize)] +#[serde(transparent)] +pub struct SparkProof(pub WhirR1CSProof); + +impl FileFormat for SparkProof { + const FORMAT: [u8; 8] = SPARK_PROOF_FORMAT; + const EXTENSION: &'static str = "sp"; + const VERSION: (u16, u16) = SPARK_PROOF_VERSION; + const COMPRESSION: Compression = Compression::Zstd; +} + +impl MaybeHashAware for SparkProof { + fn maybe_hash_config(&self) -> Option { + None + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(from = "SparkMatrixSerde", into = "SparkMatrixSerde")] +pub struct SparkMatrix { + pub coo: COOMatrix, + pub timestamps: TimeStamps, +} + +impl SparkMatrix { + pub fn new( + row: Vec, + col: Vec, + val: Vec, + num_rows: usize, + num_cols: usize, + ) -> Self { + let len = row.len(); + let mut read_row_counters = vec![0u32; num_rows]; + let mut read_col_counters = vec![0u32; num_cols]; + let mut read_row = Vec::with_capacity(len); + let mut read_col = Vec::with_capacity(len); + for i in 0..len { + read_row.push(read_row_counters[row[i]]); + read_row_counters[row[i]] += 1; + read_col.push(read_col_counters[col[i]]); + read_col_counters[col[i]] += 1; + } + Self { + coo: COOMatrix { row, col, val }, + timestamps: TimeStamps { + read_row, + read_col, + final_row: read_row_counters, + final_col: read_col_counters, + }, + } + } +} + +#[derive(Debug, Clone)] +pub struct COOMatrix { + pub row: Vec, + pub col: Vec, + pub val: Vec, +} + +impl COOMatrix { + pub fn row_field(&self) -> Vec { + self.row + .iter() + .map(|&r| FieldElement::from(r as u64)) + .collect() + } + + pub fn col_field(&self) -> Vec { + self.col + .iter() + .map(|&c| FieldElement::from(c as u64)) + .collect() + } +} + +#[derive(Debug, Clone)] +pub struct TimeStamps { + pub read_row: Vec, + pub read_col: Vec, + pub final_row: Vec, + pub final_col: Vec, +} + +impl TimeStamps { + pub fn read_row_field(&self) -> Vec { + self.read_row + .iter() + .map(|&x| FieldElement::from(u64::from(x))) + .collect() + } + + pub fn read_col_field(&self) -> Vec { + self.read_col + .iter() + .map(|&x| FieldElement::from(u64::from(x))) + .collect() + } + + pub fn final_row_field(&self) -> Vec { + self.final_row + .iter() + .map(|&x| FieldElement::from(u64::from(x))) + .collect() + } + + pub fn final_col_field(&self) -> Vec { + self.final_col + .iter() + .map(|&x| FieldElement::from(u64::from(x))) + .collect() + } +} + +#[derive(Serialize, Deserialize)] +struct SparkMatrixSerde { + row: Vec, + col: Vec, + #[serde(with = "serde_ark_vec")] + val: Vec, + num_rows: usize, + num_cols: usize, +} + +impl From for SparkMatrixSerde { + fn from(m: SparkMatrix) -> Self { + Self { + num_rows: m.timestamps.final_row.len(), + num_cols: m.timestamps.final_col.len(), + row: m.coo.row, + col: m.coo.col, + val: m.coo.val, + } + } +} + +impl From for SparkMatrix { + fn from(s: SparkMatrixSerde) -> Self { + SparkMatrix::new(s.row, s.col, s.val, s.num_rows, s.num_cols) + } +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct SparkWitnesses { + #[serde(with = "crate::serde_whir_witness")] + pub vals_rs_ws_witness: WhirWitness, + #[serde(with = "crate::serde_whir_witness")] + pub final_row_ts_witness: WhirWitness, + #[serde(with = "crate::serde_whir_witness")] + pub final_col_ts_witness: WhirWitness, +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct SparkProverContext { + pub matrix: SparkMatrix, + pub witnesses: SparkWitnesses, + pub setup: SparkSetup, +} + +impl FileFormat for SparkProverContext { + const FORMAT: [u8; 8] = SPARK_CONTEXT_FORMAT; + const EXTENSION: &'static str = "spctx"; + const VERSION: (u16, u16) = SPARK_CONTEXT_VERSION; + const COMPRESSION: Compression = Compression::Zstd; +} + +impl MaybeHashAware for SparkProverContext { + fn maybe_hash_config(&self) -> Option { + None + } +} + +#[derive(Debug, Clone)] +pub struct Memory { + pub eq_rx: Vec, + pub eq_ry: Vec, +} + +#[derive(Debug, Clone)] +pub struct EValuesForMatrix { + pub e_rx: Vec, + pub e_ry: Vec, +} + +/// Challenges drawn from the Fiat-Shamir transcript during proving. +#[derive(Debug, Clone)] +pub struct Challenges { + pub gamma: FieldElement, + pub tau: FieldElement, +} diff --git a/provekit/spark/src/utils.rs b/provekit/spark/src/utils.rs new file mode 100644 index 000000000..158293be0 --- /dev/null +++ b/provekit/spark/src/utils.rs @@ -0,0 +1,79 @@ +pub use crate::types::Memory; +use { + crate::types::{MatrixDimensions, SparkMatrix}, + anyhow::{anyhow, Result}, + ark_ff::Zero, + ark_serialize::CanonicalDeserialize, + provekit_common::{ + utils::sumcheck::calculate_evaluations_over_boolean_hypercube_for_eq, FieldElement, + TranscriptSponge, + }, + whir::transcript::VerifierState, +}; + +pub fn read_hint( + arthur: &mut VerifierState<'_, TranscriptSponge>, + label: &str, +) -> Result { + arthur + .prover_hint_ark() + .map_err(|e| anyhow!("Failed to read {label} hint: {e}")) +} + +#[tracing::instrument(skip_all)] +pub fn alphas_from_spark( + spark: &SparkMatrix, + dims: &MatrixDimensions, + alpha: &[FieldElement], +) -> [Vec; 3] { + let row_cnt = dims.num_rows / 2; + let col_cnt = dims.num_cols / 2; + + let eq_alpha = calculate_evaluations_over_boolean_hypercube_for_eq(alpha, row_cnt); + + let mut a = vec![FieldElement::zero(); col_cnt]; + let mut b = vec![FieldElement::zero(); col_cnt]; + let mut c = vec![FieldElement::zero(); col_cnt]; + + for i in 0..spark.coo.val.len() { + let v = spark.coo.val[i]; + if v.is_zero() { + continue; + } + let r = spark.coo.row[i]; + let col = spark.coo.col[i]; + if r < row_cnt && col < col_cnt { + a[col] += v * eq_alpha[r]; + } else if r < row_cnt { + b[col - col_cnt] += v * eq_alpha[r]; + } else { + c[col - col_cnt] += v * eq_alpha[r - row_cnt]; + } + } + + [a, b, c] +} + +#[tracing::instrument(skip_all)] +pub fn calculate_memory( + b: FieldElement, + point_row: &[FieldElement], + point_col: &[FieldElement], +) -> Memory { + let row_point: Vec<_> = std::iter::once(b) + .chain(point_row.iter().copied()) + .collect(); + let col_point: Vec<_> = std::iter::once(b) + .chain(point_col.iter().copied()) + .collect(); + Memory { + eq_rx: calculate_evaluations_over_boolean_hypercube_for_eq( + &row_point, + 1 << row_point.len(), + ), + eq_ry: calculate_evaluations_over_boolean_hypercube_for_eq( + &col_point, + 1 << col_point.len(), + ), + } +} diff --git a/provekit/spark/src/verifier.rs b/provekit/spark/src/verifier.rs new file mode 100644 index 000000000..2187442ca --- /dev/null +++ b/provekit/spark/src/verifier.rs @@ -0,0 +1,290 @@ +use { + crate::{ + gpa::gpa_sumcheck_verifier4, + memory::verify_axis, + setup::{extract_commitments, PrecomputedCommitments}, + sumcheck::{run_parallel_sumchecks_verifier, run_sumcheck_verifier_spark}, + types::{MatrixDimensions, SparkProof, SparkSetup, SparkWhirConfigs}, + utils::read_hint, + }, + anyhow::{ensure, Context, Result}, + ark_ff::{AdditiveGroup, Field, Zero}, + provekit_common::{ + spark::{SparkColQuery, SparkQueryBatch}, + utils::{next_power_of_two, sumcheck::calculate_eq}, + FieldElement, TranscriptSponge, + }, + tracing::instrument, + whir::{ + algebra::linear_form::MultilinearExtension, + transcript::{DomainSeparator, Proof, VerifierMessage, VerifierState}, + }, +}; + +pub struct SparkVerifierScheme; + +impl SparkVerifierScheme { + #[instrument(skip_all)] + pub fn verify( + &self, + proof: SparkProof, + setup: &SparkSetup, + batch: &SparkQueryBatch, + ) -> Result<()> { + ensure!( + !batch.queries.is_empty(), + "SPARK verifier needs at least one query" + ); + + let precomputed_commitments = extract_commitments(setup)?; + + let whir_proof = Proof { + narg_string: proof.0.narg_string, + hints: proof.0.hints, + #[cfg(debug_assertions)] + pattern: proof.0.pattern, + }; + let mut arthur = VerifierState::new( + &DomainSeparator::protocol(&setup.whir_configs) + .session(&setup.transcript.narg_string) + .instance(&batch.hash_bytes()?), + &whir_proof, + TranscriptSponge::from_config(setup.hash_config), + ); + + let request: SparkColQuery = if batch.queries.len() == 1 { + batch.queries[0].clone() + } else { + let beta: FieldElement = arthur.verifier_message(); + + let mut claimed_evals = [FieldElement::ZERO; 3]; + let mut beta_pow = FieldElement::ONE; + for query in &batch.queries { + claimed_evals[0] += beta_pow * query.claimed_a; + claimed_evals[1] += beta_pow * query.claimed_b; + claimed_evals[2] += beta_pow * query.claimed_c; + beta_pow *= beta; + } + + let domain_size = setup.matrix_dimensions.num_cols / 2; + ensure!( + domain_size.is_power_of_two(), + "SPARK domain_size must be a power of two (got {domain_size})" + ); + let variable_count = domain_size.ilog2() as usize; + let (final_claims, folded, folding_randomness) = + run_parallel_sumchecks_verifier(&mut arthur, variable_count, claimed_evals) + .context("verifying parallel sumchecks")?; + + let mut h_at_fr = FieldElement::ZERO; + let mut beta_pow = FieldElement::ONE; + for query in &batch.queries { + h_at_fr += beta_pow * calculate_eq(&query.col, &folding_randomness); + beta_pow *= beta; + } + ensure!( + final_claims[0] == h_at_fr * folded[0], + "parallel sumcheck final-equation check failed (a)" + ); + ensure!( + final_claims[1] == h_at_fr * folded[1], + "parallel sumcheck final-equation check failed (b)" + ); + ensure!( + final_claims[2] == h_at_fr * folded[2], + "parallel sumcheck final-equation check failed (c)" + ); + + SparkColQuery { + col: folding_randomness, + claimed_a: folded[0], + claimed_b: folded[1], + claimed_c: folded[2], + } + }; + + let r: FieldElement = arthur.verifier_message(); + ensure!( + !(FieldElement::ONE + r).is_zero(), + "SPARK RLC randomness must not equal -1 (would zero the denominator)" + ); + + let b1 = r / (FieldElement::ONE + r); + let combined = request.claimed_a + r * request.claimed_b + r * r * request.claimed_c; + let claimed_value = combined / (FieldElement::ONE + r) / (FieldElement::ONE + r); + + let extended_row: Vec = std::iter::once(b1) + .chain(batch.row.iter().copied()) + .collect(); + let extended_col: Vec = std::iter::once(b1) + .chain(request.col.iter().copied()) + .collect(); + + verify_spark_single_matrix( + &setup.whir_configs, + setup.matrix_dimensions.clone(), + &mut arthur, + &precomputed_commitments, + &extended_row, + &extended_col, + &claimed_value, + ) + } +} + +#[instrument(skip_all)] +pub(crate) fn verify_spark_single_matrix( + whir_configs: &SparkWhirConfigs, + matrix_dimensions: MatrixDimensions, + arthur: &mut VerifierState<'_, TranscriptSponge>, + precomputed_commitments: &PrecomputedCommitments, + row: &[FieldElement], + col: &[FieldElement], + claimed_value: &FieldElement, +) -> Result<()> { + let e_values_commitment = whir_configs + .num_terms_2batched + .receive_commitment(arthur) + .map_err(|e| anyhow::anyhow!("Failed to receive e_values commitment: {e}"))?; + + let (randomness, last_sumcheck_value) = run_sumcheck_verifier_spark( + arthur, + next_power_of_two(matrix_dimensions.nonzero_terms), + *claimed_value, + ) + .context("While verifying SPARK sumcheck")?; + let eval_weight = MultilinearExtension::new(randomness); + + let sumcheck_hints: [FieldElement; 3] = read_hint(arthur, "SPARK sumcheck final folds")?; + + ensure!( + last_sumcheck_value == sumcheck_hints[0] * sumcheck_hints[1] * sumcheck_hints[2], + "SPARK sumcheck final folds inconsistent with claimed last value" + ); + + let tau: FieldElement = arthur.verifier_message(); + let gamma: FieldElement = arthur.verifier_message(); + + let gpa_result = gpa_sumcheck_verifier4( + arthur, + provekit_common::utils::next_power_of_two(matrix_dimensions.nonzero_terms) + 3, + )?; + + let (combination_randomness, evaluation_randomness) = gpa_result.randomness.split_at(2); + + let claimed_row_rs = gpa_result.claimed_values[0]; + let claimed_row_ws = gpa_result.claimed_values[1]; + let claimed_col_rs = gpa_result.claimed_values[2]; + let claimed_col_ws = gpa_result.claimed_values[3]; + + let row_adr: FieldElement = read_hint(arthur, "row_adr")?; + let row_timestamp: FieldElement = read_hint(arthur, "row_timestamp")?; + let col_adr: FieldElement = read_hint(arthur, "col_adr")?; + let col_timestamp: FieldElement = read_hint(arthur, "col_timestamp")?; + + let gpa_eval_weight = MultilinearExtension::new(evaluation_randomness.to_vec()); + let gpa_eval_lf: &dyn whir::algebra::linear_form::LinearForm = &gpa_eval_weight; + + let row_field_at_fold: FieldElement = read_hint(arthur, "row_field_at_fold")?; + let read_row_at_fold: FieldElement = read_hint(arthur, "read_row_at_fold")?; + let col_field_at_fold: FieldElement = read_hint(arthur, "col_field_at_fold")?; + let read_col_at_fold: FieldElement = read_hint(arthur, "read_col_at_fold")?; + let vals_at_eval: FieldElement = read_hint(arthur, "vals_at_eval")?; + + let fold_lf_for_vals_rs_ws = MultilinearExtension::new(eval_weight.point.clone()); + let eval_lf_for_vals_rs_ws = MultilinearExtension::new(evaluation_randomness.to_vec()); + + let vals_rs_ws_claim = whir_configs + .num_terms_5batched + .verify(arthur, &[&precomputed_commitments.vals_rsws], &[ + // fold_lf + sumcheck_hints[0], + row_field_at_fold, + read_row_at_fold, + col_field_at_fold, + read_col_at_fold, + // eval_lf + vals_at_eval, + row_adr, + row_timestamp, + col_adr, + col_timestamp, + ]) + .map_err(|e| anyhow::anyhow!("WHIR verify failed for vals_rs_ws: {e}"))?; + vals_rs_ws_claim + .verify([ + &fold_lf_for_vals_rs_ws as &dyn whir::algebra::linear_form::LinearForm, + &eval_lf_for_vals_rs_ws as &dyn whir::algebra::linear_form::LinearForm, + ]) + .map_err(|e| anyhow::anyhow!("FinalClaim check failed for vals_rs_ws: {e}"))?; + + let row_mem: FieldElement = read_hint(arthur, "row_mem")?; + let col_mem: FieldElement = read_hint(arthur, "col_mem")?; + + let e_values_combined_claim = whir_configs + .num_terms_2batched + .verify(arthur, &[&e_values_commitment], &[ + sumcheck_hints[1], + sumcheck_hints[2], + row_mem, + col_mem, + ]) + .map_err(|e| anyhow::anyhow!("WHIR verify failed for e_values (combined): {e}"))?; + e_values_combined_claim + .verify([ + &eval_weight as &dyn whir::algebra::linear_form::LinearForm, + gpa_eval_lf, + ]) + .map_err(|e| anyhow::anyhow!("FinalClaim check failed for e_values (combined): {e}"))?; + + let gamma_sq = gamma * gamma; + + let row_rs_opening = row_adr * gamma_sq + row_mem * gamma + row_timestamp - tau; + let row_ws_opening = + row_adr * gamma_sq + row_mem * gamma + row_timestamp + FieldElement::ONE - tau; + let col_rs_opening = col_adr * gamma_sq + col_mem * gamma + col_timestamp - tau; + let col_ws_opening = + col_adr * gamma_sq + col_mem * gamma + col_timestamp + FieldElement::ONE - tau; + + let evaluated_value = row_rs_opening + * (FieldElement::ONE - combination_randomness[0]) + * (FieldElement::ONE - combination_randomness[1]) + + row_ws_opening + * (FieldElement::ONE - combination_randomness[0]) + * combination_randomness[1] + + col_rs_opening + * combination_randomness[0] + * (FieldElement::ONE - combination_randomness[1]) + + col_ws_opening * combination_randomness[0] * combination_randomness[1]; + + ensure!( + evaluated_value == gpa_result.last_sumcheck_value, + "rs/ws GPA sumcheck final value inconsistent with evaluated multilinear extension" + ); + + verify_axis( + arthur, + matrix_dimensions.num_rows, + &whir_configs.row, + &precomputed_commitments.a_row_finalts, + |eval_rand| calculate_eq(row, eval_rand), + &tau, + &gamma, + &claimed_row_rs, + &claimed_row_ws, + )?; + + verify_axis( + arthur, + matrix_dimensions.num_cols, + &whir_configs.col, + &precomputed_commitments.a_col_finalts, + |eval_rand| calculate_eq(col, eval_rand), + &tau, + &gamma, + &claimed_col_rs, + &claimed_col_ws, + )?; + + Ok(()) +} diff --git a/tooling/cli/Cargo.toml b/tooling/cli/Cargo.toml index c45feb548..97ee1c071 100644 --- a/tooling/cli/Cargo.toml +++ b/tooling/cli/Cargo.toml @@ -10,10 +10,12 @@ repository.workspace = true [dependencies] # Workspace crates +mavros-artifacts.workspace = true provekit-common.workspace = true provekit-gnark.workspace = true provekit-prover = { workspace = true, features = ["witness-generation", "parallel"] } provekit-r1cs-compiler.workspace = true +provekit-spark.workspace = true provekit-verifier.workspace = true # Noir language @@ -29,6 +31,7 @@ ark-ff.workspace = true # 3rd party anyhow.workspace = true +bincode.workspace = true argh.workspace = true base64.workspace = true hex.workspace = true diff --git a/tooling/cli/src/cmd/mod.rs b/tooling/cli/src/cmd/mod.rs index 9bc3fdb90..ccd2b6373 100644 --- a/tooling/cli/src/cmd/mod.rs +++ b/tooling/cli/src/cmd/mod.rs @@ -1,11 +1,13 @@ mod analyze_pkp; mod circuit_stats; mod generate_gnark_inputs; -mod prepare; +pub mod prepare; mod prove; +mod prove_spark; mod show_inputs; mod util; mod verify; +mod verify_spark; use {anyhow::Result, argh::FromArgs}; @@ -42,8 +44,10 @@ enum Commands { AnalyzePkp(analyze_pkp::Args), Prepare(prepare::Args), Prove(prove::Args), + ProveSpark(prove_spark::Args), CircuitStats(circuit_stats::Args), Verify(verify::Args), + VerifySpark(verify_spark::Args), GenerateGnarkInputs(generate_gnark_inputs::Args), ShowInputs(show_inputs::Args), } @@ -60,8 +64,10 @@ impl Command for Commands { Self::AnalyzePkp(args) => args.run(), Self::Prepare(args) => args.run(), Self::Prove(args) => args.run(), + Self::ProveSpark(args) => args.run(), Self::CircuitStats(args) => args.run(), Self::Verify(args) => args.run(), + Self::VerifySpark(args) => args.run(), Self::GenerateGnarkInputs(args) => args.run(), Self::ShowInputs(args) => args.run(), } diff --git a/tooling/cli/src/cmd/prepare.rs b/tooling/cli/src/cmd/prepare.rs index bcfd34d91..03f151c5e 100644 --- a/tooling/cli/src/cmd/prepare.rs +++ b/tooling/cli/src/cmd/prepare.rs @@ -2,6 +2,7 @@ use { super::{util::resolve_key_path, Command}, anyhow::{anyhow, bail, Context as _, Result}, argh::FromArgs, + mavros_artifacts::R1CS as MavrosR1CS, nargo::{ insert_all_files_for_workspace_into_file_manager, ops::{check_program, collect_errors, compile_program, optimize_program, report_errors}, @@ -10,18 +11,22 @@ use { nargo_toml::{find_root, get_package_manifest, resolve_workspace_from_toml, PackageSelection}, noir_artifact_cli::fs::artifact::save_program_to_file, noirc_driver::{CompilationResult, CompileOptions, CrateName, NOIR_ARTIFACT_VERSION_STRING}, - provekit_common::{file::write, HashConfig, Prover, Verifier}, + provekit_common::{ + file::write, utils::next_power_of_two, FieldElement, HashConfig, NoirProofScheme, Prover, + SparkSetup, Verifier, R1CS, + }, provekit_r1cs_compiler::{MavrosCompiler, NoirCompiler}, + provekit_spark::SparkMatrix, rayon::prelude::*, std::{ path::{Path, PathBuf}, str::FromStr, }, - tracing::instrument, + tracing::{info, instrument}, }; #[derive(PartialEq, Eq, Debug)] -enum Compiler { +pub enum Compiler { Noir, Mavros, } @@ -105,6 +110,20 @@ pub struct Args { /// blake3, poseidon2) #[argh(option, long = "hash", default = "String::from(\"skyscraper\")")] hash: String, + + /// also run SPARK preprocessing; the setup is folded into the PKV and the + /// prover context is written to `--spctx` + #[argh(switch, long = "spark")] + spark: bool, + + /// output path for the bundled SPARK prover context — matrix, witnesses + /// and setup (used with --spark) + #[argh( + option, + long = "spctx", + default = "PathBuf::from(\"noir_proof_scheme.spctx\")" + )] + spctx_path: PathBuf, } impl Command for Args { @@ -152,6 +171,9 @@ impl Args { if binary_packages.len() > 1 && (self.pkp_path.is_some() || self.pkv_path.is_some()) { bail!("--pkp/--pkv cannot be used with multiple binary packages"); } + if binary_packages.len() > 1 && self.spark { + bail!("--spark cannot be used with multiple binary packages"); + } let target_dir = workspace.target_directory_path(); @@ -185,6 +207,7 @@ impl Args { for (package, artifact) in binary_packages.iter().zip(artifacts) { let scheme = NoirCompiler::from_program(artifact, hash_config) .context("while building Noir proof scheme")?; + let spark_setup = self.maybe_build_spark(&scheme, hash_config)?; let pkp_path = self .pkp_path .clone() @@ -195,8 +218,9 @@ impl Args { .unwrap_or_else(|| format!("{}.pkv", package.name).into()); write(&Prover::from_noir_proof_scheme(scheme.clone()), &pkp_path) .context("while writing prover key")?; - write(&Verifier::from_noir_proof_scheme(scheme), &pkv_path) - .context("while writing verifier key")?; + let mut verifier = Verifier::from_noir_proof_scheme(scheme); + verifier.spark_setup = spark_setup; + write(&verifier, &pkv_path).context("while writing verifier key")?; } Ok(()) } @@ -208,15 +232,43 @@ impl Args { .context("--r1cs is required when using the mavros compiler")?; let scheme = MavrosCompiler::compile(&self.program_path, r1cs_path, hash_config) .context("while compiling with Mavros")?; + let spark_setup = self.maybe_build_spark(&scheme, hash_config)?; let pkp_path = resolve_key_path(self.pkp_path.as_deref(), "pkp")?; let pkv_path = resolve_key_path(self.pkv_path.as_deref(), "pkv")?; write(&Prover::from_noir_proof_scheme(scheme.clone()), &pkp_path) .context("while writing prover key")?; - write(&Verifier::from_noir_proof_scheme(scheme), &pkv_path) - .context("while writing verifier key")?; + let mut verifier = Verifier::from_noir_proof_scheme(scheme); + verifier.spark_setup = spark_setup; + write(&verifier, &pkv_path).context("while writing verifier key")?; Ok(()) } + /// Run SPARK preprocessing if `--spark` is set: write the prover context + /// (`.spctx`) and return the [`SparkSetup`] for the caller to fold into the + /// PKV. The verifier consumes the setup from the trusted PKV; there is no + /// standalone `.spc` artifact. + fn maybe_build_spark( + &self, + scheme: &NoirProofScheme, + hash_config: HashConfig, + ) -> Result> { + if !self.spark { + return Ok(None); + } + provekit_common::register_ntt(); + let matrix = build_spark_matrix_for_scheme(scheme, self.r1cs_path.as_deref())?; + let (setup, witnesses) = provekit_spark::preprocess_spark(&matrix, hash_config); + let context = provekit_spark::SparkProverContext { + matrix, + witnesses, + setup: setup.clone(), + }; + write(&context, &self.spctx_path) + .with_context(|| format!("writing SPARK prover context to {:?}", self.spctx_path))?; + info!("Wrote SPARK prover context to {:?}", self.spctx_path); + Ok(Some(setup)) + } + fn compile_options(&self) -> CompileOptions { CompileOptions { deny_warnings: self.deny_warnings, @@ -260,3 +312,159 @@ impl Args { Ok(PackageSelection::DefaultOrAll) } } + +pub fn build_spark_matrix_for_scheme( + scheme: &NoirProofScheme, + r1cs_path: Option<&Path>, +) -> Result { + let whir = match scheme { + NoirProofScheme::Noir(s) => s.whir_for_witness.clone(), + NoirProofScheme::Mavros(s) => s.whir_for_witness.clone(), + }; + match scheme { + NoirProofScheme::Noir(noir) => build_spark_r1cs_noir( + &noir.r1cs, + whir.m_0, + whir.m, + whir.w1_size, + whir.num_challenges, + ), + NoirProofScheme::Mavros(_) => { + let r1cs_path = + r1cs_path.context("--r1cs is required for SPARK with the mavros compiler")?; + build_spark_r1cs_mavros( + r1cs_path, + whir.m_0, + whir.m, + whir.w1_size, + whir.num_challenges, + ) + } + } +} + +pub fn build_spark_r1cs_noir( + r1cs: &R1CS, + log_row: usize, + log_col: usize, + w1_size: usize, + num_challenges: usize, +) -> Result { + let is_single_commitment = num_challenges == 0; + + let original_num_entries = + r1cs.a().iter().count() + r1cs.b().iter().count() + r1cs.c().iter().count(); + + let padded_num_entries = 1 << next_power_of_two(original_num_entries); + let to_fill = padded_num_entries - original_num_entries; + + let row_cnt = 1 << log_row; + let col_cnt = if is_single_commitment { + 1 << log_col + } else { + 1 << (1 + log_col) + }; + + let col_witness_split_offset = |c: usize| -> usize { + if !is_single_commitment && (c >= w1_size) { + (1 << log_col) - w1_size + } else { + 0 + } + }; + + let (mut row, mut col, mut val) = ( + Vec::with_capacity(padded_num_entries), + Vec::with_capacity(padded_num_entries), + Vec::with_capacity(padded_num_entries), + ); + + for (matrix, row_offset, col_offset) in [ + (r1cs.a(), 0, 0), + (r1cs.b(), 0, col_cnt), + (r1cs.c(), row_cnt, col_cnt), + ] { + for ((r, c), v) in matrix.iter() { + row.push(r + row_offset); + col.push(c + col_offset + col_witness_split_offset(c)); + val.push(v); + } + } + for _ in 0..to_fill { + row.push(0); + col.push(0); + val.push(FieldElement::from(0u64)); + } + + Ok(SparkMatrix::new(row, col, val, 2 * row_cnt, 2 * col_cnt)) +} + +pub fn build_spark_r1cs_mavros( + r1cs_path: &Path, + log_row: usize, + log_col: usize, + w1_size: usize, + num_challenges: usize, +) -> Result { + let is_single_commitment = num_challenges == 0; + + let r1cs_bytes = std::fs::read(r1cs_path).context("while reading R1CS file")?; + let r1cs: MavrosR1CS = + bincode::deserialize(&r1cs_bytes).context("while deserializing R1CS from bincode")?; + + let row_cnt = 1 << log_row; + let col_cnt = if is_single_commitment { + 1 << log_col + } else { + 1 << (1 + log_col) + }; + + let col_witness_split_offset = |c: usize| -> usize { + if !is_single_commitment && (c >= w1_size) { + (1 << log_col) - w1_size + } else { + 0 + } + }; + + let original_num_entries: usize = r1cs + .constraints + .iter() + .map(|r1c| r1c.a.len() + r1c.b.len() + r1c.c.len()) + .sum(); + + let padded_num_entries = 1 << next_power_of_two(original_num_entries); + let to_fill = padded_num_entries - original_num_entries; + + let (mut row, mut col, mut val) = ( + Vec::with_capacity(padded_num_entries), + Vec::with_capacity(padded_num_entries), + Vec::with_capacity(padded_num_entries), + ); + + for (i, r1c) in r1cs.constraints.iter().enumerate() { + for &(c, v) in &r1c.a { + row.push(i); + col.push(c + col_witness_split_offset(c)); + val.push(v); + } + for &(c, v) in &r1c.b { + row.push(i); + col.push(c + col_cnt + col_witness_split_offset(c)); + val.push(v); + } + for &(c, v) in &r1c.c { + row.push(i + row_cnt); + col.push(c + col_cnt + col_witness_split_offset(c)); + val.push(v); + } + } + + for _ in 0..to_fill { + row.push(0); + col.push(0); + val.push(FieldElement::from(0u64)); + } + + Ok(SparkMatrix::new(row, col, val, 2 * row_cnt, 2 * col_cnt)) +} diff --git a/tooling/cli/src/cmd/prove.rs b/tooling/cli/src/cmd/prove.rs index 7030da6a2..cf0698a25 100644 --- a/tooling/cli/src/cmd/prove.rs +++ b/tooling/cli/src/cmd/prove.rs @@ -38,6 +38,18 @@ pub struct Args { /// path to the verifier key (default: `.pkv`) #[argh(option, long = "verifier")] verifier_path: Option, + + /// directory in which to write SPARK queries (default: ./spark_proofs) + #[argh( + option, + long = "spark-queries-dir", + default = "PathBuf::from(\"./spark_proofs\")" + )] + spark_queries_dir: PathBuf, + + /// produce SPARK queries and write them to `spark_queries_dir`. + #[argh(switch, long = "produce-spark-query")] + produce_spark_query: bool, } impl Command for Args { @@ -53,9 +65,17 @@ impl Command for Args { let (constraints, witnesses) = prover.size(); info!(constraints, witnesses, "Read Noir proof scheme"); - let proof = prover - .prove_with_toml(&input_path) - .context("While proving Noir program statement")?; + let (proof, spark_queries) = if self.produce_spark_query { + let (proof, batch) = prover + .prove_with_spark_toml(&input_path) + .context("While proving Noir program statement")?; + (proof, Some(batch)) + } else { + let proof = prover + .prove_with_toml(&input_path) + .context("While proving Noir program statement")?; + (proof, None) + }; write(&proof, &self.proof_path).context("while writing proof")?; @@ -69,6 +89,19 @@ impl Command for Args { .context("While verifying Noir proof")?; } + if let Some(batch) = &spark_queries { + std::fs::create_dir_all(&self.spark_queries_dir) + .with_context(|| format!("creating {:?}", self.spark_queries_dir))?; + let queries_path = self.spark_queries_dir.join("spark_queries.json"); + let queries_file = std::fs::File::create(&queries_path) + .with_context(|| format!("creating {queries_path:?}"))?; + serde_json::to_writer_pretty(queries_file, batch).context("writing SPARK queries")?; + info!( + count = batch.queries.len(), + "Wrote SPARK queries to {queries_path:?}" + ); + } + Ok(()) } } diff --git a/tooling/cli/src/cmd/prove_spark.rs b/tooling/cli/src/cmd/prove_spark.rs new file mode 100644 index 000000000..12adea7fc --- /dev/null +++ b/tooling/cli/src/cmd/prove_spark.rs @@ -0,0 +1,81 @@ +use { + super::Command, + anyhow::{Context, Result}, + argh::FromArgs, + provekit_common::{ + file::{read, write}, + spark::SparkQueryBatch, + Prover, + }, + provekit_spark::{SparkProverContext, SparkProverScheme}, + std::{fs::File, io::BufReader, path::PathBuf}, + tracing::{info, instrument}, +}; + +/// Generate SPARK proofs for the queries emitted by `prove`. +#[derive(FromArgs, PartialEq, Eq, Debug)] +#[argh(subcommand, name = "prove-spark")] +pub struct Args { + /// path to the prepared proof scheme + #[argh(positional)] + prover_path: PathBuf, + + /// directory containing `spark_queries.json`; the SPARK proof is written + /// here as `spark_proof.sp` (default: ./spark_proofs) + #[argh( + option, + long = "spark-dir", + default = "PathBuf::from(\"./spark_proofs\")" + )] + spark_dir: PathBuf, + + /// path to the SPARK prover context (matrix + witnesses + setup) written + /// by `prepare --spark` + #[argh( + option, + long = "spctx", + default = "PathBuf::from(\"noir_proof_scheme.spctx\")" + )] + spctx_path: PathBuf, +} + +impl Command for Args { + #[instrument(skip_all)] + fn run(&self) -> Result<()> { + provekit_common::register_ntt(); + + let prover: Prover = read(&self.prover_path).context("while reading Provekit Prover")?; + + let queries_path = self.spark_dir.join("spark_queries.json"); + if !queries_path.exists() { + info!("No SPARK queries found at {queries_path:?}"); + return Ok(()); + } + let batch = read_queries(&queries_path)?; + + let hash_config = prover.whir_for_witness().hash_config; + let context: SparkProverContext = read(&self.spctx_path) + .with_context(|| format!("reading SPARK prover context from {:?}", self.spctx_path))?; + + let num_constraints = context.matrix.timestamps.final_row.len(); + let num_witnesses = context.matrix.timestamps.final_col.len(); + let num_nonzero = context.matrix.coo.val.len(); + + let scheme = + SparkProverScheme::new(num_constraints, num_witnesses, num_nonzero, hash_config); + let spark_proof = scheme + .prove(&context, &batch) + .context("generating SPARK proof")?; + let proof_path = self.spark_dir.join("spark_proof.sp"); + write(&spark_proof, &proof_path) + .with_context(|| format!("writing SPARK proof to {proof_path:?}"))?; + info!("Wrote SPARK proof to {proof_path:?}"); + + Ok(()) + } +} + +fn read_queries(path: &PathBuf) -> Result { + let file = File::open(path).with_context(|| format!("opening {path:?}"))?; + serde_json::from_reader(BufReader::new(file)).with_context(|| format!("parsing {path:?}")) +} diff --git a/tooling/cli/src/cmd/verify_spark.rs b/tooling/cli/src/cmd/verify_spark.rs new file mode 100644 index 000000000..926bea339 --- /dev/null +++ b/tooling/cli/src/cmd/verify_spark.rs @@ -0,0 +1,74 @@ +use { + super::Command, + anyhow::{Context, Result}, + argh::FromArgs, + provekit_common::{file::read, spark::SparkQueryBatch, Verifier}, + provekit_spark::{SparkProof, SparkVerifierScheme}, + std::{fs::File, io::BufReader, path::PathBuf}, + tracing::instrument, +}; + +/// Verify a standalone SPARK proof against the saved SparkQueryBatch. +#[derive(FromArgs, PartialEq, Eq, Debug)] +#[argh(subcommand, name = "verify-spark")] +pub struct Args { + /// path to the SPARK proof file (.sp or .json) + #[argh(positional)] + proof_path: PathBuf, + + /// path to the ProveKit Verifier key (.pkv) produced by `prepare --spark` + #[argh(positional)] + pkv_path: PathBuf, + + /// path to the SPARK queries JSON file (`spark_queries.json`) written by + /// `prove` + #[argh(positional)] + queries_path: PathBuf, +} + +impl Command for Args { + #[instrument(skip_all)] + fn run(&self) -> Result<()> { + provekit_common::register_ntt(); + + let (proof, (verifier, queries)) = rayon::join( + || read::(&self.proof_path).context("while reading SPARK proof"), + || { + rayon::join( + || read::(&self.pkv_path).context("while reading Provekit Verifier"), + || { + read_queries(&self.queries_path) + .with_context(|| format!("while reading {:?}", self.queries_path)) + }, + ) + }, + ); + let proof = proof?; + let verifier = verifier?; + let batch = queries?; + + let setup = verifier.spark_setup.as_ref().with_context(|| { + format!( + "PKV {:?} does not contain a SPARK setup; re-run `prepare --spark`", + self.pkv_path + ) + })?; + + anyhow::ensure!( + !batch.queries.is_empty(), + "SPARK queries file {:?} is empty", + self.queries_path + ); + + SparkVerifierScheme + .verify(proof, setup, &batch) + .context("while verifying SPARK proof")?; + + Ok(()) + } +} + +fn read_queries(path: &PathBuf) -> Result { + let file = File::open(path).with_context(|| format!("opening {path:?}"))?; + serde_json::from_reader(BufReader::new(file)).context("parsing SPARK queries JSON") +}