From 36be3dde62898fec91390d337d90441f75c0f57c Mon Sep 17 00:00:00 2001 From: xdustinface Date: Mon, 9 Mar 2026 22:33:08 +0700 Subject: [PATCH 1/2] refactor: extract capability lookup into `PeerPool` helpers - Add `peer_with_service()` and `peers_with_service()` on `PeerPool` to replace repeated "iterate peers, check service flag" loops in message routing. - Add unit tests - Generalize the match/log/error pattern for required-service peer selection in `send_to_single_peer` so new service requirements only need a single match arm. --- dash-spv/src/network/manager.rs | 80 +++++++------------- dash-spv/src/network/peer.rs | 11 ++- dash-spv/src/network/pool.rs | 113 +++++++++++++++++++++++++++++ dash-spv/src/test_utils/network.rs | 3 +- 4 files changed, 150 insertions(+), 57 deletions(-) diff --git a/dash-spv/src/network/manager.rs b/dash-spv/src/network/manager.rs index 95337fd5e..efc439b53 100644 --- a/dash-spv/src/network/manager.rs +++ b/dash-spv/src/network/manager.rs @@ -970,34 +970,24 @@ impl PeerNetworkManager { return Err(NetworkError::ConnectionFailed("No connected peers".to_string())); } - // For filter-related messages, we need a peer that supports compact filters - let requires_compact_filters = - matches!(&message, NetworkMessage::GetCFHeaders(_) | NetworkMessage::GetCFilters(_)); + let required_service = match &message { + NetworkMessage::GetCFHeaders(_) | NetworkMessage::GetCFilters(_) => { + Some(ServiceFlags::COMPACT_FILTERS) + } + _ => None, + }; let check_headers2 = matches!(&message, NetworkMessage::GetHeaders(_) | NetworkMessage::GetHeaders2(_)); - let selected_peer = if requires_compact_filters { - // Find a peer that supports compact filters - let mut filter_peer = None; - for (addr, peer) in &peers { - let peer_guard = peer.read().await; - - if peer_guard.supports_compact_filters() { - filter_peer = Some(*addr); - break; - } - } - - match filter_peer { - Some(addr) => { - log::debug!("Selected peer {} for compact filter request", addr); - addr + let selected_peer = if let Some(flags) = required_service { + match self.pool.peer_with_service(flags).await { + Some((address, _)) => { + log::debug!("Selected peer {} with {} for {}", address, flags, message.cmd()); + address } None => { - log::warn!("No peers support compact filters, cannot send {}", message.cmd()); - return Err(NetworkError::ProtocolError( - "No peers support compact filters".to_string(), - )); + log::warn!("No peers support {}, cannot send {}", flags, message.cmd()); + return Err(NetworkError::ProtocolError(format!("No peers support {}", flags))); } } } else if check_headers2 { @@ -1015,13 +1005,11 @@ impl PeerNetworkManager { } if selected.is_none() { - for (addr, peer) in &peers { - let peer_guard = peer.read().await; - if peer_guard.supports_headers2() { - selected = Some(*addr); - break; - } - } + selected = self + .pool + .peer_with_service(ServiceFlags::NODE_HEADERS_COMPRESSED) + .await + .map(|(addr, _)| addr); } let chosen = selected.unwrap_or(peers[0].0); @@ -1086,33 +1074,17 @@ impl PeerNetworkManager { // Select eligible peers based on message type let (selected_peers, require_capability) = match &message { NetworkMessage::GetCFHeaders(_) | NetworkMessage::GetCFilters(_) => { - // Filter requests require compact filter support - let filter_peers: Vec<_> = { - let mut result = Vec::new(); - for (addr, peer) in &peers { - let peer_guard = peer.read().await; - if peer_guard.supports_compact_filters() { - result.push((*addr, peer.clone())); - } - } - result - }; + let filter_peers = + self.pool.peers_with_service(ServiceFlags::COMPACT_FILTERS).await; (filter_peers, true) } NetworkMessage::GetHeaders(_) | NetworkMessage::GetHeaders2(_) => { - // Prefer headers2 peers, fall back to all - let headers2_peers: Vec<_> = { - let mut result = Vec::new(); - for (addr, peer) in &peers { - let peer_guard = peer.read().await; - if peer_guard.supports_headers2() - && !self.headers2_disabled.lock().await.contains(addr) - { - result.push((*addr, peer.clone())); - } - } - result - }; + // Prefer headers2 peers (excluding disabled), fall back to all + let disabled = self.headers2_disabled.lock().await; + let mut headers2_peers = + self.pool.peers_with_service(ServiceFlags::NODE_HEADERS_COMPRESSED).await; + headers2_peers.retain(|(addr, _)| !disabled.contains(addr)); + drop(disabled); if headers2_peers.is_empty() { (peers.clone(), false) } else { diff --git a/dash-spv/src/network/peer.rs b/dash-spv/src/network/peer.rs index dc3e484ed..5e409da80 100644 --- a/dash-spv/src/network/peer.rs +++ b/dash-spv/src/network/peer.rs @@ -833,15 +833,24 @@ impl Peer { } } +#[cfg(test)] +impl Peer { + pub(crate) fn set_services(&mut self, flags: ServiceFlags) { + self.services = Some(flags.as_u64()); + } +} + #[cfg(test)] mod tests { + use std::net::SocketAddr; use std::time::{Duration, SystemTime}; use super::Peer; #[test] fn remove_expired_pings() { - let mut peer = Peer::dummy(); + let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap(); + let mut peer = Peer::dummy(addr); let now = SystemTime::now(); let expired = now - Duration::from_secs(61); diff --git a/dash-spv/src/network/pool.rs b/dash-spv/src/network/pool.rs index 7550b1db9..64bfb42d2 100644 --- a/dash-spv/src/network/pool.rs +++ b/dash-spv/src/network/pool.rs @@ -3,6 +3,7 @@ use crate::error::{NetworkError, SpvError as Error}; use crate::network::constants::TARGET_PEERS; use crate::network::peer::Peer; +use dashcore::network::constants::ServiceFlags; use dashcore::prelude::CoreBlockHeight; use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; @@ -145,6 +146,35 @@ impl PeerPool { } } + /// Find the first connected peer that advertises the given service flags. + pub(crate) async fn peer_with_service( + &self, + flags: ServiceFlags, + ) -> Option<(SocketAddr, Arc>)> { + let peers = self.peers.read().await; + for (addr, peer) in peers.iter() { + if peer.read().await.has_service(flags) { + return Some((*addr, Arc::clone(peer))); + } + } + None + } + + /// Collect all connected peers that advertise the given service flags. + pub(crate) async fn peers_with_service( + &self, + flags: ServiceFlags, + ) -> Vec<(SocketAddr, Arc>)> { + let peers = self.peers.read().await; + let mut result = Vec::new(); + for (addr, peer) in peers.iter() { + if peer.read().await.has_service(flags) { + result.push((*addr, peer.clone())); + } + } + result + } + /// Check if we need more peers pub async fn needs_more_peers(&self) -> bool { self.peer_count().await < TARGET_PEERS @@ -189,6 +219,15 @@ impl Default for PeerPool { } } +#[cfg(test)] +impl PeerPool { + async fn insert_peer_with_services(&self, addr: SocketAddr, flags: ServiceFlags) { + let mut peer = Peer::dummy(addr); + peer.set_services(flags); + self.peers.write().await.insert(addr, Arc::new(RwLock::new(peer))); + } +} + #[cfg(test)] mod tests { use super::*; @@ -208,4 +247,78 @@ mod tests { assert!(!pool.mark_connecting(addr).await); // Already marked assert!(pool.is_connecting(&addr).await); } + + #[tokio::test] + async fn test_peer_with_service() { + let pool = PeerPool::new(); + let flags = ServiceFlags::COMPACT_FILTERS; + + // No match on empty pool + assert!(pool.peer_with_service(flags).await.is_none()); + + // No match when peers lack the requested flag + let addr1: SocketAddr = "127.0.0.1:1001".parse().unwrap(); + pool.insert_peer_with_services(addr1, ServiceFlags::NETWORK).await; + assert!(pool.peer_with_service(flags).await.is_none()); + + // Returns the matching peer with correct address and services + let addr2: SocketAddr = "127.0.0.1:1002".parse().unwrap(); + pool.insert_peer_with_services(addr2, ServiceFlags::NETWORK | flags).await; + let (found_addr, found_peer) = pool.peer_with_service(flags).await.unwrap(); + assert_eq!(found_addr, addr2); + assert!(found_peer.read().await.has_service(flags)); + } + + #[tokio::test] + async fn test_peers_with_service() { + let pool = PeerPool::new(); + let flags = ServiceFlags::COMPACT_FILTERS; + + // Empty on empty pool + assert!(pool.peers_with_service(flags).await.is_empty()); + + // Empty when no peer has the flag + let addr1: SocketAddr = "127.0.0.1:1001".parse().unwrap(); + pool.insert_peer_with_services(addr1, ServiceFlags::NETWORK).await; + assert!(pool.peers_with_service(flags).await.is_empty()); + + // Returns all matching peers, skips non-matching + let addr2: SocketAddr = "127.0.0.1:1002".parse().unwrap(); + let addr3: SocketAddr = "127.0.0.1:1003".parse().unwrap(); + pool.insert_peer_with_services(addr2, flags).await; + pool.insert_peer_with_services(addr3, ServiceFlags::NETWORK | flags).await; + + let results: HashMap = + pool.peers_with_service(flags).await.into_iter().collect(); + assert_eq!(results.len(), 2); + assert!(results[&addr2].read().await.has_service(flags)); + assert!(results[&addr3].read().await.has_service(flags)); + } + + #[tokio::test] + async fn test_service_lookup_with_combined_flags() { + let pool = PeerPool::new(); + let combined = ServiceFlags::COMPACT_FILTERS | ServiceFlags::NODE_HEADERS_COMPRESSED; + + // Peer with only one of the two flags does not match combined query + let addr1: SocketAddr = "127.0.0.1:1001".parse().unwrap(); + pool.insert_peer_with_services(addr1, ServiceFlags::COMPACT_FILTERS).await; + assert!(pool.peer_with_service(combined).await.is_none()); + assert!(pool.peers_with_service(combined).await.is_empty()); + + // Peer with both flags matches + let addr2: SocketAddr = "127.0.0.1:1002".parse().unwrap(); + pool.insert_peer_with_services(addr2, combined | ServiceFlags::NETWORK).await; + let (found, _) = pool.peer_with_service(combined).await.unwrap(); + assert_eq!(found, addr2); + + let all = pool.peers_with_service(combined).await; + assert_eq!(all.len(), 1); + assert_eq!(all[0].0, addr2); + + // Single-flag queries still match both peers appropriately + assert!(pool.peer_with_service(ServiceFlags::COMPACT_FILTERS).await.is_some()); + let cf_peers = pool.peers_with_service(ServiceFlags::COMPACT_FILTERS).await; + assert_eq!(cf_peers.len(), 2); + } } diff --git a/dash-spv/src/test_utils/network.rs b/dash-spv/src/test_utils/network.rs index 7e21c55f9..5d0f6231f 100644 --- a/dash-spv/src/test_utils/network.rs +++ b/dash-spv/src/test_utils/network.rs @@ -177,8 +177,7 @@ impl NetworkManager for MockNetworkManager { } impl Peer { - pub fn dummy() -> Self { - let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap(); + pub fn dummy(addr: SocketAddr) -> Self { Peer::new(addr, Duration::from_secs(10), Network::Mainnet) } } From d6f2108670d02be56434c80cfb1bdc65c8f4191e Mon Sep 17 00:00:00 2001 From: xdustinface Date: Tue, 10 Mar 2026 12:13:12 +0700 Subject: [PATCH 2/2] refactor: remove `current_sync_peer` from network manager Simplify peer selection in `send_to_single_peer` by removing the sticky sync peer tracking. Peers are now selected directly based on message type requirements and the round-robin counter. --- dash-spv/src/network/manager.rs | 111 +++++++------------------------- 1 file changed, 25 insertions(+), 86 deletions(-) diff --git a/dash-spv/src/network/manager.rs b/dash-spv/src/network/manager.rs index efc439b53..914aee39f 100644 --- a/dash-spv/src/network/manager.rs +++ b/dash-spv/src/network/manager.rs @@ -57,8 +57,6 @@ pub struct PeerNetworkManager { tasks: Arc>>, /// Initial peer addresses initial_peers: Vec, - /// Current sync peer (sticky during sync operations) - current_sync_peer: Arc>>, /// Data directory for storage data_dir: PathBuf, /// Mempool strategy from config @@ -113,7 +111,6 @@ impl PeerNetworkManager { shutdown_token: CancellationToken::new(), tasks: Arc::new(Mutex::new(JoinSet::new())), initial_peers: config.peers.clone(), - current_sync_peer: Arc::new(Mutex::new(None)), data_dir, mempool_strategy: config.mempool_strategy, user_agent: config.user_agent.clone(), @@ -962,7 +959,7 @@ impl PeerNetworkManager { }); } - /// Send a message to a single peer (using sticky peer selection for sync consistency) + /// Send a message to a single peer selected by message type requirements. async fn send_to_single_peer(&self, message: NetworkMessage) -> NetworkResult<()> { let peers = self.pool.get_all_peers().await; @@ -970,92 +967,33 @@ impl PeerNetworkManager { return Err(NetworkError::ConnectionFailed("No connected peers".to_string())); } - let required_service = match &message { + let preferred_service = match &message { NetworkMessage::GetCFHeaders(_) | NetworkMessage::GetCFilters(_) => { - Some(ServiceFlags::COMPACT_FILTERS) + Some((ServiceFlags::COMPACT_FILTERS, true)) + } + NetworkMessage::GetHeaders(_) | NetworkMessage::GetHeaders2(_) => { + Some((ServiceFlags::NODE_HEADERS_COMPRESSED, false)) } _ => None, }; - let check_headers2 = - matches!(&message, NetworkMessage::GetHeaders(_) | NetworkMessage::GetHeaders2(_)); - let selected_peer = if let Some(flags) = required_service { + let (addr, peer) = if let Some((flags, required)) = preferred_service { match self.pool.peer_with_service(flags).await { - Some((address, _)) => { + Some((address, peer)) => { log::debug!("Selected peer {} with {} for {}", address, flags, message.cmd()); - address + (address, peer) } - None => { + None if required => { log::warn!("No peers support {}, cannot send {}", flags, message.cmd()); return Err(NetworkError::ProtocolError(format!("No peers support {}", flags))); } + None => self.next_peer(&peers), } - } else if check_headers2 { - // Prefer a peer that advertises headers2 support - let mut current_sync_peer = self.current_sync_peer.lock().await; - let mut selected: Option = None; - - if let Some(current_addr) = *current_sync_peer { - if let Some((_, peer)) = peers.iter().find(|(addr, _)| *addr == current_addr) { - let peer_guard = peer.read().await; - if peer_guard.supports_headers2() { - selected = Some(current_addr); - } - } - } - - if selected.is_none() { - selected = self - .pool - .peer_with_service(ServiceFlags::NODE_HEADERS_COMPRESSED) - .await - .map(|(addr, _)| addr); - } - - let chosen = selected.unwrap_or(peers[0].0); - if Some(chosen) != *current_sync_peer { - log::info!("Sync peer selected for Headers2: {}", chosen); - *current_sync_peer = Some(chosen); - } - drop(current_sync_peer); - chosen } else { - // For non-filter messages, use the sticky sync peer - let mut current_sync_peer = self.current_sync_peer.lock().await; - let selected = if let Some(current_addr) = *current_sync_peer { - // Check if current sync peer is still connected - if peers.iter().any(|(addr, _)| *addr == current_addr) { - // Keep using the same peer for sync consistency - current_addr - } else { - // Current sync peer disconnected, pick a new one - let new_addr = peers[0].0; - log::info!( - "Sync peer switched from {} to {} (previous peer disconnected)", - current_addr, - new_addr - ); - *current_sync_peer = Some(new_addr); - new_addr - } - } else { - // No current sync peer, pick the first available - let new_addr = peers[0].0; - log::info!("Sync peer selected: {}", new_addr); - *current_sync_peer = Some(new_addr); - new_addr - }; - drop(current_sync_peer); - selected + self.next_peer(&peers) }; - // Find the peer for the selected address - let (addr, peer) = peers - .iter() - .find(|(a, _)| *a == selected_peer) - .ok_or_else(|| NetworkError::ConnectionFailed("Selected peer not found".to_string()))?; - - self.send_message_to_peer(addr, peer, message).await + self.send_message_to_peer(&addr, &peer, message).await } /// Send a message distributed across connected peers using round-robin selection. @@ -1105,18 +1043,20 @@ impl PeerNetworkManager { }; } - // Round-robin selection - let idx = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % selected_peers.len(); - let (addr, peer) = &selected_peers[idx]; + let (addr, peer) = self.next_peer(&selected_peers); - log::debug!( - "Distributing {} request to peer {} (round-robin idx {})", - message.cmd(), - addr, - idx - ); + log::debug!("Distributing {} request to peer {}", message.cmd(), addr); - self.send_message_to_peer(addr, peer, message).await + self.send_message_to_peer(&addr, &peer, message).await + } + + /// Pick the next peer from `peers` using round-robin rotation. + fn next_peer( + &self, + peers: &[(SocketAddr, Arc>)], + ) -> (SocketAddr, Arc>) { + let idx = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % peers.len(); + (peers[idx].0, peers[idx].1.clone()) } /// Send a message to the given peer. @@ -1279,7 +1219,6 @@ impl Clone for PeerNetworkManager { shutdown_token: self.shutdown_token.clone(), tasks: self.tasks.clone(), initial_peers: self.initial_peers.clone(), - current_sync_peer: self.current_sync_peer.clone(), data_dir: self.data_dir.clone(), mempool_strategy: self.mempool_strategy, user_agent: self.user_agent.clone(),