diff --git a/Cargo.toml b/Cargo.toml index 7e53b66..26ff04e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -72,6 +72,7 @@ default = ["alloy", "rustls"] alloy = ["dep:alloy"] aws = ["alloy", "alloy?/signer-aws", "dep:async-trait", "dep:aws-config", "dep:aws-sdk-kms"] perms = ["dep:oauth2", "dep:tokio", "dep:reqwest", "dep:signet-tx-cache"] +block_watcher = ["dep:tokio"] rustls = ["dep:rustls", "rustls/aws-lc-rs"] [[example]] diff --git a/src/lib.rs b/src/lib.rs index a114ff5..cf1efc0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,6 +47,10 @@ pub mod utils { /// Tracing utilities. pub mod tracing; + + /// Block watcher utilities. + #[cfg(feature = "block_watcher")] + pub mod block_watcher; } /// Re-exports of common dependencies. diff --git a/src/utils/block_watcher.rs b/src/utils/block_watcher.rs new file mode 100644 index 0000000..9eb6727 --- /dev/null +++ b/src/utils/block_watcher.rs @@ -0,0 +1,120 @@ +//! Host chain block watcher that subscribes to new blocks and tracks the +//! current host block number. + +use alloy::{ + network::Ethereum, + providers::{Provider, RootProvider}, + transports::TransportError, +}; +use tokio::{ + sync::{broadcast::error::RecvError, watch}, + task::JoinHandle, +}; +use tracing::{debug, error, trace}; + +/// Host chain block watcher that subscribes to new blocks and broadcasts +/// updates via a watch channel. +#[derive(Debug)] +pub struct BlockWatcher { + /// Watch channel responsible for broadcasting block number updates. + block_number: watch::Sender, + + /// Host chain provider. + host_provider: RootProvider, +} + +impl BlockWatcher { + /// Creates a new [`BlockWatcher`] with the given provider and initial + /// block number. + pub fn new(host_provider: RootProvider, initial: u64) -> Self { + Self { + block_number: watch::channel(initial).0, + host_provider, + } + } + + /// Creates a new [`BlockWatcher`], fetching the current block number first. + pub async fn with_current_block( + host_provider: RootProvider, + ) -> Result { + let block_number = host_provider.get_block_number().await?; + Ok(Self::new(host_provider, block_number)) + } + + /// Subscribe to block number updates. + pub fn subscribe(&self) -> SharedBlockNumber { + self.block_number.subscribe().into() + } + + /// Spawns the block watcher task. + pub fn spawn(self) -> (SharedBlockNumber, JoinHandle<()>) { + (self.subscribe(), tokio::spawn(self.task_future())) + } + + async fn task_future(self) { + let mut sub = match self.host_provider.subscribe_blocks().await { + Ok(sub) => sub, + Err(error) => { + error!(%error); + return; + } + }; + + debug!("subscribed to host chain blocks"); + + loop { + match sub.recv().await { + Ok(header) => { + let block_number = header.number; + self.block_number.send_replace(block_number); + trace!(block_number, "updated host block number"); + } + Err(RecvError::Lagged(missed)) => { + debug!(%missed, "block subscription lagged"); + } + Err(RecvError::Closed) => { + debug!("block subscription closed"); + break; + } + } + } + } +} + +/// A shared block number, wrapped in a [`tokio::sync::watch`] Receiver. +/// +/// The block number is periodically updated by a [`BlockWatcher`] task, and +/// can be read or awaited for changes. This allows multiple tasks to observe +/// block number updates. +#[derive(Debug, Clone)] +pub struct SharedBlockNumber(watch::Receiver); + +impl From> for SharedBlockNumber { + fn from(inner: watch::Receiver) -> Self { + Self(inner) + } +} + +impl SharedBlockNumber { + /// Get the current block number. + pub fn get(&self) -> u64 { + *self.0.borrow() + } + + /// Wait for the block number to change, then return the new value. + /// + /// This is implemented using [`Receiver::changed`]. + /// + /// [`Receiver::changed`]: tokio::sync::watch::Receiver::changed + pub async fn changed(&mut self) -> Result { + self.0.changed().await?; + Ok(*self.0.borrow_and_update()) + } + + /// Wait for the block number to reach at least `target`. + /// + /// Returns the block number once it is >= `target`. + pub async fn wait_until(&mut self, target: u64) -> Result { + self.0.wait_for(|&n| n >= target).await.map(|r| *r) + } +}