diff --git a/Cargo.lock b/Cargo.lock index b25d89b1ca..e5b31e5d0c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3057,11 +3057,14 @@ dependencies = [ "e3-fhe", "e3-fhe-params", "e3-multithread", + "e3-polynomial", "e3-request", "e3-sortition", "e3-trbfv", "e3-utils", "e3-zk-helpers", + "fhe-math", + "num-bigint", "serde", "tracing", ] diff --git a/crates/aggregator/Cargo.toml b/crates/aggregator/Cargo.toml index bb8bd0876e..70abbb63da 100644 --- a/crates/aggregator/Cargo.toml +++ b/crates/aggregator/Cargo.toml @@ -18,7 +18,10 @@ e3-evm = { workspace = true } e3-fhe = { workspace = true } e3-fhe-params = { workspace = true } e3-multithread = { workspace = true } +e3-polynomial = { workspace = true } e3-trbfv = { workspace = true } +fhe-math = { workspace = true } +num-bigint = { workspace = true } e3-bfv-client = { workspace = true } e3-request = { workspace = true } e3-sortition = { workspace = true } diff --git a/crates/aggregator/src/threshold_plaintext_aggregator.rs b/crates/aggregator/src/threshold_plaintext_aggregator.rs index 695822208a..955ce108f6 100644 --- a/crates/aggregator/src/threshold_plaintext_aggregator.rs +++ b/crates/aggregator/src/threshold_plaintext_aggregator.rs @@ -11,12 +11,12 @@ use actix::prelude::*; use anyhow::{anyhow, bail, ensure, Result}; use e3_data::Persistable; use e3_events::{ - prelude::*, trap, AggregationProofPending, AggregationProofSigned, BusHandle, + prelude::*, trap, AggregationProofPending, AggregationProofSigned, BusHandle, CircuitName, CommitteeMemberExpelled, ComputeRequest, ComputeResponse, ComputeResponseKind, CorrelationId, DecryptedSharesAggregationProofRequest, DecryptionshareCreated, Die, E3id, EType, EnclaveEvent, - EnclaveEventData, EventContext, PartyProofsToVerify, PlaintextAggregated, Proof, Seed, - Sequenced, ShareVerificationComplete, ShareVerificationDispatched, SignedProofPayload, - TypedEvent, VerificationKind, ZkResponse, + EnclaveEventData, EventContext, PartyProofsToVerify, PlaintextAggregated, Proof, ProofType, + Seed, Sequenced, ShareVerificationComplete, ShareVerificationDispatched, SignedProofFailed, + SignedProofPayload, TypedEvent, VerificationKind, ZkResponse, }; use e3_fhe_params::BfvPreset; use e3_sortition::{E3CommitteeContainsRequest, E3CommitteeContainsResponse, Sortition}; @@ -26,8 +26,11 @@ use e3_trbfv::{ }; use e3_utils::NotifySync; use e3_utils::{utility_types::ArcBytes, MAILBOX_LIMIT}; +use e3_zk_helpers::circuits::commitments::compute_threshold_decryption_share_commitment; use e3_zk_helpers::circuits::threshold::decrypted_shares_aggregation::MAX_MSG_NON_ZERO_COEFFS; -use tracing::{debug, info, trace, warn}; +use e3_zk_helpers::threshold::share_decryption::{Bits as C6Bits, Bounds as C6Bounds}; +use e3_zk_helpers::Computation; +use tracing::{debug, error, info, trace, warn}; #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct Collecting { @@ -376,7 +379,7 @@ impl ThresholdPlaintextAggregator { .ok_or(anyhow!("Could not get state"))? .try_into()?; - let dishonest_parties = &msg.dishonest_parties; + let mut dishonest_parties = msg.dishonest_parties.clone(); if !dishonest_parties.is_empty() { warn!( "C6 verification: {} dishonest parties filtered: {:?}", @@ -386,7 +389,7 @@ impl ThresholdPlaintextAggregator { } // Filter shares to only honest parties - let honest_shares: Vec<(u64, Vec)> = state + let mut honest_shares: Vec<(u64, Vec)> = state .shares .iter() .filter(|(id, _)| !dishonest_parties.contains(id)) @@ -400,9 +403,32 @@ impl ThresholdPlaintextAggregator { state.threshold_m + 1 ); + // Verify each honest party's raw decryption share matches the + // d_commitment attested by their verified C6 proof. Catches the attack + // where a node sends a valid C6 proof for share d_A but broadcasts + // different bytes d_B. + let share_mismatch_parties = + self.verify_shares_match_c6_commitments(&honest_shares, &state.c6_proofs); + if !share_mismatch_parties.is_empty() { + warn!( + "C6 share-commitment mismatch for {} parties: {:?} — excluding from aggregation", + share_mismatch_parties.len(), + share_mismatch_parties, + ); + + dishonest_parties.extend(&share_mismatch_parties); + honest_shares.retain(|(id, _)| !share_mismatch_parties.contains(id)); + ensure!( + honest_shares.len() > state.threshold_m as usize, + "Not enough honest shares after d_commitment check: {} honest, {} required", + honest_shares.len(), + state.threshold_m + 1 + ); + } + info!( "C6 verification passed: {} honest parties, transitioning to Computing", - honest_shares.len() + honest_shares.len(), ); // Collect honest C6 wrapped proofs sorted by party_id for cross-node folding. @@ -462,6 +488,118 @@ impl ThresholdPlaintextAggregator { Ok(()) } + /// Verify that each honest party's raw decryption share bytes match the + /// `d_commitment` output in their verified C6 proof. Returns party IDs + /// that failed the check. + fn verify_shares_match_c6_commitments( + &self, + honest_shares: &[(u64, Vec)], + c6_proofs: &HashMap>, + ) -> BTreeSet { + let mut mismatched = BTreeSet::new(); + + let Ok((threshold_params, _)) = e3_fhe_params::build_pair_for_preset(self.params_preset) + else { + warn!("Could not build BFV params for d_commitment check — skipping"); + return mismatched; + }; + + // Reuse the same Bounds/Bits computation that C6 codegen uses, + // so d_bit stays in sync if the formula ever changes. + let Ok(bounds) = C6Bounds::compute(self.params_preset, &()) else { + warn!("Could not compute bounds for d_commitment check — skipping"); + return mismatched; + }; + let Ok(bits) = C6Bits::compute(self.params_preset, &bounds) else { + warn!("Could not compute bits for d_commitment check — skipping"); + return mismatched; + }; + let d_bit = bits.d_bit; + + let max_k = MAX_MSG_NON_ZERO_COEFFS; + let c6_output_layout = CircuitName::ThresholdShareDecryption.output_layout(); + let moduli: Vec = threshold_params.moduli().to_vec(); + + for (party_id, shares) in honest_shares { + let Some(proofs) = c6_proofs.get(party_id) else { + warn!( + "No C6 proofs for party {} — marking as mismatched", + party_id + ); + mismatched.insert(*party_id); + continue; + }; + let Some(first_proof) = proofs.first() else { + warn!( + "Empty C6 proof list for party {} — marking as mismatched", + party_id + ); + mismatched.insert(*party_id); + continue; + }; + let Some(c6_d_bytes) = c6_output_layout + .extract_field(&first_proof.payload.proof.public_signals, "d_commitment") + else { + warn!( + "Could not extract d_commitment from C6 proof for party {} — marking as mismatched", + party_id + ); + mismatched.insert(*party_id); + continue; + }; + + let Some(share_bytes) = shares.first() else { + warn!( + "No share bytes for party {} — marking as mismatched", + party_id + ); + mismatched.insert(*party_id); + continue; + }; + let Ok(poly) = e3_trbfv::helpers::try_poly_from_bytes(share_bytes, &threshold_params) + else { + warn!( + "Could not deserialize share for party {} — marking as mismatched", + party_id + ); + mismatched.insert(*party_id); + continue; + }; + let mut crt = e3_polynomial::CrtPolynomial::from_fhe_polynomial(&poly); + + // Apply the same transformations C6's Inputs::compute applies: + // reverse coefficient order + center each limb mod qi. + crt.reverse(); + if let Err(e) = crt.center(&moduli) { + warn!( + "Could not center d_share for party {} — marking as mismatched: {e}", + party_id + ); + mismatched.insert(*party_id); + continue; + } + + let computed = compute_threshold_decryption_share_commitment(&crt, d_bit, max_k); + + // Convert to big-endian 32-byte padded format matching + // Barretenberg's public_signals encoding. + let (_, be_bytes) = computed.to_bytes_be(); + let mut computed_padded = [0u8; 32]; + let start = 32usize.saturating_sub(be_bytes.len()); + computed_padded[start..].copy_from_slice(&be_bytes[..be_bytes.len().min(32)]); + + if computed_padded != c6_d_bytes { + warn!( + "d_commitment mismatch for party {}: raw share commitment differs from C6 proof output", + party_id + ); + mismatched.insert(*party_id); + } + } + + mismatched + } + /// Publish AggregationProofPending for C7 proof generation through ProofRequestActor. pub fn dispatch_c7_proof_request( &mut self, diff --git a/crates/multithread/src/multithread.rs b/crates/multithread/src/multithread.rs index 6357ede622..3aed910f11 100644 --- a/crates/multithread/src/multithread.rs +++ b/crates/multithread/src/multithread.rs @@ -1443,7 +1443,7 @@ fn handle_decrypted_shares_aggregation_proof( proofs.push(proof); } - // 4. Return response + // 5. Return response Ok(ComputeResponse::zk( ZkResponse::DecryptedSharesAggregation(DecryptedSharesAggregationProofResponse { proofs }), request.correlation_id, diff --git a/crates/polynomial/src/crt_polynomial.rs b/crates/polynomial/src/crt_polynomial.rs index ab9fa037ea..237e081b3d 100644 --- a/crates/polynomial/src/crt_polynomial.rs +++ b/crates/polynomial/src/crt_polynomial.rs @@ -75,18 +75,19 @@ impl CrtPolynomial { Self::from_bigint_vectors(limbs) } - /// Builds a `CrtPolynomial` from an fhe-math `Poly` in PowerBasis representation. + /// Builds a `CrtPolynomial` from an fhe-math `Poly` in any representation. /// /// Used to prepare inputs for ZK circuits by converting FHE BFV ciphertext polynomials - /// into CRT limb format. If `p` is in NTT form, it is converted to PowerBasis first. + /// into CRT limb format. If `p` is not in PowerBasis form (e.g. NTT or NttShoup), + /// it is converted first. /// /// # Arguments /// - /// * `p` - An fhe-math polynomial (PowerBasis or Ntt). + /// * `p` - An fhe-math polynomial (any representation). pub fn from_fhe_polynomial(p: &Poly) -> Self { let mut p = p.clone(); - if *p.representation() == Representation::Ntt { + if *p.representation() != Representation::PowerBasis { p.change_representation(Representation::PowerBasis); } diff --git a/crates/zk-helpers/src/circuits/threshold/share_decryption/computation.rs b/crates/zk-helpers/src/circuits/threshold/share_decryption/computation.rs index 839e3dc631..644351809f 100644 --- a/crates/zk-helpers/src/circuits/threshold/share_decryption/computation.rs +++ b/crates/zk-helpers/src/circuits/threshold/share_decryption/computation.rs @@ -363,6 +363,86 @@ mod tests { assert_eq!(bits.d_bit, expected_bit); } + /// Verifies that `CrtPolynomial::reverse()` + `center()` matches + /// `Inputs::compute` for d_commitment, and that the Poly bytes round-trip + /// is lossless. + #[test] + fn test_d_commitment_matches_inputs_compute() { + use crate::circuits::commitments::compute_threshold_decryption_share_commitment; + use crate::circuits::threshold::decrypted_shares_aggregation::MAX_MSG_NON_ZERO_COEFFS; + use crate::threshold::share_decryption::ShareDecryptionCircuitData; + use crate::CiphernodesCommitteeSize; + use fhe_math::rq::{Poly, Representation}; + use fhe_traits::{DeserializeWithContext, Serialize as FheSer}; + use num_traits::ToPrimitive; + + let preset = DEFAULT_BFV_PRESET; + let committee = CiphernodesCommitteeSize::Small.values(); + let sample = ShareDecryptionCircuitData::generate_sample(preset, committee).unwrap(); + let (threshold_params, _) = build_pair_for_preset(preset).unwrap(); + let bounds = Bounds::compute(preset, &()).unwrap(); + let bits = Bits::compute(preset, &bounds).unwrap(); + let moduli: Vec = threshold_params.moduli().to_vec(); + + // Ground truth: Inputs::compute (what the Noir prover receives) + let inputs = Inputs::compute(preset, &sample).unwrap(); + let truth = compute_threshold_decryption_share_commitment( + &inputs.d, + bits.d_bit, + MAX_MSG_NON_ZERO_COEFFS, + ); + + // Aggregator path: CrtPolynomial::reverse() + center() + let mut crt = sample.d_share.clone(); + crt.reverse(); + crt.center(&moduli).unwrap(); + let from_api = compute_threshold_decryption_share_commitment( + &crt, + bits.d_bit, + MAX_MSG_NON_ZERO_COEFFS, + ); + assert_eq!( + truth, from_api, + "CrtPolynomial API must match Inputs::compute" + ); + + // Bytes round-trip: Poly → to_bytes → from_bytes → from_fhe_polynomial + let raw: Vec> = sample + .d_share + .limbs + .iter() + .map(|l| { + l.coefficients() + .iter() + .map(|c| c.to_u64().unwrap()) + .collect() + }) + .collect(); + let n = raw[0].len(); + let mut arr = ndarray::Array2::::zeros((raw.len(), n)); + for (i, limb) in raw.iter().enumerate() { + for (j, &v) in limb.iter().enumerate() { + arr[[i, j]] = v; + } + } + let ctx = threshold_params.ctx_at_level(0).unwrap(); + let mut poly = Poly::zero(&ctx, Representation::PowerBasis); + poly.set_coefficients(arr); + let poly_rt = Poly::from_bytes(&poly.to_bytes(), &ctx).unwrap(); + let mut crt_rt = CrtPolynomial::from_fhe_polynomial(&poly_rt); + crt_rt.reverse(); + crt_rt.center(&moduli).unwrap(); + let from_bytes = compute_threshold_decryption_share_commitment( + &crt_rt, + bits.d_bit, + MAX_MSG_NON_ZERO_COEFFS, + ); + assert_eq!( + truth, from_bytes, + "Bytes round-trip must match Inputs::compute" + ); + } + #[test] fn test_constants_json_roundtrip() { let constants = Configs::compute(DEFAULT_BFV_PRESET, &()).unwrap(); diff --git a/templates/default/tests/integration.spec.ts b/templates/default/tests/integration.spec.ts index dedf593598..1b49a1095c 100644 --- a/templates/default/tests/integration.spec.ts +++ b/templates/default/tests/integration.spec.ts @@ -189,7 +189,7 @@ describe('Integration', () => { const { waitForEvent } = await setupEventListeners(sdk, store) const committeeSize = CommitteeSize.Micro - const duration = 500 + const duration = 450 const inputWindow = await calculateInputWindow(publicClient, duration) const thresholdBfvParams = await sdk.getThresholdBfvParamsSet() const e3ProgramParams = encodeBfvParams(thresholdBfvParams)