diff --git a/dash-spv/src/network/manager.rs b/dash-spv/src/network/manager.rs index 95337fd5e..efc439b53 100644 --- a/dash-spv/src/network/manager.rs +++ b/dash-spv/src/network/manager.rs @@ -970,34 +970,24 @@ impl PeerNetworkManager { return Err(NetworkError::ConnectionFailed("No connected peers".to_string())); } - // For filter-related messages, we need a peer that supports compact filters - let requires_compact_filters = - matches!(&message, NetworkMessage::GetCFHeaders(_) | NetworkMessage::GetCFilters(_)); + let required_service = match &message { + NetworkMessage::GetCFHeaders(_) | NetworkMessage::GetCFilters(_) => { + Some(ServiceFlags::COMPACT_FILTERS) + } + _ => None, + }; let check_headers2 = matches!(&message, NetworkMessage::GetHeaders(_) | NetworkMessage::GetHeaders2(_)); - let selected_peer = if requires_compact_filters { - // Find a peer that supports compact filters - let mut filter_peer = None; - for (addr, peer) in &peers { - let peer_guard = peer.read().await; - - if peer_guard.supports_compact_filters() { - filter_peer = Some(*addr); - break; - } - } - - match filter_peer { - Some(addr) => { - log::debug!("Selected peer {} for compact filter request", addr); - addr + let selected_peer = if let Some(flags) = required_service { + match self.pool.peer_with_service(flags).await { + Some((address, _)) => { + log::debug!("Selected peer {} with {} for {}", address, flags, message.cmd()); + address } None => { - log::warn!("No peers support compact filters, cannot send {}", message.cmd()); - return Err(NetworkError::ProtocolError( - "No peers support compact filters".to_string(), - )); + log::warn!("No peers support {}, cannot send {}", flags, message.cmd()); + return Err(NetworkError::ProtocolError(format!("No peers support {}", flags))); } } } else if check_headers2 { @@ -1015,13 +1005,11 @@ impl PeerNetworkManager { } if selected.is_none() { - for (addr, peer) in &peers { - let peer_guard = peer.read().await; - if peer_guard.supports_headers2() { - selected = Some(*addr); - break; - } - } + selected = self + .pool + .peer_with_service(ServiceFlags::NODE_HEADERS_COMPRESSED) + .await + .map(|(addr, _)| addr); } let chosen = selected.unwrap_or(peers[0].0); @@ -1086,33 +1074,17 @@ impl PeerNetworkManager { // Select eligible peers based on message type let (selected_peers, require_capability) = match &message { NetworkMessage::GetCFHeaders(_) | NetworkMessage::GetCFilters(_) => { - // Filter requests require compact filter support - let filter_peers: Vec<_> = { - let mut result = Vec::new(); - for (addr, peer) in &peers { - let peer_guard = peer.read().await; - if peer_guard.supports_compact_filters() { - result.push((*addr, peer.clone())); - } - } - result - }; + let filter_peers = + self.pool.peers_with_service(ServiceFlags::COMPACT_FILTERS).await; (filter_peers, true) } NetworkMessage::GetHeaders(_) | NetworkMessage::GetHeaders2(_) => { - // Prefer headers2 peers, fall back to all - let headers2_peers: Vec<_> = { - let mut result = Vec::new(); - for (addr, peer) in &peers { - let peer_guard = peer.read().await; - if peer_guard.supports_headers2() - && !self.headers2_disabled.lock().await.contains(addr) - { - result.push((*addr, peer.clone())); - } - } - result - }; + // Prefer headers2 peers (excluding disabled), fall back to all + let disabled = self.headers2_disabled.lock().await; + let mut headers2_peers = + self.pool.peers_with_service(ServiceFlags::NODE_HEADERS_COMPRESSED).await; + headers2_peers.retain(|(addr, _)| !disabled.contains(addr)); + drop(disabled); if headers2_peers.is_empty() { (peers.clone(), false) } else { diff --git a/dash-spv/src/network/peer.rs b/dash-spv/src/network/peer.rs index dc3e484ed..5e409da80 100644 --- a/dash-spv/src/network/peer.rs +++ b/dash-spv/src/network/peer.rs @@ -833,15 +833,24 @@ impl Peer { } } +#[cfg(test)] +impl Peer { + pub(crate) fn set_services(&mut self, flags: ServiceFlags) { + self.services = Some(flags.as_u64()); + } +} + #[cfg(test)] mod tests { + use std::net::SocketAddr; use std::time::{Duration, SystemTime}; use super::Peer; #[test] fn remove_expired_pings() { - let mut peer = Peer::dummy(); + let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap(); + let mut peer = Peer::dummy(addr); let now = SystemTime::now(); let expired = now - Duration::from_secs(61); diff --git a/dash-spv/src/network/pool.rs b/dash-spv/src/network/pool.rs index 7550b1db9..d33f02259 100644 --- a/dash-spv/src/network/pool.rs +++ b/dash-spv/src/network/pool.rs @@ -3,6 +3,7 @@ use crate::error::{NetworkError, SpvError as Error}; use crate::network::constants::TARGET_PEERS; use crate::network::peer::Peer; +use dashcore::network::constants::ServiceFlags; use dashcore::prelude::CoreBlockHeight; use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; @@ -145,6 +146,35 @@ impl PeerPool { } } + /// Find the first connected peer that advertises the given service flags. + pub(crate) async fn peer_with_service( + &self, + flags: ServiceFlags, + ) -> Option<(SocketAddr, Arc>)> { + let peers = self.peers.read().await; + for (addr, peer) in peers.iter() { + if peer.read().await.has_service(flags) { + return Some((*addr, Arc::clone(peer))); + } + } + None + } + + /// Collect all connected peers that advertise the given service flags. + pub(crate) async fn peers_with_service( + &self, + flags: ServiceFlags, + ) -> Vec<(SocketAddr, Arc>)> { + let peers = self.peers.read().await; + let mut result = Vec::new(); + for (addr, peer) in peers.iter() { + if peer.read().await.has_service(flags) { + result.push((*addr, peer.clone())); + } + } + result + } + /// Check if we need more peers pub async fn needs_more_peers(&self) -> bool { self.peer_count().await < TARGET_PEERS @@ -189,6 +219,15 @@ impl Default for PeerPool { } } +#[cfg(test)] +impl PeerPool { + async fn insert_peer_with_services(&self, addr: SocketAddr, flags: ServiceFlags) { + let mut peer = Peer::dummy(addr); + peer.set_services(flags); + self.peers.write().await.insert(addr, Arc::new(RwLock::new(peer))); + } +} + #[cfg(test)] mod tests { use super::*; @@ -208,4 +247,49 @@ mod tests { assert!(!pool.mark_connecting(addr).await); // Already marked assert!(pool.is_connecting(&addr).await); } + + #[tokio::test] + async fn test_service_lookup() { + let pool = PeerPool::new(); + let compact_filters = ServiceFlags::COMPACT_FILTERS; + let combined = compact_filters | ServiceFlags::NODE_HEADERS_COMPRESSED; + + // No matches on empty pool + assert!(pool.peer_with_service(compact_filters).await.is_none()); + assert!(pool.peers_with_service(compact_filters).await.is_empty()); + + // No matches when peers lack the requested flag + let addr1: SocketAddr = "127.0.0.1:1001".parse().unwrap(); + pool.insert_peer_with_services(addr1, ServiceFlags::NETWORK).await; + assert!(pool.peer_with_service(compact_filters).await.is_none()); + assert!(pool.peers_with_service(compact_filters).await.is_empty()); + + // Single-flag lookup returns matching peers + let addr2: SocketAddr = "127.0.0.1:1002".parse().unwrap(); + let addr3: SocketAddr = "127.0.0.1:1003".parse().unwrap(); + pool.insert_peer_with_services(addr2, ServiceFlags::NETWORK | compact_filters).await; + pool.insert_peer_with_services(addr3, ServiceFlags::NETWORK | combined).await; + + let (found_addr, found_peer) = pool.peer_with_service(compact_filters).await.unwrap(); + assert!(found_addr == addr2 || found_addr == addr3); + assert!(found_peer.read().await.has_service(compact_filters)); + + let filter_peers: HashMap = + pool.peers_with_service(compact_filters).await.into_iter().collect(); + assert_eq!(filter_peers.len(), 2); + assert!(filter_peers.contains_key(&addr2)); + assert!(filter_peers.contains_key(&addr3)); + + // Combined flags require all bits present + let (found_addr, _) = pool.peer_with_service(combined).await.unwrap(); + assert_eq!(found_addr, addr3); + let combined_peers = pool.peers_with_service(combined).await; + assert_eq!(combined_peers.len(), 1); + assert_eq!(combined_peers[0].0, addr3); + + // NONE matches every peer in the pool + assert!(pool.peer_with_service(ServiceFlags::NONE).await.is_some()); + let all = pool.peers_with_service(ServiceFlags::NONE).await; + assert_eq!(all.len(), 3); + } } diff --git a/dash-spv/src/test_utils/network.rs b/dash-spv/src/test_utils/network.rs index 7e21c55f9..5d0f6231f 100644 --- a/dash-spv/src/test_utils/network.rs +++ b/dash-spv/src/test_utils/network.rs @@ -177,8 +177,7 @@ impl NetworkManager for MockNetworkManager { } impl Peer { - pub fn dummy() -> Self { - let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap(); + pub fn dummy(addr: SocketAddr) -> Self { Peer::new(addr, Duration::from_secs(10), Network::Mainnet) } }