diff --git a/crates/aggregator/src/actors/publickey_aggregator.rs b/crates/aggregator/src/actors/publickey_aggregator.rs index 54d3d7c9c..b38331ff0 100644 --- a/crates/aggregator/src/actors/publickey_aggregator.rs +++ b/crates/aggregator/src/actors/publickey_aggregator.rs @@ -16,7 +16,7 @@ use e3_events::{ prelude::*, BusHandle, ComputeRequest, ComputeRequestError, ComputeResponse, ComputeResponseKind, CorrelationId, DKGRecursiveAggregationComplete, Die, DkgAggregationRequest, E3Failed, E3Stage, E3id, EnclaveEvent, EnclaveEventData, EventContext, - FailureReason, KeyshareCreated, OrderedSet, PkAggregationProofPending, + FailureReason, KeyshareCreated, NodesFoldStepRequest, OrderedSet, PkAggregationProofPending, PkAggregationProofRequest, PkAggregationProofSigned, Proof, ProofType, PublicKeyAggregated, Sequenced, ShareVerificationComplete, ShareVerificationDispatched, SignedProofFailed, SignedProofPayload, TypedEvent, VerificationKind, ZkRequest, ZkResponse, @@ -342,6 +342,9 @@ impl PublicKeyAggregator { dkg_aggregated_proof: None, c5_proof_pending: None, last_ec: Some(ec.clone()), + nodes_fold_accumulator: None, + nodes_fold_completed_slots: 0, + nodes_fold_step_correlation: None, }) })?; @@ -392,6 +395,9 @@ impl PublicKeyAggregator { circuit_committee_h, dkg_aggregation_correlation, dkg_aggregated_proof, + nodes_fold_accumulator, + nodes_fold_completed_slots, + nodes_fold_step_correlation, .. } = state else { @@ -412,6 +418,9 @@ impl PublicKeyAggregator { dkg_aggregated_proof, c5_proof_pending: Some(c5_proof), last_ec: Some(ec.clone()), + nodes_fold_accumulator, + nodes_fold_completed_slots, + nodes_fold_step_correlation, }) })?; self.try_publish_complete() @@ -538,6 +547,9 @@ impl PublicKeyAggregator { dkg_aggregation_correlation, dkg_aggregated_proof, c5_proof_pending, + nodes_fold_accumulator, + nodes_fold_completed_slots, + nodes_fold_step_correlation, last_ec: _, } = state else { @@ -562,13 +574,211 @@ impl PublicKeyAggregator { dkg_aggregated_proof, c5_proof_pending, last_ec: Some(ec.clone()), + nodes_fold_accumulator, + nodes_fold_completed_slots, + nodes_fold_step_correlation, }) })?; - self.try_dispatch_dkg_aggregation(&ec) + self.try_dispatch_nodes_fold_step(&ec) } - /// Dispatch [`ZkRequest::DkgAggregation`] once C5 and all honest NodeFold proofs are ready. + /// Dispatch the next [`ZkRequest::NodesFoldStep`] if the next slot's proof is buffered + /// and no step is currently in flight. When all H slots are done, calls + /// [`try_dispatch_dkg_aggregation`]. + fn try_dispatch_nodes_fold_step(&mut self, ec: &EventContext) -> Result<()> { + let state = self.state.get(); + let Some(PublicKeyAggregatorState::GeneratingC5Proof { + dkg_node_proofs, + honest_party_ids, + nodes_fold_accumulator, + nodes_fold_completed_slots, + nodes_fold_step_correlation, + dkg_aggregation_correlation, + dkg_aggregated_proof, + .. + }) = state.as_ref() + else { + return Ok(()); + }; + + if nodes_fold_step_correlation.is_some() + || dkg_aggregation_correlation.is_some() + || dkg_aggregated_proof.is_some() + { + return Ok(()); + } + + let next_slot = *nodes_fold_completed_slots; + let total_slots = honest_party_ids.len(); + + if next_slot as usize >= total_slots { + return self.try_dispatch_dkg_aggregation(ec); + } + + let Some(&party_id) = honest_party_ids.iter().nth(next_slot as usize) else { + return Ok(()); + }; + + let Some(Some(inner_proof)) = dkg_node_proofs.get(&party_id) else { + return Ok(()); + }; + + let inner_proof = inner_proof.clone(); + let prior_accumulator = nodes_fold_accumulator.clone(); + + let corr = CorrelationId::new(); + self.bus.publish( + ComputeRequest::zk( + ZkRequest::NodesFoldStep(NodesFoldStepRequest { + inner_proof, + prior_accumulator, + slot_index: next_slot, + total_slots, + e3_id: self.e3_id.to_string(), + params_preset: self.params_preset, + committee_size: self.committee_size, + }), + corr, + self.e3_id.clone(), + ), + ec.clone(), + )?; + + info!( + "PublicKeyAggregator: dispatched NodesFoldStep slot={}/{} for E3 {}", + next_slot, total_slots, self.e3_id + ); + + self.state.try_mutate(ec, |state| { + let PublicKeyAggregatorState::GeneratingC5Proof { + public_key, + keyshare_bytes, + nodes, + party_nodes, + dkg_node_proofs, + dkg_fold_attestations, + honest_party_ids, + dishonest_parties, + circuit_committee_n, + circuit_committee_h, + dkg_aggregation_correlation, + dkg_aggregated_proof, + c5_proof_pending, + last_ec, + nodes_fold_accumulator, + nodes_fold_completed_slots, + nodes_fold_step_correlation: _, + } = state + else { + return Ok(state); + }; + Ok(PublicKeyAggregatorState::GeneratingC5Proof { + public_key, + keyshare_bytes, + nodes, + party_nodes, + dkg_node_proofs, + dkg_fold_attestations, + honest_party_ids, + dishonest_parties, + circuit_committee_n, + circuit_committee_h, + dkg_aggregation_correlation, + dkg_aggregated_proof, + c5_proof_pending, + last_ec, + nodes_fold_accumulator, + nodes_fold_completed_slots, + nodes_fold_step_correlation: Some(corr), + }) + })?; + Ok(()) + } + + /// Handle a completed [`ZkResponse::NodesFoldStep`]: advance the accumulator and dispatch + /// the next fold step (or the final DkgAggregation when all H slots are done). + fn handle_nodes_fold_step_response( + &mut self, + correlation_id: CorrelationId, + accumulator_proof: Proof, + ) -> Result<()> { + let state = self.state.get(); + let Some(PublicKeyAggregatorState::GeneratingC5Proof { + nodes_fold_step_correlation, + nodes_fold_completed_slots, + last_ec, + .. + }) = state.as_ref() + else { + return Ok(()); + }; + + if nodes_fold_step_correlation.as_ref() != Some(&correlation_id) { + return Ok(()); + } + + let completed = nodes_fold_completed_slots + 1; + let Some(ec) = last_ec.clone() else { + return Err(anyhow::anyhow!( + "No EventContext for NodesFoldStep response" + )); + }; + + info!( + "PublicKeyAggregator: NodesFoldStep complete (slot {} done) for E3 {}", + completed - 1, + self.e3_id + ); + + self.state.try_mutate_without_context(|state| { + let PublicKeyAggregatorState::GeneratingC5Proof { + public_key, + keyshare_bytes, + nodes, + party_nodes, + dkg_node_proofs, + dkg_fold_attestations, + honest_party_ids, + dishonest_parties, + circuit_committee_n, + circuit_committee_h, + dkg_aggregation_correlation, + dkg_aggregated_proof, + c5_proof_pending, + last_ec, + nodes_fold_step_correlation: _, + .. + } = state + else { + return Ok(state); + }; + Ok(PublicKeyAggregatorState::GeneratingC5Proof { + public_key, + keyshare_bytes, + nodes, + party_nodes, + dkg_node_proofs, + dkg_fold_attestations, + honest_party_ids, + dishonest_parties, + circuit_committee_n, + circuit_committee_h, + dkg_aggregation_correlation, + dkg_aggregated_proof, + c5_proof_pending, + last_ec, + nodes_fold_accumulator: Some(accumulator_proof), + nodes_fold_completed_slots: completed, + nodes_fold_step_correlation: None, + }) + })?; + + self.try_dispatch_nodes_fold_step(&ec) + } + + /// Dispatch [`ZkRequest::DkgAggregation`] once C5, all honest NodeFold proofs, and the + /// streaming nodes_fold are all ready. fn try_dispatch_dkg_aggregation(&mut self, ec: &EventContext) -> Result<()> { let state = self.state.get(); let Some(PublicKeyAggregatorState::GeneratingC5Proof { @@ -580,6 +790,8 @@ impl PublicKeyAggregator { dkg_aggregated_proof, circuit_committee_n, circuit_committee_h, + nodes_fold_accumulator, + nodes_fold_completed_slots, .. }) = state.as_ref() else { @@ -644,6 +856,9 @@ impl PublicKeyAggregator { dkg_aggregated_proof, c5_proof_pending: _, last_ec, + nodes_fold_accumulator, + nodes_fold_completed_slots, + nodes_fold_step_correlation, } = state else { return Ok(state); @@ -664,6 +879,9 @@ impl PublicKeyAggregator { dkg_aggregated_proof, c5_proof_pending: None, last_ec, + nodes_fold_accumulator, + nodes_fold_completed_slots, + nodes_fold_step_correlation, }) })?; return Ok(()); @@ -698,6 +916,13 @@ impl PublicKeyAggregator { return Ok(()); } + // Streaming fold must be complete before dispatching the final aggregation. + let fold_complete = *nodes_fold_completed_slots == honest_party_ids.len() as u32; + if !fold_complete { + return Ok(()); + } + let precomputed_fold = nodes_fold_accumulator.clone(); + // Build the FULL committee address vector (length N) in ascending party_id order. // The DKG aggregator circuit's `committee_members: [Field; N_PARTIES]` is the // committee-hash preimage; passing only the H honest subset would silently @@ -725,6 +950,7 @@ impl PublicKeyAggregator { ComputeRequest::zk( ZkRequest::DkgAggregation(DkgAggregationRequest { node_fold_proofs, + nodes_fold_proof: precomputed_fold, c5_proof: c5_proof.clone(), party_ids, committee_addresses, @@ -753,6 +979,9 @@ impl PublicKeyAggregator { dkg_aggregated_proof, c5_proof_pending, last_ec, + nodes_fold_accumulator, + nodes_fold_completed_slots, + nodes_fold_step_correlation, } = state else { return Ok(state); @@ -772,6 +1001,9 @@ impl PublicKeyAggregator { dkg_aggregated_proof, c5_proof_pending, last_ec, + nodes_fold_accumulator, + nodes_fold_completed_slots, + nodes_fold_step_correlation, }) })?; Ok(()) @@ -901,21 +1133,125 @@ impl PublicKeyAggregator { fn handle_compute_response(&mut self, msg: TypedEvent) -> Result<()> { let (msg, _ec) = msg.into_components(); - if let ComputeResponseKind::Zk(ZkResponse::DkgAggregation(resp)) = msg.response { - if msg.e3_id != self.e3_id { - return Ok(()); + if msg.e3_id != self.e3_id { + return Ok(()); + } + match msg.response { + ComputeResponseKind::Zk(ZkResponse::NodesFoldStep(resp)) => { + self.handle_nodes_fold_step_response(msg.correlation_id, resp.accumulator_proof)?; } - let state = self.state.get(); - let Some(PublicKeyAggregatorState::GeneratingC5Proof { last_ec, .. }) = state.as_ref() - else { - return Ok(()); - }; - let Some(_ec) = last_ec.clone() else { - return Err(anyhow::anyhow!( - "No EventContext for DkgAggregation response" - )); - }; - self.state.try_mutate_without_context(|state| { + ComputeResponseKind::Zk(ZkResponse::DkgAggregation(resp)) => { + let state = self.state.get(); + let Some(PublicKeyAggregatorState::GeneratingC5Proof { last_ec, .. }) = + state.as_ref() + else { + return Ok(()); + }; + let Some(_ec) = last_ec.clone() else { + return Err(anyhow::anyhow!( + "No EventContext for DkgAggregation response" + )); + }; + self.state.try_mutate_without_context(|state| { + let PublicKeyAggregatorState::GeneratingC5Proof { + public_key, + keyshare_bytes, + nodes, + party_nodes, + dkg_node_proofs, + dkg_fold_attestations, + honest_party_ids, + dishonest_parties, + circuit_committee_n, + circuit_committee_h, + dkg_aggregation_correlation, + dkg_aggregated_proof, + c5_proof_pending, + last_ec, + nodes_fold_accumulator, + nodes_fold_completed_slots, + nodes_fold_step_correlation, + } = state + else { + return Ok(state); + }; + if dkg_aggregation_correlation.as_ref() != Some(&msg.correlation_id) { + return Ok(PublicKeyAggregatorState::GeneratingC5Proof { + public_key, + keyshare_bytes, + nodes, + party_nodes, + dkg_node_proofs, + dkg_fold_attestations, + honest_party_ids, + dishonest_parties, + circuit_committee_n, + circuit_committee_h, + dkg_aggregation_correlation, + dkg_aggregated_proof, + c5_proof_pending, + last_ec, + nodes_fold_accumulator, + nodes_fold_completed_slots, + nodes_fold_step_correlation, + }); + } + Ok(PublicKeyAggregatorState::GeneratingC5Proof { + public_key, + keyshare_bytes, + nodes, + party_nodes, + dkg_node_proofs, + dkg_fold_attestations, + honest_party_ids, + dishonest_parties, + circuit_committee_n, + circuit_committee_h, + dkg_aggregation_correlation: None, + dkg_aggregated_proof: Some(resp.proof.clone()), + c5_proof_pending, + last_ec, + nodes_fold_accumulator, + nodes_fold_completed_slots, + nodes_fold_step_correlation, + }) + })?; + self.try_publish_complete()?; + } + _ => {} + } + Ok(()) + } + + fn handle_compute_request_error(&mut self, msg: TypedEvent) -> Result<()> { + let (msg, ec) = msg.into_components(); + if msg.request().e3_id != self.e3_id { + return Ok(()); + } + + let matched_nodes_fold_step = matches!( + self.state.get(), + Some(PublicKeyAggregatorState::GeneratingC5Proof { + nodes_fold_step_correlation, + .. + }) if nodes_fold_step_correlation.as_ref() == Some(msg.correlation_id()) + ); + + if matched_nodes_fold_step { + error!( + "PublicKeyAggregator: NodesFoldStep failed for E3 {}: {:?}", + self.e3_id, + msg.get_err() + ); + self.bus.publish( + E3Failed { + e3_id: self.e3_id.clone(), + failed_at_stage: E3Stage::CommitteeFinalized, + reason: FailureReason::DKGInvalidShares, + }, + ec.clone(), + )?; + self.state.try_mutate(&ec, |state| { let PublicKeyAggregatorState::GeneratingC5Proof { public_key, keyshare_bytes, @@ -929,30 +1265,15 @@ impl PublicKeyAggregator { circuit_committee_h, dkg_aggregation_correlation, dkg_aggregated_proof, - c5_proof_pending, + c5_proof_pending: _, last_ec, + nodes_fold_accumulator, + nodes_fold_completed_slots, + nodes_fold_step_correlation: _, } = state else { return Ok(state); }; - if dkg_aggregation_correlation.as_ref() != Some(&msg.correlation_id) { - return Ok(PublicKeyAggregatorState::GeneratingC5Proof { - public_key, - keyshare_bytes, - nodes, - party_nodes, - dkg_node_proofs, - dkg_fold_attestations, - honest_party_ids, - dishonest_parties, - circuit_committee_n, - circuit_committee_h, - dkg_aggregation_correlation, - dkg_aggregated_proof, - c5_proof_pending, - last_ec, - }); - } Ok(PublicKeyAggregatorState::GeneratingC5Proof { public_key, keyshare_bytes, @@ -964,24 +1285,19 @@ impl PublicKeyAggregator { dishonest_parties, circuit_committee_n, circuit_committee_h, - dkg_aggregation_correlation: None, - dkg_aggregated_proof: Some(resp.proof.clone()), - c5_proof_pending, + dkg_aggregation_correlation, + dkg_aggregated_proof, + c5_proof_pending: None, last_ec, + nodes_fold_accumulator, + nodes_fold_completed_slots, + nodes_fold_step_correlation: None, }) })?; - self.try_publish_complete()?; - } - Ok(()) - } - - fn handle_compute_request_error(&mut self, msg: TypedEvent) -> Result<()> { - let (msg, ec) = msg.into_components(); - if msg.request().e3_id != self.e3_id { return Ok(()); } - let matched_correlation = matches!( + let matched_dkg_aggregation = matches!( self.state.get(), Some(PublicKeyAggregatorState::GeneratingC5Proof { dkg_aggregation_correlation, @@ -989,7 +1305,7 @@ impl PublicKeyAggregator { }) if dkg_aggregation_correlation.as_ref() == Some(msg.correlation_id()) ); - if !matched_correlation { + if !matched_dkg_aggregation { return Ok(()); } @@ -1024,6 +1340,9 @@ impl PublicKeyAggregator { dkg_aggregated_proof, c5_proof_pending: _, last_ec, + nodes_fold_accumulator, + nodes_fold_completed_slots, + nodes_fold_step_correlation, } = state else { return Ok(state); @@ -1044,6 +1363,9 @@ impl PublicKeyAggregator { dkg_aggregated_proof, c5_proof_pending: None, last_ec, + nodes_fold_accumulator, + nodes_fold_completed_slots, + nodes_fold_step_correlation, }) })?; @@ -1315,6 +1637,9 @@ mod tests { dkg_aggregated_proof: None, c5_proof_pending: Some(dummy_proof(CircuitName::PkAggregation)), last_ec: None, + nodes_fold_accumulator: None, + nodes_fold_completed_slots: 0, + nodes_fold_step_correlation: None, } } @@ -1358,6 +1683,7 @@ mod tests { let request = ComputeRequest::zk( ZkRequest::DkgAggregation(DkgAggregationRequest { node_fold_proofs: vec![dummy_proof(CircuitName::PkAggregation)], + nodes_fold_proof: None, c5_proof: dummy_proof(CircuitName::PkAggregation), party_ids: vec![0], committee_addresses: vec!["0x0000000000000000000000000000000000000001" diff --git a/crates/aggregator/src/domain/publickey_aggregation.rs b/crates/aggregator/src/domain/publickey_aggregation.rs index da5361362..b8f99e9c7 100644 --- a/crates/aggregator/src/domain/publickey_aggregation.rs +++ b/crates/aggregator/src/domain/publickey_aggregation.rs @@ -206,6 +206,12 @@ pub enum PublicKeyAggregatorState { dkg_aggregated_proof: Option, c5_proof_pending: Option, last_ec: Option>, + /// Accumulated nodes_fold proof after `nodes_fold_completed_slots` streaming steps. + nodes_fold_accumulator: Option, + /// Number of slots folded so far; equals the next slot index to dispatch. + nodes_fold_completed_slots: u32, + /// Correlation ID of the in-flight [`ZkRequest::NodesFoldStep`], if any. + nodes_fold_step_correlation: Option, }, Complete { public_key: ArcBytes, diff --git a/crates/events/src/enclave_event/compute_request/mod.rs b/crates/events/src/enclave_event/compute_request/mod.rs index 9a76a88f6..2c8fdec28 100644 --- a/crates/events/src/enclave_event/compute_request/mod.rs +++ b/crates/events/src/enclave_event/compute_request/mod.rs @@ -92,6 +92,7 @@ impl ToString for ComputeRequest { ZkRequest::ThresholdShareDecryption(_) => "ZkThresholdShareDecryption", ZkRequest::DecryptedSharesAggregation(_) => "ZkDecryptedSharesAggregation", ZkRequest::NodeDkgFold(_) => "ZkNodeDkgFold", + ZkRequest::NodesFoldStep(_) => "ZkNodesFoldStep", ZkRequest::DkgAggregation(_) => "ZkDkgAggregation", ZkRequest::DecryptionAggregation(_) => "ZkDecryptionAggregation", }, diff --git a/crates/events/src/enclave_event/compute_request/zk.rs b/crates/events/src/enclave_event/compute_request/zk.rs index fa5e88225..c82263147 100644 --- a/crates/events/src/enclave_event/compute_request/zk.rs +++ b/crates/events/src/enclave_event/compute_request/zk.rs @@ -38,6 +38,8 @@ pub enum ZkRequest { DecryptedSharesAggregation(DecryptedSharesAggregationProofRequest), /// Per-node DKG recursive fold (C2abFold → … → NodeFold). NodeDkgFold(NodeDkgFoldRequest), + /// Single step of the streaming cross-node nodes_fold accumulation. + NodesFoldStep(NodesFoldStepRequest), /// Cross-node DKG aggregator (NodesFold + C5 + DkgAggregator). DkgAggregation(DkgAggregationRequest), /// Phase-7 decryption aggregator (C6Fold + C7 + DecryptionAggregator). @@ -71,10 +73,30 @@ pub struct NodeDkgFoldRequest { pub committee_size: CiphernodesCommitteeSize, } +/// Single step of the streaming cross-node nodes_fold accumulation. +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct NodesFoldStepRequest { + /// The `node_fold` proof for this slot. + pub inner_proof: Proof, + /// The prior accumulator proof, or `None` for the first step. + pub prior_accumulator: Option, + /// Slot index for this honest party (position in ascending `party_id` order). + pub slot_index: u32, + /// Total honest-party count H. + pub total_slots: usize, + /// E3 identifier used for job namespacing. + pub e3_id: String, + pub params_preset: BfvPreset, + pub committee_size: CiphernodesCommitteeSize, +} + /// Cross-node DKG aggregation (NodesFold + C5 + DkgAggregator). #[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct DkgAggregationRequest { pub node_fold_proofs: Vec, + /// Pre-computed nodes_fold accumulator proof. When present the prover skips the + /// sequential fold and uses this directly as input to the DkgAggregator circuit. + pub nodes_fold_proof: Option, pub c5_proof: Proof, pub party_ids: Vec, /// Ordered committee addresses (`topNodes`) for `committee_hash_*` public inputs. @@ -293,6 +315,8 @@ pub enum ZkResponse { DecryptedSharesAggregation(DecryptedSharesAggregationProofResponse), /// Output of [`ZkRequest::NodeDkgFold`]. NodeDkgFold(NodeDkgFoldResponse), + /// Output of [`ZkRequest::NodesFoldStep`]. + NodesFoldStep(NodesFoldStepResponse), /// Output of [`ZkRequest::DkgAggregation`]. DkgAggregation(DkgAggregationResponse), /// Output of [`ZkRequest::DecryptionAggregation`]. @@ -305,6 +329,12 @@ pub struct NodeDkgFoldResponse { pub proof: Proof, } +/// Response from [`ZkRequest::NodesFoldStep`]. +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct NodesFoldStepResponse { + pub accumulator_proof: Proof, +} + /// Response from [`ZkRequest::DkgAggregation`]. #[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct DkgAggregationResponse { diff --git a/crates/multithread/src/multithread.rs b/crates/multithread/src/multithread.rs index 1afacb965..b11aa50ad 100644 --- a/crates/multithread/src/multithread.rs +++ b/crates/multithread/src/multithread.rs @@ -29,10 +29,10 @@ use e3_events::{ DecryptionAggregationResponse, DkgAggregationRequest, DkgAggregationResponse, DkgShareDecryptionProofRequest, DkgShareDecryptionProofResponse, EnclaveEvent, EnclaveEventData, EventPublisher, EventSubscriber, EventType, NodeDkgFoldRequest, - NodeDkgFoldResponse, PartyVerificationResult, PkAggregationProofRequest, - PkAggregationProofResponse, PkBfvProofRequest, PkBfvProofResponse, PkGenerationProofRequest, - PkGenerationProofResponse, Proof, ShareComputationProofRequest, ShareComputationProofResponse, - ShareEncryptionProofRequest, ShareEncryptionProofResponse, + NodeDkgFoldResponse, NodesFoldStepRequest, NodesFoldStepResponse, PartyVerificationResult, + PkAggregationProofRequest, PkAggregationProofResponse, PkBfvProofRequest, PkBfvProofResponse, + PkGenerationProofRequest, PkGenerationProofResponse, Proof, ShareComputationProofRequest, + ShareComputationProofResponse, ShareEncryptionProofRequest, ShareEncryptionProofResponse, ThresholdShareDecryptionProofRequest, ThresholdShareDecryptionProofResponse, TypedEvent, VerifyShareDecryptionProofsRequest, VerifyShareDecryptionProofsResponse, VerifyShareProofsRequest, VerifyShareProofsResponse, ZkError as ZkEventError, ZkRequest, @@ -72,9 +72,9 @@ use e3_zk_helpers::threshold::pk_aggregation::PkAggregationCircuitData; use e3_zk_helpers::CiphernodesCommittee; use e3_zk_helpers::CiphernodesCommitteeSize; use e3_zk_prover::{ - prove_decryption_aggregation_jobs, prove_dkg_aggregation, prove_node_dkg_fold, CircuitVariant, - DecryptionAggregationJob, DkgAggregationInput, NodeDkgFoldInput, NodeDkgFoldProveResult, - Provable, ZkBackend, ZkError, ZkProver, + generate_nodes_fold_step, prove_decryption_aggregation_jobs, prove_dkg_aggregation, + prove_node_dkg_fold, CircuitVariant, DecryptionAggregationJob, DkgAggregationInput, + NodeDkgFoldInput, NodeDkgFoldProveResult, Provable, ZkBackend, ZkError, ZkProver, }; use fhe::bfv::{Ciphertext, Encoding, Plaintext, PublicKey, SecretKey}; use fhe::mbfv::PublicKeyShare; @@ -693,6 +693,9 @@ fn handle_zk_request( ZkRequest::NodeDkgFold(req) => timefunc("zk_node_dkg_fold", id, || { handle_node_dkg_fold_proof(&prover, req, request.clone(), report.clone()) }), + ZkRequest::NodesFoldStep(req) => timefunc("zk_nodes_fold_step", id, || { + handle_nodes_fold_step_proof(&prover, req, request.clone()) + }), ZkRequest::DkgAggregation(req) => timefunc("zk_dkg_aggregation", id, || { handle_dkg_aggregation_proof(&prover, req, request.clone()) }), @@ -749,6 +752,35 @@ fn handle_node_dkg_fold_proof( )) } +fn handle_nodes_fold_step_proof( + prover: &ZkProver, + req: NodesFoldStepRequest, + request: ComputeRequest, +) -> Result { + let artifacts_dir = + prover.resolve_artifacts_dir(req.params_preset, req.committee_size.as_str()); + let accumulator_proof = generate_nodes_fold_step( + prover, + &req.inner_proof, + req.prior_accumulator.as_ref(), + req.slot_index, + req.total_slots, + &format!("{}-nodesfold-step-{}", req.e3_id, req.slot_index), + artifacts_dir.as_str(), + ) + .map_err(|e| { + ComputeRequestError::new( + ComputeRequestErrorKind::Zk(ZkEventError::ProofGenerationFailed(e.to_string())), + request.clone(), + ) + })?; + Ok(ComputeResponse::zk( + ZkResponse::NodesFoldStep(NodesFoldStepResponse { accumulator_proof }), + request.correlation_id, + request.e3_id, + )) +} + fn handle_dkg_aggregation_proof( prover: &ZkProver, req: DkgAggregationRequest, @@ -757,6 +789,7 @@ fn handle_dkg_aggregation_proof( let job_id = zk_bb_work_id(&request); let input = DkgAggregationInput { node_fold_proofs: &req.node_fold_proofs, + nodes_fold_proof: req.nodes_fold_proof.as_ref(), c5_proof: &req.c5_proof, party_ids: &req.party_ids, committee_addresses: &req.committee_addresses, diff --git a/crates/zk-prover/src/circuits/aggregation/c3_accumulator.rs b/crates/zk-prover/src/circuits/aggregation/c3_accumulator.rs index 85319b198..6f12871d7 100644 --- a/crates/zk-prover/src/circuits/aggregation/c3_accumulator.rs +++ b/crates/zk-prover/src/circuits/aggregation/c3_accumulator.rs @@ -32,6 +32,31 @@ fn c3_fold_public_input_field_count(total_slots: usize) -> usize { const C3_FOLD_PREFIX_LEN: usize = 4; const C3_FOLD_SLOT_WIDTH: usize = 3; +struct C3FoldVks { + inner_vk: vk::VkArtifacts, + fold_vk: vk::VkArtifacts, + kernel_vk: vk::VkArtifacts, +} + +impl C3FoldVks { + fn load(prover: &ZkProver, artifacts_dir: &str) -> Result { + Ok(Self { + inner_vk: vk::load_vk_artifacts( + &prover.circuits_dir(CircuitVariant::Recursive, artifacts_dir), + CircuitName::ShareEncryption, + )?, + fold_vk: vk::load_vk_artifacts( + &prover.circuits_dir(CircuitVariant::Default, artifacts_dir), + CircuitName::C3Fold, + )?, + kernel_vk: vk::load_vk_artifacts( + &prover.circuits_dir(CircuitVariant::Default, artifacts_dir), + CircuitName::C3FoldKernel, + )?, + }) + } +} + /// Proves [`CircuitName::C3FoldKernel`] for the same `inner` / `total_slots` as the fold step. /// /// Uses work dir `job_id` (caller should use a suffix of the fold `e3_id` so jobs stay distinct). @@ -130,15 +155,7 @@ fn parse_c3_fold_public_field_strings(proof: &Proof) -> Result, ZkEr ) } -/// One sequential `c3_fold` step. -/// -/// `prior_fold` is `None` on the first step and `Some` on all subsequent steps. -/// `total_slots` is `N_PARTIES * L_THRESHOLD` — one slot per (party, threshold-modulus) pair. -/// On the first step this sets the accumulator size; on subsequent steps it is cross-checked -/// against the slot count already encoded in `prior_fold`. -/// -/// Used only by [`generate_sequential_c3_fold`]; callers should use that entry point. -fn generate_c3_fold_step( +fn generate_c3_fold_step_with_vks( prover: &ZkProver, inner: &Proof, prior_fold: Option<&Proof>, @@ -146,22 +163,14 @@ fn generate_c3_fold_step( total_slots: usize, e3_id: &str, artifacts_dir: &str, + vks: &C3FoldVks, ) -> Result { let is_first_step = prior_fold.is_none(); - let inner_vk = vk::load_vk_artifacts( - &prover.circuits_dir(CircuitVariant::Recursive, artifacts_dir), - CircuitName::ShareEncryption, - )?; let c3_public_inputs = share_encryption_inner_public_inputs(inner)?; - let expected_acc_pub = c3_fold_public_input_field_count(total_slots); - let (acc_vk_art, acc_proof, acc_public_inputs) = if is_first_step { - let kernel_vk = vk::load_vk_artifacts( - &prover.circuits_dir(CircuitVariant::Default, artifacts_dir), - CircuitName::C3FoldKernel, - )?; + let (acc_vk_fields, acc_vk_hash, acc_proof, acc_public_inputs) = if is_first_step { let kernel_job_id = format!("{e3_id}-c3fold-kernel"); let kernel_proof = generate_c3_fold_kernel_genesis_proof( prover, @@ -180,13 +189,13 @@ fn generate_c3_fold_step( ))); } ( - kernel_vk, + vks.kernel_vk.verification_key.clone(), + vks.kernel_vk.key_hash.clone(), bytes_to_field_strings(&kernel_proof.data)?, acc_pi, ) } else { let p = prior_fold.expect("prior_fold required when is_first_step is false"); - // Parse once; derive slot count from field count to avoid a second parse. let acc_pi = parse_c3_fold_public_field_strings(p)?; let prior_slots = (acc_pi.len() - 4) / 3; if prior_slots == 0 { @@ -209,24 +218,22 @@ fn generate_c3_fold_step( ))); } ( - vk::load_vk_artifacts( - &prover.circuits_dir(CircuitVariant::Default, artifacts_dir), - CircuitName::C3Fold, - )?, + vks.fold_vk.verification_key.clone(), + vks.fold_vk.key_hash.clone(), bytes_to_field_strings(&p.data)?, acc_pi, ) }; let full_input = C3FoldStepInput { - inner_vk: inner_vk.verification_key, + inner_vk: vks.inner_vk.verification_key.clone(), inner_proof: bytes_to_field_strings(&inner.data)?, c3_public_inputs, - acc_vk: acc_vk_art.verification_key, + acc_vk: acc_vk_fields, acc_proof, acc_public_inputs, - inner_key_hash: inner_vk.key_hash, - acc_key_hash: acc_vk_art.key_hash, + inner_key_hash: vks.inner_vk.key_hash.clone(), + acc_key_hash: acc_vk_hash, is_first_step, slot_index, }; @@ -291,12 +298,13 @@ pub fn generate_sequential_c3_fold( } seen[idx] = true; } + let vks = C3FoldVks::load(prover, artifacts_dir)?; sequential_fold( "generate_sequential_c3_fold", inner_proofs, slot_indices, |inner, prior, slot| { - generate_c3_fold_step( + generate_c3_fold_step_with_vks( prover, inner, prior, @@ -304,6 +312,7 @@ pub fn generate_sequential_c3_fold( total_slots, e3_id, artifacts_dir, + &vks, ) }, ) diff --git a/crates/zk-prover/src/circuits/aggregation/c6_accumulator.rs b/crates/zk-prover/src/circuits/aggregation/c6_accumulator.rs index 3c7fcb4a8..199710891 100644 --- a/crates/zk-prover/src/circuits/aggregation/c6_accumulator.rs +++ b/crates/zk-prover/src/circuits/aggregation/c6_accumulator.rs @@ -29,6 +29,31 @@ fn c6_fold_public_input_field_count(total_slots: usize) -> usize { const C6_FOLD_PREFIX_LEN: usize = 4; const C6_FOLD_SLOT_WIDTH: usize = 4; +struct C6FoldVks { + inner_vk: vk::VkArtifacts, + fold_vk: vk::VkArtifacts, + kernel_vk: vk::VkArtifacts, +} + +impl C6FoldVks { + fn load(prover: &ZkProver, artifacts_dir: &str) -> Result { + Ok(Self { + inner_vk: vk::load_vk_artifacts( + &prover.circuits_dir(CircuitVariant::Recursive, artifacts_dir), + CircuitName::ThresholdShareDecryption, + )?, + fold_vk: vk::load_vk_artifacts( + &prover.circuits_dir(CircuitVariant::Default, artifacts_dir), + CircuitName::C6Fold, + )?, + kernel_vk: vk::load_vk_artifacts( + &prover.circuits_dir(CircuitVariant::Default, artifacts_dir), + CircuitName::C6FoldKernel, + )?, + }) + } +} + fn generate_c6_fold_kernel_genesis_proof( prover: &ZkProver, inner: &Proof, @@ -124,7 +149,7 @@ fn parse_c6_fold_public_field_strings(proof: &Proof) -> Result, ZkEr ) } -fn generate_c6_fold_step( +fn generate_c6_fold_step_with_vks( prover: &ZkProver, inner: &Proof, prior_fold: Option<&Proof>, @@ -132,22 +157,14 @@ fn generate_c6_fold_step( total_slots: usize, e3_id: &str, artifacts_dir: &str, + vks: &C6FoldVks, ) -> Result { let is_first_step = prior_fold.is_none(); - let inner_vk = vk::load_vk_artifacts( - &prover.circuits_dir(CircuitVariant::Recursive, artifacts_dir), - CircuitName::ThresholdShareDecryption, - )?; let c6_public_inputs = threshold_share_decryption_inner_public_inputs(inner)?; - let expected_acc_pub = c6_fold_public_input_field_count(total_slots); - let (acc_vk_art, acc_proof, acc_public_inputs) = if is_first_step { - let kernel_vk = vk::load_vk_artifacts( - &prover.circuits_dir(CircuitVariant::Default, artifacts_dir), - CircuitName::C6FoldKernel, - )?; + let (acc_vk_fields, acc_vk_hash, acc_proof, acc_public_inputs) = if is_first_step { let kernel_job_id = format!("{e3_id}-c6fold-kernel"); let kernel_proof = generate_c6_fold_kernel_genesis_proof( prover, @@ -167,7 +184,8 @@ fn generate_c6_fold_step( ))); } ( - kernel_vk, + vks.kernel_vk.verification_key.clone(), + vks.kernel_vk.key_hash.clone(), bytes_to_field_strings(&kernel_proof.data)?, acc_pi, ) @@ -195,24 +213,22 @@ fn generate_c6_fold_step( ))); } ( - vk::load_vk_artifacts( - &prover.circuits_dir(CircuitVariant::Default, artifacts_dir), - CircuitName::C6Fold, - )?, + vks.fold_vk.verification_key.clone(), + vks.fold_vk.key_hash.clone(), bytes_to_field_strings(&p.data)?, acc_pi, ) }; let full_input = C6FoldStepInput { - inner_vk: inner_vk.verification_key, + inner_vk: vks.inner_vk.verification_key.clone(), inner_proof: bytes_to_field_strings(&inner.data)?, c6_public_inputs, - acc_vk: acc_vk_art.verification_key, + acc_vk: acc_vk_fields, acc_proof, acc_public_inputs, - inner_key_hash: inner_vk.key_hash, - acc_key_hash: acc_vk_art.key_hash, + inner_key_hash: vks.inner_vk.key_hash.clone(), + acc_key_hash: acc_vk_hash, is_first_step, slot_index, }; @@ -274,12 +290,13 @@ pub fn generate_sequential_c6_fold( } seen[idx] = true; } + let vks = C6FoldVks::load(prover, artifacts_dir)?; sequential_fold( "generate_sequential_c6_fold", inner_proofs, slot_indices, |inner, prior, slot| { - generate_c6_fold_step( + generate_c6_fold_step_with_vks( prover, inner, prior, @@ -287,6 +304,7 @@ pub fn generate_sequential_c6_fold( total_slots, e3_id, artifacts_dir, + &vks, ) }, ) diff --git a/crates/zk-prover/src/circuits/aggregation/node_dkg_fold.rs b/crates/zk-prover/src/circuits/aggregation/node_dkg_fold.rs index 9970415d8..3b2f3ac09 100644 --- a/crates/zk-prover/src/circuits/aggregation/node_dkg_fold.rs +++ b/crates/zk-prover/src/circuits/aggregation/node_dkg_fold.rs @@ -369,6 +369,8 @@ pub fn prove_node_dkg_fold( /// Inputs for [`prove_dkg_aggregation`]. pub struct DkgAggregationInput<'a> { pub node_fold_proofs: &'a [Proof], + /// Pre-computed nodes_fold accumulator. When `Some`, the sequential fold is skipped. + pub nodes_fold_proof: Option<&'a Proof>, pub c5_proof: &'a Proof, /// Honest party ids in the same order as `node_fold_proofs` (e.g. sorted ascending). pub party_ids: &'a [u64], @@ -430,15 +432,19 @@ pub fn prove_dkg_aggregation( "DkgAggregator honest-set H must equal registered committee size until expulsion enables H < N" ); } - let slot_indices: Vec = (0u32..h as u32).collect(); - let nodes_fold_proof = generate_sequential_nodes_fold( - prover, - input.node_fold_proofs, - &slot_indices, - h, - &format!("{e3_id}-nodesfold"), - artifacts_dir, - )?; + let nodes_fold_proof = if let Some(precomputed) = input.nodes_fold_proof { + precomputed.clone() + } else { + let slot_indices: Vec = (0u32..h as u32).collect(); + generate_sequential_nodes_fold( + prover, + input.node_fold_proofs, + &slot_indices, + h, + &format!("{e3_id}-nodesfold"), + artifacts_dir, + )? + }; let nodes_fold_vk = vk::load_vk_artifacts( &prover.circuits_dir(CircuitVariant::Default, artifacts_dir), diff --git a/crates/zk-prover/src/circuits/aggregation/nodes_fold_accumulator.rs b/crates/zk-prover/src/circuits/aggregation/nodes_fold_accumulator.rs index 5950579af..7c352527f 100644 --- a/crates/zk-prover/src/circuits/aggregation/nodes_fold_accumulator.rs +++ b/crates/zk-prover/src/circuits/aggregation/nodes_fold_accumulator.rs @@ -58,6 +58,31 @@ struct NodesFoldStepInput { slot_index: u32, } +struct NodesFoldVks { + inner_vk: vk::VkArtifacts, + fold_vk: vk::VkArtifacts, + kernel_vk: vk::VkArtifacts, +} + +impl NodesFoldVks { + fn load(prover: &ZkProver, artifacts_dir: &str) -> Result { + Ok(Self { + inner_vk: vk::load_vk_artifacts( + &prover.circuits_dir(CircuitVariant::Default, artifacts_dir), + CircuitName::NodeFold, + )?, + fold_vk: vk::load_vk_artifacts( + &prover.circuits_dir(CircuitVariant::Default, artifacts_dir), + CircuitName::NodesFold, + )?, + kernel_vk: vk::load_vk_artifacts( + &prover.circuits_dir(CircuitVariant::Default, artifacts_dir), + CircuitName::NodesFoldKernel, + )?, + }) + } +} + /// Proves [`CircuitName::NodesFoldKernel`] for the same `inner` / `total_slots` / `slot_index` as the /// fold step. fn generate_nodes_fold_kernel_genesis_proof( @@ -129,7 +154,7 @@ fn parse_nodes_fold_public_field_strings(proof: &Proof) -> Result, Z parse_acc_public_field_strings_flat(proof, CircuitName::NodesFold, NODES_FOLD_PREFIX_LEN) } -fn generate_nodes_fold_step( +fn generate_nodes_fold_step_with_vks( prover: &ZkProver, inner: &Proof, prior_fold: Option<&Proof>, @@ -137,13 +162,10 @@ fn generate_nodes_fold_step( total_slots: usize, e3_id: &str, artifacts_dir: &str, + vks: &NodesFoldVks, ) -> Result { let is_first_step = prior_fold.is_none(); - let inner_vk = vk::load_vk_artifacts( - &prover.circuits_dir(CircuitVariant::Default, artifacts_dir), - CircuitName::NodeFold, - )?; let nf_fields = node_fold_statement_field_count(inner)?; let node_fold_public_inputs = bytes_to_field_strings(inner.public_signals.as_ref())?; if node_fold_public_inputs.len() != nf_fields { @@ -154,11 +176,7 @@ fn generate_nodes_fold_step( let expected_acc_pub = nodes_fold_acc_public_len(nf_fields, total_slots); - let (acc_vk_art, acc_proof, acc_public_inputs) = if is_first_step { - let kernel_vk = vk::load_vk_artifacts( - &prover.circuits_dir(CircuitVariant::Default, artifacts_dir), - CircuitName::NodesFoldKernel, - )?; + let (acc_vk_fields, acc_vk_hash, acc_proof, acc_public_inputs) = if is_first_step { let kernel_job_id = format!("{e3_id}-nodesfold-kernel"); let kernel_proof = generate_nodes_fold_kernel_genesis_proof( prover, @@ -178,7 +196,8 @@ fn generate_nodes_fold_step( ))); } ( - kernel_vk, + vks.kernel_vk.verification_key.clone(), + vks.kernel_vk.key_hash.clone(), bytes_to_field_strings(&kernel_proof.data)?, acc_pi, ) @@ -194,24 +213,22 @@ fn generate_nodes_fold_step( ))); } ( - vk::load_vk_artifacts( - &prover.circuits_dir(CircuitVariant::Default, artifacts_dir), - CircuitName::NodesFold, - )?, + vks.fold_vk.verification_key.clone(), + vks.fold_vk.key_hash.clone(), bytes_to_field_strings(&p.data)?, acc_pi, ) }; let full_input = NodesFoldStepInput { - inner_vk: inner_vk.verification_key, + inner_vk: vks.inner_vk.verification_key.clone(), inner_proof: bytes_to_field_strings(&inner.data)?, node_fold_public_inputs, - acc_vk: acc_vk_art.verification_key, + acc_vk: acc_vk_fields, acc_proof, acc_public_inputs, - inner_key_hash: inner_vk.key_hash, - acc_key_hash: acc_vk_art.key_hash, + inner_key_hash: vks.inner_vk.key_hash.clone(), + acc_key_hash: acc_vk_hash, is_first_step, slot_index, }; @@ -237,6 +254,29 @@ fn generate_nodes_fold_step( ) } +/// Single streaming step — loads VKs from disk on each call (suitable for async dispatch). +pub fn generate_nodes_fold_step( + prover: &ZkProver, + inner: &Proof, + prior_fold: Option<&Proof>, + slot_index: u32, + total_slots: usize, + e3_id: &str, + artifacts_dir: &str, +) -> Result { + let vks = NodesFoldVks::load(prover, artifacts_dir)?; + generate_nodes_fold_step_with_vks( + prover, + inner, + prior_fold, + slot_index, + total_slots, + e3_id, + artifacts_dir, + &vks, + ) +} + /// Folds `inner_proofs` (one [`CircuitName::NodeFold`] per honest party) into a single /// [`CircuitName::NodesFold`] proof for [`CircuitName::DkgAggregator`]. /// @@ -250,12 +290,13 @@ pub fn generate_sequential_nodes_fold( e3_id: &str, artifacts_dir: &str, ) -> Result { + let vks = NodesFoldVks::load(prover, artifacts_dir)?; sequential_fold( "generate_sequential_nodes_fold", inner_proofs, slot_indices, |inner, prior, slot| { - generate_nodes_fold_step( + generate_nodes_fold_step_with_vks( prover, inner, prior, @@ -263,6 +304,7 @@ pub fn generate_sequential_nodes_fold( total_slots, e3_id, artifacts_dir, + &vks, ) }, ) diff --git a/crates/zk-prover/src/lib.rs b/crates/zk-prover/src/lib.rs index 5286dc20c..eec058f3c 100644 --- a/crates/zk-prover/src/lib.rs +++ b/crates/zk-prover/src/lib.rs @@ -32,7 +32,9 @@ pub use circuits::aggregation::node_dkg_fold::{ DecryptionAggregationJob, DkgAggregationInput, FoldProveStepTiming, NodeDkgFoldInput, NodeDkgFoldProveResult, }; -pub use circuits::aggregation::nodes_fold_accumulator::generate_sequential_nodes_fold; +pub use circuits::aggregation::nodes_fold_accumulator::{ + generate_nodes_fold_step, generate_sequential_nodes_fold, +}; pub use config::{verify_checksum, BbTarget, CircuitInfo, VersionInfo, ZkConfig}; pub use dkg_attestation_bundle::encode_dkg_attestation_bundle; pub use e3_events::CircuitVariant; diff --git a/crates/zk-prover/tests/fold_accumulators_e2e_tests.rs b/crates/zk-prover/tests/fold_accumulators_e2e_tests.rs index ac6ba69e9..ab1d61fe7 100644 --- a/crates/zk-prover/tests/fold_accumulators_e2e_tests.rs +++ b/crates/zk-prover/tests/fold_accumulators_e2e_tests.rs @@ -196,6 +196,7 @@ async fn recursive_aggregation_default_artifacts_staged() { let base = backend .circuits_dir .join("insecure-512") + .join("micro") .join("default") .join(CircuitName::C3Fold.dir_path()); let pkg = CircuitName::C3Fold.as_str(); @@ -232,6 +233,7 @@ async fn recursive_aggregation_c6_fold_kernel_artifacts_staged() { let base = backend .circuits_dir .join("insecure-512") + .join("micro") .join("default") .join(CircuitName::C6FoldKernel.dir_path()); let pkg = CircuitName::C6FoldKernel.as_str(); @@ -269,7 +271,11 @@ async fn node_fold_pipeline_recursive_aggregation_artifacts_staged() { setup_recursive_aggregation_fold_circuit(&backend, c).await; } - let preset_base = backend.circuits_dir.join("insecure-512").join("default"); + let preset_base = backend + .circuits_dir + .join("insecure-512") + .join("micro") + .join("default"); for &c in NODE_FOLD_PIPELINE { let base = preset_base.join(c.dir_path()); let pkg = c.as_str(); @@ -361,7 +367,7 @@ async fn c3_fold_sequential_proves_and_verifies() { return; } - let artifacts_dir = preset.artifacts_dir(); + let artifacts_dir = preset.artifacts_dir_for_committee("micro"); let inner_e3_a = "e3-c3fold-inner-0"; let inner_e3_b = "e3-c3fold-inner-1"; let fold_e3 = "e3-c3fold-step"; @@ -472,7 +478,7 @@ async fn c6_fold_sequential_proves_and_verifies() { ); return; } - let artifacts_dir = preset.artifacts_dir(); + let artifacts_dir = preset.artifacts_dir_for_committee("micro"); let inner_e3_a = "e3-c6fold-inner-0"; let inner_e3_b = "e3-c6fold-inner-1"; let fold_e3 = "e3-c6fold-step";