diff --git a/dash-spv/src/network/manager.rs b/dash-spv/src/network/manager.rs index 95337fd5e..914aee39f 100644 --- a/dash-spv/src/network/manager.rs +++ b/dash-spv/src/network/manager.rs @@ -57,8 +57,6 @@ pub struct PeerNetworkManager { tasks: Arc>>, /// Initial peer addresses initial_peers: Vec, - /// Current sync peer (sticky during sync operations) - current_sync_peer: Arc>>, /// Data directory for storage data_dir: PathBuf, /// Mempool strategy from config @@ -113,7 +111,6 @@ impl PeerNetworkManager { shutdown_token: CancellationToken::new(), tasks: Arc::new(Mutex::new(JoinSet::new())), initial_peers: config.peers.clone(), - current_sync_peer: Arc::new(Mutex::new(None)), data_dir, mempool_strategy: config.mempool_strategy, user_agent: config.user_agent.clone(), @@ -962,7 +959,7 @@ impl PeerNetworkManager { }); } - /// Send a message to a single peer (using sticky peer selection for sync consistency) + /// Send a message to a single peer selected by message type requirements. async fn send_to_single_peer(&self, message: NetworkMessage) -> NetworkResult<()> { let peers = self.pool.get_all_peers().await; @@ -970,104 +967,33 @@ impl PeerNetworkManager { return Err(NetworkError::ConnectionFailed("No connected peers".to_string())); } - // For filter-related messages, we need a peer that supports compact filters - let requires_compact_filters = - matches!(&message, NetworkMessage::GetCFHeaders(_) | NetworkMessage::GetCFilters(_)); - let check_headers2 = - matches!(&message, NetworkMessage::GetHeaders(_) | NetworkMessage::GetHeaders2(_)); - - let selected_peer = if requires_compact_filters { - // Find a peer that supports compact filters - let mut filter_peer = None; - for (addr, peer) in &peers { - let peer_guard = peer.read().await; - - if peer_guard.supports_compact_filters() { - filter_peer = Some(*addr); - break; - } + let preferred_service = match &message { + NetworkMessage::GetCFHeaders(_) | NetworkMessage::GetCFilters(_) => { + Some((ServiceFlags::COMPACT_FILTERS, true)) } - - match filter_peer { - Some(addr) => { - log::debug!("Selected peer {} for compact filter request", addr); - addr - } - None => { - log::warn!("No peers support compact filters, cannot send {}", message.cmd()); - return Err(NetworkError::ProtocolError( - "No peers support compact filters".to_string(), - )); - } + NetworkMessage::GetHeaders(_) | NetworkMessage::GetHeaders2(_) => { + Some((ServiceFlags::NODE_HEADERS_COMPRESSED, false)) } - } else if check_headers2 { - // Prefer a peer that advertises headers2 support - let mut current_sync_peer = self.current_sync_peer.lock().await; - let mut selected: Option = None; + _ => None, + }; - if let Some(current_addr) = *current_sync_peer { - if let Some((_, peer)) = peers.iter().find(|(addr, _)| *addr == current_addr) { - let peer_guard = peer.read().await; - if peer_guard.supports_headers2() { - selected = Some(current_addr); - } + let (addr, peer) = if let Some((flags, required)) = preferred_service { + match self.pool.peer_with_service(flags).await { + Some((address, peer)) => { + log::debug!("Selected peer {} with {} for {}", address, flags, message.cmd()); + (address, peer) } - } - - if selected.is_none() { - for (addr, peer) in &peers { - let peer_guard = peer.read().await; - if peer_guard.supports_headers2() { - selected = Some(*addr); - break; - } + None if required => { + log::warn!("No peers support {}, cannot send {}", flags, message.cmd()); + return Err(NetworkError::ProtocolError(format!("No peers support {}", flags))); } + None => self.next_peer(&peers), } - - let chosen = selected.unwrap_or(peers[0].0); - if Some(chosen) != *current_sync_peer { - log::info!("Sync peer selected for Headers2: {}", chosen); - *current_sync_peer = Some(chosen); - } - drop(current_sync_peer); - chosen } else { - // For non-filter messages, use the sticky sync peer - let mut current_sync_peer = self.current_sync_peer.lock().await; - let selected = if let Some(current_addr) = *current_sync_peer { - // Check if current sync peer is still connected - if peers.iter().any(|(addr, _)| *addr == current_addr) { - // Keep using the same peer for sync consistency - current_addr - } else { - // Current sync peer disconnected, pick a new one - let new_addr = peers[0].0; - log::info!( - "Sync peer switched from {} to {} (previous peer disconnected)", - current_addr, - new_addr - ); - *current_sync_peer = Some(new_addr); - new_addr - } - } else { - // No current sync peer, pick the first available - let new_addr = peers[0].0; - log::info!("Sync peer selected: {}", new_addr); - *current_sync_peer = Some(new_addr); - new_addr - }; - drop(current_sync_peer); - selected + self.next_peer(&peers) }; - // Find the peer for the selected address - let (addr, peer) = peers - .iter() - .find(|(a, _)| *a == selected_peer) - .ok_or_else(|| NetworkError::ConnectionFailed("Selected peer not found".to_string()))?; - - self.send_message_to_peer(addr, peer, message).await + self.send_message_to_peer(&addr, &peer, message).await } /// Send a message distributed across connected peers using round-robin selection. @@ -1086,33 +1012,17 @@ impl PeerNetworkManager { // Select eligible peers based on message type let (selected_peers, require_capability) = match &message { NetworkMessage::GetCFHeaders(_) | NetworkMessage::GetCFilters(_) => { - // Filter requests require compact filter support - let filter_peers: Vec<_> = { - let mut result = Vec::new(); - for (addr, peer) in &peers { - let peer_guard = peer.read().await; - if peer_guard.supports_compact_filters() { - result.push((*addr, peer.clone())); - } - } - result - }; + let filter_peers = + self.pool.peers_with_service(ServiceFlags::COMPACT_FILTERS).await; (filter_peers, true) } NetworkMessage::GetHeaders(_) | NetworkMessage::GetHeaders2(_) => { - // Prefer headers2 peers, fall back to all - let headers2_peers: Vec<_> = { - let mut result = Vec::new(); - for (addr, peer) in &peers { - let peer_guard = peer.read().await; - if peer_guard.supports_headers2() - && !self.headers2_disabled.lock().await.contains(addr) - { - result.push((*addr, peer.clone())); - } - } - result - }; + // Prefer headers2 peers (excluding disabled), fall back to all + let disabled = self.headers2_disabled.lock().await; + let mut headers2_peers = + self.pool.peers_with_service(ServiceFlags::NODE_HEADERS_COMPRESSED).await; + headers2_peers.retain(|(addr, _)| !disabled.contains(addr)); + drop(disabled); if headers2_peers.is_empty() { (peers.clone(), false) } else { @@ -1133,18 +1043,20 @@ impl PeerNetworkManager { }; } - // Round-robin selection - let idx = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % selected_peers.len(); - let (addr, peer) = &selected_peers[idx]; + let (addr, peer) = self.next_peer(&selected_peers); - log::debug!( - "Distributing {} request to peer {} (round-robin idx {})", - message.cmd(), - addr, - idx - ); + log::debug!("Distributing {} request to peer {}", message.cmd(), addr); + + self.send_message_to_peer(&addr, &peer, message).await + } - self.send_message_to_peer(addr, peer, message).await + /// Pick the next peer from `peers` using round-robin rotation. + fn next_peer( + &self, + peers: &[(SocketAddr, Arc>)], + ) -> (SocketAddr, Arc>) { + let idx = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % peers.len(); + (peers[idx].0, peers[idx].1.clone()) } /// Send a message to the given peer. @@ -1307,7 +1219,6 @@ impl Clone for PeerNetworkManager { shutdown_token: self.shutdown_token.clone(), tasks: self.tasks.clone(), initial_peers: self.initial_peers.clone(), - current_sync_peer: self.current_sync_peer.clone(), data_dir: self.data_dir.clone(), mempool_strategy: self.mempool_strategy, user_agent: self.user_agent.clone(), diff --git a/dash-spv/src/network/peer.rs b/dash-spv/src/network/peer.rs index dc3e484ed..5e409da80 100644 --- a/dash-spv/src/network/peer.rs +++ b/dash-spv/src/network/peer.rs @@ -833,15 +833,24 @@ impl Peer { } } +#[cfg(test)] +impl Peer { + pub(crate) fn set_services(&mut self, flags: ServiceFlags) { + self.services = Some(flags.as_u64()); + } +} + #[cfg(test)] mod tests { + use std::net::SocketAddr; use std::time::{Duration, SystemTime}; use super::Peer; #[test] fn remove_expired_pings() { - let mut peer = Peer::dummy(); + let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap(); + let mut peer = Peer::dummy(addr); let now = SystemTime::now(); let expired = now - Duration::from_secs(61); diff --git a/dash-spv/src/network/pool.rs b/dash-spv/src/network/pool.rs index 7550b1db9..64bfb42d2 100644 --- a/dash-spv/src/network/pool.rs +++ b/dash-spv/src/network/pool.rs @@ -3,6 +3,7 @@ use crate::error::{NetworkError, SpvError as Error}; use crate::network::constants::TARGET_PEERS; use crate::network::peer::Peer; +use dashcore::network::constants::ServiceFlags; use dashcore::prelude::CoreBlockHeight; use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; @@ -145,6 +146,35 @@ impl PeerPool { } } + /// Find the first connected peer that advertises the given service flags. + pub(crate) async fn peer_with_service( + &self, + flags: ServiceFlags, + ) -> Option<(SocketAddr, Arc>)> { + let peers = self.peers.read().await; + for (addr, peer) in peers.iter() { + if peer.read().await.has_service(flags) { + return Some((*addr, Arc::clone(peer))); + } + } + None + } + + /// Collect all connected peers that advertise the given service flags. + pub(crate) async fn peers_with_service( + &self, + flags: ServiceFlags, + ) -> Vec<(SocketAddr, Arc>)> { + let peers = self.peers.read().await; + let mut result = Vec::new(); + for (addr, peer) in peers.iter() { + if peer.read().await.has_service(flags) { + result.push((*addr, peer.clone())); + } + } + result + } + /// Check if we need more peers pub async fn needs_more_peers(&self) -> bool { self.peer_count().await < TARGET_PEERS @@ -189,6 +219,15 @@ impl Default for PeerPool { } } +#[cfg(test)] +impl PeerPool { + async fn insert_peer_with_services(&self, addr: SocketAddr, flags: ServiceFlags) { + let mut peer = Peer::dummy(addr); + peer.set_services(flags); + self.peers.write().await.insert(addr, Arc::new(RwLock::new(peer))); + } +} + #[cfg(test)] mod tests { use super::*; @@ -208,4 +247,78 @@ mod tests { assert!(!pool.mark_connecting(addr).await); // Already marked assert!(pool.is_connecting(&addr).await); } + + #[tokio::test] + async fn test_peer_with_service() { + let pool = PeerPool::new(); + let flags = ServiceFlags::COMPACT_FILTERS; + + // No match on empty pool + assert!(pool.peer_with_service(flags).await.is_none()); + + // No match when peers lack the requested flag + let addr1: SocketAddr = "127.0.0.1:1001".parse().unwrap(); + pool.insert_peer_with_services(addr1, ServiceFlags::NETWORK).await; + assert!(pool.peer_with_service(flags).await.is_none()); + + // Returns the matching peer with correct address and services + let addr2: SocketAddr = "127.0.0.1:1002".parse().unwrap(); + pool.insert_peer_with_services(addr2, ServiceFlags::NETWORK | flags).await; + let (found_addr, found_peer) = pool.peer_with_service(flags).await.unwrap(); + assert_eq!(found_addr, addr2); + assert!(found_peer.read().await.has_service(flags)); + } + + #[tokio::test] + async fn test_peers_with_service() { + let pool = PeerPool::new(); + let flags = ServiceFlags::COMPACT_FILTERS; + + // Empty on empty pool + assert!(pool.peers_with_service(flags).await.is_empty()); + + // Empty when no peer has the flag + let addr1: SocketAddr = "127.0.0.1:1001".parse().unwrap(); + pool.insert_peer_with_services(addr1, ServiceFlags::NETWORK).await; + assert!(pool.peers_with_service(flags).await.is_empty()); + + // Returns all matching peers, skips non-matching + let addr2: SocketAddr = "127.0.0.1:1002".parse().unwrap(); + let addr3: SocketAddr = "127.0.0.1:1003".parse().unwrap(); + pool.insert_peer_with_services(addr2, flags).await; + pool.insert_peer_with_services(addr3, ServiceFlags::NETWORK | flags).await; + + let results: HashMap = + pool.peers_with_service(flags).await.into_iter().collect(); + assert_eq!(results.len(), 2); + assert!(results[&addr2].read().await.has_service(flags)); + assert!(results[&addr3].read().await.has_service(flags)); + } + + #[tokio::test] + async fn test_service_lookup_with_combined_flags() { + let pool = PeerPool::new(); + let combined = ServiceFlags::COMPACT_FILTERS | ServiceFlags::NODE_HEADERS_COMPRESSED; + + // Peer with only one of the two flags does not match combined query + let addr1: SocketAddr = "127.0.0.1:1001".parse().unwrap(); + pool.insert_peer_with_services(addr1, ServiceFlags::COMPACT_FILTERS).await; + assert!(pool.peer_with_service(combined).await.is_none()); + assert!(pool.peers_with_service(combined).await.is_empty()); + + // Peer with both flags matches + let addr2: SocketAddr = "127.0.0.1:1002".parse().unwrap(); + pool.insert_peer_with_services(addr2, combined | ServiceFlags::NETWORK).await; + let (found, _) = pool.peer_with_service(combined).await.unwrap(); + assert_eq!(found, addr2); + + let all = pool.peers_with_service(combined).await; + assert_eq!(all.len(), 1); + assert_eq!(all[0].0, addr2); + + // Single-flag queries still match both peers appropriately + assert!(pool.peer_with_service(ServiceFlags::COMPACT_FILTERS).await.is_some()); + let cf_peers = pool.peers_with_service(ServiceFlags::COMPACT_FILTERS).await; + assert_eq!(cf_peers.len(), 2); + } } diff --git a/dash-spv/src/test_utils/network.rs b/dash-spv/src/test_utils/network.rs index 7e21c55f9..5d0f6231f 100644 --- a/dash-spv/src/test_utils/network.rs +++ b/dash-spv/src/test_utils/network.rs @@ -177,8 +177,7 @@ impl NetworkManager for MockNetworkManager { } impl Peer { - pub fn dummy() -> Self { - let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap(); + pub fn dummy(addr: SocketAddr) -> Self { Peer::new(addr, Duration::from_secs(10), Network::Mainnet) } }