From 2230559d30b2d00c450b5fe42e19f304080873ed Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 4 Feb 2025 17:37:48 -0500 Subject: [PATCH 01/18] Add commitments --- spartan_parallel/src/r1csproof.rs | 62 +++++++++++++++++++++++-------- 1 file changed, 47 insertions(+), 15 deletions(-) diff --git a/spartan_parallel/src/r1csproof.rs b/spartan_parallel/src/r1csproof.rs index 8dc188ab..6ab01495 100644 --- a/spartan_parallel/src/r1csproof.rs +++ b/spartan_parallel/src/r1csproof.rs @@ -8,7 +8,8 @@ use super::sumcheck::SumcheckInstanceProof; use super::timer::Timer; use super::transcript::{Transcript, append_protocol_name, challenge_vector, challenge_scalar, append_field_to_transcript}; use ff_ext::ExtensionField; -use mpcs::PolynomialCommitmentScheme; +use itertools::max; +use mpcs::{pcs_open, PolynomialCommitmentScheme}; use crate::{ProverWitnessSecInfo, VerifierWitnessSecInfo}; use serde::Serialize; use std::cmp::min; @@ -432,18 +433,50 @@ impl> R1CSPr } } - /* - let proof_eval_vars_at_ry_list = PolyEvalProof::prove_batched_instances_disjoint_rounds( - &poly_list, - &num_proofs_list, - &num_inputs_list, - &rq, - &ry, - &Zr_list, - transcript, - random_tape, - ); - */ + let max_len = max(poly_list.iter().map(|&p| p.num_vars)).unwrap().next_power_of_two(); + let param = Pcs::setup(max_len).unwrap(); + let (pp, _error) = Pcs::trim(param, max_len).unwrap(); + + let mut proof_eval_vars_at_ry_list: Vec = Vec::new(); + let mut proof_idx: usize = 0; + for i in 0..num_witness_secs { + let w = witness_secs[i]; + let wit_sec_num_instance = w.w_mat.len(); + for p in 0..wit_sec_num_instance { + let poly = poly_list[proof_idx]; + + let num_proofs = num_proofs_list[proof_idx]; + let num_inputs = num_inputs_list[proof_idx]; + let num_vars_q = num_proofs.log_2(); + let num_vars_y = num_inputs.log_2(); + let ry_short = { + if num_vars_y >= ry.len() { + let mut ry_pad: Vec = vec![E::ZERO; num_vars_y - ry.len()]; + ry_pad.extend_from_slice(&ry); + ry_pad + } + // Else ry_short is the last w.num_inputs[p].log_2() entries of ry + // thus, to obtain the actual ry, need to multiply by (1 - ry2)(1 - ry3)..., which is ry_factors[num_rounds_y - w.num_inputs[p]] + else { + ry[ry.len() - num_vars_y..].to_vec() + } + }; + let rq_short = rq[rq.len() - num_vars_q..].to_vec(); + let r = [rq_short, ry_short.clone()].concat(); + let Zr = Zr_list[proof_idx]; + + let pcs_proof: Pcs::Proof = pcs_open::( + &pp, + poly, + &witness_secs[i].comm_w[p], + &r, + &Zr, + transcript + ).expect("PCS proof should not fail"); + + proof_eval_vars_at_ry_list.push(pcs_proof); + } + } // Bind the resulting witness list to rp // poly_vars stores the result of each witness matrix bounded to (rq_short ++ ry) @@ -510,8 +543,7 @@ impl> R1CSPr claims_phase2: (*Az_claim, *Bz_claim, *Cz_claim), eval_vars_at_ry_list: raw_eval_vars_at_ry_list, eval_vars_at_ry, - proof_eval_vars_at_ry_list: Vec::new(), - // proof_eval_vars_at_ry_list, + proof_eval_vars_at_ry_list, }, [rp, rq, rx, [rw, ry].concat()], ) From 8bb5ed552761574b65da4f8b41f8233b64e8e83e Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 5 Feb 2025 15:30:21 -0500 Subject: [PATCH 02/18] Append commitments --- spartan_parallel/src/lib.rs | 64 ++++++++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/spartan_parallel/src/lib.rs b/spartan_parallel/src/lib.rs index 04c66d6f..9060969e 100644 --- a/spartan_parallel/src/lib.rs +++ b/spartan_parallel/src/lib.rs @@ -1574,11 +1574,23 @@ impl> SNARK< let (perm_exec_w2_prover, perm_exec_w2_v_comm) = Self::mat_to_prove_wit_sec(perm_exec_w2, &poly_pp); let (perm_exec_w3_prover, perm_exec_w3_v_comm) = Self::mat_to_prove_wit_sec(perm_exec_w3, &poly_pp); let (perm_exec_w3_shifted_prover, perm_exec_w3_shifted_v_comm) = Self::mat_to_prove_wit_sec(perm_exec_w3_shifted_mat, &poly_pp); - let (block_w2_prover, block_w2_v_comm) = Self::mats_to_prove_wit_sec(block_w2, &poly_pp); let (block_w3_prover, block_w3_v_comm) = Self::mats_to_prove_wit_sec(block_w3, &poly_pp); let (block_w3_shifted_prover, block_w3_shifted_v_comm) = Self::mats_to_prove_wit_sec(block_w3_shifted_mat, &poly_pp); + // Append commmitments to transcript + for comm in vec![ + perm_exec_w2_v_comm, + perm_exec_w3_v_comm, + perm_exec_w3_shifted_v_comm, + ] + .into_iter() + .chain(block_w2_v_comm.into_iter()) + .chain(block_w3_v_comm.into_iter()) + .chain(block_w3_shifted_v_comm.into_iter()) { + Pcs::write_commitment(comm, transcript); + } + ( comb_tau, comb_r, @@ -1728,6 +1740,10 @@ impl> SNARK< let timer_commit = Timer::new("input_commit"); let (block_vars_prover, block_vars_v_comm_list) = Self::mats_to_prove_wit_sec(block_vars_mat, &poly_pp); let (exec_inputs_prover, exec_inputs_v_comm) = Self::mat_to_prove_wit_sec(exec_inputs_list, &poly_pp); + for comm in block_vars_v_comm_list.into_iter().chain(vec![exec_inputs_v_comm].into_iter()) { + Pcs::write_commitment(comm, transcript); + } + let init_phy_mems_prover = if total_num_init_phy_mem_accesses > 0 { Self::mat_to_prover_wit_sec_no_commit(init_phy_mems_list) } else { @@ -1744,6 +1760,14 @@ impl> SNARK< let addr_phy_mems_shifted_list = vec![addr_phy_mems_list[1..].to_vec(), vec![vec![E::ZERO; PHY_MEM_WIDTH]]].concat(); let (addr_phy_mems_prover, addr_phy_mems_v_comm) = Self::mat_to_prove_wit_sec(addr_phy_mems_list, &poly_pp); let (addr_phy_mems_shifted_prover, addr_phy_mems_shifted_v_comm) = Self::mat_to_prove_wit_sec(addr_phy_mems_shifted_list, &poly_pp); + + for comm in vec![ + addr_phy_mems_v_comm, + addr_phy_mems_shifted_v_comm, + ] { + Pcs::write_commitment(comm, transcript); + } + (addr_phy_mems_prover, Some(addr_phy_mems_v_comm), addr_phy_mems_shifted_prover, Some(addr_phy_mems_shifted_v_comm)) } else { (ProverWitnessSecInfo::dummy(), None, ProverWitnessSecInfo::dummy(), None) @@ -1756,6 +1780,15 @@ impl> SNARK< let (addr_vir_mems_prover, addr_vir_mems_v_comm) = Self::mat_to_prove_wit_sec(addr_vir_mems_list, &poly_pp); let (addr_vir_mems_shifted_prover, addr_vir_mems_shifted_v_comm) = Self::mat_to_prove_wit_sec(addr_vir_mems_shifted_list, &poly_pp); let (addr_ts_bits_prover, addr_ts_bits_v_comm) = Self::mat_to_prove_wit_sec(addr_ts_bits_list, &poly_pp); + + for comm in vec![ + addr_vir_mems_v_comm, + addr_vir_mems_shifted_v_comm, + addr_ts_bits_v_comm, + ] { + Pcs::write_commitment(comm, transcript); + } + ( addr_vir_mems_prover, Some(addr_vir_mems_v_comm), @@ -2730,6 +2763,19 @@ impl> SNARK< ) }; + // Append commmitments to transcript + for comm in vec![ + self.perm_exec_w2_comm, + self.perm_exec_w3_comm, + self.perm_exec_w3_shifted_comm, + ] + .into_iter() + .chain(self.block_w2_comm_list.into_iter()) + .chain(self.block_w3_comm_list.into_iter()) + .chain(self.block_w3_shifted_comm_list.into_iter()) { + Pcs::write_commitment(comm, transcript); + } + let (init_phy_mem_w2_verifier, init_phy_mem_w3_verifier, init_phy_mem_w3_shifted_verifier) = { if total_num_init_phy_mem_accesses > 0 { ( @@ -2808,6 +2854,10 @@ impl> SNARK< ) }; + for comm in self.block_vars_comm_list.into_iter().chain(vec![self.exec_inputs_comm]) { + Pcs::write_commitment(comm, transcript); + } + let init_phy_mems_verifier = { if input_stack.len() > 0 { assert_eq!( @@ -2879,6 +2929,12 @@ impl> SNARK< let (addr_phy_mems_verifier, addr_phy_mems_shifted_verifier) = { if total_num_phy_mem_accesses > 0 { + for comm in vec![ + self.addr_phy_mems_comm, + self.addr_phy_mems_shifted_comm, + ] { + Pcs::write_commitment(comm, transcript); + } ( VerifierWitnessSecInfo::new(vec![PHY_MEM_WIDTH], &vec![total_num_phy_mem_accesses]), VerifierWitnessSecInfo::new(vec![PHY_MEM_WIDTH], &vec![total_num_phy_mem_accesses]), @@ -2893,6 +2949,12 @@ impl> SNARK< let (addr_vir_mems_verifier, addr_vir_mems_shifted_verifier, addr_ts_bits_verifier) = { if total_num_vir_mem_accesses > 0 { + for comm in vec![ + self.addr_vir_mems_comm, + self.addr_vir_mems_shifted_comm, + ] { + Pcs::write_commitment(comm, transcript); + } ( VerifierWitnessSecInfo::new(vec![VIR_MEM_WIDTH], &vec![total_num_vir_mem_accesses]), VerifierWitnessSecInfo::new(vec![VIR_MEM_WIDTH], &vec![total_num_vir_mem_accesses]), From 0d83d91feeb03bce93f3ed6c7c9df55b5f1b2dc8 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 5 Feb 2025 16:32:49 -0500 Subject: [PATCH 03/18] Add commitments to witness section info --- spartan_parallel/src/lib.rs | 129 +++++++++++++++++++++++++++--------- 1 file changed, 99 insertions(+), 30 deletions(-) diff --git a/spartan_parallel/src/lib.rs b/spartan_parallel/src/lib.rs index 9060969e..523f7353 100644 --- a/spartan_parallel/src/lib.rs +++ b/spartan_parallel/src/lib.rs @@ -530,20 +530,23 @@ impl> ProverWitnessSecInfo // Information regarding one witness sec #[derive(Clone)] -struct VerifierWitnessSecInfo { +struct VerifierWitnessSecInfo> { // Number of inputs per block num_inputs: Vec, // Number of proofs per block, used by merge num_proofs: Vec, + // One commitment per circuit + comm_w: Vec } -impl VerifierWitnessSecInfo { +impl> VerifierWitnessSecInfo { // Unfortunately, cannot obtain all metadata from the commitment - fn new(num_inputs: Vec, num_proofs: &Vec) -> VerifierWitnessSecInfo { + fn new(num_inputs: Vec, num_proofs: &Vec, comm_w: Vec) -> VerifierWitnessSecInfo { let l = num_inputs.len(); VerifierWitnessSecInfo { num_inputs, num_proofs: num_proofs[..l].to_vec(), + comm_w, } } @@ -551,6 +554,7 @@ impl VerifierWitnessSecInfo { VerifierWitnessSecInfo { num_inputs: Vec::new(), num_proofs: Vec::new(), + comm_w: Vec::new(), } } @@ -558,6 +562,7 @@ impl VerifierWitnessSecInfo { VerifierWitnessSecInfo { num_inputs: vec![1], num_proofs: vec![1], + comm_w: vec![Pcs::Commitment::default()], } } @@ -565,15 +570,18 @@ impl VerifierWitnessSecInfo { fn concat(components: Vec<&VerifierWitnessSecInfo>) -> VerifierWitnessSecInfo { let mut num_inputs = Vec::new(); let mut num_proofs = Vec::new(); + let mut comm_w = Vec::new(); for c in components { num_inputs.extend(c.num_inputs.clone()); num_proofs.extend(c.num_proofs.clone()); + comm_w.extend(c.comm_w.clone()); } VerifierWitnessSecInfo { num_inputs, num_proofs, + comm_w, } } @@ -589,6 +597,7 @@ impl VerifierWitnessSecInfo { let mut inst_map = Vec::new(); let mut merged_num_inputs = Vec::new(); let mut merged_num_proofs = Vec::new(); + let mut merged_comm_w = Vec::new(); while inst_map.len() < merged_size { // Choose the next instance with the most proofs let mut next_max_num_proofs = 0; @@ -606,6 +615,7 @@ impl VerifierWitnessSecInfo { inst_map.push(next_component); merged_num_inputs.push(components[next_component].num_inputs[pointers[next_component]]); merged_num_proofs.push(components[next_component].num_proofs[pointers[next_component]]); + merged_comm_w.push(components[next_component].comm_w[pointers[next_component]]); pointers[next_component] = pointers[next_component] + 1; } @@ -613,6 +623,7 @@ impl VerifierWitnessSecInfo { VerifierWitnessSecInfo { num_inputs: merged_num_inputs, num_proofs: merged_num_proofs, + comm_w: merged_comm_w, }, inst_map, ) @@ -637,6 +648,14 @@ pub struct SNARK> { block_w2_comm_list: Vec, block_w3_comm_list: Vec, block_w3_shifted_comm_list: Vec, + + init_phy_mem_w2_comm: Option, + init_phy_mem_w3_comm: Option, + init_phy_mem_w3_shifted_comm: Option, + init_vir_mem_w2_comm: Option, + init_vir_mem_w3_comm: Option, + init_vir_mem_w3_shifted_comm: Option, + phy_mem_addr_w2_comm: Option, phy_mem_addr_w3_comm: Option, phy_mem_addr_w3_shifted_comm: Option, @@ -1613,7 +1632,7 @@ impl> SNARK< // Initial Physical Memory-as-a-whole let timer_sec_gen = Timer::new("init_phy_mem_witness_gen"); - let (init_phy_mem_w2_prover, _, init_phy_mem_w3_prover, _, init_phy_mem_w3_shifted_prover, _) = + let (init_phy_mem_w2_prover, init_phy_mem_w2_comm, init_phy_mem_w3_prover, init_phy_mem_w3_comm, init_phy_mem_w3_shifted_prover, init_phy_mem_w3_shifted_comm) = Self::mem_gen::( total_num_init_phy_mem_accesses, &init_phy_mems_list, @@ -1626,7 +1645,7 @@ impl> SNARK< // Initial Virtual Memory-as-a-whole let timer_sec_gen = Timer::new("init_vir_mem_witness_gen"); - let (init_vir_mem_w2_prover, _, init_vir_mem_w3_prover, _, init_vir_mem_w3_shifted_prover, _) = + let (init_vir_mem_w2_prover, init_vir_mem_w2_comm, init_vir_mem_w3_prover, init_vir_mem_w3_comm, init_vir_mem_w3_shifted_prover, init_vir_mem_w3_shifted_comm) = Self::mem_gen::( total_num_init_vir_mem_accesses, &init_vir_mems_list, @@ -1648,6 +1667,24 @@ impl> SNARK< transcript, &poly_pp, ); + + for op_comm in vec![ + init_phy_mem_w2_comm, + init_phy_mem_w3_comm, + init_phy_mem_w3_shifted_comm, + init_vir_mem_w2_comm, + init_vir_mem_w3_comm, + init_vir_mem_w3_shifted_comm, + phy_mem_addr_w2_comm, + phy_mem_addr_w3_comm, + phy_mem_addr_w3_shifted_comm, + ] { + match op_comm { + Some(comm) => Pcs::write_commitment(comm, transcript), + None => Ok(()), + } + } + timer_sec_gen.stop(); // Virtual Memory-as-a-whole @@ -1701,6 +1738,14 @@ impl> SNARK< let (vir_mem_addr_w3_mle, vir_mem_addr_w3_p_comm, vir_mem_addr_w3_v_comm) = Self::mat_to_comm(&vir_mem_addr_w3, &poly_pp); let (vir_mem_addr_w3_shifted_mle, vir_mem_addr_w3_shifted_p_comm, vir_mem_addr_w3_shifted_v_comm) = Self::mat_to_comm(&vir_mem_addr_w3_shifted_mat, &poly_pp); + for comm in vec![ + vir_mem_addr_w2_v_comm, + vir_mem_addr_w3_v_comm, + vir_mem_addr_w3_shifted_v_comm, + ] { + Pcs::write_commitment(comm, transcript); + } + let vir_mem_addr_w2_prover = ProverWitnessSecInfo::new(vec![vir_mem_addr_w2], vec![vir_mem_addr_w2_mle], vec![vir_mem_addr_w2_p_comm]); let vir_mem_addr_w3_prover = @@ -2394,6 +2439,14 @@ impl> SNARK< block_w2_comm_list, block_w3_comm_list, block_w3_shifted_comm_list, + + init_phy_mem_w2_comm, + init_phy_mem_w3_comm, + init_phy_mem_w3_shifted_comm, + init_vir_mem_w2_comm, + init_vir_mem_w3_comm, + init_vir_mem_w3_shifted_comm, + phy_mem_addr_w2_comm, phy_mem_addr_w3_comm, phy_mem_addr_w3_shifted_comm, @@ -2744,21 +2797,23 @@ impl> SNARK< .next_power_of_two() }) .collect(); - VerifierWitnessSecInfo::new(block_w2_size_list, &block_num_proofs) + VerifierWitnessSecInfo::new(block_w2_size_list, &block_num_proofs, self.block_w2_comm_list) }; ( - VerifierWitnessSecInfo::new(vec![num_ios], &vec![1]), - VerifierWitnessSecInfo::new(vec![num_ios], &vec![consis_num_proofs]), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![consis_num_proofs]), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![consis_num_proofs]), + VerifierWitnessSecInfo::new(vec![num_ios], &vec![1], vec![Pcs::Commitment::default()]), + VerifierWitnessSecInfo::new(vec![num_ios], &vec![consis_num_proofs], vec![self.perm_exec_w2_comm]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![consis_num_proofs], vec![self.perm_exec_w3_comm]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![consis_num_proofs], vec![self.perm_exec_w3_shifted_comm]), block_w2_verifier, VerifierWitnessSecInfo::new( vec![W3_WIDTH; block_num_instances], &block_num_proofs.clone(), + self.block_w3_comm_list, ), VerifierWitnessSecInfo::new( vec![W3_WIDTH; block_num_instances], &block_num_proofs.clone(), + self.block_w3_shifted_comm_list, ), ) }; @@ -2782,9 +2837,10 @@ impl> SNARK< VerifierWitnessSecInfo::new( vec![INIT_PHY_MEM_WIDTH], &vec![total_num_init_phy_mem_accesses], + vec![self.init_phy_mem_w2_comm.expect("commitment should exist")] ), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_init_phy_mem_accesses]), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_init_phy_mem_accesses]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_init_phy_mem_accesses], vec![self.init_phy_mem_w3_comm.expect("commitment should exist")]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_init_phy_mem_accesses], vec![self.init_phy_mem_w3_shifted_comm.expect("commitment should exist")]), ) } else { ( @@ -2801,9 +2857,10 @@ impl> SNARK< VerifierWitnessSecInfo::new( vec![INIT_VIR_MEM_WIDTH], &vec![total_num_init_vir_mem_accesses], + vec![self.init_vir_mem_w2_comm.expect("commitment should exist")], ), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_init_vir_mem_accesses]), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_init_vir_mem_accesses]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_init_vir_mem_accesses], vec![self.init_vir_mem_w3_comm.expect("commitment should exist")]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_init_vir_mem_accesses], vec![self.init_vir_mem_w3_shifted_comm.expect("commitment should exist")]), ) } else { ( @@ -2817,9 +2874,9 @@ impl> SNARK< let (phy_mem_addr_w2_verifier, phy_mem_addr_w3_verifier, phy_mem_addr_w3_shifted_verifier) = { if total_num_phy_mem_accesses > 0 { ( - VerifierWitnessSecInfo::new(vec![PHY_MEM_WIDTH], &vec![total_num_phy_mem_accesses]), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_phy_mem_accesses]), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_phy_mem_accesses]), + VerifierWitnessSecInfo::new(vec![PHY_MEM_WIDTH], &vec![total_num_phy_mem_accesses], vec![self.phy_mem_addr_w2_comm.expect("commitment should exist")]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_phy_mem_accesses], vec![self.phy_mem_addr_w3_comm].expect("commitment should exist")), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_phy_mem_accesses], vec![self.phy_mem_addr_w3_shifted_comm].expect("commitment should exist")), ) } else { ( @@ -2832,10 +2889,18 @@ impl> SNARK< let (vir_mem_addr_w2_verifier, vir_mem_addr_w3_verifier, vir_mem_addr_w3_shifted_verifier) = { if total_num_vir_mem_accesses > 0 { + for comm in vec![ + self.vir_mem_addr_w2_comm, + self.vir_mem_addr_w3_comm, + self.vir_mem_addr_w3_shifted_comm, + ] { + Pcs::write_commitment(comm, transcript); + } + ( - VerifierWitnessSecInfo::new(vec![VIR_MEM_WIDTH], &vec![total_num_vir_mem_accesses]), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_vir_mem_accesses]), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_vir_mem_accesses]), + VerifierWitnessSecInfo::new(vec![VIR_MEM_WIDTH], &vec![total_num_vir_mem_accesses], vec![self.vir_mem_addr_w2_comm]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_vir_mem_accesses], vec![self.vir_mem_addr_w3_comm]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_vir_mem_accesses], vec![self.vir_mem_addr_w3_shifted_comm]), ) } else { ( @@ -2848,16 +2913,16 @@ impl> SNARK< let (block_vars_verifier, exec_inputs_verifier) = { // add the commitment to the verifier's transcript + for comm in self.block_vars_comm_list.into_iter().chain(vec![self.exec_inputs_comm]) { + Pcs::write_commitment(comm, transcript); + } + ( - VerifierWitnessSecInfo::new(block_num_vars, &block_num_proofs), - VerifierWitnessSecInfo::new(vec![num_ios], &vec![consis_num_proofs]), + VerifierWitnessSecInfo::new(block_num_vars, &block_num_proofs, self.block_vars_comm_list), + VerifierWitnessSecInfo::new(vec![num_ios], &vec![consis_num_proofs], vec![self.exec_inputs_comm]), ) }; - for comm in self.block_vars_comm_list.into_iter().chain(vec![self.exec_inputs_comm]) { - Pcs::write_commitment(comm, transcript); - } - let init_phy_mems_verifier = { if input_stack.len() > 0 { assert_eq!( @@ -2887,6 +2952,7 @@ impl> SNARK< VerifierWitnessSecInfo::new( vec![INIT_PHY_MEM_WIDTH], &vec![total_num_init_phy_mem_accesses], + vec![Pcs::Commitment::default()], ) } else { VerifierWitnessSecInfo::dummy() @@ -2921,6 +2987,7 @@ impl> SNARK< VerifierWitnessSecInfo::new( vec![INIT_VIR_MEM_WIDTH], &vec![total_num_init_vir_mem_accesses], + vec![Pcs::Commitment::default()], ) } else { VerifierWitnessSecInfo::dummy() @@ -2936,8 +3003,8 @@ impl> SNARK< Pcs::write_commitment(comm, transcript); } ( - VerifierWitnessSecInfo::new(vec![PHY_MEM_WIDTH], &vec![total_num_phy_mem_accesses]), - VerifierWitnessSecInfo::new(vec![PHY_MEM_WIDTH], &vec![total_num_phy_mem_accesses]), + VerifierWitnessSecInfo::new(vec![PHY_MEM_WIDTH], &vec![total_num_phy_mem_accesses], vec![self.addr_phy_mems_comm.expect("commitment should exist")]), + VerifierWitnessSecInfo::new(vec![PHY_MEM_WIDTH], &vec![total_num_phy_mem_accesses], vec![self.addr_phy_mems_shifted_comm.expect("commitment should exist")]), ) } else { ( @@ -2952,15 +3019,17 @@ impl> SNARK< for comm in vec![ self.addr_vir_mems_comm, self.addr_vir_mems_shifted_comm, + self.addr_ts_bits_comm, ] { Pcs::write_commitment(comm, transcript); } ( - VerifierWitnessSecInfo::new(vec![VIR_MEM_WIDTH], &vec![total_num_vir_mem_accesses]), - VerifierWitnessSecInfo::new(vec![VIR_MEM_WIDTH], &vec![total_num_vir_mem_accesses]), + VerifierWitnessSecInfo::new(vec![VIR_MEM_WIDTH], &vec![total_num_vir_mem_accesses], vec![self.addr_vir_mems_comm.expect("commitment should exist")]), + VerifierWitnessSecInfo::new(vec![VIR_MEM_WIDTH], &vec![total_num_vir_mem_accesses], vec![self.addr_vir_mems_shifted_comm.expect("commitment should exist")]), VerifierWitnessSecInfo::new( vec![mem_addr_ts_bits_size], &vec![total_num_vir_mem_accesses], + vec![self.addr_ts_bits_comm.expect("commitment should exist")], ), ) } else { From 8257bec4d1c1e99ce06f669d323757194a6304e6 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 6 Feb 2025 00:08:59 -0500 Subject: [PATCH 04/18] Add verification --- spartan_parallel/src/r1csproof.rs | 57 ++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/spartan_parallel/src/r1csproof.rs b/spartan_parallel/src/r1csproof.rs index 6ab01495..85a4a717 100644 --- a/spartan_parallel/src/r1csproof.rs +++ b/spartan_parallel/src/r1csproof.rs @@ -435,7 +435,7 @@ impl> R1CSPr let max_len = max(poly_list.iter().map(|&p| p.num_vars)).unwrap().next_power_of_two(); let param = Pcs::setup(max_len).unwrap(); - let (pp, _error) = Pcs::trim(param, max_len).unwrap(); + let (pp, vp) = Pcs::trim(param, max_len).unwrap(); let mut proof_eval_vars_at_ry_list: Vec = Vec::new(); let mut proof_idx: usize = 0; @@ -672,6 +672,7 @@ impl> R1CSPr let mut num_proofs_list = Vec::new(); let mut num_inputs_list = Vec::new(); let mut eval_Zr_list = Vec::new(); + let mut comm_w_list = Vec::new(); for i in 0..num_witness_secs { let w = witness_secs[i]; let wit_sec_num_instance = w.num_proofs.len(); @@ -681,22 +682,23 @@ impl> R1CSPr num_proofs_list.push(w.num_proofs[p]); num_inputs_list.push(w.num_inputs[p]); eval_Zr_list.push(self.eval_vars_at_ry_list[i][p]); + comm_w_list.push(w.comm_w[p]); } else { assert_eq!(self.eval_vars_at_ry_list[i][p], ZERO); } } } - /* - PolyEvalProof::verify_batched_instances_disjoint_rounds( + Self::verify_batched_instances_disjoint_rounds( &self.proof_eval_vars_at_ry_list, &num_proofs_list, &num_inputs_list, transcript, &rq, &ry, + eval_Zr_list, + &comm_w_list, )?; - */ // Then on rp let mut expected_eval_vars_list = Vec::new(); @@ -776,4 +778,51 @@ impl> R1CSPr Ok([rp, rq, rx, [rw, ry].concat()]) } + + pub fn verify_batched_instances_disjoint_rounds( + proof_list: &Vec, + num_proofs_list: &Vec, + num_inputs_list: &Vec, + transcript: &mut Transcript, + rq: &[E], + ry: &[E], + eval_Zr_list: Vec, + comm_w: &Vec, + ) -> Result<(), ProofVerifyError> { + let max_num_proofs = max(num_proofs_list); + let max_num_inputs = max(num_inputs_list); + let max_len = (max_num_proofs + max_num_inputs).next_power_of_two(); + + let param = Pcs::setup(max_len).unwrap(); + let (_pp, vp) = Pcs::trim(param, max_len).unwrap(); + + for (idx, proof) in proof_list.into_iter().enumerate() { + let num_proofs = num_proofs_list[idx]; + let num_inputs = num_inputs_list[idx]; + + let num_vars_q = num_proofs.log_2(); + let num_vars_y = num_inputs.log_2(); + + let ry_short = { + if num_vars_y >= ry.len() { + let mut ry_pad: Vec = vec![E::ZERO; num_vars_y - ry.len()]; + ry_pad.extend_from_slice(&ry); + ry_pad + } + // Else ry_short is the last w.num_inputs[p].log_2() entries of ry + // thus, to obtain the actual ry, need to multiply by (1 - ry2)(1 - ry3)..., which is ry_factors[num_rounds_y - w.num_inputs[p]] + else { + ry[ry.len() - num_vars_y..].to_vec() + } + }; + let rq_short = rq[rq.len() - num_vars_q..].to_vec(); + let r = [rq_short, ry_short.clone()].concat(); + let Zr = eval_Zr_list[idx]; + let comm = comm_w[idx]; + + pcs_verify(vp, comm, &r, &Zr, proof, transcript); + } + + Ok(()) + } } From 98fd06096806383f5c9bb7b169fcb8c543aff319 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 6 Feb 2025 01:11:12 -0500 Subject: [PATCH 05/18] Build fix --- spartan_parallel/src/dense_mlpoly.rs | 5 +- spartan_parallel/src/lib.rs | 185 ++++++++++++++------------- spartan_parallel/src/r1csproof.rs | 16 +-- spartan_parallel/src/transcript.rs | 2 +- 4 files changed, 106 insertions(+), 102 deletions(-) diff --git a/spartan_parallel/src/dense_mlpoly.rs b/spartan_parallel/src/dense_mlpoly.rs index 93e57d87..781bcf90 100644 --- a/spartan_parallel/src/dense_mlpoly.rs +++ b/spartan_parallel/src/dense_mlpoly.rs @@ -9,7 +9,6 @@ use core::ops::Index; use rayon::{iter::ParallelIterator, slice::ParallelSliceMut}; use serde::{Deserialize, Serialize}; use std::cmp::min; -use rayon::prelude::*; #[derive(Debug, Clone)] pub struct DensePolynomial { @@ -383,8 +382,8 @@ impl PolyEvalProof { pub fn verify( &self, - transcript: &mut Transcript, - r: &[E], // point at which the polynomial is evaluated + _transcript: &mut Transcript, + _r: &[E], // point at which the polynomial is evaluated ) -> Result<(), ProofVerifyError> { // TODO: Alternative evaluation proof scheme Ok(()) diff --git a/spartan_parallel/src/lib.rs b/spartan_parallel/src/lib.rs index 523f7353..15fc12e0 100644 --- a/spartan_parallel/src/lib.rs +++ b/spartan_parallel/src/lib.rs @@ -39,7 +39,7 @@ mod unipoly; use std::{ cmp::{max, Ordering}, fs::File, - io::Write, iter::zip, + io::Write, }; use dense_mlpoly::{DensePolynomial, PolyEvalProof}; @@ -57,7 +57,7 @@ use ff_ext::ExtensionField; use serde::{Deserialize, Serialize}; use timer::Timer; use bytes::{from_bytes, to_bytes}; -use transcript::{Transcript, append_protocol_name, append_field_to_transcript, append_field_vector_to_transcript, challenge_scalar, challenge_vector}; +use transcript::{Transcript, append_protocol_name, append_field_to_transcript, append_field_vector_to_transcript, challenge_scalar}; const INIT_PHY_MEM_WIDTH: usize = 4; const INIT_VIR_MEM_WIDTH: usize = 4; @@ -541,7 +541,7 @@ struct VerifierWitnessSecInfo> VerifierWitnessSecInfo { // Unfortunately, cannot obtain all metadata from the commitment - fn new(num_inputs: Vec, num_proofs: &Vec, comm_w: Vec) -> VerifierWitnessSecInfo { + fn new(num_inputs: Vec, num_proofs: &Vec, comm_w: Vec) -> VerifierWitnessSecInfo { let l = num_inputs.len(); VerifierWitnessSecInfo { num_inputs, @@ -550,7 +550,7 @@ impl> VerifierWitnessSecIn } } - fn dummy() -> VerifierWitnessSecInfo { + fn dummy() -> VerifierWitnessSecInfo { VerifierWitnessSecInfo { num_inputs: Vec::new(), num_proofs: Vec::new(), @@ -558,7 +558,7 @@ impl> VerifierWitnessSecIn } } - fn pad() -> VerifierWitnessSecInfo { + fn pad() -> VerifierWitnessSecInfo { VerifierWitnessSecInfo { num_inputs: vec![1], num_proofs: vec![1], @@ -567,7 +567,7 @@ impl> VerifierWitnessSecIn } // Concatenate the components in the given order to a new verifier witness sec - fn concat(components: Vec<&VerifierWitnessSecInfo>) -> VerifierWitnessSecInfo { + fn concat(components: Vec<&VerifierWitnessSecInfo>) -> VerifierWitnessSecInfo { let mut num_inputs = Vec::new(); let mut num_proofs = Vec::new(); let mut comm_w = Vec::new(); @@ -589,7 +589,7 @@ impl> VerifierWitnessSecIn // Assume all components are sorted // Returns: 1. the merged VerifierWitnessSec, // 2. for each instance in the merged VerifierWitnessSec, the component it orignally belong to - fn merge(components: Vec<&VerifierWitnessSecInfo>) -> (VerifierWitnessSecInfo, Vec) { + fn merge(components: Vec<&VerifierWitnessSecInfo>) -> (VerifierWitnessSecInfo, Vec) { // Merge algorithm with pointer on each component let mut pointers = vec![0; components.len()]; let merged_size = components.iter().map(|c| c.num_inputs.len()).sum(); @@ -615,7 +615,7 @@ impl> VerifierWitnessSecIn inst_map.push(next_component); merged_num_inputs.push(components[next_component].num_inputs[pointers[next_component]]); merged_num_proofs.push(components[next_component].num_proofs[pointers[next_component]]); - merged_comm_w.push(components[next_component].comm_w[pointers[next_component]]); + merged_comm_w.push(components[next_component].comm_w[pointers[next_component]].clone()); pointers[next_component] = pointers[next_component] + 1; } @@ -1599,14 +1599,19 @@ impl> SNARK< // Append commmitments to transcript for comm in vec![ - perm_exec_w2_v_comm, - perm_exec_w3_v_comm, - perm_exec_w3_shifted_v_comm, - ] - .into_iter() - .chain(block_w2_v_comm.into_iter()) - .chain(block_w3_v_comm.into_iter()) - .chain(block_w3_shifted_v_comm.into_iter()) { + &perm_exec_w2_v_comm, + &perm_exec_w3_v_comm, + &perm_exec_w3_shifted_v_comm, + ] { + Pcs::write_commitment(comm, transcript); + } + for comm in block_w2_v_comm.iter() { + Pcs::write_commitment(comm, transcript); + } + for comm in block_w3_v_comm.iter() { + Pcs::write_commitment(comm, transcript); + } + for comm in block_w3_shifted_v_comm.iter() { Pcs::write_commitment(comm, transcript); } @@ -1669,19 +1674,22 @@ impl> SNARK< ); for op_comm in vec![ - init_phy_mem_w2_comm, - init_phy_mem_w3_comm, - init_phy_mem_w3_shifted_comm, - init_vir_mem_w2_comm, - init_vir_mem_w3_comm, - init_vir_mem_w3_shifted_comm, - phy_mem_addr_w2_comm, - phy_mem_addr_w3_comm, - phy_mem_addr_w3_shifted_comm, + &init_phy_mem_w2_comm, + &init_phy_mem_w3_comm, + &init_phy_mem_w3_shifted_comm, + &init_vir_mem_w2_comm, + &init_vir_mem_w3_comm, + &init_vir_mem_w3_shifted_comm, + &phy_mem_addr_w2_comm, + &phy_mem_addr_w3_comm, + &phy_mem_addr_w3_shifted_comm, ] { match op_comm { - Some(comm) => Pcs::write_commitment(comm, transcript), - None => Ok(()), + Some(comm) => { + Pcs::write_commitment(comm, transcript); + () + }, + None => (), } } @@ -1739,9 +1747,9 @@ impl> SNARK< let (vir_mem_addr_w3_shifted_mle, vir_mem_addr_w3_shifted_p_comm, vir_mem_addr_w3_shifted_v_comm) = Self::mat_to_comm(&vir_mem_addr_w3_shifted_mat, &poly_pp); for comm in vec![ - vir_mem_addr_w2_v_comm, - vir_mem_addr_w3_v_comm, - vir_mem_addr_w3_shifted_v_comm, + &vir_mem_addr_w2_v_comm, + &vir_mem_addr_w3_v_comm, + &vir_mem_addr_w3_shifted_v_comm, ] { Pcs::write_commitment(comm, transcript); } @@ -1785,9 +1793,10 @@ impl> SNARK< let timer_commit = Timer::new("input_commit"); let (block_vars_prover, block_vars_v_comm_list) = Self::mats_to_prove_wit_sec(block_vars_mat, &poly_pp); let (exec_inputs_prover, exec_inputs_v_comm) = Self::mat_to_prove_wit_sec(exec_inputs_list, &poly_pp); - for comm in block_vars_v_comm_list.into_iter().chain(vec![exec_inputs_v_comm].into_iter()) { + for comm in block_vars_v_comm_list.iter() { Pcs::write_commitment(comm, transcript); } + Pcs::write_commitment(&exec_inputs_v_comm, transcript); let init_phy_mems_prover = if total_num_init_phy_mem_accesses > 0 { Self::mat_to_prover_wit_sec_no_commit(init_phy_mems_list) @@ -1807,8 +1816,8 @@ impl> SNARK< let (addr_phy_mems_shifted_prover, addr_phy_mems_shifted_v_comm) = Self::mat_to_prove_wit_sec(addr_phy_mems_shifted_list, &poly_pp); for comm in vec![ - addr_phy_mems_v_comm, - addr_phy_mems_shifted_v_comm, + &addr_phy_mems_v_comm, + &addr_phy_mems_shifted_v_comm, ] { Pcs::write_commitment(comm, transcript); } @@ -1827,9 +1836,9 @@ impl> SNARK< let (addr_ts_bits_prover, addr_ts_bits_v_comm) = Self::mat_to_prove_wit_sec(addr_ts_bits_list, &poly_pp); for comm in vec![ - addr_vir_mems_v_comm, - addr_vir_mems_shifted_v_comm, - addr_ts_bits_v_comm, + &addr_vir_mems_v_comm, + &addr_vir_mems_shifted_v_comm, + &addr_ts_bits_v_comm, ] { Pcs::write_commitment(comm, transcript); } @@ -2797,37 +2806,42 @@ impl> SNARK< .next_power_of_two() }) .collect(); - VerifierWitnessSecInfo::new(block_w2_size_list, &block_num_proofs, self.block_w2_comm_list) + VerifierWitnessSecInfo::new(block_w2_size_list, &block_num_proofs, self.block_w2_comm_list.clone()) }; ( VerifierWitnessSecInfo::new(vec![num_ios], &vec![1], vec![Pcs::Commitment::default()]), - VerifierWitnessSecInfo::new(vec![num_ios], &vec![consis_num_proofs], vec![self.perm_exec_w2_comm]), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![consis_num_proofs], vec![self.perm_exec_w3_comm]), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![consis_num_proofs], vec![self.perm_exec_w3_shifted_comm]), + VerifierWitnessSecInfo::new(vec![num_ios], &vec![consis_num_proofs], vec![self.perm_exec_w2_comm.clone()]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![consis_num_proofs], vec![self.perm_exec_w3_comm.clone()]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![consis_num_proofs], vec![self.perm_exec_w3_shifted_comm.clone()]), block_w2_verifier, VerifierWitnessSecInfo::new( vec![W3_WIDTH; block_num_instances], &block_num_proofs.clone(), - self.block_w3_comm_list, + self.block_w3_comm_list.clone(), ), VerifierWitnessSecInfo::new( vec![W3_WIDTH; block_num_instances], &block_num_proofs.clone(), - self.block_w3_shifted_comm_list, + self.block_w3_shifted_comm_list.clone(), ), ) }; // Append commmitments to transcript for comm in vec![ - self.perm_exec_w2_comm, - self.perm_exec_w3_comm, - self.perm_exec_w3_shifted_comm, - ] - .into_iter() - .chain(self.block_w2_comm_list.into_iter()) - .chain(self.block_w3_comm_list.into_iter()) - .chain(self.block_w3_shifted_comm_list.into_iter()) { + &self.perm_exec_w2_comm, + &self.perm_exec_w3_comm, + &self.perm_exec_w3_shifted_comm, + ] { + Pcs::write_commitment(comm, transcript); + } + for comm in self.block_w2_comm_list.iter() { + Pcs::write_commitment(comm, transcript); + } + for comm in self.block_w3_comm_list.iter() { + Pcs::write_commitment(comm, transcript); + } + for comm in self.block_w3_shifted_comm_list.iter() { Pcs::write_commitment(comm, transcript); } @@ -2837,10 +2851,10 @@ impl> SNARK< VerifierWitnessSecInfo::new( vec![INIT_PHY_MEM_WIDTH], &vec![total_num_init_phy_mem_accesses], - vec![self.init_phy_mem_w2_comm.expect("commitment should exist")] + vec![self.init_phy_mem_w2_comm.clone().expect("commitment should exist")] ), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_init_phy_mem_accesses], vec![self.init_phy_mem_w3_comm.expect("commitment should exist")]), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_init_phy_mem_accesses], vec![self.init_phy_mem_w3_shifted_comm.expect("commitment should exist")]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_init_phy_mem_accesses], vec![self.init_phy_mem_w3_comm.clone().expect("commitment should exist")]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_init_phy_mem_accesses], vec![self.init_phy_mem_w3_shifted_comm.clone().expect("commitment should exist")]), ) } else { ( @@ -2857,10 +2871,10 @@ impl> SNARK< VerifierWitnessSecInfo::new( vec![INIT_VIR_MEM_WIDTH], &vec![total_num_init_vir_mem_accesses], - vec![self.init_vir_mem_w2_comm.expect("commitment should exist")], + vec![self.init_vir_mem_w2_comm.clone().expect("commitment should exist")], ), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_init_vir_mem_accesses], vec![self.init_vir_mem_w3_comm.expect("commitment should exist")]), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_init_vir_mem_accesses], vec![self.init_vir_mem_w3_shifted_comm.expect("commitment should exist")]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_init_vir_mem_accesses], vec![self.init_vir_mem_w3_comm.clone().expect("commitment should exist")]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_init_vir_mem_accesses], vec![self.init_vir_mem_w3_shifted_comm.clone().expect("commitment should exist")]), ) } else { ( @@ -2874,9 +2888,9 @@ impl> SNARK< let (phy_mem_addr_w2_verifier, phy_mem_addr_w3_verifier, phy_mem_addr_w3_shifted_verifier) = { if total_num_phy_mem_accesses > 0 { ( - VerifierWitnessSecInfo::new(vec![PHY_MEM_WIDTH], &vec![total_num_phy_mem_accesses], vec![self.phy_mem_addr_w2_comm.expect("commitment should exist")]), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_phy_mem_accesses], vec![self.phy_mem_addr_w3_comm].expect("commitment should exist")), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_phy_mem_accesses], vec![self.phy_mem_addr_w3_shifted_comm].expect("commitment should exist")), + VerifierWitnessSecInfo::new(vec![PHY_MEM_WIDTH], &vec![total_num_phy_mem_accesses], vec![self.phy_mem_addr_w2_comm.clone().expect("commitment should exist")]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_phy_mem_accesses], vec![self.phy_mem_addr_w3_comm.clone().expect("commitment should exist")]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_phy_mem_accesses], vec![self.phy_mem_addr_w3_shifted_comm.clone().expect("commitment should exist")]), ) } else { ( @@ -2889,18 +2903,14 @@ impl> SNARK< let (vir_mem_addr_w2_verifier, vir_mem_addr_w3_verifier, vir_mem_addr_w3_shifted_verifier) = { if total_num_vir_mem_accesses > 0 { - for comm in vec![ - self.vir_mem_addr_w2_comm, - self.vir_mem_addr_w3_comm, - self.vir_mem_addr_w3_shifted_comm, - ] { - Pcs::write_commitment(comm, transcript); - } + Pcs::write_commitment(self.vir_mem_addr_w2_comm.as_ref().clone().expect("valid commitment expected"), transcript); + Pcs::write_commitment(self.vir_mem_addr_w3_comm.as_ref().clone().expect("valid commitment expected"), transcript); + Pcs::write_commitment(self.vir_mem_addr_w3_shifted_comm.as_ref().clone().expect("valid commitment expected"), transcript); ( - VerifierWitnessSecInfo::new(vec![VIR_MEM_WIDTH], &vec![total_num_vir_mem_accesses], vec![self.vir_mem_addr_w2_comm]), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_vir_mem_accesses], vec![self.vir_mem_addr_w3_comm]), - VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_vir_mem_accesses], vec![self.vir_mem_addr_w3_shifted_comm]), + VerifierWitnessSecInfo::new(vec![VIR_MEM_WIDTH], &vec![total_num_vir_mem_accesses], vec![self.vir_mem_addr_w2_comm.clone().expect("valid commitment expected")]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_vir_mem_accesses], vec![self.vir_mem_addr_w3_comm.clone().expect("valid commitment expected")]), + VerifierWitnessSecInfo::new(vec![W3_WIDTH], &vec![total_num_vir_mem_accesses], vec![self.vir_mem_addr_w3_shifted_comm.clone().expect("valid commitment expected")]), ) } else { ( @@ -2913,13 +2923,14 @@ impl> SNARK< let (block_vars_verifier, exec_inputs_verifier) = { // add the commitment to the verifier's transcript - for comm in self.block_vars_comm_list.into_iter().chain(vec![self.exec_inputs_comm]) { + for comm in self.block_vars_comm_list.iter() { Pcs::write_commitment(comm, transcript); } + Pcs::write_commitment(&self.exec_inputs_comm, transcript); ( - VerifierWitnessSecInfo::new(block_num_vars, &block_num_proofs, self.block_vars_comm_list), - VerifierWitnessSecInfo::new(vec![num_ios], &vec![consis_num_proofs], vec![self.exec_inputs_comm]), + VerifierWitnessSecInfo::new(block_num_vars, &block_num_proofs, self.block_vars_comm_list.clone()), + VerifierWitnessSecInfo::new(vec![num_ios], &vec![consis_num_proofs], vec![self.exec_inputs_comm.clone()]), ) }; @@ -2996,15 +3007,12 @@ impl> SNARK< let (addr_phy_mems_verifier, addr_phy_mems_shifted_verifier) = { if total_num_phy_mem_accesses > 0 { - for comm in vec![ - self.addr_phy_mems_comm, - self.addr_phy_mems_shifted_comm, - ] { - Pcs::write_commitment(comm, transcript); - } + Pcs::write_commitment(&self.addr_phy_mems_comm.clone().expect("valid commitment expected"), transcript); + Pcs::write_commitment(&self.addr_phy_mems_shifted_comm.clone().expect("valid commitment expected"), transcript); + ( - VerifierWitnessSecInfo::new(vec![PHY_MEM_WIDTH], &vec![total_num_phy_mem_accesses], vec![self.addr_phy_mems_comm.expect("commitment should exist")]), - VerifierWitnessSecInfo::new(vec![PHY_MEM_WIDTH], &vec![total_num_phy_mem_accesses], vec![self.addr_phy_mems_shifted_comm.expect("commitment should exist")]), + VerifierWitnessSecInfo::new(vec![PHY_MEM_WIDTH], &vec![total_num_phy_mem_accesses], vec![self.addr_phy_mems_comm.clone().expect("commitment should exist")]), + VerifierWitnessSecInfo::new(vec![PHY_MEM_WIDTH], &vec![total_num_phy_mem_accesses], vec![self.addr_phy_mems_shifted_comm.clone().expect("commitment should exist")]), ) } else { ( @@ -3016,20 +3024,17 @@ impl> SNARK< let (addr_vir_mems_verifier, addr_vir_mems_shifted_verifier, addr_ts_bits_verifier) = { if total_num_vir_mem_accesses > 0 { - for comm in vec![ - self.addr_vir_mems_comm, - self.addr_vir_mems_shifted_comm, - self.addr_ts_bits_comm, - ] { - Pcs::write_commitment(comm, transcript); - } + Pcs::write_commitment(self.addr_vir_mems_comm.clone().as_ref().expect("valid commitment expected"), transcript); + Pcs::write_commitment(self.addr_vir_mems_shifted_comm.clone().as_ref().expect("valid commitment expected"), transcript); + Pcs::write_commitment(self.addr_ts_bits_comm.clone().as_ref().expect("valid commitment expected"), transcript); + ( - VerifierWitnessSecInfo::new(vec![VIR_MEM_WIDTH], &vec![total_num_vir_mem_accesses], vec![self.addr_vir_mems_comm.expect("commitment should exist")]), - VerifierWitnessSecInfo::new(vec![VIR_MEM_WIDTH], &vec![total_num_vir_mem_accesses], vec![self.addr_vir_mems_shifted_comm.expect("commitment should exist")]), + VerifierWitnessSecInfo::new(vec![VIR_MEM_WIDTH], &vec![total_num_vir_mem_accesses], vec![self.addr_vir_mems_comm.clone().expect("commitment should exist")]), + VerifierWitnessSecInfo::new(vec![VIR_MEM_WIDTH], &vec![total_num_vir_mem_accesses], vec![self.addr_vir_mems_shifted_comm.clone().expect("commitment should exist")]), VerifierWitnessSecInfo::new( vec![mem_addr_ts_bits_size], &vec![total_num_vir_mem_accesses], - vec![self.addr_ts_bits_comm.expect("commitment should exist")], + vec![self.addr_ts_bits_comm.clone().expect("commitment should exist")], ), ) } else { diff --git a/spartan_parallel/src/r1csproof.rs b/spartan_parallel/src/r1csproof.rs index 85a4a717..33083ef7 100644 --- a/spartan_parallel/src/r1csproof.rs +++ b/spartan_parallel/src/r1csproof.rs @@ -15,6 +15,7 @@ use serde::Serialize; use std::cmp::min; use std::iter::zip; use rayon::prelude::*; +use mpcs::pcs_verify; #[derive(Serialize, Debug)] pub struct R1CSProof> { @@ -435,7 +436,7 @@ impl> R1CSPr let max_len = max(poly_list.iter().map(|&p| p.num_vars)).unwrap().next_power_of_two(); let param = Pcs::setup(max_len).unwrap(); - let (pp, vp) = Pcs::trim(param, max_len).unwrap(); + let (pp, _vp) = Pcs::trim(param, max_len).unwrap(); let mut proof_eval_vars_at_ry_list: Vec = Vec::new(); let mut proof_idx: usize = 0; @@ -475,6 +476,7 @@ impl> R1CSPr ).expect("PCS proof should not fail"); proof_eval_vars_at_ry_list.push(pcs_proof); + proof_idx += 1; } } @@ -566,7 +568,7 @@ impl> R1CSPr // NUM_INPUTS: number of inputs per block // W_MAT: num_instances x num_proofs x num_inputs hypermatrix for all values // COMM_W: one commitment per instance - witness_secs: Vec<&VerifierWitnessSecInfo>, + witness_secs: Vec<&VerifierWitnessSecInfo>, num_cons: usize, evals: &[E; 3], @@ -682,7 +684,7 @@ impl> R1CSPr num_proofs_list.push(w.num_proofs[p]); num_inputs_list.push(w.num_inputs[p]); eval_Zr_list.push(self.eval_vars_at_ry_list[i][p]); - comm_w_list.push(w.comm_w[p]); + comm_w_list.push(w.comm_w[p].clone()); } else { assert_eq!(self.eval_vars_at_ry_list[i][p], ZERO); } @@ -789,8 +791,8 @@ impl> R1CSPr eval_Zr_list: Vec, comm_w: &Vec, ) -> Result<(), ProofVerifyError> { - let max_num_proofs = max(num_proofs_list); - let max_num_inputs = max(num_inputs_list); + let max_num_proofs = max(num_proofs_list).expect("max should exist").clone(); + let max_num_inputs = max(num_inputs_list).expect("max should exist").clone(); let max_len = (max_num_proofs + max_num_inputs).next_power_of_two(); let param = Pcs::setup(max_len).unwrap(); @@ -817,10 +819,8 @@ impl> R1CSPr }; let rq_short = rq[rq.len() - num_vars_q..].to_vec(); let r = [rq_short, ry_short.clone()].concat(); - let Zr = eval_Zr_list[idx]; - let comm = comm_w[idx]; - pcs_verify(vp, comm, &r, &Zr, proof, transcript); + pcs_verify::(&vp, &comm_w[idx], &r, &eval_Zr_list[idx], proof, transcript); } Ok(()) diff --git a/spartan_parallel/src/transcript.rs b/spartan_parallel/src/transcript.rs index 7553c924..7a4b2ffa 100644 --- a/spartan_parallel/src/transcript.rs +++ b/spartan_parallel/src/transcript.rs @@ -51,7 +51,7 @@ pub fn append_field_to_transcript( /// Append a vector ExtensionField scalars to transcript pub fn append_field_vector_to_transcript( - label: &'static [u8], + _label: &'static [u8], transcript: &mut Transcript, input: &[E], ) { From 85fbd37b85b10cedd453bf02fdca1d028c3dc29e Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Thu, 6 Feb 2025 10:04:20 -0500 Subject: [PATCH 06/18] Minor fixes --- spartan_parallel/src/lib.rs | 43 ++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/spartan_parallel/src/lib.rs b/spartan_parallel/src/lib.rs index 15fc12e0..e215a1a7 100644 --- a/spartan_parallel/src/lib.rs +++ b/spartan_parallel/src/lib.rs @@ -562,7 +562,7 @@ impl> VerifierWitnessSecIn VerifierWitnessSecInfo { num_inputs: vec![1], num_proofs: vec![1], - comm_w: vec![Pcs::Commitment::default()], + comm_w: Vec::new(), } } @@ -1686,8 +1686,7 @@ impl> SNARK< ] { match op_comm { Some(comm) => { - Pcs::write_commitment(comm, transcript); - () + Pcs::write_commitment(comm, transcript).unwrap() }, None => (), } @@ -1751,7 +1750,7 @@ impl> SNARK< &vir_mem_addr_w3_v_comm, &vir_mem_addr_w3_shifted_v_comm, ] { - Pcs::write_commitment(comm, transcript); + Pcs::write_commitment(comm, transcript).unwrap(); } let vir_mem_addr_w2_prover = @@ -1794,9 +1793,9 @@ impl> SNARK< let (block_vars_prover, block_vars_v_comm_list) = Self::mats_to_prove_wit_sec(block_vars_mat, &poly_pp); let (exec_inputs_prover, exec_inputs_v_comm) = Self::mat_to_prove_wit_sec(exec_inputs_list, &poly_pp); for comm in block_vars_v_comm_list.iter() { - Pcs::write_commitment(comm, transcript); + Pcs::write_commitment(comm, transcript).unwrap(); } - Pcs::write_commitment(&exec_inputs_v_comm, transcript); + Pcs::write_commitment(&exec_inputs_v_comm, transcript).unwrap(); let init_phy_mems_prover = if total_num_init_phy_mem_accesses > 0 { Self::mat_to_prover_wit_sec_no_commit(init_phy_mems_list) @@ -1819,7 +1818,7 @@ impl> SNARK< &addr_phy_mems_v_comm, &addr_phy_mems_shifted_v_comm, ] { - Pcs::write_commitment(comm, transcript); + Pcs::write_commitment(comm, transcript).unwrap(); } (addr_phy_mems_prover, Some(addr_phy_mems_v_comm), addr_phy_mems_shifted_prover, Some(addr_phy_mems_shifted_v_comm)) @@ -1840,7 +1839,7 @@ impl> SNARK< &addr_vir_mems_shifted_v_comm, &addr_ts_bits_v_comm, ] { - Pcs::write_commitment(comm, transcript); + Pcs::write_commitment(comm, transcript).unwrap(); } ( @@ -2833,16 +2832,16 @@ impl> SNARK< &self.perm_exec_w3_comm, &self.perm_exec_w3_shifted_comm, ] { - Pcs::write_commitment(comm, transcript); + Pcs::write_commitment(comm, transcript).unwrap(); } for comm in self.block_w2_comm_list.iter() { - Pcs::write_commitment(comm, transcript); + Pcs::write_commitment(comm, transcript).unwrap(); } for comm in self.block_w3_comm_list.iter() { - Pcs::write_commitment(comm, transcript); + Pcs::write_commitment(comm, transcript).unwrap(); } for comm in self.block_w3_shifted_comm_list.iter() { - Pcs::write_commitment(comm, transcript); + Pcs::write_commitment(comm, transcript).unwrap(); } let (init_phy_mem_w2_verifier, init_phy_mem_w3_verifier, init_phy_mem_w3_shifted_verifier) = { @@ -2903,9 +2902,9 @@ impl> SNARK< let (vir_mem_addr_w2_verifier, vir_mem_addr_w3_verifier, vir_mem_addr_w3_shifted_verifier) = { if total_num_vir_mem_accesses > 0 { - Pcs::write_commitment(self.vir_mem_addr_w2_comm.as_ref().clone().expect("valid commitment expected"), transcript); - Pcs::write_commitment(self.vir_mem_addr_w3_comm.as_ref().clone().expect("valid commitment expected"), transcript); - Pcs::write_commitment(self.vir_mem_addr_w3_shifted_comm.as_ref().clone().expect("valid commitment expected"), transcript); + Pcs::write_commitment(self.vir_mem_addr_w2_comm.as_ref().clone().expect("valid commitment expected"), transcript).unwrap(); + Pcs::write_commitment(self.vir_mem_addr_w3_comm.as_ref().clone().expect("valid commitment expected"), transcript).unwrap(); + Pcs::write_commitment(self.vir_mem_addr_w3_shifted_comm.as_ref().clone().expect("valid commitment expected"), transcript).unwrap(); ( VerifierWitnessSecInfo::new(vec![VIR_MEM_WIDTH], &vec![total_num_vir_mem_accesses], vec![self.vir_mem_addr_w2_comm.clone().expect("valid commitment expected")]), @@ -2924,9 +2923,9 @@ impl> SNARK< let (block_vars_verifier, exec_inputs_verifier) = { // add the commitment to the verifier's transcript for comm in self.block_vars_comm_list.iter() { - Pcs::write_commitment(comm, transcript); + Pcs::write_commitment(comm, transcript).unwrap(); } - Pcs::write_commitment(&self.exec_inputs_comm, transcript); + Pcs::write_commitment(&self.exec_inputs_comm, transcript).unwrap(); ( VerifierWitnessSecInfo::new(block_num_vars, &block_num_proofs, self.block_vars_comm_list.clone()), @@ -3007,8 +3006,8 @@ impl> SNARK< let (addr_phy_mems_verifier, addr_phy_mems_shifted_verifier) = { if total_num_phy_mem_accesses > 0 { - Pcs::write_commitment(&self.addr_phy_mems_comm.clone().expect("valid commitment expected"), transcript); - Pcs::write_commitment(&self.addr_phy_mems_shifted_comm.clone().expect("valid commitment expected"), transcript); + Pcs::write_commitment(&self.addr_phy_mems_comm.clone().expect("valid commitment expected"), transcript).unwrap(); + Pcs::write_commitment(&self.addr_phy_mems_shifted_comm.clone().expect("valid commitment expected"), transcript).unwrap(); ( VerifierWitnessSecInfo::new(vec![PHY_MEM_WIDTH], &vec![total_num_phy_mem_accesses], vec![self.addr_phy_mems_comm.clone().expect("commitment should exist")]), @@ -3024,9 +3023,9 @@ impl> SNARK< let (addr_vir_mems_verifier, addr_vir_mems_shifted_verifier, addr_ts_bits_verifier) = { if total_num_vir_mem_accesses > 0 { - Pcs::write_commitment(self.addr_vir_mems_comm.clone().as_ref().expect("valid commitment expected"), transcript); - Pcs::write_commitment(self.addr_vir_mems_shifted_comm.clone().as_ref().expect("valid commitment expected"), transcript); - Pcs::write_commitment(self.addr_ts_bits_comm.clone().as_ref().expect("valid commitment expected"), transcript); + Pcs::write_commitment(self.addr_vir_mems_comm.clone().as_ref().expect("valid commitment expected"), transcript).unwrap(); + Pcs::write_commitment(self.addr_vir_mems_shifted_comm.clone().as_ref().expect("valid commitment expected"), transcript).unwrap(); + Pcs::write_commitment(self.addr_ts_bits_comm.clone().as_ref().expect("valid commitment expected"), transcript).unwrap(); ( VerifierWitnessSecInfo::new(vec![VIR_MEM_WIDTH], &vec![total_num_vir_mem_accesses], vec![self.addr_vir_mems_comm.clone().expect("commitment should exist")]), From f8d5950c6ae8595e567bd5e26afa0c3801c4876c Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 9 Feb 2025 18:41:03 -0500 Subject: [PATCH 07/18] Consolidate changes on PCS --- spartan_parallel/src/dense_mlpoly.rs | 163 +-------------------- spartan_parallel/src/lib.rs | 197 +++++++++++++------------- spartan_parallel/src/r1csinstance.rs | 67 ++++----- spartan_parallel/src/sparse_mlpoly.rs | 179 ++++++++++++++--------- 4 files changed, 254 insertions(+), 352 deletions(-) diff --git a/spartan_parallel/src/dense_mlpoly.rs b/spartan_parallel/src/dense_mlpoly.rs index 781bcf90..78385690 100644 --- a/spartan_parallel/src/dense_mlpoly.rs +++ b/spartan_parallel/src/dense_mlpoly.rs @@ -1,5 +1,6 @@ #![allow(clippy::too_many_arguments)] use ff_ext::ExtensionField; +use multilinear_extensions::mle::DenseMultilinearExtension; use super::errors::ProofVerifyError; use super::math::Math; @@ -346,6 +347,10 @@ impl DensePolynomial { .collect::>(), ) } + + pub fn to_ceno_mle(&self) -> DenseMultilinearExtension { + DenseMultilinearExtension::from_evaluation_vec_smart(self.num_vars, self.Z.clone()) + } } impl Index for DensePolynomial { @@ -357,164 +362,6 @@ impl Index for DensePolynomial { } } -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct PolyEvalProof { - _phantom: E, -} - -impl PolyEvalProof { - fn protocol_name() -> &'static [u8] { - b"polynomial evaluation proof" - } - - pub fn prove( - _poly: &DensePolynomial, - _r: &[E], // point at which the polynomial is evaluated - _Zr: &E, // evaluation of \widetilde{Z}(r) - _transcript: &mut Transcript, - _random_tape: &mut RandomTape, - ) -> PolyEvalProof { - // TODO: Alternative evaluation proof scheme - PolyEvalProof { - _phantom: E::ZERO, - } - } - - pub fn verify( - &self, - _transcript: &mut Transcript, - _r: &[E], // point at which the polynomial is evaluated - ) -> Result<(), ProofVerifyError> { - // TODO: Alternative evaluation proof scheme - Ok(()) - } - - pub fn verify_plain( - &self, - transcript: &mut Transcript, - r: &[E], // point at which the polynomial is evaluated - _Zr: &E, // evaluation \widetilde{Z}(r) - ) -> Result<(), ProofVerifyError> { - self.verify(transcript, r) - } - - // Evaluation of multiple points on the same instance - pub fn prove_batched_points( - _poly: &DensePolynomial, - _r_list: Vec>, // point at which the polynomial is evaluated - _Zr_list: Vec, // evaluation of \widetilde{Z}(r) on each point - _transcript: &mut Transcript, - _random_tape: &mut RandomTape, - ) -> Vec> { - // TODO: Alternative evaluation proof scheme - vec![] - } - - pub fn verify_plain_batched_points( - _proof_list: &Vec>, - _transcript: &mut Transcript, - _r_list: Vec>, // point at which the polynomial is evaluated - _Zr_list: Vec, // commitment to \widetilde{Z}(r) on each point - ) -> Result<(), ProofVerifyError> { - // TODO: Alternative evaluation proof scheme - Ok(()) - } - - // Evaluation on multiple instances, each at different point - // Size of each instance might be different, but all are larger than the evaluation point - pub fn prove_batched_instances( - _poly_list: &Vec>, // list of instances - _r_list: Vec<&Vec>, // point at which the polynomial is evaluated - _Zr_list: &Vec, // evaluation of \widetilde{Z}(r) on each instance - _transcript: &mut Transcript, - _random_tape: &mut RandomTape, - ) -> Vec> { - // TODO: Alternative evaluation proof scheme - vec![] - } - - pub fn verify_plain_batched_instances( - _proof_list: &Vec>, - _transcript: &mut Transcript, - _r_list: Vec<&Vec>, // point at which the polynomial is evaluated - _Zr_list: &Vec, // commitment to \widetilde{Z}(r) of each instance - _num_vars_list: &Vec, // size of each polynomial - ) -> Result<(), ProofVerifyError> { - // TODO: Alternative evaluation proof scheme - Ok(()) - } - - // Like prove_batched_instances, but r is divided into rq ++ ry - // Each polynomial is supplemented with num_proofs and num_inputs - pub fn prove_batched_instances_disjoint_rounds( - _poly_list: &Vec<&DensePolynomial>, - _num_proofs_list: &Vec, - _num_inputs_list: &Vec, - _rq: &[E], - _ry: &[E], - _Zr_list: &Vec, - _transcript: &mut Transcript, - _random_tape: &mut RandomTape, - ) -> Vec> { - // TODO: Alternative evaluation proof scheme - /* Pad or trim rq and ry in the following sense - let num_vars_q = num_proofs.log_2(); - let num_vars_y = num_inputs.log_2(); - let ry_short = { - if num_vars_y >= ry.len() { - let ry_pad = &vec![zero; num_vars_y - ry.len()]; - [ry_pad, ry].concat() - } - // Else ry_short is the last w.num_inputs[p].log_2() entries of ry - // thus, to obtain the actual ry, need to multiply by (1 - ry2)(1 - ry3)..., which is ry_factors[num_rounds_y - w.num_inputs[p]] - else { - ry[ry.len() - num_vars_y..].to_vec() - } - }; - let rq_short = rq[rq.len() - num_vars_q..].to_vec(); - let r = [rq_short, ry_short.clone()].concat(); - }; - */ - vec![] - } - - pub fn verify_batched_instances_disjoint_rounds( - _proof_list: &Vec>, - _num_proofs_list: &Vec, - _num_inputs_list: &Vec, - _transcript: &mut Transcript, - _rq: &[E], - _ry: &[E], - ) -> Result<(), ProofVerifyError> { - // TODO: Alternative evaluation proof scheme - Ok(()) - } - - // Treat the polynomial(s) as univariate and open on a single point - pub fn prove_uni_batched_instances( - _poly_list: &Vec<&DensePolynomial>, - _r: &E, // point at which the polynomial is evaluated - _Zr: &Vec, // evaluation of \widetilde{Z}(r) - _transcript: &mut Transcript, - _random_tape: &mut RandomTape, - ) -> PolyEvalProof { - // TODO: Alternative evaluation proof scheme - PolyEvalProof { - _phantom: E::ZERO, - } - } - - pub fn verify_uni_batched_instances( - &self, - _transcript: &mut Transcript, - _r: &E, // point at which the polynomial is evaluated - _poly_size: Vec, - ) -> Result<(), ProofVerifyError> { - // TODO: Alternative evaluation proof scheme - Ok(()) - } -} - /* #[cfg(test)] mod tests { diff --git a/spartan_parallel/src/lib.rs b/spartan_parallel/src/lib.rs index e215a1a7..7ec4bf8e 100644 --- a/spartan_parallel/src/lib.rs +++ b/spartan_parallel/src/lib.rs @@ -42,7 +42,7 @@ use std::{ io::Write, }; -use dense_mlpoly::{DensePolynomial, PolyEvalProof}; +use dense_mlpoly::DensePolynomial; use errors::{ProofVerifyError, R1CSError}; use instance::Instance; use itertools::Itertools; @@ -67,8 +67,8 @@ const W3_WIDTH: usize = 8; /// `ComputationCommitment` holds a public preprocessed NP statement (e.g., R1CS) #[derive(Clone, Serialize)] -pub struct ComputationCommitment { - comm: R1CSCommitment, +pub struct ComputationCommitment> { + comm: R1CSCommitment, } /// `ComputationDecommitment` holds information to decommit `ComputationCommitment` @@ -143,16 +143,16 @@ pub type MemsAssignment = Assignment; // IOProofs contains a series of proofs that the committed values match the input and output of the program #[derive(Serialize, Deserialize, Debug)] -struct IOProofs { +struct IOProofs> { // The prover needs to prove: // 1. Input and output block are both valid // 2. Block number of the input and output block are correct // 3. Input and outputs are correct // 4. The constant value of the input is 1 - proofs: Vec>, + proofs: Vec, } -impl IOProofs { +impl> IOProofs { // Given the polynomial in execution order, generate all proofs fn prove( exec_poly_inputs: &DensePolynomial, @@ -171,7 +171,7 @@ impl IOProofs { output_exec_num: usize, transcript: &mut Transcript, random_tape: &mut RandomTape, - ) -> IOProofs { + ) -> IOProofs { let r_len = (num_proofs * num_ios).log_2(); let to_bin_array = |x: usize| { (0..r_len) @@ -203,37 +203,39 @@ impl IOProofs { } input_indices = input_indices[..live_input.len()].to_vec(); + // _debug // batch prove all proofs - let proofs = PolyEvalProof::prove_batched_points( - exec_poly_inputs, - [ - vec![ - 0, // input valid - output_exec_num * num_ios, // output valid - 2, // input block num - output_exec_num * num_ios + 2 + (num_inputs_unpadded - 1), // output block num - output_exec_num * num_ios + 2 + (num_inputs_unpadded - 1) + output_offset - 1, // output correctness - ], - input_indices, // input correctness - ] - .concat() - .iter() - .map(|i| to_bin_array(*i)) - .collect(), - vec![ - vec![ - E::ONE, - E::ONE, - input_block_num, - output_block_num, - output, - ], - live_input, - ] - .concat(), - transcript, - random_tape, - ); + // let proofs = PolyEvalProof::prove_batched_points( + // exec_poly_inputs, + // [ + // vec![ + // 0, // input valid + // output_exec_num * num_ios, // output valid + // 2, // input block num + // output_exec_num * num_ios + 2 + (num_inputs_unpadded - 1), // output block num + // output_exec_num * num_ios + 2 + (num_inputs_unpadded - 1) + output_offset - 1, // output correctness + // ], + // input_indices, // input correctness + // ] + // .concat() + // .iter() + // .map(|i| to_bin_array(*i)) + // .collect(), + // vec![ + // vec![ + // E::ONE, + // E::ONE, + // input_block_num, + // output_block_num, + // output, + // ], + // live_input, + // ] + // .concat(), + // transcript, + // random_tape, + // ); + let proofs = vec![]; IOProofs { proofs } } @@ -284,36 +286,38 @@ impl IOProofs { } input_indices = input_indices[..live_input.len()].to_vec(); + // _debug // batch verify all proofs - PolyEvalProof::verify_plain_batched_points( - &self.proofs, - transcript, - [ - vec![ - 0, // input valid - output_exec_num * num_ios, // output valid - 2, // input block num - output_exec_num * num_ios + 2 + (num_inputs_unpadded - 1), // output block num - output_exec_num * num_ios + 2 + (num_inputs_unpadded - 1) + output_offset - 1, // output correctness - ], - input_indices, // input correctness - ] - .concat() - .iter() - .map(|i| to_bin_array(*i)) - .collect(), - vec![ - vec![ - E::ONE, - E::ONE, - input_block_num, - output_block_num, - output, - ], - live_input, - ] - .concat(), - ) + // PolyEvalProof::verify_plain_batched_points( + // &self.proofs, + // transcript, + // [ + // vec![ + // 0, // input valid + // output_exec_num * num_ios, // output valid + // 2, // input block num + // output_exec_num * num_ios + 2 + (num_inputs_unpadded - 1), // output block num + // output_exec_num * num_ios + 2 + (num_inputs_unpadded - 1) + output_offset - 1, // output correctness + // ], + // input_indices, // input correctness + // ] + // .concat() + // .iter() + // .map(|i| to_bin_array(*i)) + // .collect(), + // vec![ + // vec![ + // E::ONE, + // E::ONE, + // input_block_num, + // output_block_num, + // output, + // ], + // live_input, + // ] + // .concat(), + // ) + Ok(()) } } @@ -321,11 +325,11 @@ impl IOProofs { // We do so by treating both polynomials as univariate and evaluate on a single point C // Finally, show shifted(C) = orig(C) * C^(shift_size) + rc * openings, where rc * openings are the first few entries of the original poly dot product with the power series of C #[derive(Serialize, Deserialize, Debug)] -struct ShiftProofs { - proof: PolyEvalProof, +struct ShiftProofs> { + proof: Pcs::Proof, } -impl ShiftProofs { +impl> ShiftProofs { fn prove( orig_polys: Vec<&DensePolynomial>, shifted_polys: Vec<&DensePolynomial>, @@ -333,7 +337,7 @@ impl ShiftProofs { header_len_list: Vec, transcript: &mut Transcript, random_tape: &mut RandomTape, - ) -> ShiftProofs { + ) -> ShiftProofs { // Assert that all polynomials are of the same size let num_instances = orig_polys.len(); assert_eq!(num_instances, shifted_polys.len()); @@ -367,13 +371,15 @@ impl ShiftProofs { orig_evals.push(orig_eval); shifted_evals.push(shifted_eval); } - let addr_phy_mems_shift_proof = PolyEvalProof::prove_uni_batched_instances( - &[orig_polys, shifted_polys].concat(), - &c, - &[orig_evals, shifted_evals].concat(), - transcript, - random_tape, - ); + + // _debug + // let addr_phy_mems_shift_proof = PolyEvalProof::prove_uni_batched_instances( + // &[orig_polys, shifted_polys].concat(), + // &c, + // &[orig_evals, shifted_evals].concat(), + // transcript, + // random_tape, + // ); ShiftProofs { proof: addr_phy_mems_shift_proof, @@ -405,12 +411,13 @@ impl ShiftProofs { next_c = next_c * c; } + // _debug // Proof of opening - self.proof.verify_uni_batched_instances( - transcript, - &c, - [poly_size_list.clone(), poly_size_list].concat(), - )?; + // self.proof.verify_uni_batched_instances( + // transcript, + // &c, + // [poly_size_list.clone(), poly_size_list].concat(), + // )?; Ok(()) } } @@ -666,16 +673,16 @@ pub struct SNARK> { block_r1cs_sat_proof: R1CSProof, block_inst_evals_bound_rp: [E; 3], block_inst_evals_list: Vec, - block_r1cs_eval_proof_list: Vec>, + block_r1cs_eval_proof_list: Vec>, pairwise_check_r1cs_sat_proof: R1CSProof, pairwise_check_inst_evals_bound_rp: [E; 3], pairwise_check_inst_evals_list: Vec, - pairwise_check_r1cs_eval_proof: R1CSEvalProof, + pairwise_check_r1cs_eval_proof: R1CSEvalProof, perm_root_r1cs_sat_proof: R1CSProof, perm_root_inst_evals: [E; 3], - perm_root_r1cs_eval_proof: R1CSEvalProof, + perm_root_r1cs_eval_proof: R1CSEvalProof, // Product proof for permutation // perm_poly_poly_list: Vec, // proof_eval_perm_poly_prod_list: Vec>, @@ -808,11 +815,11 @@ impl> SNARK< inst: &Instance, ) -> ( Vec>, - Vec>, + Vec>, Vec>, ) { let timer_encode = Timer::new("SNARK::encode"); - let (label_map, mut comm, mut decomm) = inst.inst.multi_commit(); + let (label_map, mut comm, mut decomm) = R1CSCommitment::multi_commit(&inst.inst); timer_encode.stop(); ( @@ -829,9 +836,9 @@ impl> SNARK< } /// A public computation to create a commitment to a single R1CS instance - pub fn encode(inst: &Instance) -> (ComputationCommitment, ComputationDecommitment) { + pub fn encode(inst: &Instance) -> (ComputationCommitment, ComputationDecommitment) { let timer_encode = Timer::new("SNARK::encode"); - let (comm, decomm) = inst.inst.commit(); + let (comm, decomm) = R1CSCommitment::commit(&inst.inst); timer_encode.stop(); ( @@ -995,7 +1002,7 @@ impl> SNARK< block_num_proofs: &Vec, block_inst: &mut Instance, block_comm_map: &Vec>, - block_comm_list: &Vec>, + block_comm_list: &Vec>, block_decomm_list: &Vec>, consis_num_proofs: usize, @@ -1004,7 +1011,7 @@ impl> SNARK< total_num_phy_mem_accesses: usize, total_num_vir_mem_accesses: usize, pairwise_check_inst: &mut Instance, - pairwise_check_comm: &ComputationCommitment, + pairwise_check_comm: &ComputationCommitment, pairwise_check_decomm: &ComputationDecommitment, block_vars_mat: Vec>>, @@ -1016,7 +1023,7 @@ impl> SNARK< addr_ts_bits_list: Vec>, perm_root_inst: &Instance, - perm_root_comm: &ComputationCommitment, + perm_root_comm: &ComputationCommitment, perm_root_decomm: &ComputationDecommitment, transcript: &mut Transcript, ) -> Self { @@ -2515,7 +2522,7 @@ impl> SNARK< block_num_proofs: &Vec, block_num_cons: usize, block_comm_map: &Vec>, - block_comm_list: &Vec>, + block_comm_list: &Vec>, consis_num_proofs: usize, total_num_init_phy_mem_accesses: usize, @@ -2523,10 +2530,10 @@ impl> SNARK< total_num_phy_mem_accesses: usize, total_num_vir_mem_accesses: usize, pairwise_check_num_cons: usize, - pairwise_check_comm: &ComputationCommitment, + pairwise_check_comm: &ComputationCommitment, perm_root_num_cons: usize, - perm_root_comm: &ComputationCommitment, + perm_root_comm: &ComputationCommitment, transcript: &mut Transcript, ) -> Result<(), ProofVerifyError> { diff --git a/spartan_parallel/src/r1csinstance.rs b/spartan_parallel/src/r1csinstance.rs index 1ed894c3..65d88edd 100644 --- a/spartan_parallel/src/r1csinstance.rs +++ b/spartan_parallel/src/r1csinstance.rs @@ -1,3 +1,4 @@ +use mpcs::PolynomialCommitmentScheme; use rayon::prelude::*; use std::cmp::{max, min}; @@ -36,10 +37,10 @@ pub struct R1CSInstance { } #[derive(Debug, Serialize, Deserialize, Clone)] -pub struct R1CSCommitment { +pub struct R1CSCommitment> { num_cons: usize, num_vars: usize, - comm: SparseMatPolyCommitment, + comm: SparseMatPolyCommitment, } pub struct R1CSDecommitment { @@ -48,7 +49,7 @@ pub struct R1CSDecommitment { dense: MultiSparseMatPolynomialAsDense, } -impl R1CSCommitment { +impl> R1CSCommitment { pub fn get_num_cons(&self) -> usize { self.num_cons } @@ -466,19 +467,21 @@ impl R1CSInstance { pub fn evaluate(&self, rx: &[E], ry: &[E]) -> (E, E, E) { assert_eq!(self.num_instances, 1); - let evals = SparseMatPolynomial::multi_evaluate( + let evals: Vec = SparseMatPolynomial::multi_evaluate( &[&self.A_list[0], &self.B_list[0], &self.C_list[0]], rx, ry, ); (evals[0], evals[1], evals[2]) } +} +impl> R1CSCommitment { pub fn multi_commit( - &self, + instance: &R1CSInstance, ) -> ( Vec>, - Vec>, + Vec>, Vec>, ) { let mut vars_size: HashMap = HashMap::new(); @@ -488,25 +491,25 @@ impl R1CSInstance { let mut max_num_vars_list: Vec = Vec::new(); // Group the instances based on number of variables, which are already orders of 2^4 - for i in 0..self.num_instances { - let var_len = self.num_vars[i]; + for i in 0..instance.num_instances { + let var_len = instance.num_vars[i]; // A_list, B_list, C_list if let Some(index) = vars_size.get(&var_len) { label_map[*index].push(3 * i); - sparse_polys_list[*index].push(&self.A_list[i]); + sparse_polys_list[*index].push(&instance.A_list[i]); label_map[*index].push(3 * i + 1); - sparse_polys_list[*index].push(&self.B_list[i]); + sparse_polys_list[*index].push(&instance.B_list[i]); label_map[*index].push(3 * i + 2); - sparse_polys_list[*index].push(&self.C_list[i]); - max_num_cons_list[*index] = max(max_num_cons_list[*index], self.num_cons[i]); - max_num_vars_list[*index] = max(max_num_vars_list[*index], self.num_vars[i]); + sparse_polys_list[*index].push(&instance.C_list[i]); + max_num_cons_list[*index] = max(max_num_cons_list[*index], instance.num_cons[i]); + max_num_vars_list[*index] = max(max_num_vars_list[*index], instance.num_vars[i]); } else { let next_label = vars_size.len(); vars_size.insert(var_len, next_label); label_map.push(vec![3 * i, 3 * i + 1, 3 * i + 2]); - sparse_polys_list.push(vec![&self.A_list[i], &self.B_list[i], &self.C_list[i]]); - max_num_cons_list.push(self.num_cons[i]); - max_num_vars_list.push(self.num_vars[i]); + sparse_polys_list.push(vec![&instance.A_list[i], &instance.B_list[i], &instance.C_list[i]]); + max_num_cons_list.push(instance.num_cons[i]); + max_num_vars_list.push(instance.num_vars[i]); } } @@ -515,7 +518,7 @@ impl R1CSInstance { for ((sparse_polys, max_num_cons), max_num_vars) in zip(zip(sparse_polys_list, max_num_cons_list), max_num_vars_list) { - let (comm, dense) = SparseMatPolynomial::multi_commit(&sparse_polys); + let (comm, dense) = SparseMatPolyCommitment::multi_commit(&sparse_polys); let r1cs_comm = R1CSCommitment { num_cons: max_num_cons.next_power_of_two(), num_vars: max_num_vars, @@ -535,24 +538,24 @@ impl R1CSInstance { } // Used if there is only one instance - pub fn commit(&self) -> (R1CSCommitment, R1CSDecommitment) { + pub fn commit(instance: &R1CSInstance) -> (R1CSCommitment, R1CSDecommitment) { let mut sparse_polys = Vec::new(); - for i in 0..self.num_instances { - sparse_polys.push(&self.A_list[i]); - sparse_polys.push(&self.B_list[i]); - sparse_polys.push(&self.C_list[i]); + for i in 0..instance.num_instances { + sparse_polys.push(&instance.A_list[i]); + sparse_polys.push(&instance.B_list[i]); + sparse_polys.push(&instance.C_list[i]); } - let (comm, dense) = SparseMatPolynomial::multi_commit(&sparse_polys); + let (comm, dense) = SparseMatPolyCommitment::multi_commit(&sparse_polys); let r1cs_comm = R1CSCommitment { - num_cons: self.num_instances * self.max_num_cons, - num_vars: self.max_num_vars, + num_cons: instance.num_instances * instance.max_num_cons, + num_vars: instance.max_num_vars, comm, }; let r1cs_decomm = R1CSDecommitment { - num_cons: self.num_instances * self.max_num_cons, - num_vars: self.max_num_vars, + num_cons: instance.num_instances * instance.max_num_cons, + num_vars: instance.max_num_vars, dense, }; @@ -561,11 +564,11 @@ impl R1CSInstance { } #[derive(Debug, Serialize, Deserialize)] -pub struct R1CSEvalProof { - proof: SparseMatPolyEvalProof, +pub struct R1CSEvalProof> { + proof: SparseMatPolyEvalProof, } -impl R1CSEvalProof { +impl> R1CSEvalProof { // If is BLOCK, separate the first 3 entries of ry out (corresponding to the 5 segments of witnesses) pub fn prove( decomm: &R1CSDecommitment, @@ -574,7 +577,7 @@ impl R1CSEvalProof { evals: &Vec, transcript: &mut Transcript, random_tape: &mut RandomTape, - ) -> R1CSEvalProof { + ) -> R1CSEvalProof { let timer = Timer::new("R1CSEvalProof::prove"); let rx_skip_len = rx.len() - min(rx.len(), decomm.num_cons.log_2()); let rx_header = rx[..rx_skip_len] @@ -614,7 +617,7 @@ impl R1CSEvalProof { pub fn verify( &self, - comm: &R1CSCommitment, + comm: &R1CSCommitment, rx: &[E], // point at which the R1CS matrix polynomials are evaluated ry: &[E], evals: &Vec, diff --git a/spartan_parallel/src/sparse_mlpoly.rs b/spartan_parallel/src/sparse_mlpoly.rs index ba457902..79c7a56c 100644 --- a/spartan_parallel/src/sparse_mlpoly.rs +++ b/spartan_parallel/src/sparse_mlpoly.rs @@ -2,8 +2,10 @@ #![allow(clippy::too_many_arguments)] #![allow(clippy::needless_range_loop)] use ff_ext::ExtensionField; +use mpcs::PolynomialCommitmentScheme; +use multilinear_extensions::mle::DenseMultilinearExtension; use super::dense_mlpoly::DensePolynomial; -use super::dense_mlpoly::{EqPolynomial, IdentityPolynomial, PolyEvalProof}; +use super::dense_mlpoly::{EqPolynomial, IdentityPolynomial}; use super::errors::ProofVerifyError; use super::math::Math; use super::product_tree::{DotProductCircuit, ProductCircuit, ProductCircuitEvalProofBatched}; @@ -39,6 +41,11 @@ pub struct Derefs { comb: DensePolynomial, } +#[derive(Debug, Serialize, Deserialize)] +pub struct DerefsCommitment> { + comm_ops_val: Pcs::Commitment +} + impl Derefs { pub fn new(row_ops_val: Vec>, col_ops_val: Vec>) -> Self { assert_eq!(row_ops_val.len(), col_ops_val.len()); @@ -61,12 +68,23 @@ impl Derefs { } } +impl> DerefsCommitment { + pub fn commit(derefs: &Derefs) -> DerefsCommitment { + let l = derefs.comb.len(); + let param = Pcs::setup(l).unwrap(); + let (pp, _vp) = Pcs::trim(param, l).unwrap(); + let comm_ops_val = Pcs::get_pure_commitment(&Pcs::commit(&pp, &derefs.comb.to_ceno_mle()).expect("Commitment should not fail")); + DerefsCommitment { comm_ops_val } + } +} + #[derive(Debug, Serialize, Deserialize)] -pub struct DerefsEvalProof { - proof_derefs: PolyEvalProof, +pub struct DerefsEvalProof> { + comm_derefs: Pcs::CommitmentWithWitness, + proof_derefs: Pcs::Proof, } -impl DerefsEvalProof { +impl> DerefsEvalProof { fn protocol_name() -> &'static [u8] { b"Derefs evaluation proof" } @@ -77,7 +95,7 @@ impl DerefsEvalProof { evals: Vec, transcript: &mut Transcript, random_tape: &mut RandomTape, - ) -> PolyEvalProof { + ) -> Self { assert_eq!(joint_poly.get_num_vars(), r.len() + evals.len().log_2()); // append the claimed evaluations to transcript @@ -102,10 +120,16 @@ impl DerefsEvalProof { // decommit the joint polynomial at r_joint append_field_to_transcript(b"joint_claim_eval", transcript, eval_joint); - let proof_derefs = - PolyEvalProof::prove(joint_poly, &r_joint, &eval_joint, transcript, random_tape); + let l: usize = 1 << joint_poly.get_num_vars(); + let mle = joint_poly.to_ceno_mle(); + let (pp, _vp) = Pcs::trim(Pcs::setup(l), l); + let comm_derefs = Pcs::commit(&pp, &mle).expect("Commit should not fail."); + let proof_derefs = Pcs::open(pp, &mle, comm_derefs, &r_joint, &eval_joint, transcript).expect("Proof should not fail"); - proof_derefs + Self { + comm_derefs, + proof_derefs, + } } // evalues both polynomials at r and produces a joint proof of opening @@ -119,7 +143,7 @@ impl DerefsEvalProof { ) -> Self { append_protocol_name( transcript, - DerefsEvalProof::::protocol_name(), + DerefsEvalProof::::protocol_name(), ); let evals = { @@ -128,14 +152,12 @@ impl DerefsEvalProof { evals.resize(evals.len().next_power_of_two(), E::ZERO); evals }; - let proof_derefs = - DerefsEvalProof::prove_single(&derefs.comb, r, evals, transcript, random_tape); - - DerefsEvalProof { proof_derefs } + + DerefsEvalProof::::prove_single(&derefs.comb, r, evals, transcript, random_tape) } fn verify_single( - proof: &PolyEvalProof, + proof: &Self, r: &[E], evals: Vec, transcript: &mut Transcript, @@ -158,7 +180,9 @@ impl DerefsEvalProof { // decommit the joint polynomial at r_joint append_field_to_transcript(b"joint_claim_eval", transcript, joint_claim_eval); - proof.verify_plain(transcript, &r_joint, &joint_claim_eval) + let l: usize = 1 << poly_evals.get_num_vars(); + let (_pp, vp) = Pcs::trim(Pcs::setup(l), l); + Pcs::verify(vp, proof.comm_derefs, &r_joint, &joint_claim_eval, proof.proof_derefs, transcript) } // verify evaluations of both polynomials at r @@ -171,14 +195,14 @@ impl DerefsEvalProof { ) -> Result<(), ProofVerifyError> { append_protocol_name( transcript, - DerefsEvalProof::::protocol_name(), + DerefsEvalProof::::protocol_name(), ); let mut evals = eval_row_ops_val_vec.to_owned(); evals.extend(eval_col_ops_val_vec); evals.resize(evals.len().next_power_of_two(), E::ZERO); - DerefsEvalProof::verify_single(&self.proof_derefs, r, evals, transcript) + DerefsEvalProof::::verify_single(&self.proof_derefs, r, evals, transcript) } } @@ -251,17 +275,21 @@ pub struct MultiSparseMatPolynomialAsDense { col: AddrTimestamps, comb_ops: DensePolynomial, comb_mem: DensePolynomial, + comb_ops_ceno_mle: DenseMultilinearExtension, + comb_mem_ceno_mle: DenseMultilinearExtension, } #[derive(Debug, Serialize, Deserialize, Clone)] -pub struct SparseMatPolyCommitment { +pub struct SparseMatPolyCommitment> { batch_size: usize, num_ops: usize, num_mem_cells: usize, + comm_comb_ops: Pcs::Commitment, + comm_comb_mem: Pcs::Commitment, _phantom: E, } -impl SparseMatPolyCommitment { +impl> SparseMatPolyCommitment { pub fn append_to_transcript(&self, _label: &'static [u8], transcript: &mut Transcript) { append_field_to_transcript(b"batch_size", transcript, E::from(self.batch_size as u64)); append_field_to_transcript(b"num_ops", transcript, E::from(self.num_ops as u64)); @@ -348,6 +376,9 @@ impl SparseMatPolynomial { let mut comb_mem = row.audit_ts.clone(); comb_mem.extend(&col.audit_ts); + let comb_ops_ceno_mle = comb_ops.clone().to_ceno_mle(); + let comb_mem_ceno_mle = comb_mem.clone().to_ceno_mle(); + MultiSparseMatPolynomialAsDense { batch_size: sparse_polys.len(), row: ret_row, @@ -355,10 +386,11 @@ impl SparseMatPolynomial { val: ret_val_vec, comb_ops, comb_mem, + comb_ops_ceno_mle, + comb_mem_ceno_mle, } } - fn evaluate_with_tables(&self, eval_table_rx: &[E], eval_table_ry: &[E]) -> E { assert_eq!(self.num_vars_x.pow2(), eval_table_rx.len()); assert_eq!(self.num_vars_y.pow2(), eval_table_ry.len()); @@ -454,21 +486,34 @@ impl SparseMatPolynomial { } M_evals } +} +impl> SparseMatPolyCommitment { pub fn multi_commit( sparse_polys: &[&SparseMatPolynomial], ) -> ( - SparseMatPolyCommitment, + SparseMatPolyCommitment, MultiSparseMatPolynomialAsDense, ) { let batch_size = sparse_polys.len(); let dense = SparseMatPolynomial::multi_sparse_to_dense_rep(sparse_polys); + let l_ops = dense.comb_ops.len(); + let l_mem = dense.comb_mem.len(); + + let (p_ops, _) = Pcs::trim(Pcs::setup(l_ops).expect("Param setup should not fail"), l_ops).unwrap(); + let (p_mem, _) = Pcs::trim(Pcs::setup(l_mem).expect("Param setup should not fail"), l_mem).unwrap(); + + let comm_comb_ops = Pcs::get_pure_commitment(&Pcs::commit(&p_ops, &dense.comb_ops_ceno_mle).expect("Commit should not fail")); + let comm_comb_mem = Pcs::get_pure_commitment(&Pcs::commit(&p_mem, &dense.comb_mem_ceno_mle).expect("Commit should not fail")); + ( SparseMatPolyCommitment { batch_size, num_mem_cells: dense.row.audit_ts.len(), num_ops: dense.row.read_ts[0].len(), + comm_comb_ops, + comm_comb_mem, _phantom: E::ZERO, }, dense, @@ -480,7 +525,6 @@ impl MultiSparseMatPolynomialAsDense { pub fn deref(&self, row_mem_val: &[E], col_mem_val: &[E]) -> Derefs { let row_ops_val = self.row.deref(row_mem_val); let col_ops_val = self.col.deref(col_mem_val); - Derefs::new(row_ops_val, col_ops_val) } } @@ -649,17 +693,19 @@ impl PolyEvalNetwork { } #[derive(Debug, Serialize, Deserialize)] -pub struct HashLayerProof { +pub struct HashLayerProof> { eval_row: (Vec, Vec, E), eval_col: (Vec, Vec, E), eval_val: Vec, eval_derefs: (Vec, Vec), - pub proof_ops: PolyEvalProof, - pub proof_mem: PolyEvalProof, - pub proof_derefs: DerefsEvalProof, + pub comm_ops: Pcs::CommitmentWithWitness, + pub proof_ops: Pcs::Proof, + pub comm_mem: Pcs::CommitmentWithWitness, + pub proof_mem: Pcs::Proof, + pub proof_derefs: DerefsEvalProof, } -impl HashLayerProof { +impl> HashLayerProof { fn protocol_name() -> &'static [u8] { b"Sparse polynomial hash layer proof" } @@ -699,7 +745,7 @@ impl HashLayerProof { ) -> Self { append_protocol_name( transcript, - HashLayerProof::::protocol_name(), + HashLayerProof::::protocol_name(), ); let (rand_mem, rand_ops) = rand; @@ -754,13 +800,11 @@ impl HashLayerProof { debug_assert_eq!(dense.comb_ops.evaluate(&r_joint_ops), joint_claim_eval_ops); append_field_to_transcript(b"joint_claim_eval_ops", transcript, joint_claim_eval_ops); - let proof_ops = PolyEvalProof::prove( - &dense.comb_ops, - &r_joint_ops, - &joint_claim_eval_ops, - transcript, - random_tape, - ); + let l: usize = 1 << dense.comb_ops.get_num_vars(); + let mle = dense.comb_ops.clone().to_ceno_mle(); + let (pp, _vp) = Pcs::trim(Pcs::setup(l), l); + let comm_ops = Pcs::commit(&pp, &mle).expect("Commit should not fail."); + let proof_ops = Pcs::open(pp, &mle, comm_ops, &r_joint_ops, &joint_claim_eval_ops, transcript).expect("Proof should not fail"); // form a single decommitment using comb_comb_mem at rand_mem let evals_mem: Vec = vec![eval_row_audit_ts, eval_col_audit_ts]; @@ -779,20 +823,20 @@ impl HashLayerProof { debug_assert_eq!(dense.comb_mem.evaluate(&r_joint_mem), joint_claim_eval_mem); append_field_to_transcript(b"joint_claim_eval_mem", transcript, joint_claim_eval_mem); - let proof_mem = PolyEvalProof::prove( - &dense.comb_mem, - &r_joint_mem, - &joint_claim_eval_mem, - transcript, - random_tape, - ); + let l: usize = 1 << dense.comb_mem.get_num_vars(); + let mle = dense.comb_mem.clone().to_ceno_mle(); + let (pp, _vp) = Pcs::trim(Pcs::setup(l), l); + let comm_mem = Pcs::commit(&pp, &mle).expect("Commit should not fail."); + let proof_mem = Pcs::open(pp, &mle, comm_mem, &r_joint_mem, &joint_claim_eval_mem, transcript).expect("Proof should not fail"); HashLayerProof { eval_row: (eval_row_addr_vec, eval_row_read_ts_vec, eval_row_audit_ts), eval_col: (eval_col_addr_vec, eval_col_read_ts_vec, eval_col_audit_ts), eval_val: eval_val_vec, eval_derefs, + comm_ops, proof_ops, + comm_mem, proof_mem, proof_derefs, } @@ -857,7 +901,7 @@ impl HashLayerProof { claims_row: &(E, Vec, Vec, E), claims_col: &(E, Vec, Vec, E), claims_dotp: &[E], - _comm: &SparseMatPolyCommitment, + _comm: &SparseMatPolyCommitment, rx: &[E], ry: &[E], r_hash: &E, @@ -867,7 +911,7 @@ impl HashLayerProof { let timer = Timer::new("verify_hash_proof"); append_protocol_name( transcript, - HashLayerProof::::protocol_name(), + HashLayerProof::::protocol_name(), ); let (rand_mem, rand_ops) = rand; @@ -916,9 +960,10 @@ impl HashLayerProof { let mut r_joint_ops = challenges_ops; r_joint_ops.extend(rand_ops); append_field_to_transcript(b"joint_claim_eval_ops", transcript, joint_claim_eval_ops); - self - .proof_ops - .verify_plain(transcript, &r_joint_ops, &joint_claim_eval_ops)?; + + let l: usize = 1 << poly_evals_ops.get_num_vars(); + let (_pp, vp) = Pcs::trim(Pcs::setup(l), l); + Pcs::verify(vp, self.comm_ops, &r_joint_ops, &joint_claim_eval_ops, self.proof_ops, transcript)?; // verify proof-mem using comm_comb_mem at rand_mem // form a single decommitment using comb_comb_mem at rand_mem @@ -937,13 +982,13 @@ impl HashLayerProof { r_joint_mem.extend(rand_mem); append_field_to_transcript(b"joint_claim_eval_mem", transcript, joint_claim_eval_mem); - self - .proof_mem - .verify_plain(transcript, &r_joint_mem, &joint_claim_eval_mem)?; + let l: usize = 1 << poly_evals_mem.get_num_vars(); + let (_pp, vp) = Pcs::trim(Pcs::setup(l), l); + Pcs::verify(vp, self.comm_mem, &r_joint_mem, &joint_claim_eval_mem, self.proof_mem, transcript)?; // verify the claims from the product layer let (eval_ops_addr, eval_read_ts, eval_audit_ts) = &self.eval_row; - HashLayerProof::verify_helper( + HashLayerProof::::verify_helper( &(rand_mem, rand_ops), claims_row, eval_row_ops_val, @@ -956,7 +1001,7 @@ impl HashLayerProof { )?; let (eval_ops_addr, eval_read_ts, eval_audit_ts) = &self.eval_col; - HashLayerProof::verify_helper( + HashLayerProof::::verify_helper( &(rand_mem, rand_ops), claims_col, eval_col_ops_val, @@ -1231,12 +1276,12 @@ impl ProductLayerProof { } #[derive(Debug, Serialize, Deserialize)] -pub struct PolyEvalNetworkProof { +pub struct PolyEvalNetworkProof> { pub proof_prod_layer: ProductLayerProof, - pub proof_hash_layer: HashLayerProof, + pub proof_hash_layer: HashLayerProof, } -impl PolyEvalNetworkProof { +impl> PolyEvalNetworkProof { fn protocol_name() -> &'static [u8] { b"Sparse polynomial evaluation proof" } @@ -1252,7 +1297,7 @@ impl PolyEvalNetworkProof { ) -> Self { append_protocol_name( transcript, - PolyEvalNetworkProof::::protocol_name(), + PolyEvalNetworkProof::::protocol_name(), ); let (proof_prod_layer, rand_mem, rand_ops) = ProductLayerProof::prove( @@ -1282,7 +1327,7 @@ impl PolyEvalNetworkProof { pub fn verify( &self, - comm: &SparseMatPolyCommitment, + comm: &SparseMatPolyCommitment, r_header: E, evals: &[E], rx: &[E], @@ -1294,7 +1339,7 @@ impl PolyEvalNetworkProof { let timer = Timer::new("verify_polyeval_proof"); append_protocol_name( transcript, - PolyEvalNetworkProof::::protocol_name(), + PolyEvalNetworkProof::::protocol_name(), ); let num_instances = evals.len(); @@ -1345,11 +1390,11 @@ impl PolyEvalNetworkProof { } #[derive(Debug, Serialize, Deserialize)] -pub struct SparseMatPolyEvalProof { - pub poly_eval_network_proof: PolyEvalNetworkProof, +pub struct SparseMatPolyEvalProof> { + pub poly_eval_network_proof: PolyEvalNetworkProof, } -impl SparseMatPolyEvalProof { +impl> SparseMatPolyEvalProof { fn protocol_name() -> &'static [u8] { b"Sparse polynomial evaluation proof" } @@ -1380,10 +1425,10 @@ impl SparseMatPolyEvalProof { evals: &[E], // a vector evaluation of \widetilde{M}(r = (rx,ry)) for each M transcript: &mut Transcript, random_tape: &mut RandomTape, - ) -> SparseMatPolyEvalProof { + ) -> SparseMatPolyEvalProof { append_protocol_name( transcript, - SparseMatPolyEvalProof::::protocol_name(), + SparseMatPolyEvalProof::::protocol_name(), ); // ensure there is one eval for each polynomial in dense @@ -1391,7 +1436,7 @@ impl SparseMatPolyEvalProof { let (mem_rx, mem_ry) = { // equalize the lengths of rx and ry - let (rx_ext, ry_ext) = SparseMatPolyEvalProof::equalize(rx, ry); + let (rx_ext, ry_ext) = SparseMatPolyEvalProof::::equalize(rx, ry); let poly_rx = EqPolynomial::new(rx_ext).evals(); let poly_ry = EqPolynomial::new(ry_ext).evals(); (poly_rx, poly_ry) @@ -1441,7 +1486,7 @@ impl SparseMatPolyEvalProof { pub fn verify( &self, - comm: &SparseMatPolyCommitment, + comm: &SparseMatPolyCommitment, r_header: E, rx: &[E], // point at which the polynomial is evaluated ry: &[E], @@ -1450,11 +1495,11 @@ impl SparseMatPolyEvalProof { ) -> Result<(), ProofVerifyError> { append_protocol_name( transcript, - SparseMatPolyEvalProof::::protocol_name(), + SparseMatPolyEvalProof::::protocol_name(), ); // equalize the lengths of rx and ry - let (rx_ext, ry_ext) = SparseMatPolyEvalProof::equalize(rx, ry); + let (rx_ext, ry_ext) = SparseMatPolyEvalProof::::equalize(rx, ry); let (nz, num_mem_cells) = (comm.num_ops, comm.num_mem_cells); assert_eq!(rx_ext.len().pow2(), num_mem_cells); From 148e583d3500d6378d51882aaf0e448f7dbd6d42 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 9 Feb 2025 20:53:20 -0500 Subject: [PATCH 08/18] Compilation fix : --- spartan_parallel/src/sparse_mlpoly.rs | 38 ++++++++++++++++++--------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/spartan_parallel/src/sparse_mlpoly.rs b/spartan_parallel/src/sparse_mlpoly.rs index 79c7a56c..8dba77d5 100644 --- a/spartan_parallel/src/sparse_mlpoly.rs +++ b/spartan_parallel/src/sparse_mlpoly.rs @@ -122,9 +122,9 @@ impl> DerefsEvalProof> DerefsEvalProof Ok(()), + Error => ProofVerifyError::InternalError + } } // verify evaluations of both polynomials at r @@ -202,7 +206,7 @@ impl> DerefsEvalProof::verify_single(&self.proof_derefs, r, evals, transcript) + DerefsEvalProof::::verify_single(&self, r, evals, transcript) } } @@ -802,9 +806,9 @@ impl> HashLayerProof = vec![eval_row_audit_ts, eval_col_audit_ts]; @@ -825,9 +829,9 @@ impl> HashLayerProof> HashLayerProof Ok(()), + Error => ProofVerifyError::InternalError + } // verify proof-mem using comm_comb_mem at rand_mem // form a single decommitment using comb_comb_mem at rand_mem @@ -983,8 +991,12 @@ impl> HashLayerProof Ok(()), + Error => ProofVerifyError::InternalError + } // verify the claims from the product layer let (eval_ops_addr, eval_read_ts, eval_audit_ts) = &self.eval_row; From 065dde7e3874d68a5be3d2bfbb111eb6911ce5c0 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 9 Feb 2025 21:09:14 -0500 Subject: [PATCH 09/18] Compilation Fix --- spartan_parallel/src/sparse_mlpoly.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/spartan_parallel/src/sparse_mlpoly.rs b/spartan_parallel/src/sparse_mlpoly.rs index 8dba77d5..d3fa3fd1 100644 --- a/spartan_parallel/src/sparse_mlpoly.rs +++ b/spartan_parallel/src/sparse_mlpoly.rs @@ -185,7 +185,7 @@ impl> DerefsEvalProof Ok(()), - Error => ProofVerifyError::InternalError + Err(e) => Err(ProofVerifyError::InternalError) } } @@ -774,9 +774,9 @@ impl> HashLayerProof::prove_helper((rand_mem, rand_ops), &dense.row); let (eval_col_addr_vec, eval_col_read_ts_vec, eval_col_audit_ts) = - HashLayerProof::prove_helper((rand_mem, rand_ops), &dense.col); + HashLayerProof::::prove_helper((rand_mem, rand_ops), &dense.col); let eval_val_vec = (0..dense.val.len()) .map(|i| dense.val[i].evaluate(rand_ops)) .collect::>(); @@ -970,7 +970,7 @@ impl> HashLayerProof Ok(()), - Error => ProofVerifyError::InternalError + Err(e) => Err(ProofVerifyError::InternalError) } // verify proof-mem using comm_comb_mem at rand_mem From 76060afc1247cbf83026422ac8413c863c191bba Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 9 Feb 2025 22:17:04 -0500 Subject: [PATCH 10/18] Adjust verify result handling --- spartan_parallel/src/sparse_mlpoly.rs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/spartan_parallel/src/sparse_mlpoly.rs b/spartan_parallel/src/sparse_mlpoly.rs index d3fa3fd1..5c7e7bbb 100644 --- a/spartan_parallel/src/sparse_mlpoly.rs +++ b/spartan_parallel/src/sparse_mlpoly.rs @@ -967,11 +967,7 @@ impl> HashLayerProof Ok(()), - Err(e) => Err(ProofVerifyError::InternalError) - } + Pcs::verify(&vp, &Pcs::get_pure_commitment(&self.comm_ops), &r_joint_ops, &joint_claim_eval_ops, &self.proof_ops, transcript).map_err(|e| ProofVerifyError::InternalError)?; // verify proof-mem using comm_comb_mem at rand_mem // form a single decommitment using comb_comb_mem at rand_mem @@ -992,11 +988,7 @@ impl> HashLayerProof Ok(()), - Error => ProofVerifyError::InternalError - } + Pcs::verify(&vp, &Pcs::get_pure_commitment(&self.comm_mem), &r_joint_mem, &joint_claim_eval_mem, &self.proof_mem, transcript).map_err(|e| ProofVerifyError::InternalError)?; // verify the claims from the product layer let (eval_ops_addr, eval_read_ts, eval_audit_ts) = &self.eval_row; From ba08bb993693afac4bf0569acf7903087c530cae Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 10 Feb 2025 18:10:58 -0500 Subject: [PATCH 11/18] Temporarily remove serialization requirements --- spartan_parallel/src/lib.rs | 36 ++++++++++++++++++++++----- spartan_parallel/src/r1csinstance.rs | 14 ++++++++++- spartan_parallel/src/sparse_mlpoly.rs | 4 +-- 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/spartan_parallel/src/lib.rs b/spartan_parallel/src/lib.rs index 7ec4bf8e..c6f9f109 100644 --- a/spartan_parallel/src/lib.rs +++ b/spartan_parallel/src/lib.rs @@ -39,7 +39,7 @@ mod unipoly; use std::{ cmp::{max, Ordering}, fs::File, - io::Write, + io::Write, marker::PhantomData, }; use dense_mlpoly::DensePolynomial; @@ -326,7 +326,10 @@ impl> IOProofs { // Finally, show shifted(C) = orig(C) * C^(shift_size) + rc * openings, where rc * openings are the first few entries of the original poly dot product with the power series of C #[derive(Serialize, Deserialize, Debug)] struct ShiftProofs> { - proof: Pcs::Proof, + // _debug + // proof: Pcs::Proof, + _phantom_e: E, + _phantom_c: Pcs::Commitment } impl> ShiftProofs { @@ -382,7 +385,9 @@ impl> ShiftProofs // ); ShiftProofs { - proof: addr_phy_mems_shift_proof, + // proof: addr_phy_mems_shift_proof, + _phantom_e: E::from_u128(0), + _phantom_c: Pcs::Commitment::default(), } } @@ -762,6 +767,8 @@ impl> SNARK< */ let dense_commit_size = 0; + // _debug_serialization + /* let block_proof_size = bincode::serialize(&self.block_r1cs_sat_proof) .unwrap() .len() @@ -774,7 +781,11 @@ impl> SNARK< + bincode::serialize(&self.block_r1cs_eval_proof_list) .unwrap() .len(); + */ + let block_proof_size = 0; + // _debug_serialization + /* let pairwise_proof_size = bincode::serialize(&self.pairwise_check_r1cs_sat_proof) .unwrap() .len() @@ -787,7 +798,11 @@ impl> SNARK< + bincode::serialize(&self.pairwise_check_r1cs_eval_proof) .unwrap() .len(); + */ + let pairwise_proof_size = 0; + // _debug_serialization + /* let perm_proof_size = bincode::serialize(&self.perm_root_r1cs_sat_proof) .unwrap() .len() @@ -797,6 +812,9 @@ impl> SNARK< + bincode::serialize(&self.perm_root_r1cs_eval_proof) .unwrap() .len(); + */ + let perm_proof_size = 0; + // + bincode::serialize(&self.perm_poly_poly_list).unwrap().len() // + bincode::serialize(&self.proof_eval_perm_poly_prod_list).unwrap().len(); @@ -2039,7 +2057,9 @@ impl> SNARK< &mut random_tape, ); - let proof_encoded: Vec = bincode::serialize(&proof).unwrap(); + // _debug_serialization + // let proof_encoded: Vec = bincode::serialize(&proof).unwrap(); + let proof_encoded: Vec = vec![]; Timer::print(&format!("len_r1cs_eval_proof {:?}", proof_encoded.len())); r1cs_eval_proof_list.push(proof); @@ -2155,7 +2175,9 @@ impl> SNARK< &mut random_tape, ); - let proof_encoded: Vec = bincode::serialize(&proof).unwrap(); + // _debug_serialization + // let proof_encoded: Vec = bincode::serialize(&proof).unwrap(); + let proof_encoded: Vec = vec![]; Timer::print(&format!("len_r1cs_eval_proof {:?}", proof_encoded.len())); proof }; @@ -2270,7 +2292,9 @@ impl> SNARK< &mut random_tape, ); - let proof_encoded: Vec = bincode::serialize(&proof).unwrap(); + // _debug_serialization + // let proof_encoded: Vec = bincode::serialize(&proof).unwrap(); + let proof_encoded: Vec = vec![]; Timer::print(&format!("len_r1cs_eval_proof {:?}", proof_encoded.len())); proof }; diff --git a/spartan_parallel/src/r1csinstance.rs b/spartan_parallel/src/r1csinstance.rs index 65d88edd..87c001a3 100644 --- a/spartan_parallel/src/r1csinstance.rs +++ b/spartan_parallel/src/r1csinstance.rs @@ -19,6 +19,7 @@ use super::sparse_mlpoly::{ use super::timer::Timer; use flate2::{write::ZlibEncoder, Compression}; use serde::{Deserialize, Serialize}; +use serde::ser::{Serializer, SerializeStruct}; use std::iter::zip; #[derive(Debug, Serialize, Deserialize, Clone)] @@ -563,11 +564,22 @@ impl> R1CSCommitment> { proof: SparseMatPolyEvalProof, } +impl> Serialize for R1CSEvalProof { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut state = serializer.serialize_struct("R1CSEvalProof", 1)?; + state.serialize_field("_inner", &0)?; + state.end() + } +} + impl> R1CSEvalProof { // If is BLOCK, separate the first 3 entries of ry out (corresponding to the 5 segments of witnesses) pub fn prove( diff --git a/spartan_parallel/src/sparse_mlpoly.rs b/spartan_parallel/src/sparse_mlpoly.rs index 5c7e7bbb..2cd55b4b 100644 --- a/spartan_parallel/src/sparse_mlpoly.rs +++ b/spartan_parallel/src/sparse_mlpoly.rs @@ -1279,7 +1279,7 @@ impl ProductLayerProof { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug)] pub struct PolyEvalNetworkProof> { pub proof_prod_layer: ProductLayerProof, pub proof_hash_layer: HashLayerProof, @@ -1393,7 +1393,7 @@ impl> PolyEvalNetworkProof } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug)] pub struct SparseMatPolyEvalProof> { pub poly_eval_network_proof: PolyEvalNetworkProof, } From b1442de6fe2a95c01953c7fc2e987642de287acf Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 11 Feb 2025 00:55:22 -0500 Subject: [PATCH 12/18] Remove dense polynomial on sparse mat polynomial --- spartan_parallel/src/lib.rs | 8 ++++---- spartan_parallel/src/r1csproof.rs | 2 +- spartan_parallel/src/sparse_mlpoly.rs | 28 +++++++++++---------------- 3 files changed, 16 insertions(+), 22 deletions(-) diff --git a/spartan_parallel/src/lib.rs b/spartan_parallel/src/lib.rs index c6f9f109..05444dbd 100644 --- a/spartan_parallel/src/lib.rs +++ b/spartan_parallel/src/lib.rs @@ -1628,16 +1628,16 @@ impl> SNARK< &perm_exec_w3_v_comm, &perm_exec_w3_shifted_v_comm, ] { - Pcs::write_commitment(comm, transcript); + Pcs::write_commitment(comm, transcript).expect("Appending commitment should not fail"); } for comm in block_w2_v_comm.iter() { - Pcs::write_commitment(comm, transcript); + Pcs::write_commitment(comm, transcript).expect("Appending commitment should not fail"); } for comm in block_w3_v_comm.iter() { - Pcs::write_commitment(comm, transcript); + Pcs::write_commitment(comm, transcript).expect("Appending commitment should not fail"); } for comm in block_w3_shifted_v_comm.iter() { - Pcs::write_commitment(comm, transcript); + Pcs::write_commitment(comm, transcript).expect("Appending commitment should not fail"); } ( diff --git a/spartan_parallel/src/r1csproof.rs b/spartan_parallel/src/r1csproof.rs index 33083ef7..c4dcc51e 100644 --- a/spartan_parallel/src/r1csproof.rs +++ b/spartan_parallel/src/r1csproof.rs @@ -820,7 +820,7 @@ impl> R1CSPr let rq_short = rq[rq.len() - num_vars_q..].to_vec(); let r = [rq_short, ry_short.clone()].concat(); - pcs_verify::(&vp, &comm_w[idx], &r, &eval_Zr_list[idx], proof, transcript); + pcs_verify::(&vp, &comm_w[idx], &r, &eval_Zr_list[idx], proof, transcript).map_err(|_e| ProofVerifyError::InternalError)? } Ok(()) diff --git a/spartan_parallel/src/sparse_mlpoly.rs b/spartan_parallel/src/sparse_mlpoly.rs index 2cd55b4b..9cb22f6e 100644 --- a/spartan_parallel/src/sparse_mlpoly.rs +++ b/spartan_parallel/src/sparse_mlpoly.rs @@ -3,7 +3,7 @@ #![allow(clippy::needless_range_loop)] use ff_ext::ExtensionField; use mpcs::PolynomialCommitmentScheme; -use multilinear_extensions::mle::DenseMultilinearExtension; +use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; use super::dense_mlpoly::DensePolynomial; use super::dense_mlpoly::{EqPolynomial, IdentityPolynomial}; use super::errors::ProofVerifyError; @@ -277,8 +277,6 @@ pub struct MultiSparseMatPolynomialAsDense { val: Vec>, row: AddrTimestamps, col: AddrTimestamps, - comb_ops: DensePolynomial, - comb_mem: DensePolynomial, comb_ops_ceno_mle: DenseMultilinearExtension, comb_mem_ceno_mle: DenseMultilinearExtension, } @@ -388,8 +386,6 @@ impl SparseMatPolynomial { row: ret_row, col: ret_col, val: ret_val_vec, - comb_ops, - comb_mem, comb_ops_ceno_mle, comb_mem_ceno_mle, } @@ -502,8 +498,8 @@ impl> SparseMatPolyCommitm let batch_size = sparse_polys.len(); let dense = SparseMatPolynomial::multi_sparse_to_dense_rep(sparse_polys); - let l_ops = dense.comb_ops.len(); - let l_mem = dense.comb_mem.len(); + let l_ops = 1 << dense.comb_ops_ceno_mle.num_vars; + let l_mem = 1 << dense.comb_mem_ceno_mle.num_vars; let (p_ops, _) = Pcs::trim(Pcs::setup(l_ops).expect("Param setup should not fail"), l_ops).unwrap(); let (p_mem, _) = Pcs::trim(Pcs::setup(l_mem).expect("Param setup should not fail"), l_mem).unwrap(); @@ -801,14 +797,13 @@ impl> HashLayerProof = vec![eval_row_audit_ts, eval_col_audit_ts]; @@ -824,14 +819,13 @@ impl> HashLayerProof Date: Wed, 12 Feb 2025 14:44:32 -0500 Subject: [PATCH 13/18] Replace DensePolynomial --- spartan_parallel/src/custom_dense_mlpoly.rs | 55 ++-- spartan_parallel/src/dense_mlpoly.rs | 9 + spartan_parallel/src/lib.rs | 23 +- spartan_parallel/src/product_tree.rs | 120 +++++---- spartan_parallel/src/r1csinstance.rs | 11 +- spartan_parallel/src/r1csproof.rs | 33 ++- spartan_parallel/src/sparse_mlpoly.rs | 282 ++++++++++---------- spartan_parallel/src/sumcheck.rs | 72 ++--- 8 files changed, 324 insertions(+), 281 deletions(-) diff --git a/spartan_parallel/src/custom_dense_mlpoly.rs b/spartan_parallel/src/custom_dense_mlpoly.rs index 547304c7..728d8b35 100644 --- a/spartan_parallel/src/custom_dense_mlpoly.rs +++ b/spartan_parallel/src/custom_dense_mlpoly.rs @@ -1,7 +1,7 @@ #![allow(clippy::too_many_arguments)] use std::cmp::min; -use crate::dense_mlpoly::DensePolynomial; +use multilinear_extensions::mle::DenseMultilinearExtension; use crate::math::Math; use ff_ext::ExtensionField; use rayon::prelude::*; @@ -366,30 +366,31 @@ impl DensePolynomialPqx { return cl.index(0, 0, 0, 0); } - // Convert to a (p, q_rev, x_rev) regular dense poly of form (p, q, x) - pub fn to_dense_poly(&self) -> DensePolynomial { - let ZERO = E::ZERO; - - let p_space = self.num_vars_p.pow2(); - let q_space = self.num_vars_q.pow2(); - let w_space = self.num_vars_w.pow2(); - let x_space = self.num_vars_x.pow2(); - - let mut Z_poly = vec![ZERO; p_space * q_space * w_space * x_space]; - for p in 0..self.num_instances { - for q in 0..self.num_proofs[p] { - for w in 0..self.num_witness_secs { - for x in 0..self.num_inputs[p][w] { - Z_poly[ - p * q_space * w_space * x_space - + q * w_space * x_space - + w * x_space - + x - ] = self.Z[p][q][w][x]; - } - } - } - } - DensePolynomial::new(Z_poly) - } + // _debug + // // Convert to a (p, q_rev, x_rev) regular dense poly of form (p, q, x) + // pub fn to_dense_poly(&self) -> DensePolynomial { + // let ZERO = E::ZERO; + + // let p_space = self.num_vars_p.pow2(); + // let q_space = self.num_vars_q.pow2(); + // let w_space = self.num_vars_w.pow2(); + // let x_space = self.num_vars_x.pow2(); + + // let mut Z_poly = vec![ZERO; p_space * q_space * w_space * x_space]; + // for p in 0..self.num_instances { + // for q in 0..self.num_proofs[p] { + // for w in 0..self.num_witness_secs { + // for x in 0..self.num_inputs[p][w] { + // Z_poly[ + // p * q_space * w_space * x_space + // + q * w_space * x_space + // + w * x_space + // + x + // ] = self.Z[p][q][w][x]; + // } + // } + // } + // } + // DensePolynomial::new(Z_poly) + // } } \ No newline at end of file diff --git a/spartan_parallel/src/dense_mlpoly.rs b/spartan_parallel/src/dense_mlpoly.rs index 78385690..544f3998 100644 --- a/spartan_parallel/src/dense_mlpoly.rs +++ b/spartan_parallel/src/dense_mlpoly.rs @@ -11,12 +11,15 @@ use rayon::{iter::ParallelIterator, slice::ParallelSliceMut}; use serde::{Deserialize, Serialize}; use std::cmp::min; +// _debug +/* #[derive(Debug, Clone)] pub struct DensePolynomial { num_vars: usize, // the number of variables in the multilinear polynomial len: usize, Z: Vec, // evaluations of the polynomial in all the 2^num_vars Boolean inputs } +*/ pub struct EqPolynomial { r: Vec, @@ -112,6 +115,8 @@ impl IdentityPolynomial { } } +// _debug +/* impl DensePolynomial { pub fn new(mut Z: Vec) -> Self { // If length of Z is not a power of 2, append Z with 0 @@ -352,7 +357,10 @@ impl DensePolynomial { DenseMultilinearExtension::from_evaluation_vec_smart(self.num_vars, self.Z.clone()) } } +*/ +// _debug +/* impl Index for DensePolynomial { type Output = E; @@ -361,6 +369,7 @@ impl Index for DensePolynomial { &(self.Z[_index]) } } +*/ /* #[cfg(test)] diff --git a/spartan_parallel/src/lib.rs b/spartan_parallel/src/lib.rs index 05444dbd..0df1e78f 100644 --- a/spartan_parallel/src/lib.rs +++ b/spartan_parallel/src/lib.rs @@ -42,12 +42,11 @@ use std::{ io::Write, marker::PhantomData, }; -use dense_mlpoly::DensePolynomial; +use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; use errors::{ProofVerifyError, R1CSError}; use instance::Instance; use itertools::Itertools; use math::Math; -use multilinear_extensions::mle::DenseMultilinearExtension; use mpcs::{PolynomialCommitmentScheme, ProverParam}; use r1csinstance::{R1CSCommitment, R1CSDecommitment, R1CSEvalProof, R1CSInstance}; use r1csproof::R1CSProof; @@ -155,7 +154,7 @@ struct IOProofs> { impl> IOProofs { // Given the polynomial in execution order, generate all proofs fn prove( - exec_poly_inputs: &DensePolynomial, + exec_poly_inputs: &DenseMultilinearExtension, num_ios: usize, num_inputs_unpadded: usize, @@ -334,8 +333,8 @@ struct ShiftProofs> { impl> ShiftProofs { fn prove( - orig_polys: Vec<&DensePolynomial>, - shifted_polys: Vec<&DensePolynomial>, + orig_polys: Vec<&DenseMultilinearExtension>, + shifted_polys: Vec<&DenseMultilinearExtension>, // For each orig_poly, how many entries at the front of proof 0 are non-zero? header_len_list: Vec, transcript: &mut Transcript, @@ -346,11 +345,11 @@ impl> ShiftProofs assert_eq!(num_instances, shifted_polys.len()); let max_poly_size = orig_polys .iter() - .fold(0, |m, p| if p.len() > m { p.len() } else { m }); + .fold(0, |m, p| if p.evaluations().len() > m { p.evaluations().len() } else { m }); let max_poly_size = shifted_polys .iter() - .fold(max_poly_size, |m, p| if p.len() > m { p.len() } else { m }); + .fold(max_poly_size, |m, p| if p.evaluations().len() > m { p.evaluations().len() } else { m }); // Open entry 0..header_len_list[p] - 1 for p in 0..num_instances { for _i in 0..header_len_list[p] {} @@ -368,9 +367,9 @@ impl> ShiftProofs for p in 0..num_instances { let orig_poly = orig_polys[p]; let shifted_poly = shifted_polys[p]; - let orig_eval = (0..orig_poly.len()).fold(E::ZERO, |a, b| a + orig_poly[b] * rc[b]); + let orig_eval = (0..orig_poly.evaluations().len()).fold(E::ZERO, |a, b| a + orig_poly.get_ext_field_vec()[b] * rc[b]); let shifted_eval = - (0..shifted_poly.len()).fold(E::ZERO, |a, b| a + shifted_poly[b] * rc[b]); + (0..shifted_poly.evaluations().len()).fold(E::ZERO, |a, b| a + shifted_poly.get_ext_field_vec()[b] * rc[b]); orig_evals.push(orig_eval); shifted_evals.push(shifted_eval); } @@ -2826,7 +2825,7 @@ impl> SNARK< } perm_w0.extend(vec![E::ZERO; num_ios - 2 * num_inputs_unpadded]); // create a multilinear polynomial using the supplied assignment for variables - let _perm_poly_w0 = DensePolynomial::new(perm_w0.clone()); + let _perm_poly_w0: DenseMultilinearExtension = DenseMultilinearExtension::from_evaluation_vec_smart(perm_w0.len().log_2(), perm_w0); // block_w2 let block_w2_verifier = { @@ -2989,7 +2988,7 @@ impl> SNARK< ] .concat(); // create a multilinear polynomial using the supplied assignment for variables - let _poly_init_stacks = DensePolynomial::new(init_stacks.clone()); + let _poly_init_stacks: DenseMultilinearExtension = DenseMultilinearExtension::from_evaluation_vec_smart(init_stacks.len().log_2(), init_stacks); VerifierWitnessSecInfo::new( vec![INIT_PHY_MEM_WIDTH], &vec![total_num_init_phy_mem_accesses], @@ -3024,7 +3023,7 @@ impl> SNARK< ] .concat(); // create a multilinear polynomial using the supplied assignment for variables - let _poly_init_mems = DensePolynomial::new(init_mems.clone()); + let _poly_init_mems: DenseMultilinearExtension = DenseMultilinearExtension::from_evaluation_vec_smart(init_mems.len().log_2(), init_mems); VerifierWitnessSecInfo::new( vec![INIT_VIR_MEM_WIDTH], &vec![total_num_init_vir_mem_accesses], diff --git a/spartan_parallel/src/product_tree.rs b/spartan_parallel/src/product_tree.rs index 4ec266d1..d3fe7d0f 100644 --- a/spartan_parallel/src/product_tree.rs +++ b/spartan_parallel/src/product_tree.rs @@ -1,7 +1,7 @@ #![allow(dead_code)] use ff_ext::ExtensionField; -use super::dense_mlpoly::DensePolynomial; +use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; use super::dense_mlpoly::EqPolynomial; use super::math::Math; use super::sumcheck::SumcheckInstanceProof; @@ -10,35 +10,40 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone)] pub struct ProductCircuit { - left_vec: Vec>, - right_vec: Vec>, + left_vec: Vec>, + right_vec: Vec>, } impl ProductCircuit { fn compute_layer( - inp_left: &DensePolynomial, - inp_right: &DensePolynomial, - ) -> (DensePolynomial, DensePolynomial) { - let len = inp_left.len() + inp_right.len(); + inp_left: &DenseMultilinearExtension, + inp_right: &DenseMultilinearExtension, + ) -> (DenseMultilinearExtension, DenseMultilinearExtension) { + let len = inp_left.evaluations().len() + inp_right.evaluations().len(); let outp_left = (0..len / 4) - .map(|i| inp_left[i] * inp_right[i]) + .map(|i| inp_left.get_ext_field_vec()[i] * inp_right.get_ext_field_vec()[i]) .collect::>(); let outp_right = (len / 4..len / 2) - .map(|i| inp_left[i] * inp_right[i]) + .map(|i| inp_left.get_ext_field_vec()[i] * inp_right.get_ext_field_vec()[i]) .collect::>(); ( - DensePolynomial::new(outp_left), - DensePolynomial::new(outp_right), + DenseMultilinearExtension::from_evaluation_vec_smart(outp_left.len().log_2(), outp_left), + DenseMultilinearExtension::from_evaluation_vec_smart(outp_right.len().log_2(), outp_right), ) } - pub fn new(poly: &DensePolynomial) -> Self { - let mut left_vec: Vec> = Vec::new(); - let mut right_vec: Vec> = Vec::new(); + pub fn new(poly: &DenseMultilinearExtension) -> Self { + let mut left_vec: Vec> = Vec::new(); + let mut right_vec: Vec> = Vec::new(); - let num_layers = poly.len().log_2(); - let (outp_left, outp_right) = poly.split(poly.len() / 2); + let split_idx = poly.evaluations().len() / 2; + let num_layers = poly.evaluations().len().log_2(); + + let (outp_left, outp_right): (DenseMultilinearExtension, DenseMultilinearExtension) = ( + DenseMultilinearExtension::from_evaluation_vec_smart(split_idx.log_2(), poly.get_ext_field_vec()[0..split_idx].to_vec()), + DenseMultilinearExtension::from_evaluation_vec_smart(split_idx.log_2(), poly.get_ext_field_vec()[split_idx..].to_vec()) + ); left_vec.push(outp_left); right_vec.push(outp_right); @@ -57,27 +62,27 @@ impl ProductCircuit { pub fn evaluate(&self) -> E { let len = self.left_vec.len(); - assert_eq!(self.left_vec[len - 1].get_num_vars(), 0); - assert_eq!(self.right_vec[len - 1].get_num_vars(), 0); - self.left_vec[len - 1][0] * self.right_vec[len - 1][0] + assert_eq!(self.left_vec[len - 1].num_vars, 0); + assert_eq!(self.right_vec[len - 1].num_vars, 0); + self.left_vec[len - 1].get_ext_field_vec()[0] * self.right_vec[len - 1].get_ext_field_vec()[0] } } #[derive(Clone)] pub struct DotProductCircuit { - left: DensePolynomial, - right: DensePolynomial, - weight: DensePolynomial, + left: DenseMultilinearExtension, + right: DenseMultilinearExtension, + weight: DenseMultilinearExtension, } impl DotProductCircuit { pub fn new( - left: DensePolynomial, - right: DensePolynomial, - weight: DensePolynomial, + left: DenseMultilinearExtension, + right: DenseMultilinearExtension, + weight: DenseMultilinearExtension, ) -> Self { - assert_eq!(left.len(), right.len()); - assert_eq!(left.len(), weight.len()); + assert_eq!(left.evaluations().len(), right.evaluations().len()); + assert_eq!(left.evaluations().len(), weight.evaluations().len()); DotProductCircuit { left, right, @@ -86,17 +91,28 @@ impl DotProductCircuit { } pub fn evaluate(&self) -> E { - (0..self.left.len()) - .map(|i| self.left[i] * self.right[i] * self.weight[i]) + (0..self.left.evaluations().len()) + .map(|i| self.left.get_ext_field_vec()[i] * self.right.get_ext_field_vec()[i] * self.weight.get_ext_field_vec()[i]) .sum() } pub fn split(&mut self) -> (DotProductCircuit, DotProductCircuit) { - let idx = self.left.len() / 2; - assert_eq!(idx * 2, self.left.len()); - let (l1, l2) = self.left.split(idx); - let (r1, r2) = self.right.split(idx); - let (w1, w2) = self.weight.split(idx); + let idx = self.left.evaluations().len() / 2; + assert_eq!(idx * 2, self.left.evaluations().len()); + + let (l1, l2): (DenseMultilinearExtension, DenseMultilinearExtension) = ( + DenseMultilinearExtension::from_evaluation_vec_smart(idx.log_2(), self.left.get_ext_field_vec()[0..idx].to_vec()), + DenseMultilinearExtension::from_evaluation_vec_smart(idx.log_2(), self.left.get_ext_field_vec()[idx..].to_vec()) + ); + let (r1, r2): (DenseMultilinearExtension, DenseMultilinearExtension) = ( + DenseMultilinearExtension::from_evaluation_vec_smart(idx.log_2(), self.right.get_ext_field_vec()[0..idx].to_vec()), + DenseMultilinearExtension::from_evaluation_vec_smart(idx.log_2(), self.right.get_ext_field_vec()[idx..].to_vec()) + ); + let (w1, w2): (DenseMultilinearExtension, DenseMultilinearExtension) = ( + DenseMultilinearExtension::from_evaluation_vec_smart(idx.log_2(), self.weight.get_ext_field_vec()[0..idx].to_vec()), + DenseMultilinearExtension::from_evaluation_vec_smart(idx.log_2(), self.weight.get_ext_field_vec()[idx..].to_vec()) + ); + ( DotProductCircuit { left: l1, @@ -179,12 +195,13 @@ impl ProductCircuitEvalProof { let mut claim = circuit.evaluate(); let mut rand = Vec::new(); for layer_id in (0..num_layers).rev() { - let len = circuit.left_vec[layer_id].len() + circuit.right_vec[layer_id].len(); + let len = circuit.left_vec[layer_id].evaluations().len() + circuit.right_vec[layer_id].evaluations().len(); - let mut poly_C = DensePolynomial::new(EqPolynomial::new(rand.clone()).evals()); - assert_eq!(poly_C.len(), len / 2); + let poly_c_evals = EqPolynomial::new(rand.clone()).evals(); + let mut poly_C = DenseMultilinearExtension::from_evaluation_vec_smart(poly_c_evals.len().log_2(), poly_c_evals); + assert_eq!(poly_C.evaluations().len(), len / 2); - let num_rounds_prod = poly_C.len().log_2(); + let num_rounds_prod = poly_C.evaluations().len().log_2(); let comb_func_prod = |poly_A_comp: &E, poly_B_comp: &E, poly_C_comp: &E| -> E { *poly_A_comp * *poly_B_comp * *poly_C_comp }; @@ -269,19 +286,20 @@ impl ProductCircuitEvalProofBatched { let mut rand = Vec::new(); for layer_id in (0..num_layers).rev() { // prepare paralell instance that share poly_C first - let len = prod_circuit_vec[0].left_vec[layer_id].len() - + prod_circuit_vec[0].right_vec[layer_id].len(); + let len = prod_circuit_vec[0].left_vec[layer_id].evaluations().len() + + prod_circuit_vec[0].right_vec[layer_id].evaluations().len(); - let mut poly_C_par = DensePolynomial::new(EqPolynomial::new(rand.clone()).evals()); - assert_eq!(poly_C_par.len(), len / 2); + let poly_c_par_evals = EqPolynomial::new(rand.clone()).evals(); + let mut poly_C_par = DenseMultilinearExtension::from_evaluation_vec_smart(poly_c_par_evals.len().log_2(), poly_c_par_evals); + assert_eq!(poly_C_par.evaluations().len(), len / 2); - let num_rounds_prod = poly_C_par.len().log_2(); + let num_rounds_prod = poly_C_par.evaluations().len().log_2(); let comb_func_prod = |poly_A_comp: &E, poly_B_comp: &E, poly_C_comp: &E| -> E { *poly_A_comp * *poly_B_comp * *poly_C_comp }; - let mut poly_A_batched_par: Vec<&mut DensePolynomial> = Vec::new(); - let mut poly_B_batched_par: Vec<&mut DensePolynomial> = Vec::new(); + let mut poly_A_batched_par: Vec<&mut DenseMultilinearExtension> = Vec::new(); + let mut poly_B_batched_par: Vec<&mut DenseMultilinearExtension> = Vec::new(); for prod_circuit in prod_circuit_vec.iter_mut() { poly_A_batched_par.push(&mut prod_circuit.left_vec[layer_id]); poly_B_batched_par.push(&mut prod_circuit.right_vec[layer_id]) @@ -293,16 +311,16 @@ impl ProductCircuitEvalProofBatched { ); // prepare sequential instances that don't share poly_C - let mut poly_A_batched_seq: Vec<&mut DensePolynomial> = Vec::new(); - let mut poly_B_batched_seq: Vec<&mut DensePolynomial> = Vec::new(); - let mut poly_C_batched_seq: Vec<&mut DensePolynomial> = Vec::new(); + let mut poly_A_batched_seq: Vec<&mut DenseMultilinearExtension> = Vec::new(); + let mut poly_B_batched_seq: Vec<&mut DenseMultilinearExtension> = Vec::new(); + let mut poly_C_batched_seq: Vec<&mut DenseMultilinearExtension> = Vec::new(); if layer_id == 0 && !dotp_circuit_vec.is_empty() { // add additional claims for item in dotp_circuit_vec.iter() { claims_to_verify.push(item.evaluate()); - assert_eq!(len / 2, item.left.len()); - assert_eq!(len / 2, item.right.len()); - assert_eq!(len / 2, item.weight.len()); + assert_eq!(len / 2, item.left.evaluations().len()); + assert_eq!(len / 2, item.right.evaluations().len()); + assert_eq!(len / 2, item.weight.evaluations().len()); } for dotp_circuit in dotp_circuit_vec.iter_mut() { diff --git a/spartan_parallel/src/r1csinstance.rs b/spartan_parallel/src/r1csinstance.rs index 87c001a3..0f1fbaeb 100644 --- a/spartan_parallel/src/r1csinstance.rs +++ b/spartan_parallel/src/r1csinstance.rs @@ -8,7 +8,7 @@ use ff_ext::ExtensionField; use crate::transcript::{Transcript, append_field_to_transcript}; use super::custom_dense_mlpoly::DensePolynomialPqx; -use super::dense_mlpoly::DensePolynomial; +use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; use super::errors::ProofVerifyError; use super::math::Math; use super::random::RandomTape; @@ -456,9 +456,12 @@ impl R1CSInstance { c_evals.push(evals[2]); } // Bind A, B, C to rp - let a_eval = DensePolynomial::new(a_evals).evaluate(rp); - let b_eval = DensePolynomial::new(b_evals).evaluate(rp); - let c_eval = DensePolynomial::new(c_evals).evaluate(rp); + let a_poly = DenseMultilinearExtension::from_evaluation_vec_smart(a_evals.len().log_2(), a_evals); + let b_poly = DenseMultilinearExtension::from_evaluation_vec_smart(b_evals.len().log_2(), b_evals); + let c_poly = DenseMultilinearExtension::from_evaluation_vec_smart(c_evals.len().log_2(), c_evals); + let a_eval = a_poly.evaluate(rp); + let b_eval = b_poly.evaluate(rp); + let c_eval = c_poly.evaluate(rp); let eval_bound_rp = (a_eval, b_eval, c_eval); (eval_list, eval_bound_rp) diff --git a/spartan_parallel/src/r1csproof.rs b/spartan_parallel/src/r1csproof.rs index c4dcc51e..d018de22 100644 --- a/spartan_parallel/src/r1csproof.rs +++ b/spartan_parallel/src/r1csproof.rs @@ -1,6 +1,7 @@ #![allow(clippy::too_many_arguments)] use super::custom_dense_mlpoly::DensePolynomialPqx; -use super::dense_mlpoly::{DensePolynomial, EqPolynomial}; +use super::dense_mlpoly::EqPolynomial; +use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; use super::errors::ProofVerifyError; use super::math::Math; use super::r1csinstance::R1CSInstance; @@ -65,9 +66,9 @@ impl> R1CSPr num_rounds_p: usize, num_proofs: &Vec, num_cons: &Vec, - evals_tau_p: &mut DensePolynomial, - evals_tau_q: &mut DensePolynomial, - evals_tau_x: &mut DensePolynomial, + evals_tau_p: &mut DenseMultilinearExtension, + evals_tau_q: &mut DenseMultilinearExtension, + evals_tau_x: &mut DenseMultilinearExtension, evals_Az: &mut DensePolynomialPqx, evals_Bz: &mut DensePolynomialPqx, evals_Cz: &mut DensePolynomialPqx, @@ -108,7 +109,7 @@ impl> R1CSPr num_witness_secs: usize, num_inputs: Vec>, claim: &E, - evals_eq: &mut DensePolynomial, + evals_eq: &mut DenseMultilinearExtension, evals_ABC: &mut DensePolynomialPqx, evals_z: &mut DensePolynomialPqx, transcript: &mut Transcript, @@ -242,9 +243,12 @@ impl> R1CSPr let tau_x = challenge_vector(transcript, b"challenge_tau_x", num_rounds_x); // compute the initial evaluation table for R(\tau, x) - let mut poly_tau_p = DensePolynomial::new(EqPolynomial::new(tau_p).evals()); - let mut poly_tau_q = DensePolynomial::new(EqPolynomial::new(tau_q).evals()); - let mut poly_tau_x = DensePolynomial::new(EqPolynomial::new(tau_x).evals()); + let tau_p_evals = EqPolynomial::new(tau_p).evals(); + let tau_q_evals = EqPolynomial::new(tau_q).evals(); + let tau_x_evals = EqPolynomial::new(tau_x).evals(); + let mut poly_tau_p = DenseMultilinearExtension::from_evaluation_vec_smart(tau_p_evals.len().log_2(), tau_p_evals); + let mut poly_tau_q = DenseMultilinearExtension::from_evaluation_vec_smart(tau_q_evals.len().log_2(), tau_q_evals); + let mut poly_tau_x = DenseMultilinearExtension::from_evaluation_vec_smart(tau_x_evals.len().log_2(), tau_x_evals); let (mut poly_Az, mut poly_Bz, mut poly_Cz) = inst.multiply_vec_block( num_instances, num_proofs.clone(), @@ -274,9 +278,9 @@ impl> R1CSPr transcript, ); - assert_eq!(poly_tau_p.len(), 1); - assert_eq!(poly_tau_q.len(), 1); - assert_eq!(poly_tau_x.len(), 1); + assert_eq!(poly_tau_p.evaluations().len(), 1); + assert_eq!(poly_tau_q.evaluations().len(), 1); + assert_eq!(poly_tau_x.evaluations().len(), 1); assert_eq!(poly_Az.len(), 1); assert_eq!(poly_Bz.len(), 1); assert_eq!(poly_Cz.len(), 1); @@ -284,7 +288,7 @@ impl> R1CSPr timer_sc_proof_phase1.stop(); let (_tau_claim, Az_claim, Bz_claim, Cz_claim) = ( - &(poly_tau_p[0] * poly_tau_q[0] * poly_tau_x[0]), + &(poly_tau_p.get_ext_field_vec()[0] * poly_tau_q.get_ext_field_vec()[0] * poly_tau_x.get_ext_field_vec()[0]), &poly_Az.index(0, 0, 0, 0), &poly_Bz.index(0, 0, 0, 0), &poly_Cz.index(0, 0, 0, 0), @@ -355,7 +359,8 @@ impl> R1CSPr timer_tmp.stop(); // An Eq function to match p with rp - let mut eq_p_rp_poly = DensePolynomial::new(EqPolynomial::new(rp).evals()); + let rp_evals = EqPolynomial::new(rp).evals(); + let mut eq_p_rp_poly = DenseMultilinearExtension::from_evaluation_vec_smart(rp_evals.len().log_2(), rp_evals); // Sumcheck 2: (rA + rB + rC) * Z * eq(p) = e let timer_tmp = Timer::new("prove_sum_check"); @@ -532,7 +537,7 @@ impl> R1CSPr } timer_polyeval.stop(); - let poly_vars = DensePolynomial::new(eval_vars_comb_list); + let poly_vars = DenseMultilinearExtension::from_evaluation_vec_smart(eval_vars_comb_list.len().log_2(), eval_vars_comb_list); let eval_vars_at_ry = poly_vars.evaluate(&rp); // prove the final step of sum-check #2 // Deferred to verifier diff --git a/spartan_parallel/src/sparse_mlpoly.rs b/spartan_parallel/src/sparse_mlpoly.rs index 9cb22f6e..0b927de8 100644 --- a/spartan_parallel/src/sparse_mlpoly.rs +++ b/spartan_parallel/src/sparse_mlpoly.rs @@ -4,7 +4,6 @@ use ff_ext::ExtensionField; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; -use super::dense_mlpoly::DensePolynomial; use super::dense_mlpoly::{EqPolynomial, IdentityPolynomial}; use super::errors::ProofVerifyError; use super::math::Math; @@ -36,9 +35,9 @@ pub struct SparseMatPolynomial { } pub struct Derefs { - row_ops_val: Vec>, - col_ops_val: Vec>, - comb: DensePolynomial, + row_ops_val: Vec>, + col_ops_val: Vec>, + comb: DenseMultilinearExtension, } #[derive(Debug, Serialize, Deserialize)] @@ -47,7 +46,7 @@ pub struct DerefsCommitment Derefs { - pub fn new(row_ops_val: Vec>, col_ops_val: Vec>) -> Self { + pub fn new(row_ops_val: Vec>, col_ops_val: Vec>) -> Self { assert_eq!(row_ops_val.len(), col_ops_val.len()); let ret_row_ops_val = row_ops_val.clone(); @@ -55,7 +54,10 @@ impl Derefs { let derefs = { // combine all polynomials into a single polynomial (used below to produce a single commitment) - let comb = DensePolynomial::merge(row_ops_val.into_iter().chain(col_ops_val.into_iter())); + let mut comb = row_ops_val[0].clone(); + for p in row_ops_val.into_iter().skip(1).chain(col_ops_val.into_iter()) { + comb.merge(p); + } Derefs { row_ops_val: ret_row_ops_val, @@ -70,10 +72,10 @@ impl Derefs { impl> DerefsCommitment { pub fn commit(derefs: &Derefs) -> DerefsCommitment { - let l = derefs.comb.len(); + let l = derefs.comb.evaluations().len(); let param = Pcs::setup(l).unwrap(); let (pp, _vp) = Pcs::trim(param, l).unwrap(); - let comm_ops_val = Pcs::get_pure_commitment(&Pcs::commit(&pp, &derefs.comb.to_ceno_mle()).expect("Commitment should not fail")); + let comm_ops_val = Pcs::get_pure_commitment(&Pcs::commit(&pp, &derefs.comb).expect("Commitment should not fail")); DerefsCommitment { comm_ops_val } } } @@ -90,13 +92,13 @@ impl> DerefsEvalProof, + joint_poly: &DenseMultilinearExtension, r: &[E], evals: Vec, transcript: &mut Transcript, random_tape: &mut RandomTape, ) -> Self { - assert_eq!(joint_poly.get_num_vars(), r.len() + evals.len().log_2()); + assert_eq!(joint_poly.num_vars, r.len() + evals.len().log_2()); // append the claimed evaluations to transcript append_field_vector_to_transcript(b"evals_ops_val", transcript, &evals); @@ -105,12 +107,11 @@ impl> DerefsEvalProof> DerefsEvalProof> DerefsEvalProof> DerefsEvalProof { ops_addr_usize: Vec>, - ops_addr: Vec>, - read_ts: Vec>, - audit_ts: DensePolynomial, + ops_addr: Vec>, + read_ts: Vec>, + audit_ts: DenseMultilinearExtension, } impl AddrTimestamps { @@ -225,8 +224,8 @@ impl AddrTimestamps { } let mut audit_ts = vec![0usize; num_cells]; - let mut ops_addr_vec: Vec> = Vec::new(); - let mut read_ts_vec: Vec> = Vec::new(); + let mut ops_addr_vec: Vec> = Vec::new(); + let mut read_ts_vec: Vec> = Vec::new(); for ops_addr_inst in ops_addr.iter() { let mut read_ts = vec![0usize; num_ops]; @@ -242,39 +241,43 @@ impl AddrTimestamps { audit_ts[addr] = w_ts; } - ops_addr_vec.push(DensePolynomial::from_usize(ops_addr_inst)); - read_ts_vec.push(DensePolynomial::from_usize(&read_ts)); + let ops_addr_inst_evals = ops_addr_inst.into_iter().map(|&n| E::from_u128(n as u128)).collect::>(); + let read_ts_evals = read_ts.into_iter().map(|n| E::from_u128(n as u128)).collect::>(); + ops_addr_vec.push(DenseMultilinearExtension::from_evaluation_vec_smart(ops_addr_inst_evals.len().log_2(), ops_addr_inst_evals)); + read_ts_vec.push(DenseMultilinearExtension::from_evaluation_vec_smart(read_ts_evals.len().log_2(), read_ts_evals)); } + let audit_ts_evals = audit_ts.into_iter().map(|n| E::from_u128(n as u128)).collect::>(); + AddrTimestamps { ops_addr: ops_addr_vec, ops_addr_usize: ops_addr, read_ts: read_ts_vec, - audit_ts: DensePolynomial::from_usize(&audit_ts), + audit_ts: DenseMultilinearExtension::from_evaluation_vec_smart(audit_ts_evals.len().log_2(), audit_ts_evals), } } - fn deref_mem(addr: &[usize], mem_val: &[E]) -> DensePolynomial { - DensePolynomial::new( - (0..addr.len()) - .map(|i| { - let a = addr[i]; - mem_val[a] - }) - .collect::>(), - ) + fn deref_mem(addr: &[usize], mem_val: &[E]) -> DenseMultilinearExtension { + let evals = (0..addr.len()) + .map(|i| { + let a = addr[i]; + mem_val[a] + }) + .collect::>(); + + DenseMultilinearExtension::from_evaluation_vec_smart(evals.len().log_2(), evals) } - pub fn deref(&self, mem_val: &[E]) -> Vec> { + pub fn deref(&self, mem_val: &[E]) -> Vec> { (0..self.ops_addr.len()) .map(|i| AddrTimestamps::deref_mem(&self.ops_addr_usize[i], mem_val)) - .collect::>>() + .collect::>>() } } pub struct MultiSparseMatPolynomialAsDense { batch_size: usize, - val: Vec>, + val: Vec>, row: AddrTimestamps, col: AddrTimestamps, comb_ops_ceno_mle: DenseMultilinearExtension, @@ -342,12 +345,12 @@ impl SparseMatPolynomial { let mut ops_row_vec: Vec> = Vec::new(); let mut ops_col_vec: Vec> = Vec::new(); - let mut val_vec: Vec> = Vec::new(); + let mut val_vec: Vec> = Vec::new(); for poly in sparse_polys { let (ops_row, ops_col, val) = poly.sparse_to_dense_vecs(N); ops_row_vec.push(ops_row); ops_col_vec.push(ops_col); - val_vec.push(DensePolynomial::new(val)); + val_vec.push(DenseMultilinearExtension::from_evaluation_vec_smart(val.len().log_2(), val)); } let any_poly = &sparse_polys[0]; @@ -366,28 +369,29 @@ impl SparseMatPolynomial { let ret_val_vec = val_vec.clone(); // combine polynomials into a single polynomial for commitment purposes - let comb_ops = DensePolynomial::merge( - row - .ops_addr - .into_iter() - .chain(row.read_ts.into_iter()) - .chain(col.ops_addr.into_iter()) - .chain(col.read_ts.into_iter()) - .chain(val_vec.into_iter()), - ); - let mut comb_mem = row.audit_ts.clone(); - comb_mem.extend(&col.audit_ts); + let mut comb_ops = row.ops_addr[0].clone(); + for p in + row.ops_addr + .into_iter() + .skip(1) + .chain(row.read_ts.into_iter()) + .chain(col.ops_addr.into_iter()) + .chain(col.read_ts.into_iter()) + .chain(val_vec.into_iter()) + { + comb_ops.merge(p); + } - let comb_ops_ceno_mle = comb_ops.clone().to_ceno_mle(); - let comb_mem_ceno_mle = comb_mem.clone().to_ceno_mle(); + let mut comb_mem = row.audit_ts.clone(); + comb_mem.merge(col.audit_ts); MultiSparseMatPolynomialAsDense { batch_size: sparse_polys.len(), row: ret_row, col: ret_col, val: ret_val_vec, - comb_ops_ceno_mle, - comb_mem_ceno_mle, + comb_ops_ceno_mle: comb_ops, + comb_mem_ceno_mle: comb_mem, } } @@ -510,8 +514,8 @@ impl> SparseMatPolyCommitm ( SparseMatPolyCommitment { batch_size, - num_mem_cells: dense.row.audit_ts.len(), - num_ops: dense.row.read_ts[0].len(), + num_mem_cells: dense.row.audit_ts.evaluations().len(), + num_ops: dense.row.read_ts[0].evaluations().len(), comm_comb_ops, comm_comb_mem, _phantom: E::ZERO, @@ -545,16 +549,16 @@ struct Layers { impl Layers { fn build_hash_layer( eval_table: &[E], - addrs_vec: &[DensePolynomial], - derefs_vec: &[DensePolynomial], - read_ts_vec: &[DensePolynomial], - audit_ts: &DensePolynomial, + addrs_vec: &[DenseMultilinearExtension], + derefs_vec: &[DenseMultilinearExtension], + read_ts_vec: &[DenseMultilinearExtension], + audit_ts: &DenseMultilinearExtension, r_mem_check: &(E, E), ) -> ( - DensePolynomial, - Vec>, - Vec>, - DensePolynomial, + DenseMultilinearExtension, + Vec>, + Vec>, + DenseMultilinearExtension, ) { let (r_hash, r_multiset_check) = r_mem_check; @@ -564,49 +568,45 @@ impl Layers { // hash init and audit that does not depend on #instances let num_mem_cells = eval_table.len(); - let poly_init_hashed = DensePolynomial::new( - (0..num_mem_cells) - .map(|i| { - // at init time, addr is given by i, init value is given by eval_table, and ts = 0 - hash_func(&E::from(i as u64), &eval_table[i], &E::ZERO) - *r_multiset_check - }) - .collect::>(), - ); - let poly_audit_hashed = DensePolynomial::new( - (0..num_mem_cells) - .map(|i| { - // at audit time, addr is given by i, value is given by eval_table, and ts is given by audit_ts - hash_func(&E::from(i as u64), &eval_table[i], &audit_ts[i]) - *r_multiset_check - }) - .collect::>(), - ); + let poly_init_hashed_evals = (0..num_mem_cells) + .map(|i| { + // at init time, addr is given by i, init value is given by eval_table, and ts = 0 + hash_func(&E::from(i as u64), &eval_table[i], &E::ZERO) - *r_multiset_check + }) + .collect::>(); + let poly_init_hashed = DenseMultilinearExtension::from_evaluation_vec_smart(poly_init_hashed_evals.len().log_2(), poly_init_hashed_evals); + let poly_audit_hashed_evals = (0..num_mem_cells) + .map(|i| { + // at audit time, addr is given by i, value is given by eval_table, and ts is given by audit_ts + hash_func(&E::from(i as u64), &eval_table[i], &audit_ts.get_ext_field_vec()[i]) - *r_multiset_check + }) + .collect::>(); + let poly_audit_hashed = DenseMultilinearExtension::from_evaluation_vec_smart(poly_audit_hashed_evals.len().log_2(), poly_audit_hashed_evals); // hash read and write that depends on #instances - let mut poly_read_hashed_vec: Vec> = Vec::new(); - let mut poly_write_hashed_vec: Vec> = Vec::new(); + let mut poly_read_hashed_vec: Vec> = Vec::new(); + let mut poly_write_hashed_vec: Vec> = Vec::new(); for i in 0..addrs_vec.len() { let (addrs, derefs, read_ts) = (&addrs_vec[i], &derefs_vec[i], &read_ts_vec[i]); - assert_eq!(addrs.len(), derefs.len()); - assert_eq!(addrs.len(), read_ts.len()); - let num_ops = addrs.len(); - let poly_read_hashed = DensePolynomial::new( - (0..num_ops) - .map(|i| { - // at read time, addr is given by addrs, value is given by derefs, and ts is given by read_ts - hash_func(&addrs[i], &derefs[i], &read_ts[i]) - *r_multiset_check - }) - .collect::>(), - ); + assert_eq!(addrs.evaluations().len(), derefs.evaluations().len()); + assert_eq!(addrs.evaluations().len(), read_ts.evaluations().len()); + let num_ops = addrs.evaluations().len(); + let poly_read_hashed_evals = (0..num_ops) + .map(|i| { + // at read time, addr is given by addrs, value is given by derefs, and ts is given by read_ts + hash_func(&addrs.get_ext_field_vec()[i], &derefs.get_ext_field_vec()[i], &read_ts.get_ext_field_vec()[i]) - *r_multiset_check + }) + .collect::>(); + let poly_read_hashed = DenseMultilinearExtension::from_evaluation_vec_smart(poly_read_hashed_evals.len().log_2(), poly_read_hashed_evals); poly_read_hashed_vec.push(poly_read_hashed); - let poly_write_hashed = DensePolynomial::new( - (0..num_ops) - .map(|i| { - // at write time, addr is given by addrs, value is given by derefs, and ts is given by write_ts = read_ts + 1 - hash_func(&addrs[i], &derefs[i], &(read_ts[i] + E::ONE)) - *r_multiset_check - }) - .collect::>(), - ); + let poly_write_hashed_evals = (0..num_ops) + .map(|i| { + // at write time, addr is given by addrs, value is given by derefs, and ts is given by write_ts = read_ts + 1 + hash_func(&addrs.get_ext_field_vec()[i], &derefs.get_ext_field_vec()[i], &(read_ts.get_ext_field_vec()[i] + E::ONE)) - *r_multiset_check + }) + .collect::>(); + let poly_write_hashed = DenseMultilinearExtension::from_evaluation_vec_smart(poly_write_hashed_evals.len().log_2(), poly_write_hashed_evals); poly_write_hashed_vec.push(poly_write_hashed); } @@ -621,7 +621,7 @@ impl Layers { pub fn new( eval_table: &[E], addr_timestamps: &AddrTimestamps, - poly_ops_val: &[DensePolynomial], + poly_ops_val: &[DenseMultilinearExtension], r_mem_check: &(E, E), ) -> Self { let (poly_init_hashed, poly_read_hashed_vec, poly_write_hashed_vec, poly_audit_hashed) = @@ -786,15 +786,15 @@ impl> HashLayerProof> HashLayerProof = vec![eval_row_audit_ts, eval_col_audit_ts]; append_field_vector_to_transcript(b"claim_evals_mem", transcript, &evals_mem); + let num_vars_mem = evals_mem.len().log_2(); let challenges_mem = - challenge_vector(transcript, b"challenge_combine_two_to_one", evals_mem.len().log_2()); + challenge_vector(transcript, b"challenge_combine_two_to_one", num_vars_mem); - let mut poly_evals_mem = DensePolynomial::new(evals_mem); - for i in (0..challenges_mem.len()).rev() { - poly_evals_mem.bound_poly_var_bot(&challenges_mem[i]); - } - assert_eq!(poly_evals_mem.len(), 1); - let joint_claim_eval_mem = poly_evals_mem[0]; + let mut poly_evals_mem = DenseMultilinearExtension::from_evaluation_vec_smart(num_vars_mem, evals_mem); + // _debug: variable order + poly_evals_mem.fix_variables_in_place(&challenges_mem); + assert_eq!(poly_evals_mem.evaluations().len(), 1); + let joint_claim_eval_mem = poly_evals_mem.get_ext_field_vec()[0]; let mut r_joint_mem = challenges_mem; r_joint_mem.extend(rand_mem); debug_assert_eq!(dense.comb_mem_ceno_mle.evaluate(&r_joint_mem), joint_claim_eval_mem); @@ -946,20 +946,19 @@ impl> HashLayerProof> HashLayerProof = vec![*eval_row_audit_ts, *eval_col_audit_ts]; append_field_vector_to_transcript(b"claim_evals_mem", transcript, &evals_mem); + let num_vars_mem = evals_mem.len().log_2(); let challenges_mem = - challenge_vector(transcript, b"challenge_combine_two_to_one", evals_mem.len().log_2()); + challenge_vector(transcript, b"challenge_combine_two_to_one", num_vars_mem); - let mut poly_evals_mem = DensePolynomial::new(evals_mem); - for i in (0..challenges_mem.len()).rev() { - poly_evals_mem.bound_poly_var_bot(&challenges_mem[i]); - } - assert_eq!(poly_evals_mem.len(), 1); - let joint_claim_eval_mem = poly_evals_mem[0]; + let mut poly_evals_mem = DenseMultilinearExtension::from_evaluation_vec_smart(num_vars_mem, evals_mem); + poly_evals_mem.fix_variables_in_place(&challenges_mem); + assert_eq!(poly_evals_mem.evaluations().len(), 1); + let joint_claim_eval_mem = poly_evals_mem.get_ext_field_vec()[0]; let mut r_joint_mem = challenges_mem; r_joint_mem.extend(rand_mem); append_field_to_transcript(b"joint_claim_eval_mem", transcript, joint_claim_eval_mem); - let l: usize = 1 << poly_evals_mem.get_num_vars(); + let l: usize = 1 << poly_evals_mem.num_vars; let (_pp, vp) = Pcs::trim(Pcs::setup(l).expect("Param setup should not fail."), l).expect("Param trim should not fail."); Pcs::verify(&vp, &Pcs::get_pure_commitment(&self.comm_mem), &r_joint_mem, &joint_claim_eval_mem, &self.proof_mem, transcript).map_err(|e| ProofVerifyError::InternalError)?; diff --git a/spartan_parallel/src/sumcheck.rs b/spartan_parallel/src/sumcheck.rs index 1f134c82..69eaa995 100644 --- a/spartan_parallel/src/sumcheck.rs +++ b/spartan_parallel/src/sumcheck.rs @@ -4,7 +4,7 @@ use crate::custom_dense_mlpoly::DensePolynomialPqx; use crate::math::Math; use ff_ext::ExtensionField; -use super::dense_mlpoly::DensePolynomial; +use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; use super::errors::ProofVerifyError; use super::transcript::{Transcript, challenge_scalar}; use super::unipoly::{CompressedUniPoly, UniPoly}; @@ -69,12 +69,14 @@ impl SumcheckInstanceProof { } impl SumcheckInstanceProof { + // _debug: remove native sumcheck prover + /* pub fn prove_cubic( claim: &E, num_rounds: usize, - poly_A: &mut DensePolynomial, - poly_B: &mut DensePolynomial, - poly_C: &mut DensePolynomial, + poly_A: &mut DenseMultilinearExtension, + poly_B: &mut DenseMultilinearExtension, + poly_C: &mut DenseMultilinearExtension, comb_func: F, transcript: &mut Transcript, ) -> (Self, Vec, Vec) @@ -89,7 +91,7 @@ impl SumcheckInstanceProof { let mut eval_point_2 = E::ZERO; let mut eval_point_3 = E::ZERO; - let len = poly_A.len() / 2; + let len = poly_A.evaluations().len() / 2; for i in 0..len { // eval 0: bound_func is A(low) eval_point_0 += comb_func(&poly_A[i], &poly_B[i], &poly_C[i]); @@ -146,14 +148,14 @@ impl SumcheckInstanceProof { claim: &E, num_rounds: usize, poly_vec_par: ( - &mut Vec<&mut DensePolynomial>, - &mut Vec<&mut DensePolynomial>, - &mut DensePolynomial, + &mut Vec<&mut DenseMultilinearExtension>, + &mut Vec<&mut DenseMultilinearExtension>, + &mut DenseMultilinearExtension, ), poly_vec_seq: ( - &mut Vec<&mut DensePolynomial>, - &mut Vec<&mut DensePolynomial>, - &mut Vec<&mut DensePolynomial>, + &mut Vec<&mut DenseMultilinearExtension>, + &mut Vec<&mut DenseMultilinearExtension>, + &mut Vec<&mut DenseMultilinearExtension>, ), coeffs: &[E], comb_func: F, @@ -316,7 +318,8 @@ impl SumcheckInstanceProof { claims_dotp, ) } - + */ + pub fn prove_cubic_disjoint_rounds( claim: &E, num_rounds: usize, @@ -326,7 +329,7 @@ impl SumcheckInstanceProof { single_inst: bool, // indicates whether poly_B only has one instance num_witness_secs: usize, mut num_inputs: Vec>, - poly_A: &mut DensePolynomial, + poly_A: &mut DenseMultilinearExtension, poly_B: &mut DensePolynomialPqx, poly_C: &mut DensePolynomialPqx, comb_func: F, @@ -418,10 +421,11 @@ impl SumcheckInstanceProof { } for y in 0..num_inputs[p][w] { // evaluate A, B, C on p, w, y + let poly_A_vec = poly_A.get_ext_field_vec(); let (poly_A_low, poly_A_high) = match mode { - MODE_X => (poly_A[p], poly_A[p]), - MODE_W => (poly_A[p], poly_A[p]), - MODE_P => (poly_A[2 * p], poly_A[2 * p + 1]), + MODE_X => (poly_A_vec[p], poly_A_vec[p]), + MODE_W => (poly_A_vec[p], poly_A_vec[p]), + MODE_P => (poly_A_vec[2 * p], poly_A_vec[2 * p + 1]), _ => unreachable!() }; let poly_B_low = poly_B.index_low(p_inst, 0, w, y, mode); @@ -477,7 +481,7 @@ impl SumcheckInstanceProof { // bound all tables to the verifier's challenege if mode == MODE_P { - poly_A.bound_poly_var_bot(&r_j); + poly_A.fix_variables_in_place(&[r_j]); } if mode != MODE_P || !single_inst { poly_B.bound_poly(&r_j, mode); @@ -491,7 +495,7 @@ impl SumcheckInstanceProof { SumcheckInstanceProof::new(polys), r, vec![ - poly_A[0], + poly_A.get_ext_field_vec()[0], poly_B.index(0, 0, 0, 0), poly_C.index(0, 0, 0, 0), ], @@ -507,9 +511,9 @@ impl SumcheckInstanceProof { num_rounds_p: usize, mut num_proofs: Vec, mut num_cons: Vec, - poly_Ap: &mut DensePolynomial, - poly_Aq: &mut DensePolynomial, - poly_Ax: &mut DensePolynomial, + poly_Ap: &mut DenseMultilinearExtension, + poly_Aq: &mut DenseMultilinearExtension, + poly_Ax: &mut DenseMultilinearExtension, poly_B: &mut DensePolynomialPqx, poly_C: &mut DensePolynomialPqx, poly_D: &mut DensePolynomialPqx, @@ -596,8 +600,11 @@ impl SumcheckInstanceProof { let mut eval_point_3 = ZERO; for x in 0..num_cons[p] { // evaluate A, B, C, D on p, q, x - let poly_A_low = poly_Ap[p] * poly_Aq[q] * poly_Ax[2 * x]; - let poly_A_high = poly_Ap[p] * poly_Aq[q] * poly_Ax[2 * x + 1]; + let poly_Ap_vec = poly_Ap.get_ext_field_vec(); + let poly_Aq_vec = poly_Aq.get_ext_field_vec(); + let poly_Ax_vec = poly_Ax.get_ext_field_vec(); + let poly_A_low = poly_Ap_vec[p] * poly_Aq_vec[q] * poly_Ax_vec[2 * x]; + let poly_A_high = poly_Ap_vec[p] * poly_Aq_vec[q] * poly_Ax_vec[2 * x + 1]; let poly_B_low = poly_B.index_low(p, q, 0, x, mode); let poly_B_high = poly_B.index_high(p, q, 0, x, mode); let poly_C_low = poly_C.index_low(p, q, 0, x, mode); @@ -666,14 +673,17 @@ impl SumcheckInstanceProof { for q in 0..num_proofs[p] { for x in 0..num_cons[p] { // evaluate A, B, C, D on p, q, x + let poly_Ap_vec = poly_Ap.get_ext_field_vec(); + let poly_Aq_vec = poly_Aq.get_ext_field_vec(); + let poly_Ax_vec = poly_Ax.get_ext_field_vec(); let (poly_A_low, poly_A_high) = match mode { MODE_Q => ( - poly_Ap[p] * poly_Aq[2 * q] * poly_Ax[x], - poly_Ap[p] * poly_Aq[2 * q + 1] * poly_Ax[x], + poly_Ap_vec[p] * poly_Aq_vec[2 * q] * poly_Ax_vec[x], + poly_Ap_vec[p] * poly_Aq_vec[2 * q + 1] * poly_Ax_vec[x], ), MODE_P => ( - poly_Ap[2 * p] * poly_Aq[q] * poly_Ax[x], - poly_Ap[2 * p + 1] * poly_Aq[q] * poly_Ax[x], + poly_Ap_vec[2 * p] * poly_Aq_vec[q] * poly_Ax_vec[x], + poly_Ap_vec[2 * p + 1] * poly_Aq_vec[q] * poly_Ax_vec[x], ), _ => unreachable!() }; @@ -742,11 +752,11 @@ impl SumcheckInstanceProof { // bound all tables to the verifier's challenege if mode == MODE_X { - poly_Ax.bound_poly_var_bot(&r_j); + poly_Ax.fix_variables_in_place(&[r_j]); } else if mode == MODE_Q { - poly_Aq.bound_poly_var_bot(&r_j); + poly_Aq.fix_variables_in_place(&[r_j]); } else if mode == MODE_P { - poly_Ap.bound_poly_var_bot(&r_j); + poly_Ap.fix_variables_in_place(&[r_j]); } else { unreachable!() } @@ -761,7 +771,7 @@ impl SumcheckInstanceProof { SumcheckInstanceProof::new(polys), r, vec![ - poly_Ap[0] * poly_Aq[0] * poly_Ax[0], + poly_Ap.get_ext_field_vec()[0] * poly_Aq.get_ext_field_vec()[0] * poly_Ax.get_ext_field_vec()[0], poly_B.index(0, 0, 0, 0), poly_C.index(0, 0, 0, 0), poly_D.index(0, 0, 0, 0), From efdc44a40a7ed0bf3e9b849ea1062d999af337dd Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 12 Feb 2025 17:04:25 -0500 Subject: [PATCH 14/18] Correct DensePolynomial in Spartan sumcheck --- spartan_parallel/src/sumcheck.rs | 95 ++++++++++++++++++-------------- 1 file changed, 53 insertions(+), 42 deletions(-) diff --git a/spartan_parallel/src/sumcheck.rs b/spartan_parallel/src/sumcheck.rs index 69eaa995..a5cbb13a 100644 --- a/spartan_parallel/src/sumcheck.rs +++ b/spartan_parallel/src/sumcheck.rs @@ -70,7 +70,6 @@ impl SumcheckInstanceProof { impl SumcheckInstanceProof { // _debug: remove native sumcheck prover - /* pub fn prove_cubic( claim: &E, num_rounds: usize, @@ -93,13 +92,17 @@ impl SumcheckInstanceProof { let len = poly_A.evaluations().len() / 2; for i in 0..len { + let poly_A_vec = poly_A.get_ext_field_vec(); + let poly_B_vec = poly_B.get_ext_field_vec(); + let poly_C_vec = poly_C.get_ext_field_vec(); + // eval 0: bound_func is A(low) - eval_point_0 += comb_func(&poly_A[i], &poly_B[i], &poly_C[i]); + eval_point_0 += comb_func(&poly_A_vec[i], &poly_B_vec[i], &poly_C_vec[i]); // eval 2: bound_func is -A(low) + 2*A(high) - let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i]; - let poly_C_bound_point = poly_C[len + i] + poly_C[len + i] - poly_C[i]; + let poly_A_bound_point = poly_A_vec[len + i] + poly_A_vec[len + i] - poly_A_vec[i]; + let poly_B_bound_point = poly_B_vec[len + i] + poly_B_vec[len + i] - poly_B_vec[i]; + let poly_C_bound_point = poly_C_vec[len + i] + poly_C_vec[len + i] - poly_C_vec[i]; eval_point_2 = eval_point_2 + comb_func( &poly_A_bound_point, @@ -108,9 +111,9 @@ impl SumcheckInstanceProof { ); // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2) - let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i]; - let poly_C_bound_point = poly_C_bound_point + poly_C[len + i] - poly_C[i]; + let poly_A_bound_point = poly_A_bound_point + poly_A_vec[len + i] - poly_A_vec[i]; + let poly_B_bound_point = poly_B_bound_point + poly_B_vec[len + i] - poly_B_vec[i]; + let poly_C_bound_point = poly_C_bound_point + poly_C_vec[len + i] - poly_C_vec[i]; eval_point_3 = eval_point_3 + comb_func( @@ -130,9 +133,9 @@ impl SumcheckInstanceProof { let r_j = challenge_scalar(transcript, b"challenge_nextround"); r.push(r_j); // bound all tables to the verifier's challenege - poly_A.bound_poly_var_top(&r_j); - poly_B.bound_poly_var_top(&r_j); - poly_C.bound_poly_var_top(&r_j); + poly_A.fix_variables_in_place(&[r_j]); + poly_B.fix_variables_in_place(&[r_j]); + poly_C.fix_variables_in_place(&[r_j]); e = poly.evaluate(&r_j); cubic_polys.push(poly.compress()); } @@ -140,7 +143,11 @@ impl SumcheckInstanceProof { ( SumcheckInstanceProof::new(cubic_polys), r, - vec![poly_A[0], poly_B[0], poly_C[0]], + vec![ + poly_A.get_ext_field_vec()[0], + poly_B.get_ext_field_vec()[0], + poly_C.get_ext_field_vec()[0] + ], ) } @@ -167,7 +174,6 @@ impl SumcheckInstanceProof { let (poly_A_vec_par, poly_B_vec_par, poly_C_par) = poly_vec_par; let (poly_A_vec_seq, poly_B_vec_seq, poly_C_vec_seq) = poly_vec_seq; - //let (poly_A_vec_seq, poly_B_vec_seq, poly_C_vec_seq) = poly_vec_seq; let mut e = *claim; let mut r: Vec = Vec::new(); let mut cubic_polys: Vec> = Vec::new(); @@ -180,15 +186,18 @@ impl SumcheckInstanceProof { let mut eval_point_2 = E::ZERO; let mut eval_point_3 = E::ZERO; - let len = poly_A.len() / 2; + let len = poly_A.evaluations().len() / 2; for i in 0..len { + let poly_A_vec = poly_A.get_ext_field_vec(); + let poly_B_vec = poly_B.get_ext_field_vec(); + let poly_C_par_vec = poly_C_par.get_ext_field_vec(); // eval 0: bound_func is A(low) - eval_point_0 = eval_point_0 + comb_func(&poly_A[i], &poly_B[i], &poly_C_par[i]); + eval_point_0 = eval_point_0 + comb_func(&poly_A_vec[i], &poly_B_vec[i], &poly_C_par_vec[i]); // eval 2: bound_func is -A(low) + 2*A(high) - let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i]; - let poly_C_bound_point = poly_C_par[len + i] + poly_C_par[len + i] - poly_C_par[i]; + let poly_A_bound_point = poly_A_vec[len + i] + poly_A_vec[len + i] - poly_A_vec[i]; + let poly_B_bound_point = poly_B_vec[len + i] + poly_B_vec[len + i] - poly_B_vec[i]; + let poly_C_bound_point = poly_C_par_vec[len + i] + poly_C_par_vec[len + i] - poly_C_par_vec[i]; eval_point_2 = eval_point_2 + comb_func( &poly_A_bound_point, @@ -197,9 +206,9 @@ impl SumcheckInstanceProof { ); // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2) - let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i]; - let poly_C_bound_point = poly_C_bound_point + poly_C_par[len + i] - poly_C_par[i]; + let poly_A_bound_point = poly_A_bound_point + poly_A_vec[len + i] - poly_A_vec[i]; + let poly_B_bound_point = poly_B_bound_point + poly_B_vec[len + i] - poly_B_vec[i]; + let poly_C_bound_point = poly_C_bound_point + poly_C_par_vec[len + i] - poly_C_par_vec[i]; eval_point_3 = eval_point_3 + comb_func( @@ -220,14 +229,17 @@ impl SumcheckInstanceProof { let mut eval_point_0 = E::ZERO; let mut eval_point_2 = E::ZERO; let mut eval_point_3 = E::ZERO; - let len = poly_A.len() / 2; + let len = poly_A.evaluations().len() / 2; for i in 0..len { + let poly_A_vec = poly_A.get_ext_field_vec(); + let poly_B_vec = poly_B.get_ext_field_vec(); + let poly_C_vec = poly_C.get_ext_field_vec(); // eval 0: bound_func is A(low) - eval_point_0 = eval_point_0 + comb_func(&poly_A[i], &poly_B[i], &poly_C[i]); + eval_point_0 = eval_point_0 + comb_func(&poly_A_vec[i], &poly_B_vec[i], &poly_C_vec[i]); // eval 2: bound_func is -A(low) + 2*A(high) - let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i]; - let poly_C_bound_point = poly_C[len + i] + poly_C[len + i] - poly_C[i]; + let poly_A_bound_point = poly_A_vec[len + i] + poly_A_vec[len + i] - poly_A_vec[i]; + let poly_B_bound_point = poly_B_vec[len + i] + poly_B_vec[len + i] - poly_B_vec[i]; + let poly_C_bound_point = poly_C_vec[len + i] + poly_C_vec[len + i] - poly_C_vec[i]; eval_point_2 = eval_point_2 + comb_func( &poly_A_bound_point, @@ -235,9 +247,9 @@ impl SumcheckInstanceProof { &poly_C_bound_point, ); // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2) - let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i]; - let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i]; - let poly_C_bound_point = poly_C_bound_point + poly_C[len + i] - poly_C[i]; + let poly_A_bound_point = poly_A_bound_point + poly_A_vec[len + i] - poly_A_vec[i]; + let poly_B_bound_point = poly_B_bound_point + poly_B_vec[len + i] - poly_B_vec[i]; + let poly_C_bound_point = poly_C_bound_point + poly_C_vec[len + i] - poly_C_vec[i]; eval_point_3 = eval_point_3 + comb_func( &poly_A_bound_point, @@ -273,19 +285,19 @@ impl SumcheckInstanceProof { // bound all tables to the verifier's challenege for (poly_A, poly_B) in poly_A_vec_par.iter_mut().zip(poly_B_vec_par.iter_mut()) { - poly_A.bound_poly_var_top(&r_j); - poly_B.bound_poly_var_top(&r_j); + poly_A.fix_variables_in_place(&[r_j]); + poly_B.fix_variables_in_place(&[r_j]); } - poly_C_par.bound_poly_var_top(&r_j); + poly_C_par.fix_variables_in_place(&[r_j]); for (poly_A, poly_B, poly_C) in izip!( poly_A_vec_seq.iter_mut(), poly_B_vec_seq.iter_mut(), poly_C_vec_seq.iter_mut() ) { - poly_A.bound_poly_var_top(&r_j); - poly_B.bound_poly_var_top(&r_j); - poly_C.bound_poly_var_top(&r_j); + poly_A.fix_variables_in_place(&[r_j]); + poly_B.fix_variables_in_place(&[r_j]); + poly_C.fix_variables_in_place(&[r_j]); } e = poly.evaluate(&r_j); @@ -293,21 +305,21 @@ impl SumcheckInstanceProof { } let poly_A_par_final = (0..poly_A_vec_par.len()) - .map(|i| poly_A_vec_par[i][0]) + .map(|i| poly_A_vec_par[i].get_ext_field_vec()[0]) .collect(); let poly_B_par_final = (0..poly_B_vec_par.len()) - .map(|i| poly_B_vec_par[i][0]) + .map(|i| poly_B_vec_par[i].get_ext_field_vec()[0]) .collect(); - let claims_prod = (poly_A_par_final, poly_B_par_final, poly_C_par[0]); + let claims_prod = (poly_A_par_final, poly_B_par_final, poly_C_par.get_ext_field_vec()[0]); let poly_A_seq_final = (0..poly_A_vec_seq.len()) - .map(|i| poly_A_vec_seq[i][0]) + .map(|i| poly_A_vec_seq[i].get_ext_field_vec()[0]) .collect(); let poly_B_seq_final = (0..poly_B_vec_seq.len()) - .map(|i| poly_B_vec_seq[i][0]) + .map(|i| poly_B_vec_seq[i].get_ext_field_vec()[0]) .collect(); let poly_C_seq_final = (0..poly_C_vec_seq.len()) - .map(|i| poly_C_vec_seq[i][0]) + .map(|i| poly_C_vec_seq[i].get_ext_field_vec()[0]) .collect(); let claims_dotp = (poly_A_seq_final, poly_B_seq_final, poly_C_seq_final); @@ -318,7 +330,6 @@ impl SumcheckInstanceProof { claims_dotp, ) } - */ pub fn prove_cubic_disjoint_rounds( claim: &E, From a7f79ee2ec2fad1eabc23662385e0fc7d202eab1 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 12 Feb 2025 18:03:33 -0500 Subject: [PATCH 15/18] Ensure BaseField vectors --- spartan_parallel/src/lib.rs | 2 +- spartan_parallel/src/sparse_mlpoly.rs | 21 +++++++++++---------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/spartan_parallel/src/lib.rs b/spartan_parallel/src/lib.rs index 0df1e78f..fe10b8d7 100644 --- a/spartan_parallel/src/lib.rs +++ b/spartan_parallel/src/lib.rs @@ -948,7 +948,7 @@ impl> SNARK< let param = Pcs::setup(num_vars).unwrap(); let (pp, _error) = Pcs::trim(param, num_vars).unwrap(); // create a multilinear polynomial using the supplied assignment for variables - let poly = DenseMultilinearExtension::from_evaluation_vec_smart(mat_concat_p.len().log_2(), mat_concat_p); + let poly = DenseMultilinearExtension::from_evaluations_vec(mat_concat_p.len().log_2(), mat_concat_p); let p_comm = Pcs::commit(&pp, &poly).unwrap(); let v_comm = Pcs::get_pure_commitment(&p_comm); (poly, p_comm, v_comm) diff --git a/spartan_parallel/src/sparse_mlpoly.rs b/spartan_parallel/src/sparse_mlpoly.rs index 0b927de8..7dadd4d5 100644 --- a/spartan_parallel/src/sparse_mlpoly.rs +++ b/spartan_parallel/src/sparse_mlpoly.rs @@ -241,19 +241,19 @@ impl AddrTimestamps { audit_ts[addr] = w_ts; } - let ops_addr_inst_evals = ops_addr_inst.into_iter().map(|&n| E::from_u128(n as u128)).collect::>(); - let read_ts_evals = read_ts.into_iter().map(|n| E::from_u128(n as u128)).collect::>(); - ops_addr_vec.push(DenseMultilinearExtension::from_evaluation_vec_smart(ops_addr_inst_evals.len().log_2(), ops_addr_inst_evals)); - read_ts_vec.push(DenseMultilinearExtension::from_evaluation_vec_smart(read_ts_evals.len().log_2(), read_ts_evals)); + let ops_addr_inst_evals = ops_addr_inst.into_iter().map(|&n| E::from_u128(n as u128).as_bases()[0]).collect::>(); + let read_ts_evals = read_ts.into_iter().map(|n| E::from_u128(n as u128).as_bases()[0]).collect::>(); + ops_addr_vec.push(DenseMultilinearExtension::from_evaluations_vec(ops_addr_inst_evals.len().log_2(), ops_addr_inst_evals)); + read_ts_vec.push(DenseMultilinearExtension::from_evaluations_vec(read_ts_evals.len().log_2(), read_ts_evals)); } - let audit_ts_evals = audit_ts.into_iter().map(|n| E::from_u128(n as u128)).collect::>(); + let audit_ts_evals = audit_ts.into_iter().map(|n| E::from_u128(n as u128).as_bases()[0]).collect::>(); AddrTimestamps { ops_addr: ops_addr_vec, ops_addr_usize: ops_addr, read_ts: read_ts_vec, - audit_ts: DenseMultilinearExtension::from_evaluation_vec_smart(audit_ts_evals.len().log_2(), audit_ts_evals), + audit_ts: DenseMultilinearExtension::from_evaluations_vec(audit_ts_evals.len().log_2(), audit_ts_evals), } } @@ -261,11 +261,11 @@ impl AddrTimestamps { let evals = (0..addr.len()) .map(|i| { let a = addr[i]; - mem_val[a] + mem_val[a].as_bases()[0] }) - .collect::>(); + .collect::>(); - DenseMultilinearExtension::from_evaluation_vec_smart(evals.len().log_2(), evals) + DenseMultilinearExtension::from_evaluations_vec(evals.len().log_2(), evals) } pub fn deref(&self, mem_val: &[E]) -> Vec> { @@ -350,7 +350,8 @@ impl SparseMatPolynomial { let (ops_row, ops_col, val) = poly.sparse_to_dense_vecs(N); ops_row_vec.push(ops_row); ops_col_vec.push(ops_col); - val_vec.push(DenseMultilinearExtension::from_evaluation_vec_smart(val.len().log_2(), val)); + let val = val.into_iter().map(|e| e.as_bases()[0]).collect::>(); + val_vec.push(DenseMultilinearExtension::from_evaluations_vec(val.len().log_2(), val)); } let any_poly = &sparse_polys[0]; From 3b6955e18595f2f50848bb44e0e7784faa441148 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 12 Feb 2025 18:23:29 -0500 Subject: [PATCH 16/18] Remove DensePolynomial --- spartan_parallel/src/dense_mlpoly.rs | 422 +-------------------------- 1 file changed, 1 insertion(+), 421 deletions(-) diff --git a/spartan_parallel/src/dense_mlpoly.rs b/spartan_parallel/src/dense_mlpoly.rs index 544f3998..bdb6cd5b 100644 --- a/spartan_parallel/src/dense_mlpoly.rs +++ b/spartan_parallel/src/dense_mlpoly.rs @@ -113,424 +113,4 @@ impl IdentityPolynomial { .map(|i| E::from((len - i - 1).pow2() as u64) * r[i]) .sum() } -} - -// _debug -/* -impl DensePolynomial { - pub fn new(mut Z: Vec) -> Self { - // If length of Z is not a power of 2, append Z with 0 - let zero = E::ZERO; - Z.extend(vec![zero; Z.len().next_power_of_two() - Z.len()]); - DensePolynomial { - num_vars: Z.len().log_2(), - len: Z.len(), - Z, - } - } - - pub fn get_num_vars(&self) -> usize { - self.num_vars - } - - pub fn len(&self) -> usize { - self.len - } - - pub fn clone(&self) -> DensePolynomial { - DensePolynomial::new(self.Z[0..self.len].to_vec()) - } - - pub fn split(&self, idx: usize) -> (DensePolynomial, DensePolynomial) { - assert!(idx < self.len()); - ( - DensePolynomial::new(self.Z[..idx].to_vec()), - DensePolynomial::new(self.Z[idx..2 * idx].to_vec()), - ) - } - - pub fn bound(&self, L: &[E]) -> Vec { - let (left_num_vars, right_num_vars) = - EqPolynomial::::compute_factored_lens(self.get_num_vars()); - let L_size = left_num_vars.pow2(); - let R_size = right_num_vars.pow2(); - (0..R_size) - .map(|i| (0..L_size).map(|j| L[j] * self.Z[j * R_size + i]).sum()) - .collect() - } - - pub fn bound_poly_var_top(&mut self, r: &E) { - let n = self.len() / 2; - for i in 0..n { - self.Z[i] = self.Z[i] + *r * (self.Z[i + n] - self.Z[i]); - } - self.num_vars -= 1; - self.len = n; - } - - // Bound_var_top but the polynomial is in (x, q, p) form and certain (p, q) pair is invalid - pub fn bound_poly_var_top_disjoint_rounds( - &mut self, - r: &E, - proof_space: usize, - instance_space: usize, - cons_len: usize, - proof_len: usize, - instance_len: usize, - num_proofs: &Vec, - ) { - let n = self.len() / 2; - assert_eq!(n, cons_len * proof_len * instance_len); - - for p in 0..instance_len { - // Certain p, q combinations within the boolean hypercube always evaluate to 0 - let max_q = if proof_len != proof_space { - proof_len - } else { - num_proofs[p] - }; - for q in 0..max_q { - for x in 0..cons_len { - let i = x * proof_space * instance_space + q * instance_space + p; - self.Z[i] = self.Z[i] + *r * (self.Z[i + n] - self.Z[i]); - } - } - } - self.num_vars -= 1; - self.len = n; - } - - // The polynomial is in (q, p, x) form and certain (p, q) pair is invalid - // Binding the entire "q" section and q is in reverse order - // Use "num_proofs" to record how many "q"s need to process for each "p" - pub fn bound_poly_var_front_rq( - &mut self, - r_q: &Vec, - mut max_proof_space: usize, - instance_space: usize, - cons_space: usize, - mut num_proofs: Vec, - ) { - let mut n = self.len(); - assert_eq!(n, max_proof_space * instance_space * cons_space); - - for r in r_q { - n /= 2; - max_proof_space /= 2; - - for p in 0..instance_space { - if num_proofs[p] == 1 { - // q = 0 - for x in 0..cons_space { - let i = p * cons_space + x; - self.Z[i] = (E::ONE - *r) * self.Z[i]; - } - } else { - num_proofs[p] /= 2; - let step = max_proof_space / num_proofs[p]; - for q in (0..max_proof_space).step_by(step) { - for x in 0..cons_space { - let i = q * instance_space * cons_space + p * cons_space + x; - self.Z[i] = self.Z[i] + *r * (self.Z[i + n] - self.Z[i]); - } - } - } - } - self.num_vars -= 1; - self.len = n; - } - } - - pub fn bound_poly_var_bot(&mut self, r: &E) { - let n = self.len() / 2; - for i in 0..n { - self.Z[i] = self.Z[2 * i] + *r * (self.Z[2 * i + 1] - self.Z[2 * i]); - } - self.num_vars -= 1; - self.len = n; - } - - fn fold_r(proofs: &mut [E], r: &[E], step: usize, mut l: usize) { - for r in r { - let r1 = E::ONE - r.clone(); - let r2 = r.clone(); - - l = l.div_ceil(2); - (0..l).for_each(|i| { - proofs[i * step] = r1 * proofs[2 * i * step] + r2 * proofs[(2 * i + 1) * step]; - }); - } - } - - // returns Z(r) in O(n) time - pub fn evaluate_and_consume_parallel(&mut self, r: &[E]) -> E { - assert_eq!(r.len(), self.get_num_vars()); - let mut inst = std::mem::take(&mut self.Z); - - let len = self.len; - let dist_size = len / min(len, rayon::current_num_threads().next_power_of_two()); // distributed number of proofs on each thread - let num_threads = len / dist_size; - - // To perform rigorous parallelism, both len and # threads must be powers of 2 - // # threads must fully divide num_proofs for even distribution - assert_eq!(len, len.next_power_of_two()); - assert_eq!(num_threads, num_threads.next_power_of_two()); - - // Determine parallelism levels - let levels = len.log_2(); // total layers - let sub_levels = dist_size.log_2(); // parallel layers - let final_levels = num_threads.log_2(); // single core final layers - // Divide r into sub and final - let sub_r = &r[0..sub_levels]; - let final_r = &r[sub_levels..levels]; - - if sub_levels > 0 { - inst = inst - .par_chunks_mut(dist_size) - .map(|chunk| { - Self::fold_r(chunk, sub_r, 1, dist_size); - chunk.to_vec() - }) - .flatten().collect() - } - - if final_levels > 0 { - // aggregate the final result from sub-threads outputs using a single core - Self::fold_r(&mut inst, final_r, dist_size, num_threads); - } - inst[0] - } - - // returns Z(r) in O(n) time - pub fn evaluate(&self, r: &[E]) -> E { - // r must have a value for each variable - assert_eq!(r.len(), self.get_num_vars()); - let chis = EqPolynomial::new(r.to_vec()).evals(); - assert_eq!(chis.len(), self.Z.len()); - Self::compute_dotproduct(&self.Z, &chis) - } - - fn compute_dotproduct(a: &[E], b: &[E]) -> E { - assert_eq!(a.len(), b.len()); - (0..a.len()).map(|i| a[i] * b[i]).sum() - } - - fn vec(&self) -> &Vec { - &self.Z - } - - pub fn extend(&mut self, other: &DensePolynomial) { - // TODO: allow extension even when some vars are bound - assert_eq!(self.Z.len(), self.len); - let other_vec = other.vec(); - assert_eq!(other_vec.len(), self.len); - self.Z.extend(other_vec); - self.num_vars += 1; - self.len *= 2; - assert_eq!(self.Z.len(), self.len); - } - - pub fn merge<'a, I>(polys: I) -> DensePolynomial - where - I: IntoIterator>, - { - let mut Z: Vec = Vec::new(); - for poly in polys.into_iter() { - Z.extend(poly.vec()); - } - - // pad the polynomial with zero polynomial at the end - Z.resize(Z.len().next_power_of_two(), E::ZERO); - - DensePolynomial::new(Z) - } - - pub fn from_usize(Z: &[usize]) -> Self { - DensePolynomial::new( - (0..Z.len()) - .map(|i| E::from(Z[i] as u64)) - .collect::>(), - ) - } - - pub fn to_ceno_mle(&self) -> DenseMultilinearExtension { - DenseMultilinearExtension::from_evaluation_vec_smart(self.num_vars, self.Z.clone()) - } -} -*/ - -// _debug -/* -impl Index for DensePolynomial { - type Output = E; - - #[inline(always)] - fn index(&self, _index: usize) -> &E { - &(self.Z[_index]) - } -} -*/ - -/* -#[cfg(test)] -mod tests { - use super::*; - use crate::scalar::Scalar; - use rand::rngs::OsRng; - - fn evaluate_with_LR(Z: &[Scalar], r: &[Scalar]) -> Scalar { - let eq = EqPolynomial::new(r.to_vec()); - let (L, R) = eq.compute_factored_evals(); - - let ell = r.len(); - // ensure ell is even - assert!(ell % 2 == 0); - // compute n = 2^\ell - let n = ell.pow2(); - // compute m = sqrt(n) = 2^{\ell/2} - let m = n.square_root(); - - // compute vector-matrix product between L and Z viewed as a matrix - let LZ = (0..m) - .map(|i| (0..m).map(|j| L[j] * Z[j * m + i]).sum()) - .collect::>(); - - // compute dot product between LZ and R - DensePolynomial::compute_dotproduct(&LZ, &R) - } - - #[test] - fn check_polynomial_evaluation() { - // Z = [1, 2, 1, 4] - let Z = vec![ - Scalar::one(), - Scalar::from(2_usize), - Scalar::from(1_usize), - Scalar::from(4_usize), - ]; - - // r = [4,3] - let r = vec![Scalar::from(4_usize), Scalar::from(3_usize)]; - - let eval_with_LR = evaluate_with_LR(&Z, &r); - let poly = DensePolynomial::new(Z); - - let eval = poly.evaluate(&r); - assert_eq!(eval, Scalar::from(28_usize)); - assert_eq!(eval_with_LR, eval); - } - - pub fn compute_factored_chis_at_r(r: &[Scalar]) -> (Vec, Vec) { - let mut L: Vec = Vec::new(); - let mut R: Vec = Vec::new(); - - let ell = r.len(); - assert!(ell % 2 == 0); // ensure ell is even - let n = ell.pow2(); - let m = n.square_root(); - - // compute row vector L - for i in 0..m { - let mut chi_i = Scalar::one(); - for j in 0..ell / 2 { - let bit_j = ((m * i) & (1 << (r.len() - j - 1))) > 0; - if bit_j { - chi_i *= r[j]; - } else { - chi_i *= Scalar::one() - r[j]; - } - } - L.push(chi_i); - } - - // compute column vector R - for i in 0..m { - let mut chi_i = Scalar::one(); - for j in ell / 2..ell { - let bit_j = (i & (1 << (r.len() - j - 1))) > 0; - if bit_j { - chi_i *= r[j]; - } else { - chi_i *= Scalar::one() - r[j]; - } - } - R.push(chi_i); - } - (L, R) - } - - pub fn compute_chis_at_r(r: &[Scalar]) -> Vec { - let ell = r.len(); - let n = ell.pow2(); - let mut chis: Vec = Vec::new(); - for i in 0..n { - let mut chi_i = Scalar::one(); - for j in 0..r.len() { - let bit_j = (i & (1 << (r.len() - j - 1))) > 0; - if bit_j { - chi_i *= r[j]; - } else { - chi_i *= Scalar::one() - r[j]; - } - } - chis.push(chi_i); - } - chis - } - - pub fn compute_outerproduct(L: Vec, R: Vec) -> Vec { - assert_eq!(L.len(), R.len()); - (0..L.len()) - .map(|i| (0..R.len()).map(|j| L[i] * R[j]).collect::>()) - .collect::>>() - .into_iter() - .flatten() - .collect::>() - } - - #[test] - fn check_memoized_chis() { - let mut csprng: OsRng = OsRng; - - let s = 10; - let mut r: Vec = Vec::new(); - for _i in 0..s { - r.push(Scalar::random(&mut csprng)); - } - let chis = tests::compute_chis_at_r(&r); - let chis_m = EqPolynomial::new(r).evals(); - assert_eq!(chis, chis_m); - } - - #[test] - fn check_factored_chis() { - let mut csprng: OsRng = OsRng; - - let s = 10; - let mut r: Vec = Vec::new(); - for _i in 0..s { - r.push(Scalar::random(&mut csprng)); - } - let chis = EqPolynomial::new(r.clone()).evals(); - let (L, R) = EqPolynomial::new(r).compute_factored_evals(); - let O = compute_outerproduct(L, R); - assert_eq!(chis, O); - } - - #[test] - fn check_memoized_factored_chis() { - let mut csprng: OsRng = OsRng; - - let s = 10; - let mut r: Vec = Vec::new(); - for _i in 0..s { - r.push(Scalar::random(&mut csprng)); - } - let (L, R) = tests::compute_factored_chis_at_r(&r); - let eq = EqPolynomial::new(r); - let (L2, R2) = eq.compute_factored_evals(); - assert_eq!(L, L2); - assert_eq!(R, R2); - } -} -*/ \ No newline at end of file +} \ No newline at end of file From aaa997ae23ae164340ec4f9f903311a0046a0fde Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 12 Feb 2025 18:25:12 -0500 Subject: [PATCH 17/18] Remove debug block --- spartan_parallel/src/custom_dense_mlpoly.rs | 28 --------------------- 1 file changed, 28 deletions(-) diff --git a/spartan_parallel/src/custom_dense_mlpoly.rs b/spartan_parallel/src/custom_dense_mlpoly.rs index 728d8b35..24360b84 100644 --- a/spartan_parallel/src/custom_dense_mlpoly.rs +++ b/spartan_parallel/src/custom_dense_mlpoly.rs @@ -365,32 +365,4 @@ impl DensePolynomialPqx { cl.bound_poly_vars_rp(rp_rev); return cl.index(0, 0, 0, 0); } - - // _debug - // // Convert to a (p, q_rev, x_rev) regular dense poly of form (p, q, x) - // pub fn to_dense_poly(&self) -> DensePolynomial { - // let ZERO = E::ZERO; - - // let p_space = self.num_vars_p.pow2(); - // let q_space = self.num_vars_q.pow2(); - // let w_space = self.num_vars_w.pow2(); - // let x_space = self.num_vars_x.pow2(); - - // let mut Z_poly = vec![ZERO; p_space * q_space * w_space * x_space]; - // for p in 0..self.num_instances { - // for q in 0..self.num_proofs[p] { - // for w in 0..self.num_witness_secs { - // for x in 0..self.num_inputs[p][w] { - // Z_poly[ - // p * q_space * w_space * x_space - // + q * w_space * x_space - // + w * x_space - // + x - // ] = self.Z[p][q][w][x]; - // } - // } - // } - // } - // DensePolynomial::new(Z_poly) - // } } \ No newline at end of file From df10b3891202df4fbc2986703de32bd286417e3a Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 12 Feb 2025 18:26:14 -0500 Subject: [PATCH 18/18] Remove debug blocks: --- spartan_parallel/src/dense_mlpoly.rs | 10 ---------- spartan_parallel/src/sumcheck.rs | 1 - 2 files changed, 11 deletions(-) diff --git a/spartan_parallel/src/dense_mlpoly.rs b/spartan_parallel/src/dense_mlpoly.rs index bdb6cd5b..5e04c3b6 100644 --- a/spartan_parallel/src/dense_mlpoly.rs +++ b/spartan_parallel/src/dense_mlpoly.rs @@ -11,16 +11,6 @@ use rayon::{iter::ParallelIterator, slice::ParallelSliceMut}; use serde::{Deserialize, Serialize}; use std::cmp::min; -// _debug -/* -#[derive(Debug, Clone)] -pub struct DensePolynomial { - num_vars: usize, // the number of variables in the multilinear polynomial - len: usize, - Z: Vec, // evaluations of the polynomial in all the 2^num_vars Boolean inputs -} -*/ - pub struct EqPolynomial { r: Vec, } diff --git a/spartan_parallel/src/sumcheck.rs b/spartan_parallel/src/sumcheck.rs index a5cbb13a..4e3d3b39 100644 --- a/spartan_parallel/src/sumcheck.rs +++ b/spartan_parallel/src/sumcheck.rs @@ -69,7 +69,6 @@ impl SumcheckInstanceProof { } impl SumcheckInstanceProof { - // _debug: remove native sumcheck prover pub fn prove_cubic( claim: &E, num_rounds: usize,