diff --git a/Cargo.lock b/Cargo.lock index a65e93d4aa..b859ce0653 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2801,6 +2801,7 @@ dependencies = [ "e3-trbfv", "e3-utils", "once_cell", + "rayon", "tempfile", "tokio", "tracing", @@ -3145,6 +3146,7 @@ dependencies = [ "e3-utils", "rand 0.8.5", "rayon", + "thiserror 1.0.69", "tokio", "tracing", "zeroize", diff --git a/crates/aggregator/src/ext.rs b/crates/aggregator/src/ext.rs index 1e97ab4ac5..b605eb9eb7 100644 --- a/crates/aggregator/src/ext.rs +++ b/crates/aggregator/src/ext.rs @@ -22,7 +22,6 @@ use e3_events::{prelude::*, E3id}; use e3_events::{BusHandle, EType, EnclaveEvent, EnclaveEventData}; use e3_fhe::ext::FHE_KEY; use e3_fhe::Fhe; -use e3_multithread::Multithread; use e3_request::{E3Context, E3ContextSnapshot, E3Extension, META_KEY}; use e3_sortition::Sortition; @@ -249,19 +248,13 @@ fn create_publickey_aggregator( pub struct ThresholdPlaintextAggregatorExtension { bus: BusHandle, sortition: Addr, - multithread: Addr, } impl ThresholdPlaintextAggregatorExtension { - pub fn create( - bus: &BusHandle, - sortition: &Addr, - multithread: &Addr, - ) -> Box { + pub fn create(bus: &BusHandle, sortition: &Addr) -> Box { Box::new(Self { bus: bus.clone(), sortition: sortition.clone(), - multithread: multithread.clone(), }) } } @@ -302,7 +295,6 @@ impl E3Extension for ThresholdPlaintextAggregatorExtension { bus: self.bus.clone(), sortition: self.sortition.clone(), e3_id: e3_id.clone(), - multithread: self.multithread.clone(), }, sync_state, ) @@ -331,7 +323,6 @@ impl E3Extension for ThresholdPlaintextAggregatorExtension { bus: self.bus.clone(), sortition: self.sortition.clone(), e3_id: ctx.e3_id.clone(), - multithread: self.multithread.clone(), }, sync_state, ) diff --git a/crates/aggregator/src/threshold_plaintext_aggregator.rs b/crates/aggregator/src/threshold_plaintext_aggregator.rs index 1c5ee9e8b7..ca861ff876 100644 --- a/crates/aggregator/src/threshold_plaintext_aggregator.rs +++ b/crates/aggregator/src/threshold_plaintext_aggregator.rs @@ -7,19 +7,17 @@ use std::collections::HashMap; use actix::prelude::*; -use anyhow::{anyhow, bail, Result}; +use anyhow::{anyhow, bail, ensure, Result}; use e3_data::Persistable; use e3_events::{ - prelude::*, BusHandle, ComputeRequest, DecryptionshareCreated, Die, E3id, EnclaveEvent, - EnclaveEventData, PlaintextAggregated, Seed, + prelude::*, trap, BusHandle, ComputeRequest, ComputeResponse, CorrelationId, + DecryptionshareCreated, Die, E3id, EType, EnclaveEvent, EnclaveEventData, PlaintextAggregated, + Seed, }; -use e3_multithread::Multithread; use e3_sortition::{GetNodesForE3, Sortition}; use e3_trbfv::{ - calculate_threshold_decryption::{ - CalculateThresholdDecryptionRequest, CalculateThresholdDecryptionResponse, - }, - TrBFVConfig, TrBFVRequest, + calculate_threshold_decryption::CalculateThresholdDecryptionRequest, TrBFVConfig, TrBFVRequest, + TrBFVResponse, }; use e3_utils::utility_types::ArcBytes; use tracing::{debug, error, info, trace}; @@ -112,7 +110,7 @@ impl ThresholdPlaintextAggregatorState { } #[derive(Message)] -#[rtype(result = "anyhow::Result<()>")] +#[rtype("()")] pub struct ComputeAggregate { pub shares: Vec<(u64, Vec)>, pub ciphertext_output: Vec, @@ -121,7 +119,6 @@ pub struct ComputeAggregate { } pub struct ThresholdPlaintextAggregator { - multithread: Addr, bus: BusHandle, sortition: Addr, e3_id: E3id, @@ -129,7 +126,6 @@ pub struct ThresholdPlaintextAggregator { } pub struct ThresholdPlaintextAggregatorParams { - pub multithread: Addr, pub bus: BusHandle, pub sortition: Addr, pub e3_id: E3id, @@ -141,7 +137,6 @@ impl ThresholdPlaintextAggregator { state: Persistable, ) -> Self { ThresholdPlaintextAggregator { - multithread: params.multithread, bus: params.bus, sortition: params.sortition, e3_id: params.e3_id, @@ -200,12 +195,10 @@ impl ThresholdPlaintextAggregator { }) } - pub fn create_calculate_threshold_decryption_event( - &self, - msg: ComputeAggregate, - ) -> Result { + pub fn handle_compute_aggregate(&mut self, msg: ComputeAggregate) -> Result<()> { info!("create_calculate_threshold_decryption_event..."); + let e3_id = self.e3_id.clone(); let state: Computing = self .state .get() @@ -215,7 +208,7 @@ impl ThresholdPlaintextAggregator { let trbfv_config = TrBFVConfig::new(state.params.clone(), state.threshold_n, state.threshold_m); - Ok(ComputeRequest::TrBFV( + let event = ComputeRequest::new( TrBFVRequest::CalculateThresholdDecryption( CalculateThresholdDecryptionRequest { ciphertexts: msg.ciphertext_output, @@ -224,7 +217,40 @@ impl ThresholdPlaintextAggregator { } .into(), ), - )) + CorrelationId::new(), + e3_id, + ); + self.bus.publish(event)?; + Ok(()) + } + + pub fn handle_compute_response(&mut self, msg: ComputeResponse) -> Result<()> { + ensure!( + msg.e3_id == self.e3_id, + "PlaintextAggregator should never receive incorrect e3_id msgs" + ); + + let TrBFVResponse::CalculateThresholdDecryption(response) = msg.response else { + // Must be another compute response so ignoring + return Ok(()); + }; + + info!("Received response {:?}", response); + + // Update the local state + let plaintext = response.plaintext; + + self.set_decryption(plaintext.clone())?; + + // Dispatch the PlaintextAggregated event + let event = PlaintextAggregated { + decrypted_output: plaintext, // Extracting here for now + e3_id: self.e3_id.clone(), + }; + + info!("Dispatching plaintext event {:?}", event); + self.bus.publish(event)?; + Ok(()) } } @@ -238,6 +264,7 @@ impl Handler for ThresholdPlaintextAggregator { match msg.into_data() { EnclaveEventData::DecryptionshareCreated(data) => ctx.notify(data), EnclaveEventData::E3RequestComplete(_) => ctx.notify(Die), + EnclaveEventData::ComputeResponse(data) => ctx.notify(data), _ => (), } } @@ -306,46 +333,20 @@ impl Handler for ThresholdPlaintextAggregator { } impl Handler for ThresholdPlaintextAggregator { - type Result = ResponseActFuture>; + type Result = (); fn handle(&mut self, msg: ComputeAggregate, _: &mut Self::Context) -> Self::Result { - let event = match self.create_calculate_threshold_decryption_event(msg) { - Ok(event) => event, - Err(e) => { - error!("{e}"); - return e3_utils::actix::bail_result(self, "{e}"); - } - }; - Box::pin( - self.multithread - .send(event) - .into_actor(self) - .map(move |res, act, _| { - let response: CalculateThresholdDecryptionResponse = match res? { - Ok(res) => res.try_into()?, - Err(e) => { - error!("{e}"); - bail!(e) - } - }; - - info!("Received response {:?}", response); - - // Update the local state - let plaintext = response.plaintext; - - act.set_decryption(plaintext.clone())?; - - // Dispatch the PlaintextAggregated event - let event = PlaintextAggregated { - decrypted_output: plaintext, // Extracting here for now - e3_id: act.e3_id.clone(), - }; - - info!("Dispatching plaintext event {:?}", event); - act.bus.publish(event)?; - Ok(()) - }), - ) + trap(EType::PlaintextAggregation, &self.bus.clone(), || { + self.handle_compute_aggregate(msg) + }) + } +} + +impl Handler for ThresholdPlaintextAggregator { + type Result = (); + fn handle(&mut self, msg: ComputeResponse, _: &mut Self::Context) -> Self::Result { + trap(EType::PlaintextAggregation, &self.bus.clone(), || { + self.handle_compute_response(msg) + }) } } diff --git a/crates/ciphernode-builder/Cargo.toml b/crates/ciphernode-builder/Cargo.toml index 233b9cda94..07b940c81e 100644 --- a/crates/ciphernode-builder/Cargo.toml +++ b/crates/ciphernode-builder/Cargo.toml @@ -25,6 +25,7 @@ e3-request.workspace = true e3-sortition.workspace = true e3-trbfv.workspace = true e3-utils.workspace = true +rayon.workspace = true tempfile.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/crates/ciphernode-builder/src/ciphernode_builder.rs b/crates/ciphernode-builder/src/ciphernode_builder.rs index dc6901f655..2c09de7996 100644 --- a/crates/ciphernode-builder/src/ciphernode_builder.rs +++ b/crates/ciphernode-builder/src/ciphernode_builder.rs @@ -16,7 +16,7 @@ use e3_aggregator::ext::{ use e3_config::chain_config::ChainConfig; use e3_crypto::Cipher; use e3_data::{InMemStore, Repositories, RepositoriesFactory}; -use e3_events::{EnclaveEvent, EventBus, EventBusConfig}; +use e3_events::{BusHandle, EnclaveEvent, EventBus, EventBusConfig}; use e3_evm::{ helpers::{ load_signer_from_repository, ConcreteReadProvider, ConcreteWriteProvider, EthProvider, @@ -29,13 +29,14 @@ use e3_evm::{ }; use e3_fhe::ext::FheExtension; use e3_keyshare::ext::{KeyshareExtension, ThresholdKeyshareExtension}; -use e3_multithread::Multithread; +use e3_multithread::{Multithread, MultithreadReport, TaskPool}; use e3_request::E3Router; use e3_sortition::{ CiphernodeSelector, CiphernodeSelectorFactory, FinalizedCommitteesRepositoryFactory, NodeStateRepositoryFactory, Sortition, SortitionBackend, SortitionRepositoryFactory, }; use e3_utils::{rand_eth_addr, SharedRng}; +use rayon::ThreadPool; use std::{collections::HashMap, path::PathBuf, sync::Arc}; use tracing::{error, info}; @@ -52,26 +53,27 @@ enum EventSystemType { #[derive(Derivative)] #[derivative(Debug)] pub struct CiphernodeBuilder { - name: String, address: Option, chains: Vec, #[derivative(Debug = "ignore")] cipher: Arc, contract_components: ContractComponents, + event_system: EventSystemType, in_mem_store: Option>, keyshare: Option, logging: bool, - event_system: EventSystemType, multithread_cache: Option>, multithread_concurrent_jobs: Option, - multithread_capture_events: bool, + multithread_report: Option>, + name: String, plaintext_agg: bool, pubkey_agg: bool, rng: SharedRng, - source_bus: Option>>>, sortition_backend: SortitionBackend, + source_bus: Option>>>, testmode_errors: bool, testmode_history: bool, + task_pool: Option, threads: Option, threshold_plaintext_agg: bool, } @@ -104,26 +106,27 @@ impl CiphernodeBuilder { /// - cipher - Cipher for encryption and decryption of sensitive data pub fn new(name: &str, rng: SharedRng, cipher: Arc) -> Self { Self { - name: name.to_owned(), address: None, chains: vec![], cipher, contract_components: ContractComponents::default(), + event_system: EventSystemType::InMem, in_mem_store: None, keyshare: None, logging: false, multithread_cache: None, + multithread_concurrent_jobs: None, + multithread_report: None, + name: name.to_owned(), plaintext_agg: false, pubkey_agg: false, - multithread_concurrent_jobs: None, - event_system: EventSystemType::InMem, rng, - source_bus: None, sortition_backend: SortitionBackend::score(), + source_bus: None, testmode_errors: false, testmode_history: false, + task_pool: None, threads: None, - multithread_capture_events: false, threshold_plaintext_agg: false, } } @@ -215,9 +218,15 @@ impl CiphernodeBuilder { self } - /// Inject a preexisting multithread actor. This is mainly used for testing. - pub fn with_injected_multithread(mut self, multithread: Addr) -> Self { - self.multithread_cache = Some(multithread); + /// Connect rayon work to the given threadpool + pub fn with_shared_taskpool(mut self, pool: &TaskPool) -> Self { + self.task_pool = Some(pool.clone()); + self + } + + /// Shared MultithreadReport for benchmarking + pub fn with_shared_multithread_report(mut self, report: &Addr) -> Self { + self.multithread_report = Some(report.clone()); self } @@ -246,11 +255,6 @@ impl CiphernodeBuilder { self } - pub fn with_multithread_capture_events(mut self) -> Self { - self.multithread_capture_events = true; - self - } - /// Setup a ThresholdPlaintextAggregator pub fn with_threshold_plaintext_aggregation(mut self) -> Self { self.threshold_plaintext_agg = true; @@ -470,13 +474,12 @@ impl CiphernodeBuilder { let mut e3_builder = E3Router::builder(&bus, store.clone()); if let Some(KeyshareKind::Threshold) = self.keyshare { - let multithread = self.ensure_multithread(); + let _ = self.ensure_multithread(&bus); let share_encryption_params = e3_trbfv::helpers::get_share_encryption_params(); info!("Setting up ThresholdKeyshareExtension"); e3_builder = e3_builder.with(ThresholdKeyshareExtension::create( &bus, &self.cipher, - &multithread, &addr, share_encryption_params, )) @@ -502,11 +505,9 @@ impl CiphernodeBuilder { if self.threshold_plaintext_agg { info!("Setting up ThresholdPlaintextAggregatorExtension NEW!"); - let multithread = self.ensure_multithread(); + let _ = self.ensure_multithread(&bus); e3_builder = e3_builder.with(ThresholdPlaintextAggregatorExtension::create( - &bus, - &sortition, - &multithread, + &bus, &sortition, )) } @@ -526,19 +527,29 @@ impl CiphernodeBuilder { )) } - fn ensure_multithread(&mut self) -> Addr { + fn ensure_multithread(&mut self, bus: &BusHandle) -> Addr { // If we have it cached return it if let Some(cached) = self.multithread_cache.clone() { return cached; } + info!("Setting up multithread actor..."); + + // Setup threadpool if not set + let task_pool = self.task_pool.clone().unwrap_or_else(|| { + Multithread::create_taskpool( + self.threads.unwrap_or(1), + self.multithread_concurrent_jobs.unwrap_or(1), + ) + }); + // Create it let addr = Multithread::attach( + bus, self.rng.clone(), self.cipher.clone(), - self.threads.unwrap_or(1), - self.multithread_concurrent_jobs.unwrap_or(1), - self.multithread_capture_events, + task_pool, + self.multithread_report.clone(), ); // Set the cache diff --git a/crates/entrypoint/src/start/aggregator_start.rs b/crates/entrypoint/src/start/aggregator_start.rs index 602863c46e..b9132f25e5 100644 --- a/crates/entrypoint/src/start/aggregator_start.rs +++ b/crates/entrypoint/src/start/aggregator_start.rs @@ -5,7 +5,7 @@ // or FITNESS FOR A PARTICULAR PURPOSE. use anyhow::Result; -use e3_ciphernode_builder::{get_enclave_bus_handle, get_enclave_event_bus, CiphernodeBuilder}; +use e3_ciphernode_builder::CiphernodeBuilder; use e3_config::AppConfig; use e3_crypto::Cipher; use e3_data::RepositoriesFactory; diff --git a/crates/entrypoint/src/start/start.rs b/crates/entrypoint/src/start/start.rs index 16a19870d9..ac2f7e2e03 100644 --- a/crates/entrypoint/src/start/start.rs +++ b/crates/entrypoint/src/start/start.rs @@ -6,7 +6,7 @@ use alloy::primitives::Address; use anyhow::Result; -use e3_ciphernode_builder::{get_enclave_bus_handle, get_enclave_event_bus, CiphernodeBuilder}; +use e3_ciphernode_builder::CiphernodeBuilder; use e3_config::AppConfig; use e3_crypto::Cipher; use e3_data::RepositoriesFactory; diff --git a/crates/events/src/enclave_event/compute_request/mod.rs b/crates/events/src/enclave_event/compute_request/mod.rs index 64ff467424..53d33bc3f7 100644 --- a/crates/events/src/enclave_event/compute_request/mod.rs +++ b/crates/events/src/enclave_event/compute_request/mod.rs @@ -17,28 +17,41 @@ use e3_trbfv::{ }; use serde::{Deserialize, Serialize}; +use crate::{CorrelationId, E3id}; + /// The compute instruction for a threadpool computation. /// This enum provides protocol disambiguation #[derive(Message, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] -#[rtype(result = "Result")] -pub enum ComputeRequest { - /// By Protocol - TrBFV(e3_trbfv::TrBFVRequest), - // Eg. TFHE(TFHERequest) +// #[rtype(result = "Result")] +#[rtype(result = "()")] +pub struct ComputeRequest { + // TODO: Disambiguate protocol later + pub request: e3_trbfv::TrBFVRequest, + pub correlation_id: CorrelationId, + pub e3_id: E3id, // It may come to pass this should be option + // but our initial need is only within the e3 flow +} +impl ComputeRequest { + pub fn new( + request: e3_trbfv::TrBFVRequest, + correlation_id: CorrelationId, + e3_id: E3id, + ) -> Self { + Self { + request, + correlation_id, + e3_id, + } + } } - impl ToString for ComputeRequest { fn to_string(&self) -> String { - match self { - Self::TrBFV(e3_trbfv::TrBFVRequest::GenEsiSss(_)) => "GenEsiSss", - Self::TrBFV(e3_trbfv::TrBFVRequest::GenPkShareAndSkSss(_)) => "GenPkShareAndSkSss", - Self::TrBFV(e3_trbfv::TrBFVRequest::CalculateDecryptionKey(_)) => { - "CalculateDecryptionKey" - } - Self::TrBFV(e3_trbfv::TrBFVRequest::CalculateDecryptionShare(_)) => { - "CalculateDecryptionShare" - } - Self::TrBFV(e3_trbfv::TrBFVRequest::CalculateThresholdDecryption(_)) => { + match self.request { + e3_trbfv::TrBFVRequest::GenEsiSss(_) => "GenEsiSss", + e3_trbfv::TrBFVRequest::GenPkShareAndSkSss(_) => "GenPkShareAndSkSss", + e3_trbfv::TrBFVRequest::CalculateDecryptionKey(_) => "CalculateDecryptionKey", + e3_trbfv::TrBFVRequest::CalculateDecryptionShare(_) => "CalculateDecryptionShare", + e3_trbfv::TrBFVRequest::CalculateThresholdDecryption(_) => { "CalculateThresholdDecryption" } } @@ -50,44 +63,66 @@ impl ToString for ComputeRequest { /// This enum provides protocol disambiguation #[derive(Message, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] #[rtype(result = "()")] -pub enum ComputeResponse { - /// By Protocol - TrBFV(e3_trbfv::TrBFVResponse), - // Eg. TFHE(TFHEResponse) +pub struct ComputeResponse { + pub response: e3_trbfv::TrBFVResponse, + pub correlation_id: CorrelationId, + pub e3_id: E3id, +} + +impl ComputeResponse { + pub fn new( + response: e3_trbfv::TrBFVResponse, + correlation_id: CorrelationId, + e3_id: E3id, + ) -> ComputeResponse { + ComputeResponse { + response, + correlation_id, + e3_id, + } + } } /// An error from a threadpool computation /// This enum provides protocol disambiguation +#[derive(Message, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[rtype(result = "()")] +pub struct ComputeRequestError { + kind: ComputeRequestErrorKind, + request: ComputeRequest, +} + +impl ComputeRequestError { + pub fn new(kind: ComputeRequestErrorKind, request: ComputeRequest) -> Self { + Self { kind, request } + } +} + #[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub enum ComputeRequestError { - /// By Protocol +pub enum ComputeRequestErrorKind { TrBFV(e3_trbfv::TrBFVError), - RecvError(String), - SemaphoreError(String), - // Eg. TFHE(TFHEError) +} + +impl ComputeRequestError { + pub fn get_err(&self) -> &ComputeRequestErrorKind { + &self.kind + } } impl std::error::Error for ComputeRequestError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - ComputeRequestError::TrBFV(err) => Some(err), - _ => None, + match self.get_err() { + ComputeRequestErrorKind::TrBFV(err) => Some(err), } } } impl fmt::Display for ComputeRequestError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - ComputeRequestError::TrBFV(err) => { + match self.get_err() { + ComputeRequestErrorKind::TrBFV(err) => { write!(f, "We had an error number crunching: {:?}", err) } - ComputeRequestError::SemaphoreError(name) => { - write!(f, "Multithread SemaphoreError. This means there was a problem acquiring the semaphore lock for this ComputeRequest: '{name}'") - } - ComputeRequestError::RecvError(name) => { - write!(f, "Multithread RecvError. This means there was a problem receiving a response for this ComputeRequest: '{name}'") - } } } } @@ -95,8 +130,8 @@ impl fmt::Display for ComputeRequestError { impl TryFrom for CalculateDecryptionShareResponse { type Error = anyhow::Error; fn try_from(value: ComputeResponse) -> Result { - match value { - ComputeResponse::TrBFV(TrBFVResponse::CalculateDecryptionShare(data)) => Ok(data), + match value.response { + TrBFVResponse::CalculateDecryptionShare(data) => Ok(data), _ => { bail!("Expected CalculateDecryptionShareResponse in response but it was not found") } @@ -107,8 +142,8 @@ impl TryFrom for CalculateDecryptionShareResponse { impl TryFrom for CalculateDecryptionKeyResponse { type Error = anyhow::Error; fn try_from(value: ComputeResponse) -> Result { - match value { - ComputeResponse::TrBFV(TrBFVResponse::CalculateDecryptionKey(data)) => Ok(data), + match value.response { + TrBFVResponse::CalculateDecryptionKey(data) => Ok(data), _ => { bail!("Expected CalculateDecryptionKeyResponse in response but it was not found") } @@ -119,8 +154,8 @@ impl TryFrom for CalculateDecryptionKeyResponse { impl TryFrom for GenPkShareAndSkSssResponse { type Error = anyhow::Error; fn try_from(value: ComputeResponse) -> Result { - match value { - ComputeResponse::TrBFV(TrBFVResponse::GenPkShareAndSkSss(data)) => Ok(data), + match value.response { + TrBFVResponse::GenPkShareAndSkSss(data) => Ok(data), _ => { bail!("Expected GenPkShareAndSkSssResponse in response but it was not found") } @@ -131,8 +166,8 @@ impl TryFrom for GenPkShareAndSkSssResponse { impl TryFrom for GenEsiSssResponse { type Error = anyhow::Error; fn try_from(value: ComputeResponse) -> Result { - match value { - ComputeResponse::TrBFV(TrBFVResponse::GenEsiSss(data)) => Ok(data), + match value.response { + TrBFVResponse::GenEsiSss(data) => Ok(data), _ => { bail!("Expected GenEsiSssResponse in response but it was not found") } @@ -143,8 +178,8 @@ impl TryFrom for GenEsiSssResponse { impl TryFrom for CalculateThresholdDecryptionResponse { type Error = anyhow::Error; fn try_from(value: ComputeResponse) -> Result { - match value { - ComputeResponse::TrBFV(TrBFVResponse::CalculateThresholdDecryption(data)) => Ok(data), + match value.response { + TrBFVResponse::CalculateThresholdDecryption(data) => Ok(data), _ => { bail!("Expected CalculateThresholdDecryptionResponse in response but it was not found") } diff --git a/crates/events/src/enclave_event/enclave_error.rs b/crates/events/src/enclave_event/enclave_error.rs index e748298cd5..241ce8f563 100644 --- a/crates/events/src/enclave_event/enclave_error.rs +++ b/crates/events/src/enclave_event/enclave_error.rs @@ -39,6 +39,7 @@ pub enum EType { Sortition, Data, Event, + Computation, } impl EnclaveError { diff --git a/crates/events/src/enclave_event/mod.rs b/crates/events/src/enclave_event/mod.rs index 0b1ac02a11..7536e30c13 100644 --- a/crates/events/src/enclave_event/mod.rs +++ b/crates/events/src/enclave_event/mod.rs @@ -122,6 +122,9 @@ pub enum EnclaveEventData { EncryptionKeyCreated(EncryptionKeyCreated), EncryptionKeyCollectionFailed(EncryptionKeyCollectionFailed), ThresholdShareCollectionFailed(ThresholdShareCollectionFailed), + ComputeRequest(ComputeRequest), + ComputeResponse(ComputeResponse), + ComputeRequestError(ComputeRequestError), /// This is a test event to use in testing TestEvent(TestEvent), } @@ -302,6 +305,7 @@ impl EnclaveEvent { EnclaveEventData::TicketGenerated(ref data) => Some(data.e3_id.clone()), EnclaveEventData::TicketSubmitted(ref data) => Some(data.e3_id.clone()), EnclaveEventData::EncryptionKeyCreated(ref data) => Some(data.e3_id.clone()), + EnclaveEventData::ComputeResponse(ref data) => Some(data.e3_id.clone()), _ => None, } } @@ -336,7 +340,10 @@ impl_into_event_data!( ThresholdShareCreated, EncryptionKeyCreated, EncryptionKeyCollectionFailed, - ThresholdShareCollectionFailed + ThresholdShareCollectionFailed, + ComputeRequest, + ComputeResponse, + ComputeRequestError ); impl TryFrom<&EnclaveEvent> for EnclaveError { diff --git a/crates/keyshare/src/ext.rs b/crates/keyshare/src/ext.rs index 2c0750224e..ab2a47d908 100644 --- a/crates/keyshare/src/ext.rs +++ b/crates/keyshare/src/ext.rs @@ -8,14 +8,13 @@ use crate::{ Keyshare, KeyshareParams, KeyshareRepositoryFactory, KeyshareState, ThresholdKeyshare, ThresholdKeyshareParams, ThresholdKeyshareRepositoryFactory, ThresholdKeyshareState, }; -use actix::{Actor, Addr}; +use actix::Actor; use anyhow::{anyhow, Result}; use async_trait::async_trait; use e3_crypto::Cipher; use e3_data::{AutoPersist, RepositoriesFactory}; use e3_events::{prelude::*, BusHandle, EType, EnclaveEvent, EnclaveEventData}; use e3_fhe::ext::FHE_KEY; -use e3_multithread::Multithread; use e3_request::{E3Context, E3ContextSnapshot, E3Extension, META_KEY}; use std::sync::Arc; @@ -116,7 +115,6 @@ pub struct ThresholdKeyshareExtension { bus: BusHandle, cipher: Arc, address: String, - multithread: Addr, share_encryption_params: Arc, } @@ -124,14 +122,12 @@ impl ThresholdKeyshareExtension { pub fn create( bus: &BusHandle, cipher: &Arc, - multithread: &Addr, address: &str, share_encryption_params: Arc, ) -> Box { Box::new(Self { bus: bus.clone(), cipher: cipher.to_owned(), - multithread: multithread.clone(), address: address.to_owned(), share_encryption_params, }) @@ -174,7 +170,6 @@ impl E3Extension for ThresholdKeyshareExtension { ThresholdKeyshare::new(ThresholdKeyshareParams { bus: self.bus.clone(), cipher: self.cipher.clone(), - multithread: self.multithread.clone(), state: container, share_encryption_params: self.share_encryption_params.clone(), }) @@ -205,7 +200,6 @@ impl E3Extension for ThresholdKeyshareExtension { let value = ThresholdKeyshare::new(ThresholdKeyshareParams { bus: self.bus.clone(), cipher: self.cipher.clone(), - multithread: self.multithread.clone(), state, share_encryption_params: self.share_encryption_params.clone(), }) diff --git a/crates/keyshare/src/threshold_keyshare.rs b/crates/keyshare/src/threshold_keyshare.rs index fc49d89a94..f437a11dc8 100644 --- a/crates/keyshare/src/threshold_keyshare.rs +++ b/crates/keyshare/src/threshold_keyshare.rs @@ -9,14 +9,13 @@ use anyhow::{anyhow, bail, Result}; use e3_crypto::{Cipher, SensitiveBytes}; use e3_data::Persistable; use e3_events::{ - prelude::*, BusHandle, CiphernodeSelected, CiphertextOutputPublished, ComputeRequest, - ComputeResponse, DecryptionshareCreated, Die, E3RequestComplete, E3id, EnclaveEvent, - EnclaveEventData, EncryptionKey, EncryptionKeyCollectionFailed, EncryptionKeyCreated, - KeyshareCreated, PartyId, ThresholdShare, ThresholdShareCollectionFailed, + prelude::*, trap, BusHandle, CiphernodeSelected, CiphertextOutputPublished, ComputeRequest, + ComputeResponse, CorrelationId, DecryptionshareCreated, Die, E3RequestComplete, E3id, EType, + EnclaveEvent, EnclaveEventData, EncryptionKey, EncryptionKeyCollectionFailed, + EncryptionKeyCreated, KeyshareCreated, PartyId, ThresholdShare, ThresholdShareCollectionFailed, ThresholdShareCreated, }; use e3_fhe::create_crp; -use e3_multithread::Multithread; use e3_trbfv::{ calculate_decryption_key::CalculateDecryptionKeyRequest, calculate_decryption_share::{ @@ -28,7 +27,7 @@ use e3_trbfv::{ shares::{BfvEncryptedShares, EncryptableVec, Encrypted, ShamirShare, SharedSecret}, TrBFVConfig, TrBFVRequest, TrBFVResponse, }; -use e3_utils::{bail, to_ordered_vec, utility_types::ArcBytes}; +use e3_utils::{to_ordered_vec, utility_types::ArcBytes}; use fhe::bfv::BfvParameters; use fhe::bfv::{PublicKey, SecretKey}; use fhe_traits::{DeserializeParametrized, Serialize}; @@ -302,7 +301,6 @@ impl TryInto for ThresholdKeyshareState { pub struct ThresholdKeyshareParams { pub bus: BusHandle, pub cipher: Arc, - pub multithread: Addr, pub state: Persistable, pub share_encryption_params: Arc, } @@ -312,7 +310,6 @@ pub struct ThresholdKeyshare { cipher: Arc, decryption_key_collector: Option>, encryption_key_collector: Option>, - multithread: Addr, state: Persistable, share_encryption_params: Arc, } @@ -324,7 +321,6 @@ impl ThresholdKeyshare { cipher: params.cipher, decryption_key_collector: None, encryption_key_collector: None, - multithread: params.multithread, state: params.state, share_encryption_params: params.share_encryption_params, } @@ -410,6 +406,22 @@ impl ThresholdKeyshare { Ok(()) } + pub fn handle_compute_response(&mut self, msg: ComputeResponse) -> Result<()> { + match &msg.response { + TrBFVResponse::GenEsiSss(_) => self.handle_gen_esi_sss_response(msg), + TrBFVResponse::GenPkShareAndSkSss(_) => { + self.handle_gen_pk_share_and_sk_sss_response(msg) + } + TrBFVResponse::CalculateDecryptionKey(_) => { + self.handle_calculate_decryption_key_response(msg) + } + TrBFVResponse::CalculateDecryptionShare(_) => { + self.handle_calculate_decryption_share_response(msg) + } + _ => Ok(()), + } + } + /// 1. CiphernodeSelected - Generate BFV keys and start collecting pub fn handle_ciphernode_selected( &mut self, @@ -457,7 +469,6 @@ impl ThresholdKeyshare { pub fn handle_all_encryption_keys_collected( &mut self, msg: AllEncryptionKeysCollected, - address: Addr, ) -> Result<()> { info!( "AllEncryptionKeysCollected - {} keys received", @@ -479,14 +490,16 @@ impl ThresholdKeyshare { )) })?; - address.do_send(GenEsiSss(current.ciphernode_selected.clone())); - address.do_send(GenPkShareAndSkSss(current.ciphernode_selected)); + self.handle_gen_esi_sss_requested(GenEsiSss(current.ciphernode_selected.clone()))?; + self.handle_gen_pk_share_and_sk_sss_requested(GenPkShareAndSkSss( + current.ciphernode_selected, + ))?; Ok(()) } /// 2. GenEsiSss - pub fn handle_gen_esi_sss_requested(&self, msg: GenEsiSss) -> Result { + pub fn handle_gen_esi_sss_requested(&self, msg: GenEsiSss) -> Result<()> { info!("GenEsiSss on ThresholdKeyshare"); let evt = msg.0; @@ -495,6 +508,7 @@ impl ThresholdKeyshare { // bundle them in with the params error_size, esi_per_ct, + e3_id, .. } = evt.clone(); @@ -505,16 +519,21 @@ impl ThresholdKeyshare { let trbfv_config = state.get_trbfv_config(); - let event = ComputeRequest::TrBFV(TrBFVRequest::GenEsiSss( - GenEsiSssRequest { - trbfv_config, - error_size, - esi_per_ct: esi_per_ct as u64, - } - .into(), - )); + let event = ComputeRequest::new( + TrBFVRequest::GenEsiSss( + GenEsiSssRequest { + trbfv_config, + error_size, + esi_per_ct: esi_per_ct as u64, + } + .into(), + ), + CorrelationId::new(), + e3_id, + ); - Ok(event) + self.bus.publish(event)?; + Ok(()) } /// 2a. GenEsiSss result @@ -567,12 +586,9 @@ impl ThresholdKeyshare { } /// 3. GenPkShareAndSkSss - pub fn handle_gen_pk_share_and_sk_sss_requested( - &self, - msg: GenPkShareAndSkSss, - ) -> Result { + pub fn handle_gen_pk_share_and_sk_sss_requested(&self, msg: GenPkShareAndSkSss) -> Result<()> { info!("GenPkShareAndSkSss on ThresholdKeyshare"); - let CiphernodeSelected { seed, .. } = msg.0; + let CiphernodeSelected { seed, e3_id, .. } = msg.0; let state = self .state .get() @@ -587,16 +603,21 @@ impl ThresholdKeyshare { ) .to_bytes(), ); - let event = ComputeRequest::TrBFV(TrBFVRequest::GenPkShareAndSkSss( - GenPkShareAndSkSssRequest { trbfv_config, crp }.into(), - )); + let event = ComputeRequest::new( + TrBFVRequest::GenPkShareAndSkSss( + GenPkShareAndSkSssRequest { trbfv_config, crp }.into(), + ), + CorrelationId::new(), + e3_id, + ); - Ok(event) + self.bus.publish(event)?; + Ok(()) } /// 3a. GenPkShareAndSkSss result pub fn handle_gen_pk_share_and_sk_sss_response(&mut self, res: ComputeResponse) -> Result<()> { - let ComputeResponse::TrBFV(TrBFVResponse::GenPkShareAndSkSss(output)) = res else { + let TrBFVResponse::GenPkShareAndSkSss(output) = res.response else { bail!("Error extracting data from compute process") }; @@ -728,10 +749,11 @@ impl ThresholdKeyshare { pub fn handle_all_threshold_shares_collected( &self, msg: AllThresholdSharesCollected, - ) -> Result { + ) -> Result<()> { info!("AllThresholdSharesCollected"); let cipher = self.cipher.clone(); let state = self.state.try_get()?; + let e3_id = state.get_e3_id(); let party_id = state.party_id as usize; let trbfv_config = state.get_trbfv_config(); @@ -784,14 +806,19 @@ impl ThresholdKeyshare { sk_sss_collected: sk_sss_collected.encrypt(&cipher)?, }; - let event = ComputeRequest::TrBFV(TrBFVRequest::CalculateDecryptionKey(request)); + let event = ComputeRequest::new( + TrBFVRequest::CalculateDecryptionKey(request), + CorrelationId::new(), + e3_id.clone(), + ); - Ok(event) + self.bus.publish(event)?; + Ok(()) } /// 5a. CalculateDecryptionKeyResponse -> KeyshareCreated pub fn handle_calculate_decryption_key_response(&mut self, res: ComputeResponse) -> Result<()> { - let ComputeResponse::TrBFV(TrBFVResponse::CalculateDecryptionKey(output)) = res else { + let TrBFVResponse::CalculateDecryptionKey(output) = res.response else { bail!("Error extracting data from compute process") }; @@ -815,13 +842,13 @@ impl ThresholdKeyshare { })?; let state = self.state.try_get()?; - let e3_id = state.get_e3_id().clone(); + let e3_id = state.get_e3_id(); let address = state.get_address().to_owned(); let current: ReadyForDecryption = state.clone().try_into()?; self.bus.publish(KeyshareCreated { pubkey: current.pk_share, - e3_id, + e3_id: e3_id.clone(), node: address, })?; @@ -832,7 +859,7 @@ impl ThresholdKeyshare { pub fn handle_ciphertext_output_published( &mut self, msg: CiphertextOutputPublished, - ) -> Result { + ) -> Result<()> { // Set state to decrypting self.state.try_mutate(|s| { use KeyshareState as K; @@ -850,20 +877,25 @@ impl ThresholdKeyshare { let ciphertext_output = msg.ciphertext_output; let state = self.state.try_get()?; + let e3_id = state.get_e3_id(); let decrypting: Decrypting = state.clone().try_into()?; let trbfv_config = state.get_trbfv_config(); - let event = ComputeRequest::TrBFV(TrBFVRequest::CalculateDecryptionShare( - CalculateDecryptionShareRequest { - name: format!("party_id({})", state.party_id), - ciphertexts: ciphertext_output, - sk_poly_sum: decrypting.sk_poly_sum, - es_poly_sum: decrypting.es_poly_sum, - trbfv_config, - } - .into(), - )); - - Ok(event) // CalculateDecryptionShareRequest + let event = ComputeRequest::new( + TrBFVRequest::CalculateDecryptionShare( + CalculateDecryptionShareRequest { + name: format!("party_id({})", state.party_id), + ciphertexts: ciphertext_output, + sk_poly_sum: decrypting.sk_poly_sum, + es_poly_sum: decrypting.es_poly_sum, + trbfv_config, + } + .into(), + ), + CorrelationId::new(), + e3_id.clone(), + ); + self.bus.publish(event)?; // CalculateDecryptionShareRequest + Ok(()) } /// CalculateDecryptionShareResponse @@ -898,45 +930,6 @@ impl ThresholdKeyshare { Ok(()) } - - /// This is handling some of the dark arts of actix - /// This effectively calls the request function which - /// generates a ComputeRequest message and then runs the request - /// on the multithread actor and trigggers the response - /// handler with the results. Errors at this stage are simply - /// logged. Eventually we will need to configure a policy here - /// For example retry with exponential backoff - fn multithread_request( - &mut self, - request_fn: F, - response_fn: R, - ) -> ResponseActFuture - where - F: FnOnce(&mut Self) -> Result, - R: FnOnce(&mut Self, ComputeResponse, &mut ::Context) -> Result<()> - + 'static, - { - // When handling futures in actix you need a pinned box - // This is so that the future stays in the same spot in memory - Box::pin( - // Run the request function and print if there is an error - match request_fn(self) { - Ok(evt) => self.multithread.send(evt), - Err(e) => { - error!("{e}"); - return bail(self); - } - } - .into_actor(self) - .map(move |res, act, ctx| { - // Run the response function and print if there is an error - match (|| -> Result<()> { response_fn(act, res??, ctx) })() { - Ok(_) => (), - Err(e) => error!("{e}"), - } - }), - ) - } } // Will only receive events that are for this specific e3_id @@ -953,68 +946,54 @@ impl Handler for ThresholdKeyshare { let _ = self.handle_encryption_key_created(data, ctx.address()); } EnclaveEventData::E3RequestComplete(data) => ctx.notify(data), + EnclaveEventData::ComputeResponse(data) => ctx.notify(data), _ => (), } } } -impl Handler for ThresholdKeyshare { +impl Handler for ThresholdKeyshare { type Result = (); - fn handle(&mut self, msg: CiphernodeSelected, ctx: &mut Self::Context) -> Self::Result { - match self.handle_ciphernode_selected(msg, ctx.address()) { - Err(e) => error!("{e}"), - Ok(_) => (), - } - } -} - -impl Handler for ThresholdKeyshare { - type Result = ResponseActFuture; - fn handle(&mut self, msg: GenEsiSss, _: &mut Self::Context) -> Self::Result { - self.multithread_request( - |act| act.handle_gen_esi_sss_requested(msg), - |act, res, _| act.handle_gen_esi_sss_response(res), - ) + fn handle(&mut self, msg: ComputeResponse, _: &mut Self::Context) -> Self::Result { + trap(EType::KeyGeneration, &self.bus.clone(), || { + self.handle_compute_response(msg) + }) } } -impl Handler for ThresholdKeyshare { - type Result = ResponseActFuture; - fn handle(&mut self, msg: GenPkShareAndSkSss, _: &mut Self::Context) -> Self::Result { - self.multithread_request( - |act| act.handle_gen_pk_share_and_sk_sss_requested(msg), - |act, res, _| act.handle_gen_pk_share_and_sk_sss_response(res), - ) +impl Handler for ThresholdKeyshare { + type Result = (); + fn handle(&mut self, msg: CiphernodeSelected, ctx: &mut Self::Context) -> Self::Result { + trap(EType::KeyGeneration, &self.bus.clone(), || { + self.handle_ciphernode_selected(msg, ctx.address()) + }) } } impl Handler for ThresholdKeyshare { type Result = (); - fn handle(&mut self, msg: AllEncryptionKeysCollected, ctx: &mut Self::Context) -> Self::Result { - match self.handle_all_encryption_keys_collected(msg, ctx.address()) { - Err(e) => error!("{e}"), - Ok(_) => (), - } + fn handle(&mut self, msg: AllEncryptionKeysCollected, _: &mut Self::Context) -> Self::Result { + trap(EType::KeyGeneration, &self.bus.clone(), || { + self.handle_all_encryption_keys_collected(msg) + }) } } impl Handler for ThresholdKeyshare { - type Result = ResponseActFuture; + type Result = (); fn handle(&mut self, msg: AllThresholdSharesCollected, _: &mut Self::Context) -> Self::Result { - self.multithread_request( - |act| act.handle_all_threshold_shares_collected(msg), - |act, res, _| act.handle_calculate_decryption_key_response(res), - ) + trap(EType::KeyGeneration, &self.bus.clone(), || { + self.handle_all_threshold_shares_collected(msg) + }) } } impl Handler for ThresholdKeyshare { - type Result = ResponseActFuture; + type Result = (); fn handle(&mut self, msg: CiphertextOutputPublished, _: &mut Self::Context) -> Self::Result { - self.multithread_request( - |act| act.handle_ciphertext_output_published(msg), - |act, res, _| act.handle_calculate_decryption_share_response(res), - ) + trap(EType::KeyGeneration, &self.bus.clone(), || { + self.handle_ciphertext_output_published(msg) + }) } } diff --git a/crates/multithread/Cargo.toml b/crates/multithread/Cargo.toml index fa2101d0fe..ddc497ad7e 100644 --- a/crates/multithread/Cargo.toml +++ b/crates/multithread/Cargo.toml @@ -18,4 +18,5 @@ rand = { workspace = true } rayon = { workspace = true } tokio = { workspace = true } tracing = { workspace = true } +thiserror = { workspace = true } zeroize = { workspace = true } diff --git a/crates/multithread/src/lib.rs b/crates/multithread/src/lib.rs index 3cb3e8a78a..c98c55dc29 100644 --- a/crates/multithread/src/lib.rs +++ b/crates/multithread/src/lib.rs @@ -4,300 +4,10 @@ // without even the implied warranty of MERCHANTABILITY // or FITNESS FOR A PARTICULAR PURPOSE. +mod multithread; +mod pool; mod report; -use std::sync::Arc; -use std::thread; -use std::time::Duration; -use std::time::Instant; - -use actix::prelude::*; -use actix::{Actor, Handler}; -use anyhow::Result; -use e3_crypto::Cipher; -use e3_events::{ComputeRequest, ComputeRequestError, ComputeResponse}; -use e3_trbfv::calculate_decryption_key::calculate_decryption_key; -use e3_trbfv::calculate_decryption_share::calculate_decryption_share; -use e3_trbfv::calculate_threshold_decryption::calculate_threshold_decryption; -use e3_trbfv::gen_esi_sss::gen_esi_sss; -use e3_trbfv::gen_pk_share_and_sk_sss::gen_pk_share_and_sk_sss; -use e3_trbfv::{TrBFVError, TrBFVRequest, TrBFVResponse}; -use e3_utils::SharedRng; -use rand::Rng; -use rayon::{self, ThreadPool}; -use report::MultithreadReport; -use tokio::sync::Semaphore; -use tracing::error; -use tracing::info; -use tracing::warn; - -/// Multithread actor -pub struct Multithread { - rng: SharedRng, - cipher: Arc, - rayon_limit: Arc, - thread_pool: Arc, - report: Option, -} - -impl Multithread { - pub fn new( - rng: SharedRng, - cipher: Arc, - rayon_threads: usize, - max_simultaneous_rayon_tasks: usize, - capture_events: bool, - ) -> Self { - let thread_pool = Arc::new( - rayon::ThreadPoolBuilder::new() - .num_threads(rayon_threads) - .build() - .expect("Failed to create Rayon thread pool"), - ); - info!( - "Created threadpool with {} threads.", - thread_pool.current_num_threads() - ); - let rayon_limit = Arc::new(Semaphore::new(max_simultaneous_rayon_tasks)); - - Self { - rng, - cipher, - thread_pool, - rayon_limit, - report: if capture_events { - Some(MultithreadReport::new( - rayon_threads, - max_simultaneous_rayon_tasks, - )) - } else { - None - }, - } - } - - /// Subtract the given amount from the total number of available threads and return the result - pub fn get_max_threads_minus(amount: usize) -> usize { - let total_threads = thread::available_parallelism() - .map(|n| n.get()) - .unwrap_or(1); - let threads_to_use = std::cmp::max(1, total_threads.saturating_sub(amount)); - threads_to_use - } - - pub fn attach( - rng: SharedRng, - cipher: Arc, - rayon_threads: usize, - max_simultaneous_rayon_tasks: usize, - capture_events: bool, - ) -> Addr { - Self::new( - rng.clone(), - cipher.clone(), - rayon_threads, - max_simultaneous_rayon_tasks, - capture_events, - ) - .start() - } -} - -impl Actor for Multithread { - type Context = actix::Context; -} - -impl Handler for Multithread { - type Result = ResponseFuture>; - fn handle(&mut self, msg: ComputeRequest, ctx: &mut Self::Context) -> Self::Result { - let cipher = self.cipher.clone(); - let rng = self.rng.clone(); - let thread_pool = self.thread_pool.clone(); - let semaphore = self.rayon_limit.clone(); - let msg_string = msg.to_string(); - let self_addr = ctx.address(); - let capture_events = self.report.is_some(); - let job_name = msg_string.clone(); - Box::pin(async move { - // Block until we have enough task slots available we have to do this this way as - // because we use do_send() everywhere there is no backpressure on the actors - let _permit = semaphore - .acquire() - .await - .map_err(|_| ComputeRequestError::SemaphoreError(msg_string.to_string()))?; - - // Warn of long running jobs - let warning_handle = tokio::spawn(async move { - tokio::time::sleep(tokio::time::Duration::from_secs(10)).await; - warn!( - "Job '{}' has been running for more than 10 seconds", - job_name - ); - tokio::time::sleep(tokio::time::Duration::from_secs(30)).await; - error!( - "Job '{}' has been running for more than 30 seconds", - job_name - ); - }); - - // This uses channels to track pending and complete tasks when - // using the thread pool - let (tx, rx) = tokio::sync::oneshot::channel(); - - // We spawn a thread on rayon moving to "sync"-land - thread_pool.spawn(move || { - // Do the actual work this is gonna take a while... - let (result, duration) = handle_compute_request(rng, cipher, msg); - - // try to return the result and it's duration note this is sync as it is a oneshot sender. - if let Err(res) = tx.send((result, Some(duration))) { - error!( - "There was an error sending the result from the multithread actor: result = {:?}", - res - ); - } - }); - // we are back in async io land... - - // await the oneshot - let (result, duration) = rx.await.unwrap_or_else(|_| { - ( - Err(ComputeRequestError::RecvError(msg_string.to_string())), - None, - ) - }); - - warning_handle.abort(); - - // incase we are collecting events for a report - if capture_events { - if let Some(dur) = duration { - self_addr.do_send(TrackDuration::new(msg_string, dur)) - } - }; - - result - }) - } -} - -impl Handler for Multithread { - type Result = (); - fn handle(&mut self, msg: TrackDuration, _: &mut Self::Context) -> Self::Result { - // If the report is there we are tracking durations - if let Some(report) = &mut self.report { - report.track(msg); - }; - } -} - -impl Handler for Multithread { - type Result = Option; - fn handle(&mut self, _: GetReport, _: &mut Self::Context) -> Self::Result { - if let Some(ref report) = self.report { - return Some(report.to_report().to_string()); - } - None - } -} - -#[derive(Message, Debug)] -#[rtype("()")] -pub struct TrackDuration { - name: String, - duration: Duration, -} - -impl TrackDuration { - pub fn new(name: String, duration: Duration) -> Self { - Self { name, duration } - } -} - -#[derive(Message, Debug)] -#[rtype("Option")] -pub struct GetReport; - -fn timefunc( - name: &str, - id: u8, - func: F, -) -> (Result, Duration) -where - F: FnOnce() -> Result, -{ - info!("\nSTARTING MULTITHREAD `{}({})`\n", name, id); - let start = Instant::now(); - let out = func(); - let dur = start.elapsed(); - info!("\nFINISHED MULTITHREAD `{}`({}) in {:?}\n", name, id, dur); - (out, dur) // return output as well as timing info -} - -/// Handle our compute request. This function is run on a rayon threadpool. -fn handle_compute_request( - rng: SharedRng, - cipher: Arc, - request: ComputeRequest, -) -> (Result, Duration) { - let id: u8 = rand::thread_rng().gen(); - match request { - ComputeRequest::TrBFV(TrBFVRequest::GenPkShareAndSkSss(req)) => timefunc( - "gen_pk_share_and_sk_sss", - id, - || match gen_pk_share_and_sk_sss(&rng, &cipher, req) { - Ok(o) => Ok(ComputeResponse::TrBFV(TrBFVResponse::GenPkShareAndSkSss(o))), - Err(e) => Err(ComputeRequestError::TrBFV(TrBFVError::GenPkShareAndSkSss( - e.to_string(), - ))), - }, - ), - ComputeRequest::TrBFV(TrBFVRequest::GenEsiSss(req)) => timefunc("gen_esi_sss", id, || { - match gen_esi_sss(&rng, &cipher, req) { - Ok(o) => Ok(ComputeResponse::TrBFV(TrBFVResponse::GenEsiSss(o))), - Err(e) => Err(ComputeRequestError::TrBFV(TrBFVError::GenEsiSss( - e.to_string(), - ))), - } - }), - ComputeRequest::TrBFV(TrBFVRequest::CalculateDecryptionKey(req)) => timefunc( - "calculate_decryption_key", - id, - || match calculate_decryption_key(&cipher, req) { - Ok(o) => Ok(ComputeResponse::TrBFV( - TrBFVResponse::CalculateDecryptionKey(o), - )), - Err(e) => { - error!("Error calculating decryption key: {}", e); - Err(ComputeRequestError::TrBFV( - TrBFVError::CalculateDecryptionKey(e.to_string()), - )) - } - }, - ), - ComputeRequest::TrBFV(TrBFVRequest::CalculateDecryptionShare(req)) => timefunc( - "calculate_decryption_share", - id, - || match calculate_decryption_share(&cipher, req) { - Ok(o) => Ok(ComputeResponse::TrBFV( - TrBFVResponse::CalculateDecryptionShare(o), - )), - Err(e) => Err(ComputeRequestError::TrBFV( - TrBFVError::CalculateDecryptionShare(e.to_string()), - )), - }, - ), - ComputeRequest::TrBFV(TrBFVRequest::CalculateThresholdDecryption(req)) => timefunc( - "calculate_threshold_decryption", - id, - || match calculate_threshold_decryption(req) { - Ok(o) => Ok(ComputeResponse::TrBFV( - TrBFVResponse::CalculateThresholdDecryption(o), - )), - Err(e) => Err(ComputeRequestError::TrBFV( - TrBFVError::CalculateThresholdDecryption(e.to_string()), - )), - }, - ), - } -} +pub use multithread::*; +pub use pool::*; +pub use report::*; diff --git a/crates/multithread/src/multithread.rs b/crates/multithread/src/multithread.rs new file mode 100644 index 0000000000..552bb73855 --- /dev/null +++ b/crates/multithread/src/multithread.rs @@ -0,0 +1,270 @@ +// SPDX-License-Identifier: LGPL-3.0-only +// +// This file is provided WITHOUT ANY WARRANTY; +// without even the implied warranty of MERCHANTABILITY +// or FITNESS FOR A PARTICULAR PURPOSE. + +use std::sync::Arc; +use std::thread; +use std::time::Duration; +use std::time::Instant; + +use crate::report::MultithreadReport; +use crate::report::TrackDuration; +use crate::TaskPool; +use crate::TaskTimeouts; +use actix::prelude::*; +use actix::{Actor, Handler}; +use anyhow::Result; +use e3_crypto::Cipher; +use e3_events::BusHandle; +use e3_events::ComputeRequestErrorKind; +use e3_events::EType; +use e3_events::EnclaveEvent; +use e3_events::EnclaveEventData; +use e3_events::ErrorDispatcher; +use e3_events::Event; +use e3_events::EventPublisher; +use e3_events::EventSubscriber; +use e3_events::{ComputeRequest, ComputeRequestError, ComputeResponse}; +use e3_trbfv::calculate_decryption_key::calculate_decryption_key; +use e3_trbfv::calculate_decryption_share::calculate_decryption_share; +use e3_trbfv::calculate_threshold_decryption::calculate_threshold_decryption; +use e3_trbfv::gen_esi_sss::gen_esi_sss; +use e3_trbfv::gen_pk_share_and_sk_sss::gen_pk_share_and_sk_sss; +use e3_trbfv::{TrBFVError, TrBFVRequest, TrBFVResponse}; +use e3_utils::SharedRng; +use rand::Rng; +use tracing::error; +use tracing::info; + +/// Multithread actor +pub struct Multithread { + bus: BusHandle, + rng: SharedRng, + cipher: Arc, + task_pool: TaskPool, + report: Option>, +} + +impl Multithread { + pub fn new( + bus: BusHandle, + rng: SharedRng, + cipher: Arc, + task_pool: TaskPool, + report: Option>, + ) -> Self { + Self { + bus, + rng, + cipher, + task_pool, + report, + } + } + + /// Subtract the given amount from the total number of available threads and return the result + pub fn get_max_threads_minus(amount: usize) -> usize { + let total_threads = thread::available_parallelism() + .map(|n| n.get()) + .unwrap_or(1); + let threads_to_use = std::cmp::max(1, total_threads.saturating_sub(amount)); + threads_to_use + } + + pub fn attach( + bus: &BusHandle, + rng: SharedRng, + cipher: Arc, + task_pool: TaskPool, + report: Option>, + ) -> Addr { + let addr = Self::new(bus.clone(), rng.clone(), cipher.clone(), task_pool, report).start(); + bus.subscribe("ComputeRequest", addr.clone().recipient()); + addr + } + + pub fn create_taskpool(threads: usize, max_tasks: usize) -> TaskPool { + TaskPool::new(threads, max_tasks) + } +} + +impl Actor for Multithread { + type Context = actix::Context; +} + +impl Handler for Multithread { + type Result = (); + fn handle(&mut self, msg: EnclaveEvent, ctx: &mut Self::Context) -> Self::Result { + info!("Multithread received EnclaveEvent!"); + match msg.get_data() { + EnclaveEventData::ComputeRequest(data) => ctx.notify(data.clone()), + _ => (), + } + } +} + +impl Handler for Multithread { + type Result = ResponseFuture<()>; + fn handle(&mut self, msg: ComputeRequest, _: &mut Self::Context) -> Self::Result { + let cipher = self.cipher.clone(); + let rng = self.rng.clone(); + let bus = self.bus.clone(); + let pool = self.task_pool.clone(); + let report = self.report.clone(); + // TODO: replace with trap_fut + Box::pin(async move { + match handle_compute_request_event(msg, bus, cipher, rng, pool, report).await { + Ok(_) => (), + Err(e) => error!("{e}"), + } + }) + } +} + +async fn handle_compute_request_event( + msg: ComputeRequest, + bus: BusHandle, + cipher: Arc, + rng: SharedRng, + pool: TaskPool, + report: Option>, +) -> anyhow::Result<()> { + let msg_string = msg.to_string(); + let job_name = msg_string.clone(); + + // We spawn a thread on rayon moving to "sync"-land + let (result, duration) = pool + .spawn(job_name, TaskTimeouts::default(), move || { + // Do the actual work this is gonna take a while... + handle_compute_request(rng, cipher, msg) + }) + .await?; + // we are back in async io land... + + // incase we are collecting events for a report + if let Some(report) = report { + report.do_send(TrackDuration::new(msg_string, duration)) + }; + + match result { + Ok(val) => bus.publish(val)?, + Err(e) => bus.err(EType::Computation, e), + }; + Ok(()) +} + +fn timefunc( + name: &str, + id: u8, + func: F, +) -> (Result, Duration) +where + F: FnOnce() -> Result, +{ + info!("\nSTARTING MULTITHREAD `{}({})`\n", name, id); + let start = Instant::now(); + let out = func(); + let dur = start.elapsed(); + info!("\nFINISHED MULTITHREAD `{}`({}) in {:?}\n", name, id, dur); + (out, dur) // return output as well as timing info +} + +/// Handle our compute request. This function is run on a rayon threadpool. +fn handle_compute_request( + rng: SharedRng, + cipher: Arc, + request: ComputeRequest, +) -> (Result, Duration) { + let id: u8 = rand::thread_rng().gen(); + + match request.request.clone() { + TrBFVRequest::GenPkShareAndSkSss(req) => { + timefunc( + "gen_pk_share_and_sk_sss", + id, + || match gen_pk_share_and_sk_sss(&rng, &cipher, req) { + Ok(o) => Ok(ComputeResponse::new( + TrBFVResponse::GenPkShareAndSkSss(o), + request.correlation_id, + request.e3_id, + )), + Err(e) => Err(ComputeRequestError::new( + ComputeRequestErrorKind::TrBFV(TrBFVError::GenPkShareAndSkSss( + e.to_string(), + )), + request, + )), + }, + ) + } + TrBFVRequest::GenEsiSss(req) => timefunc("gen_esi_sss", id, || { + match gen_esi_sss(&rng, &cipher, req) { + Ok(o) => Ok(ComputeResponse::new( + TrBFVResponse::GenEsiSss(o), + request.correlation_id, + request.e3_id, + )), + Err(e) => Err(ComputeRequestError::new( + ComputeRequestErrorKind::TrBFV(TrBFVError::GenEsiSss(e.to_string())), + request, + )), + } + }), + TrBFVRequest::CalculateDecryptionKey(req) => timefunc( + "calculate_decryption_key", + id, + || match calculate_decryption_key(&cipher, req) { + Ok(o) => Ok(ComputeResponse::new( + TrBFVResponse::CalculateDecryptionKey(o), + request.correlation_id, + request.e3_id, + )), + Err(e) => { + error!("Error calculating decryption key: {}", e); + Err(ComputeRequestError::new( + ComputeRequestErrorKind::TrBFV(TrBFVError::CalculateDecryptionKey( + e.to_string(), + )), + request, + )) + } + }, + ), + TrBFVRequest::CalculateDecryptionShare(req) => timefunc( + "calculate_decryption_share", + id, + || match calculate_decryption_share(&cipher, req) { + Ok(o) => Ok(ComputeResponse::new( + TrBFVResponse::CalculateDecryptionShare(o), + request.correlation_id, + request.e3_id, + )), + Err(e) => Err(ComputeRequestError::new( + ComputeRequestErrorKind::TrBFV(TrBFVError::CalculateDecryptionShare( + e.to_string(), + )), + request, + )), + }, + ), + TrBFVRequest::CalculateThresholdDecryption(req) => timefunc( + "calculate_threshold_decryption", + id, + || match calculate_threshold_decryption(req) { + Ok(o) => Ok(ComputeResponse::new( + TrBFVResponse::CalculateThresholdDecryption(o), + request.correlation_id, + request.e3_id, + )), + Err(e) => Err(ComputeRequestError::new( + ComputeRequestErrorKind::TrBFV(TrBFVError::CalculateThresholdDecryption( + e.to_string(), + )), + request, + )), + }, + ), + } +} diff --git a/crates/multithread/src/pool.rs b/crates/multithread/src/pool.rs new file mode 100644 index 0000000000..152b3a25b5 --- /dev/null +++ b/crates/multithread/src/pool.rs @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: LGPL-3.0-only +// +// This file is provided WITHOUT ANY WARRANTY; +// without even the implied warranty of MERCHANTABILITY +// or FITNESS FOR A PARTICULAR PURPOSE. + +use rayon::ThreadPool; +use std::fmt::Debug; +use std::ops::Deref; +use std::{sync::Arc, time::Duration}; +use thiserror::Error; +use tokio::sync::oneshot::error::RecvError; +use tokio::{sync::Semaphore, time::sleep}; +use tracing::{debug, error, info, warn, Level}; + +/// A bounded executor for CPU-bound tasks backed by a Rayon thread pool. +#[derive(Debug, Clone)] +pub struct TaskPool { + semaphore: Arc, + thread_pool: Arc, +} + +#[derive(Debug, Error)] +pub enum TaskPoolError { + #[error("{0}")] + SemaphoreError(String), + + #[error("{0}")] + RecvError(RecvError), +} + +impl TaskPool { + /// Creates a new pool with `threads` worker threads and at most `max_tasks` concurrent tasks. + pub fn new(threads: usize, max_tasks: usize) -> TaskPool { + let thread_pool = rayon::ThreadPoolBuilder::new() + .num_threads(threads) + .build() + .expect("Failed to build thread pool"); + + Self { + thread_pool: Arc::new(thread_pool), + semaphore: Arc::new(Semaphore::new(max_tasks)), + } + } + + pub async fn spawn( + &self, + task_name: String, + timed_logs: impl Into, // [(10, Level::WARN), (30, Level::ERROR)] + op: OP, + ) -> Result + where + OP: FnOnce() -> T + Send + 'static, + { + let timeouts = timed_logs.into(); + // Limit the requests and get them to block + let _permit = self + .semaphore + .acquire() + .await + .map_err(|_| TaskPoolError::SemaphoreError(task_name.to_owned()))?; + + // Warn of long running jobs + let warning_handle = tokio::spawn(async move { + let mut elapsed = Duration::ZERO; + + for log in timeouts.iter() { + let target = Duration::from_secs(log.0); + + // Sleep only for the remaining time to reach target + if target > elapsed { + sleep(target - elapsed).await; + elapsed = target; + } + let msg = format!("Job '{}' has been running for {:?}", task_name, target); + match log.1 { + Level::WARN => warn!(msg), + Level::ERROR => error!(msg), + Level::INFO => info!(msg), + Level::DEBUG => debug!(msg), + _ => (), + } + } + }); + + // This uses channels to track pending and complete tasks when + // using the thread pool + let (tx, rx) = tokio::sync::oneshot::channel(); + self.thread_pool.spawn(|| { + let t = op(); + // try to return the result and it's duration note this is sync as it is a oneshot sender. + if let Err(res) = tx.send(t) { + error!( + "There was an error sending the result from the multithread actor: result = {:?}", + res + ); + } + }); + + let output = rx.await.map_err(|r| TaskPoolError::RecvError(r))?; + + warning_handle.abort(); + + Ok(output) + } +} + +#[derive(Debug, Clone)] +pub struct TaskTimeouts(pub Vec); + +impl From<[(u64, Level); N]> for TaskTimeouts { + fn from(arr: [(u64, Level); N]) -> Self { + Self(arr.into_iter().map(|(s, l)| TimedLog(s, l)).collect()) + } +} + +impl Deref for TaskTimeouts { + type Target = Vec; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl TaskTimeouts { + pub fn new(logs: Vec) -> Self { + Self(logs) + } +} + +impl Default for TaskTimeouts { + fn default() -> Self { + [(10, Level::WARN), (30, Level::ERROR)].into() + } +} + +impl From<(u64, Level)> for TimedLog { + fn from((s, level): (u64, Level)) -> Self { + Self(s, level) + } +} + +#[derive(Debug, Clone)] +pub struct TimedLog(pub u64, pub tracing::Level); diff --git a/crates/multithread/src/report.rs b/crates/multithread/src/report.rs index 13c72f209d..9eedc23dab 100644 --- a/crates/multithread/src/report.rs +++ b/crates/multithread/src/report.rs @@ -6,7 +6,24 @@ use std::{collections::HashMap, thread, time::Duration}; -use crate::TrackDuration; +use actix::{Actor, Handler, Message, MessageResponse}; + +#[derive(Message)] +#[rtype(result = "FlattenedReport")] +pub struct ToReport; + +#[derive(Message, Debug)] +#[rtype("()")] +pub struct TrackDuration { + name: String, + duration: Duration, +} + +impl TrackDuration { + pub fn new(name: String, duration: Duration) -> Self { + Self { name, duration } + } +} #[derive(Default)] pub struct MultithreadReport { @@ -15,6 +32,10 @@ pub struct MultithreadReport { events: Vec, } +impl Actor for MultithreadReport { + type Context = actix::Context; +} + impl MultithreadReport { pub fn new(rayon_threads: usize, max_simultaneous_rayon_tasks: usize) -> Self { Self { @@ -24,11 +45,11 @@ impl MultithreadReport { } } - pub fn track(&mut self, msg: TrackDuration) { + fn track(&mut self, msg: TrackDuration) { self.events.push(msg); } - pub fn to_report(&self) -> FlattenedReport { + fn to_report(&self) -> FlattenedReport { let mut total_dur: HashMap = HashMap::new(); let mut runs: HashMap = HashMap::new(); let cores_available: usize = match thread::available_parallelism() { @@ -74,6 +95,21 @@ impl MultithreadReport { } } +impl Handler for MultithreadReport { + type Result = (); + fn handle(&mut self, msg: TrackDuration, _: &mut Self::Context) -> Self::Result { + self.track(msg) + } +} + +impl Handler for MultithreadReport { + type Result = FlattenedReport; + fn handle(&mut self, _: ToReport, _: &mut Self::Context) -> Self::Result { + self.to_report() + } +} + +#[derive(MessageResponse, Debug, Clone, Eq, PartialEq)] pub struct FlattenedReport { cores_available: usize, rayon_threads: usize, diff --git a/crates/tests/tests/integration.rs b/crates/tests/tests/integration.rs index b53f3ca924..69a556b9f1 100644 --- a/crates/tests/tests/integration.rs +++ b/crates/tests/tests/integration.rs @@ -4,6 +4,7 @@ // without even the implied warranty of MERCHANTABILITY // or FITNESS FOR A PARTICULAR PURPOSE. +use actix::Actor; use alloy::primitives::{FixedBytes, I256, U256}; use anyhow::{bail, Result}; use e3_ciphernode_builder::{CiphernodeBuilder, EventSystem}; @@ -13,7 +14,7 @@ use e3_events::{ E3Requested, E3id, EnclaveEventData, OperatorActivationChanged, PlaintextAggregated, TicketBalanceUpdated, }; -use e3_multithread::{GetReport, Multithread}; +use e3_multithread::{Multithread, MultithreadReport, ToReport}; use e3_sdk::bfv_helpers::{build_bfv_params_arc, decode_bytes_to_vec_u64, encode_bfv_params}; use e3_test_helpers::ciphernode_system::CiphernodeSystemBuilder; use e3_test_helpers::{create_seed_from_u64, create_shared_rng_from_u64, AddToCommittee}; @@ -166,13 +167,8 @@ async fn test_trbfv_actor() -> Result<()> { // Seems like you cannot send more than one job at a time to rayon let concurrent_jobs = 1; // leaving at 1 let max_threadroom = Multithread::get_max_threads_minus(1); - let multithread = Multithread::attach( - rng.clone(), - cipher.clone(), - max_threadroom, - concurrent_jobs, - true, - ); + let task_pool = Multithread::create_taskpool(max_threadroom, concurrent_jobs); + let multithread_report = MultithreadReport::new(max_threadroom, concurrent_jobs).start(); let nodes = CiphernodeSystemBuilder::new() // Adding 7 total nodes of which we are only choosing 5 for the committee @@ -181,8 +177,10 @@ async fn test_trbfv_actor() -> Result<()> { println!("Building collector {}!", addr); CiphernodeBuilder::new(&addr, rng.clone(), cipher.clone()) .with_address(&addr) - .with_injected_multithread(multithread.clone()) .testmode_with_history() + .with_shared_taskpool(&task_pool) + .with_multithread_concurrent_jobs(concurrent_jobs) + .with_shared_multithread_report(&multithread_report) .with_trbfv() .with_pubkey_aggregation() .with_sortition_score() @@ -197,7 +195,9 @@ async fn test_trbfv_actor() -> Result<()> { println!("Building normal {}", &addr); CiphernodeBuilder::new(&addr, rng.clone(), cipher.clone()) .with_address(&addr) - .with_injected_multithread(multithread.clone()) + .with_shared_taskpool(&task_pool) + .with_multithread_concurrent_jobs(concurrent_jobs) + .with_shared_multithread_report(&multithread_report) .with_trbfv() .with_sortition_score() .testmode_with_forked_bus(bus.consumer()) @@ -392,8 +392,10 @@ async fn test_trbfv_actor() -> Result<()> { "DecryptionshareCreated", "DecryptionshareCreated", "DecryptionshareCreated", + "ComputeRequest", "DecryptionshareCreated", "DecryptionshareCreated", + "ComputeResponse", "PlaintextAggregated", ]; @@ -438,7 +440,7 @@ async fn test_trbfv_actor() -> Result<()> { assert_eq!(res, exp); } - let mt_report = multithread.send(GetReport).await.unwrap().unwrap(); + let mt_report = multithread_report.send(ToReport).await.unwrap(); println!("{}", mt_report); report.push(("Entire Test", whole_test.elapsed()));