diff --git a/modules/examples/price-alert/Cargo.toml b/modules/examples/price-alert/Cargo.toml index ef16cb4..a4173d7 100644 --- a/modules/examples/price-alert/Cargo.toml +++ b/modules/examples/price-alert/Cargo.toml @@ -14,3 +14,6 @@ shepherd-sdk = { path = "../../../crates/shepherd-sdk" } alloy-primitives = { version = "1.5", default-features = false, features = ["std"] } alloy-sol-types = { version = "1.5", default-features = false, features = ["std"] } wit-bindgen = { version = "0.57", default-features = false, features = ["macros", "realloc"] } + +[dev-dependencies] +shepherd-sdk-test = { path = "../../../crates/shepherd-sdk-test" } diff --git a/modules/examples/price-alert/src/lib.rs b/modules/examples/price-alert/src/lib.rs index 7f1b5b2..0130bcb 100644 --- a/modules/examples/price-alert/src/lib.rs +++ b/modules/examples/price-alert/src/lib.rs @@ -1,36 +1,21 @@ //! # price-alert (example Shepherd module) //! -//! Polls a Chainlink price oracle on every new block and emits a -//! Warn-level log when the price crosses a config-supplied -//! threshold. Demonstrates the three load-bearing patterns of a -//! Shepherd module: +//! Polls a Chainlink price oracle on every new block (throttled by +//! `every_n_blocks`) and emits a Warn-level log when the price +//! crosses a config-supplied threshold. //! -//! - `chain::request` + ABI decode via `alloy_sol_types` -//! - `shepherd_sdk` helpers (`prelude`, `chain::eth_call_params`, -//! `chain::parse_eth_call_result`) -//! - `[config]` driven behaviour parsed once in `init` and read on -//! every subsequent event +//! ## Module layout //! -//! ## Settings +//! - `strategy.rs` holds the pure logic and tests against +//! `shepherd_sdk::host::Host`. It does not know `wit-bindgen` +//! exists. +//! - `lib.rs` (this file) bridges the per-cdylib wit-bindgen imports +//! into the trait surface and delegates `init` / `on_event` to +//! `strategy`. //! -//! ```toml -//! [config] -//! # Chainlink AggregatorV3Interface address. -//! oracle_address = "0x694AA1769357215DE4FAC081bf1f309aDC325306" # ETH/USD on Sepolia -//! # Oracle's decimals (Chainlink USD pairs are 8; ETH pairs 18). -//! decimals = "8" -//! # Threshold in the oracle's native units (decimal string). The -//! # module multiplies by 10**decimals at init. -//! threshold = "2500.00" -//! # Either "above" or "below". Fires when the answer crosses on -//! # the configured side. -//! direction = "below" -//! # Optional throttle: poll every N blocks. Default 1. -//! every_n_blocks = "1" -//! ``` +//! This split is the M3 "host trait + adapter" recipe documented in +//! `docs/tutorial-first-module.md`. -// wit_bindgen::generate! expands to host-import shims whose arity matches -// the WIT signatures, which can exceed clippy's too-many-arguments threshold. #![cfg_attr(not(test), warn(unused_crate_dependencies))] #![allow(clippy::too_many_arguments)] @@ -40,373 +25,137 @@ wit_bindgen::generate!({ generate_all, }); +mod strategy; + use std::sync::OnceLock; -use alloy_primitives::{Address, I256, U256}; -use alloy_sol_types::{SolCall, sol}; -use shepherd_sdk::chain::{eth_call_params, parse_eth_call_result}; +use shepherd_sdk::host::{ + ChainHost, CowApiHost, HostError as SdkHostError, HostErrorKind as SdkHostErrorKind, + LocalStoreHost, LogLevel as SdkLogLevel, LoggingHost, +}; use nexum::host::types::HostErrorKind; -use nexum::host::{chain, logging, types}; +use nexum::host::{chain, local_store, logging, types}; +use shepherd::cow::cow_api; -sol! { - /// Chainlink AggregatorV3Interface - only the function this - /// module needs. - interface AggregatorV3 { - function latestRoundData() external view returns ( - uint80 roundId, - int256 answer, - uint256 startedAt, - uint256 updatedAt, - uint80 answeredInRound - ); - } -} +static SETTINGS: OnceLock = OnceLock::new(); -/// Resolved configuration, parsed from `module.toml::[config]` at -/// `init` and read on every `on_event`. Stored in a `OnceLock` so -/// the module is single-init by construction. -#[derive(Debug)] -struct Settings { - oracle_address: Address, - /// Threshold scaled to the oracle's native units - /// (`threshold_decimal * 10**decimals`). - threshold_scaled: I256, - direction: Direction, - every_n_blocks: u64, -} +/// Wraps the module's per-cdylib wit-bindgen imports so the strategy +/// can hold a `&impl Host` instead of dispatching on the free +/// functions directly. The implementation is mechanical and identical +/// across modules; a future declarative macro in `shepherd-sdk` will +/// elide the boilerplate. +struct WitBindgenHost; -/// Which side of the threshold the alert fires on. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -enum Direction { - /// Fire when `answer >= threshold`. - Above, - /// Fire when `answer <= threshold`. - Below, +impl ChainHost for WitBindgenHost { + fn request(&self, chain_id: u64, method: &str, params: &str) -> Result { + chain::request(chain_id, method, params).map_err(convert_err) + } } -static CONFIG: OnceLock = OnceLock::new(); - -struct PriceAlert; - -impl Guest for PriceAlert { - fn init(config: Vec<(String, String)>) -> Result<(), HostError> { - match parse_config(&config) { - Ok(cfg) => { - logging::log( - logging::Level::Info, - &format!( - "price-alert init: oracle={:#x} threshold={} direction={:?} every_n_blocks={}", - cfg.oracle_address, - cfg.threshold_scaled, - cfg.direction, - cfg.every_n_blocks, - ), - ); - // OnceLock::set fails only if already set - in a - // single-init module that means a re-entry from the - // supervisor, which is not a hard error; we keep the - // first parse. - let _ = CONFIG.set(cfg); - Ok(()) - } - Err(e) => Err(HostError { - domain: "price-alert".into(), - kind: HostErrorKind::InvalidInput, - code: 0, - message: format!("price-alert: invalid [config]: {e}"), - data: None, - }), - } +impl LocalStoreHost for WitBindgenHost { + fn get(&self, key: &str) -> Result>, SdkHostError> { + local_store::get(key).map_err(convert_err) } - - fn on_event(event: types::Event) -> Result<(), HostError> { - let Some(cfg) = CONFIG.get() else { - return Ok(()); // init failed; no-op until a fresh load. - }; - if let types::Event::Block(block) = event { - if block.number % cfg.every_n_blocks != 0 { - return Ok(()); - } - poll_oracle(block.chain_id, cfg); - } - // Logs / Tick / Message are not used by this example. - Ok(()) + fn set(&self, key: &str, value: &[u8]) -> Result<(), SdkHostError> { + local_store::set(key, value).map_err(convert_err) } -} - -/// Build + dispatch the `latestRoundData` eth_call. Result is -/// logged: Info if the threshold is not crossed, Warn if it is. -/// Returns nothing so a single bad RPC reply does not propagate -/// into the supervisor - the next block re-polls. -fn poll_oracle(chain_id: u64, cfg: &Settings) { - let call_data = AggregatorV3::latestRoundDataCall {}.abi_encode(); - let params = eth_call_params(&cfg.oracle_address, &call_data); - let result_json = match chain::request(chain_id, "eth_call", ¶ms) { - Ok(s) => s, - Err(err) => { - logging::log( - logging::Level::Warn, - &format!("price-alert eth_call failed ({}): {}", err.code, err.message), - ); - return; - } - }; - let Some(bytes) = parse_eth_call_result(&result_json) else { - logging::log( - logging::Level::Warn, - &format!("price-alert: cannot decode result hex {result_json}"), - ); - return; - }; - let decoded = match AggregatorV3::latestRoundDataCall::abi_decode_returns(&bytes) { - Ok(d) => d, - Err(e) => { - logging::log( - logging::Level::Warn, - &format!("price-alert: latestRoundData decode failed: {e}"), - ); - return; - } - }; - let answer = decoded.answer; - if classify(answer, cfg.threshold_scaled, cfg.direction) { - logging::log( - logging::Level::Warn, - &format!( - "price-alert: TRIGGERED answer={answer} threshold={} ({:?})", - cfg.threshold_scaled, cfg.direction, - ), - ); - } else { - logging::log( - logging::Level::Info, - &format!( - "price-alert: ok answer={answer} threshold={} ({:?})", - cfg.threshold_scaled, cfg.direction, - ), - ); + fn delete(&self, key: &str) -> Result<(), SdkHostError> { + local_store::delete(key).map_err(convert_err) } -} - -/// `true` when `answer` is on the firing side of `threshold` per -/// `direction`. Pure - exercised by the unit tests. -fn classify(answer: I256, threshold: I256, direction: Direction) -> bool { - match direction { - Direction::Above => answer >= threshold, - Direction::Below => answer <= threshold, + fn list_keys(&self, prefix: &str) -> Result, SdkHostError> { + local_store::list_keys(prefix).map_err(convert_err) } } -/// Parse `module.toml::[config]` into a typed [`Settings`]. Returns a -/// human-readable error string the engine surfaces under -/// `host_error.message`. -fn parse_config(entries: &[(String, String)]) -> Result { - let oracle_address = config_get(entries, "oracle_address")? - .parse::
() - .map_err(|e| format!("oracle_address: {e}"))?; - let decimals = config_get(entries, "decimals")? - .parse::() - .map_err(|e| format!("decimals: {e}"))?; - if decimals > 38 { - return Err(format!( - "decimals={decimals} exceeds the I256 power-of-ten budget" - )); +impl CowApiHost for WitBindgenHost { + fn submit_order(&self, chain_id: u64, body: &[u8]) -> Result { + cow_api::submit_order(chain_id, body).map_err(convert_err) } - let threshold_decimal = config_get(entries, "threshold")?; - let threshold_scaled = scale_threshold(threshold_decimal, decimals)?; - let direction = match config_get(entries, "direction")?.to_ascii_lowercase().as_str() { - "above" => Direction::Above, - "below" => Direction::Below, - other => return Err(format!("direction: expected 'above'|'below', got {other:?}")), - }; - let every_n_blocks = config_get_optional(entries, "every_n_blocks") - .map(|s| s.parse::().map_err(|e| format!("every_n_blocks: {e}"))) - .transpose()? - .unwrap_or(1) - .max(1); - Ok(Settings { - oracle_address, - threshold_scaled, - direction, - every_n_blocks, - }) } -fn config_get<'a>(entries: &'a [(String, String)], key: &str) -> Result<&'a str, String> { - entries - .iter() - .find(|(k, _)| k == key) - .map(|(_, v)| v.as_str()) - .ok_or_else(|| format!("missing key {key:?}")) +impl LoggingHost for WitBindgenHost { + fn log(&self, level: SdkLogLevel, message: &str) { + logging::log(convert_level(level), message); + } } -fn config_get_optional<'a>(entries: &'a [(String, String)], key: &str) -> Option<&'a str> { - entries.iter().find(|(k, _)| k == key).map(|(_, v)| v.as_str()) +fn convert_err(e: HostError) -> SdkHostError { + SdkHostError { + domain: e.domain, + kind: match e.kind { + HostErrorKind::Unsupported => SdkHostErrorKind::Unsupported, + HostErrorKind::Unavailable => SdkHostErrorKind::Unavailable, + HostErrorKind::Denied => SdkHostErrorKind::Denied, + HostErrorKind::RateLimited => SdkHostErrorKind::RateLimited, + HostErrorKind::Timeout => SdkHostErrorKind::Timeout, + HostErrorKind::InvalidInput => SdkHostErrorKind::InvalidInput, + HostErrorKind::Internal => SdkHostErrorKind::Internal, + }, + code: e.code, + message: e.message, + data: e.data, + } } -/// Multiply `threshold_decimal` (e.g. `"2500.00"`) by `10**decimals` -/// into an `I256` for direct comparison with the oracle's answer. -/// Hand-rolled because alloy does not ship a `Decimal::parse_units`- -/// style helper and the module needs to stay no-std-ish. -fn scale_threshold(threshold_decimal: &str, decimals: u32) -> Result { - let (sign, body) = if let Some(rest) = threshold_decimal.strip_prefix('-') { - (-1i32, rest) - } else { - (1, threshold_decimal) - }; - let (whole, frac) = match body.split_once('.') { - Some((w, f)) => (w, f), - None => (body, ""), - }; - if whole.is_empty() && frac.is_empty() { - return Err("threshold: empty".into()); - } - if !whole.chars().all(|c| c.is_ascii_digit()) || !frac.chars().all(|c| c.is_ascii_digit()) { - return Err(format!( - "threshold: non-digit character in {threshold_decimal:?}" - )); +fn sdk_err_into_wit(e: SdkHostError) -> HostError { + HostError { + domain: e.domain, + kind: match e.kind { + SdkHostErrorKind::Unsupported => HostErrorKind::Unsupported, + SdkHostErrorKind::Unavailable => HostErrorKind::Unavailable, + SdkHostErrorKind::Denied => HostErrorKind::Denied, + SdkHostErrorKind::RateLimited => HostErrorKind::RateLimited, + SdkHostErrorKind::Timeout => HostErrorKind::Timeout, + SdkHostErrorKind::InvalidInput => HostErrorKind::InvalidInput, + SdkHostErrorKind::Internal => HostErrorKind::Internal, + }, + code: e.code, + message: e.message, + data: e.data, } - // Compose the un-scaled integer string, padding / truncating the - // fractional part against `decimals`. - let frac_len = frac.len() as u32; - let composed: String = if frac_len <= decimals { - let mut s = String::with_capacity(whole.len() + decimals as usize); - s.push_str(whole); - s.push_str(frac); - // Pad with zeros for the missing fractional digits. - for _ in 0..(decimals - frac_len) { - s.push('0'); - } - s - } else { - // Fractional part is longer than `decimals` - truncate - // (chops trailing digits; deliberately not rounding to keep - // behaviour predictable). - let mut s = String::with_capacity(whole.len() + decimals as usize); - s.push_str(whole); - s.push_str(&frac[..decimals as usize]); - s - }; - let raw = if composed.is_empty() { "0" } else { &composed }; - let unsigned: U256 = raw.parse().map_err(|e| format!("threshold parse: {e}"))?; - let signed = I256::try_from(unsigned).map_err(|e| format!("threshold range: {e}"))?; - Ok(if sign < 0 { -signed } else { signed }) } -export!(PriceAlert); - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn parse_config_happy_path() { - let entries = vec![ - ( - "oracle_address".into(), - "0x694AA1769357215DE4FAC081bf1f309aDC325306".into(), - ), - ("decimals".into(), "8".into()), - ("threshold".into(), "2500.50".into()), - ("direction".into(), "below".into()), - ("every_n_blocks".into(), "5".into()), - ]; - let cfg = parse_config(&entries).unwrap(); - assert_eq!(cfg.direction, Direction::Below); - assert_eq!(cfg.every_n_blocks, 5); - // 2500.50 with 8 decimals = 2500_50000000 = 250_050_000_000 - assert_eq!(cfg.threshold_scaled, I256::try_from(250_050_000_000_i64).unwrap()); +fn convert_level(l: SdkLogLevel) -> logging::Level { + match l { + SdkLogLevel::Trace => logging::Level::Trace, + SdkLogLevel::Debug => logging::Level::Debug, + SdkLogLevel::Info => logging::Level::Info, + SdkLogLevel::Warn => logging::Level::Warn, + SdkLogLevel::Error => logging::Level::Error, } +} - #[test] - fn parse_config_defaults_every_n_blocks_to_one() { - let entries = vec![ - ( - "oracle_address".into(), - "0x694AA1769357215DE4FAC081bf1f309aDC325306".into(), - ), - ("decimals".into(), "8".into()), - ("threshold".into(), "1".into()), - ("direction".into(), "above".into()), - ]; - let cfg = parse_config(&entries).unwrap(); - assert_eq!(cfg.every_n_blocks, 1); - assert_eq!(cfg.direction, Direction::Above); - } +struct PriceAlert; - #[test] - fn parse_config_rejects_unknown_direction() { - let entries = vec![ - ( - "oracle_address".into(), - "0x694AA1769357215DE4FAC081bf1f309aDC325306".into(), +impl Guest for PriceAlert { + fn init(config: Vec<(String, String)>) -> Result<(), HostError> { + let cfg = strategy::parse_config(&config).map_err(sdk_err_into_wit)?; + logging::log( + logging::Level::Info, + &format!( + "price-alert init: oracle={:#x} threshold={} direction={:?} every_n_blocks={}", + cfg.oracle_address, cfg.threshold_scaled, cfg.direction, cfg.every_n_blocks, ), - ("decimals".into(), "8".into()), - ("threshold".into(), "1".into()), - ("direction".into(), "sideways".into()), - ]; - assert!(parse_config(&entries).is_err()); - } - - #[test] - fn parse_config_rejects_missing_key() { - let entries = vec![ - ("decimals".into(), "8".into()), - ("threshold".into(), "1".into()), - ("direction".into(), "above".into()), - ]; - let err = parse_config(&entries).unwrap_err(); - assert!(err.contains("oracle_address")); - } - - #[test] - fn scale_threshold_pads_short_fractional() { - assert_eq!(scale_threshold("1.5", 8).unwrap(), I256::try_from(150_000_000_i64).unwrap()); - } - - #[test] - fn scale_threshold_truncates_long_fractional() { - // "1.123456789" with 8 decimals truncates to "1.12345678". - assert_eq!( - scale_threshold("1.123456789", 8).unwrap(), - I256::try_from(112_345_678_i64).unwrap(), - ); - } - - #[test] - fn scale_threshold_handles_no_decimal_point() { - assert_eq!(scale_threshold("42", 8).unwrap(), I256::try_from(4_200_000_000_i64).unwrap()); - } - - #[test] - fn scale_threshold_handles_negative_values() { - // Useful for non-USD pairs (yield curves, basis spreads, etc.). - assert_eq!( - scale_threshold("-1.5", 8).unwrap(), - -I256::try_from(150_000_000_i64).unwrap(), ); + // OnceLock::set fails only if already set - in a single-init + // module that means a re-entry from the supervisor, which is + // not a hard error; we keep the first parse. + let _ = SETTINGS.set(cfg); + Ok(()) } - #[test] - fn scale_threshold_rejects_garbage() { - assert!(scale_threshold("abc", 8).is_err()); - assert!(scale_threshold("1.2.3", 8).is_err()); - } - - #[test] - fn classify_below_fires_at_or_under_threshold() { - let t = I256::try_from(100_i32).unwrap(); - assert!(classify(I256::try_from(99_i32).unwrap(), t, Direction::Below)); - assert!(classify(I256::try_from(100_i32).unwrap(), t, Direction::Below)); - assert!(!classify(I256::try_from(101_i32).unwrap(), t, Direction::Below)); - } - - #[test] - fn classify_above_fires_at_or_over_threshold() { - let t = I256::try_from(100_i32).unwrap(); - assert!(classify(I256::try_from(101_i32).unwrap(), t, Direction::Above)); - assert!(classify(I256::try_from(100_i32).unwrap(), t, Direction::Above)); - assert!(!classify(I256::try_from(99_i32).unwrap(), t, Direction::Above)); + fn on_event(event: types::Event) -> Result<(), HostError> { + let Some(cfg) = SETTINGS.get() else { + return Ok(()); // init failed; no-op. + }; + if let types::Event::Block(block) = event { + strategy::on_block(&WitBindgenHost, block.chain_id, cfg, block.number) + .map_err(sdk_err_into_wit)?; + } + // Logs / Tick / Message are not used by this example. + Ok(()) } } + +export!(PriceAlert); diff --git a/modules/examples/price-alert/src/strategy.rs b/modules/examples/price-alert/src/strategy.rs new file mode 100644 index 0000000..3b7b0ec --- /dev/null +++ b/modules/examples/price-alert/src/strategy.rs @@ -0,0 +1,495 @@ +//! Pure strategy logic for the price-alert module. +//! +//! Every interaction with the world flows through the [`Host`] trait +//! seam exposed by `shepherd-sdk` — no direct calls to wit-bindgen- +//! generated free functions live here. The `lib.rs` glue wraps a +//! `WitBindgenHost` adapter around the module's per-cdylib wit-bindgen +//! imports and hands it to [`on_block`]; tests under `#[cfg(test)]` +//! hand the same function a `shepherd_sdk_test::MockHost`. + +use alloy_primitives::I256; +use alloy_sol_types::{SolCall, sol}; +use shepherd_sdk::chain::{eth_call_params, parse_eth_call_result}; +use shepherd_sdk::host::{Host, HostError, HostErrorKind, LogLevel}; +use shepherd_sdk::prelude::{Address, U256}; + +sol! { + /// Chainlink AggregatorV3Interface - only the function this module + /// needs. + interface AggregatorV3 { + function latestRoundData() external view returns ( + uint80 roundId, + int256 answer, + uint256 startedAt, + uint256 updatedAt, + uint80 answeredInRound + ); + } +} + +/// Resolved configuration, parsed from `module.toml::[config]` at +/// `init` and read on every `on_event`. +#[derive(Debug)] +pub struct Settings { + /// Chainlink AggregatorV3Interface address. + pub oracle_address: Address, + /// Threshold scaled to the oracle's native units + /// (`threshold_decimal * 10**decimals`). + pub threshold_scaled: I256, + /// Which side of the threshold fires. + pub direction: Direction, + /// Throttle: only poll every Nth block. + pub every_n_blocks: u64, +} + +/// Which side of the threshold the alert fires on. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum Direction { + /// Fire when `answer >= threshold`. + Above, + /// Fire when `answer <= threshold`. + Below, +} + +/// React to a new block. +/// +/// Returns `Ok(())` on success and on recoverable upstream failures +/// (oracle RPC error, decode failure) - the strategy logs a Warn and +/// lets the next block re-poll rather than propagating into the +/// supervisor. Only host-level I/O on the persistence side would +/// bubble up via `?`, and this module does not touch the store. +pub fn on_block( + host: &H, + chain_id: u64, + settings: &Settings, + block_number: u64, +) -> Result<(), HostError> { + if !block_number.is_multiple_of(settings.every_n_blocks) { + return Ok(()); + } + let call_data = AggregatorV3::latestRoundDataCall {}.abi_encode(); + let params = eth_call_params(&settings.oracle_address, &call_data); + let result_json = match host.request(chain_id, "eth_call", ¶ms) { + Ok(s) => s, + Err(err) => { + host.log( + LogLevel::Warn, + &format!( + "price-alert eth_call failed ({}): {}", + err.code, err.message + ), + ); + return Ok(()); + } + }; + let Some(bytes) = parse_eth_call_result(&result_json) else { + host.log( + LogLevel::Warn, + &format!("price-alert: cannot decode result hex {result_json}"), + ); + return Ok(()); + }; + let decoded = match AggregatorV3::latestRoundDataCall::abi_decode_returns(&bytes) { + Ok(d) => d, + Err(e) => { + host.log( + LogLevel::Warn, + &format!("price-alert: latestRoundData decode failed: {e}"), + ); + return Ok(()); + } + }; + let answer = decoded.answer; + if classify(answer, settings.threshold_scaled, settings.direction) { + host.log( + LogLevel::Warn, + &format!( + "price-alert: TRIGGERED answer={answer} threshold={} ({:?})", + settings.threshold_scaled, settings.direction, + ), + ); + } else { + host.log( + LogLevel::Info, + &format!( + "price-alert: ok answer={answer} threshold={} ({:?})", + settings.threshold_scaled, settings.direction, + ), + ); + } + Ok(()) +} + +/// `true` when `answer` is on the firing side of `threshold` per +/// `direction`. Pure - exercised by the unit tests. +pub fn classify(answer: I256, threshold: I256, direction: Direction) -> bool { + match direction { + Direction::Above => answer >= threshold, + Direction::Below => answer <= threshold, + } +} + +/// Parse `module.toml::[config]` into a typed [`Settings`]. +/// +/// One-shot config-parser style: returns `Result` so the +/// `Guest::init` adapter can lift the failure into the wit-bindgen +/// `HostError` with no extra plumbing. +pub fn parse_config(entries: &[(String, String)]) -> Result { + let oracle_address = config_get(entries, "oracle_address")? + .parse::
() + .map_err(|e| config_err(format!("oracle_address: {e}")))?; + let decimals = config_get(entries, "decimals")? + .parse::() + .map_err(|e| config_err(format!("decimals: {e}")))?; + if decimals > 38 { + return Err(config_err(format!( + "decimals={decimals} exceeds the I256 power-of-ten budget" + ))); + } + let threshold_decimal = config_get(entries, "threshold")?; + let threshold_scaled = scale_threshold(threshold_decimal, decimals)?; + let direction = match config_get(entries, "direction")?.to_ascii_lowercase().as_str() { + "above" => Direction::Above, + "below" => Direction::Below, + other => { + return Err(config_err(format!( + "direction: expected 'above'|'below', got {other:?}" + ))); + } + }; + let every_n_blocks = config_get_optional(entries, "every_n_blocks") + .map(|s| { + s.parse::() + .map_err(|e| config_err(format!("every_n_blocks: {e}"))) + }) + .transpose()? + .unwrap_or(1) + .max(1); + Ok(Settings { + oracle_address, + threshold_scaled, + direction, + every_n_blocks, + }) +} + +fn config_get<'a>(entries: &'a [(String, String)], key: &str) -> Result<&'a str, HostError> { + entries + .iter() + .find(|(k, _)| k == key) + .map(|(_, v)| v.as_str()) + .ok_or_else(|| config_err(format!("missing key {key:?}"))) +} + +fn config_get_optional<'a>(entries: &'a [(String, String)], key: &str) -> Option<&'a str> { + entries.iter().find(|(k, _)| k == key).map(|(_, v)| v.as_str()) +} + +fn config_err(message: impl Into) -> HostError { + HostError { + domain: "price-alert".into(), + kind: HostErrorKind::InvalidInput, + code: 0, + message: format!("price-alert: invalid [config]: {}", message.into()), + data: None, + } +} + +/// Multiply `threshold_decimal` (e.g. `"2500.00"`) by `10**decimals` +/// into an `I256` for direct comparison with the oracle's answer. +fn scale_threshold(threshold_decimal: &str, decimals: u32) -> Result { + let (sign, body) = if let Some(rest) = threshold_decimal.strip_prefix('-') { + (-1i32, rest) + } else { + (1, threshold_decimal) + }; + let (whole, frac) = match body.split_once('.') { + Some((w, f)) => (w, f), + None => (body, ""), + }; + if whole.is_empty() && frac.is_empty() { + return Err(config_err("threshold: empty")); + } + if !whole.chars().all(|c| c.is_ascii_digit()) || !frac.chars().all(|c| c.is_ascii_digit()) { + return Err(config_err(format!( + "threshold: non-digit character in {threshold_decimal:?}" + ))); + } + let frac_len = frac.len() as u32; + let composed: String = if frac_len <= decimals { + let mut s = String::with_capacity(whole.len() + decimals as usize); + s.push_str(whole); + s.push_str(frac); + for _ in 0..(decimals - frac_len) { + s.push('0'); + } + s + } else { + let mut s = String::with_capacity(whole.len() + decimals as usize); + s.push_str(whole); + s.push_str(&frac[..decimals as usize]); + s + }; + let raw = if composed.is_empty() { "0" } else { &composed }; + let unsigned: U256 = raw + .parse() + .map_err(|e| config_err(format!("threshold parse: {e}")))?; + let signed = I256::try_from(unsigned) + .map_err(|e| config_err(format!("threshold range: {e}")))?; + Ok(if sign < 0 { -signed } else { signed }) +} + +#[cfg(test)] +mod tests { + use super::*; + use alloy_primitives::hex; + use shepherd_sdk::host::HostErrorKind as Kind; + use shepherd_sdk_test::MockHost; + + fn sample_settings(trigger_scaled_dec: i128, direction: Direction) -> Settings { + Settings { + oracle_address: "0x694AA1769357215DE4FAC081bf1f309aDC325306".parse().unwrap(), + threshold_scaled: I256::try_from(trigger_scaled_dec).unwrap(), + direction, + every_n_blocks: 1, + } + } + + /// Encode a `latestRoundData` return into the `"0x..."` JSON string + /// the host's `chain::request` would yield. + fn oracle_response_json(answer_scaled: i128) -> String { + use alloy_primitives::aliases::U80; + let returns = AggregatorV3::latestRoundDataReturn { + roundId: U80::ZERO, + answer: I256::try_from(answer_scaled).unwrap(), + startedAt: U256::ZERO, + updatedAt: U256::ZERO, + answeredInRound: U80::ZERO, + }; + let encoded = AggregatorV3::latestRoundDataCall::abi_encode_returns(&returns); + let hex = hex::encode_prefixed(encoded); + format!("\"{hex}\"") + } + + fn programmed_eth_call(host: &MockHost, oracle: Address, response: Result) { + let call_data = AggregatorV3::latestRoundDataCall {}.abi_encode(); + let params = eth_call_params(&oracle, &call_data); + host.chain.respond_to("eth_call", ¶ms, response); + } + + // ---- pure helpers ---- + + #[test] + fn classify_below_fires_at_or_under_threshold() { + let t = I256::try_from(100_i32).unwrap(); + assert!(classify(I256::try_from(99_i32).unwrap(), t, Direction::Below)); + assert!(classify(I256::try_from(100_i32).unwrap(), t, Direction::Below)); + assert!(!classify(I256::try_from(101_i32).unwrap(), t, Direction::Below)); + } + + #[test] + fn classify_above_fires_at_or_over_threshold() { + let t = I256::try_from(100_i32).unwrap(); + assert!(classify(I256::try_from(101_i32).unwrap(), t, Direction::Above)); + assert!(classify(I256::try_from(100_i32).unwrap(), t, Direction::Above)); + assert!(!classify(I256::try_from(99_i32).unwrap(), t, Direction::Above)); + } + + #[test] + fn scale_threshold_pads_short_fractional() { + assert_eq!( + scale_threshold("1.5", 8).unwrap(), + I256::try_from(150_000_000_i64).unwrap(), + ); + } + + #[test] + fn scale_threshold_truncates_long_fractional() { + assert_eq!( + scale_threshold("1.123456789", 8).unwrap(), + I256::try_from(112_345_678_i64).unwrap(), + ); + } + + #[test] + fn scale_threshold_handles_no_decimal_point() { + assert_eq!( + scale_threshold("42", 8).unwrap(), + I256::try_from(4_200_000_000_i64).unwrap(), + ); + } + + #[test] + fn scale_threshold_handles_negative_values() { + assert_eq!( + scale_threshold("-1.5", 8).unwrap(), + -I256::try_from(150_000_000_i64).unwrap(), + ); + } + + #[test] + fn scale_threshold_rejects_garbage() { + assert!(matches!( + scale_threshold("abc", 8).unwrap_err().kind, + Kind::InvalidInput + )); + assert!(matches!( + scale_threshold("1.2.3", 8).unwrap_err().kind, + Kind::InvalidInput + )); + } + + #[test] + fn parse_config_happy_path() { + let entries = vec![ + ( + "oracle_address".into(), + "0x694AA1769357215DE4FAC081bf1f309aDC325306".into(), + ), + ("decimals".into(), "8".into()), + ("threshold".into(), "2500.50".into()), + ("direction".into(), "below".into()), + ("every_n_blocks".into(), "5".into()), + ]; + let cfg = parse_config(&entries).unwrap(); + assert_eq!(cfg.direction, Direction::Below); + assert_eq!(cfg.every_n_blocks, 5); + assert_eq!( + cfg.threshold_scaled, + I256::try_from(250_050_000_000_i64).unwrap() + ); + } + + #[test] + fn parse_config_defaults_every_n_blocks_to_one() { + let entries = vec![ + ( + "oracle_address".into(), + "0x694AA1769357215DE4FAC081bf1f309aDC325306".into(), + ), + ("decimals".into(), "8".into()), + ("threshold".into(), "1".into()), + ("direction".into(), "above".into()), + ]; + let cfg = parse_config(&entries).unwrap(); + assert_eq!(cfg.every_n_blocks, 1); + assert_eq!(cfg.direction, Direction::Above); + } + + #[test] + fn parse_config_rejects_missing_key() { + let entries = vec![ + ("decimals".into(), "8".into()), + ("threshold".into(), "1".into()), + ("direction".into(), "above".into()), + ]; + let err = parse_config(&entries).unwrap_err(); + assert!(matches!(err.kind, Kind::InvalidInput)); + assert!(err.message.contains("oracle_address")); + } + + // ---- strategy behaviour against MockHost ---- + + #[test] + fn on_block_idle_when_price_above_below_trigger() { + let host = MockHost::new(); + let settings = sample_settings(/*trigger*/ 250_050_000_000, Direction::Below); + programmed_eth_call( + &host, + settings.oracle_address, + Ok(oracle_response_json(300_000_000_000)), + ); + + on_block(&host, 11_155_111, &settings, 100).unwrap(); + + assert_eq!(host.chain.call_count(), 1); + assert!(host.logging.contains("ok answer=")); + assert_eq!(host.logging.count_at(LogLevel::Warn), 0); + } + + #[test] + fn on_block_triggers_below_threshold() { + let host = MockHost::new(); + let settings = sample_settings(250_050_000_000, Direction::Below); + programmed_eth_call( + &host, + settings.oracle_address, + Ok(oracle_response_json(200_000_000_000)), + ); + + on_block(&host, 11_155_111, &settings, 100).unwrap(); + + assert!(host.logging.contains("TRIGGERED")); + assert_eq!(host.logging.count_at(LogLevel::Warn), 1); + } + + #[test] + fn on_block_triggers_above_threshold() { + let host = MockHost::new(); + let settings = sample_settings(100, Direction::Above); + programmed_eth_call( + &host, + settings.oracle_address, + Ok(oracle_response_json(200)), + ); + + on_block(&host, 11_155_111, &settings, 100).unwrap(); + + assert!(host.logging.contains("TRIGGERED")); + } + + #[test] + fn on_block_warns_and_continues_on_rpc_error() { + let host = MockHost::new(); + let settings = sample_settings(100, Direction::Below); + programmed_eth_call( + &host, + settings.oracle_address, + Err(HostError { + domain: "chain".into(), + kind: Kind::Timeout, + code: 504, + message: "upstream timed out".into(), + data: None, + }), + ); + + // Strategy returns Ok so the supervisor moves on. + on_block(&host, 11_155_111, &settings, 100).unwrap(); + assert!(host.logging.contains("eth_call failed")); + // No "TRIGGERED" / "ok answer=" log because we never got an + // oracle response. + assert!(!host.logging.contains("TRIGGERED")); + } + + #[test] + fn on_block_warns_on_undecodable_result() { + let host = MockHost::new(); + let settings = sample_settings(100, Direction::Below); + programmed_eth_call(&host, settings.oracle_address, Ok("not-json".into())); + + on_block(&host, 11_155_111, &settings, 100).unwrap(); + assert!(host.logging.contains("cannot decode result hex")); + } + + #[test] + fn on_block_respects_every_n_blocks_throttle() { + let host = MockHost::new(); + let mut settings = sample_settings(100, Direction::Below); + settings.every_n_blocks = 5; + programmed_eth_call( + &host, + settings.oracle_address, + Ok(oracle_response_json(50)), + ); + + // Blocks 1..5 do not poll; only block 5 (which divides evenly). + for n in 1..5 { + on_block(&host, 11_155_111, &settings, n).unwrap(); + } + assert_eq!(host.chain.call_count(), 0); + + on_block(&host, 11_155_111, &settings, 5).unwrap(); + assert_eq!(host.chain.call_count(), 1); + } +}