Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 26 additions & 54 deletions dash-spv/src/network/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -970,34 +970,24 @@ impl PeerNetworkManager {
return Err(NetworkError::ConnectionFailed("No connected peers".to_string()));
}

// For filter-related messages, we need a peer that supports compact filters
let requires_compact_filters =
matches!(&message, NetworkMessage::GetCFHeaders(_) | NetworkMessage::GetCFilters(_));
let required_service = match &message {
NetworkMessage::GetCFHeaders(_) | NetworkMessage::GetCFilters(_) => {
Some(ServiceFlags::COMPACT_FILTERS)
}
_ => None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you are touching this logic, couldn't we improve it by returning `ServiceFlags::NONE`? That way we remove the unnecessary `Option` and the `if let` branch. Feel free to do it if you want; if not, I will take a look into it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have some other PRs coming in this area.

};
let check_headers2 =
matches!(&message, NetworkMessage::GetHeaders(_) | NetworkMessage::GetHeaders2(_));

let selected_peer = if requires_compact_filters {
// Find a peer that supports compact filters
let mut filter_peer = None;
for (addr, peer) in &peers {
let peer_guard = peer.read().await;

if peer_guard.supports_compact_filters() {
filter_peer = Some(*addr);
break;
}
}

match filter_peer {
Some(addr) => {
log::debug!("Selected peer {} for compact filter request", addr);
addr
let selected_peer = if let Some(flags) = required_service {
match self.pool.peer_with_service(flags).await {
Some((address, _)) => {
log::debug!("Selected peer {} with {} for {}", address, flags, message.cmd());
address
}
None => {
log::warn!("No peers support compact filters, cannot send {}", message.cmd());
return Err(NetworkError::ProtocolError(
"No peers support compact filters".to_string(),
));
log::warn!("No peers support {}, cannot send {}", flags, message.cmd());
return Err(NetworkError::ProtocolError(format!("No peers support {}", flags)));
}
}
} else if check_headers2 {
Expand All @@ -1015,13 +1005,11 @@ impl PeerNetworkManager {
}

if selected.is_none() {
for (addr, peer) in &peers {
let peer_guard = peer.read().await;
if peer_guard.supports_headers2() {
selected = Some(*addr);
break;
}
}
selected = self
.pool
.peer_with_service(ServiceFlags::NODE_HEADERS_COMPRESSED)
.await
.map(|(addr, _)| addr);
}

let chosen = selected.unwrap_or(peers[0].0);
Expand Down Expand Up @@ -1086,33 +1074,17 @@ impl PeerNetworkManager {
// Select eligible peers based on message type
let (selected_peers, require_capability) = match &message {
NetworkMessage::GetCFHeaders(_) | NetworkMessage::GetCFilters(_) => {
// Filter requests require compact filter support
let filter_peers: Vec<_> = {
let mut result = Vec::new();
for (addr, peer) in &peers {
let peer_guard = peer.read().await;
if peer_guard.supports_compact_filters() {
result.push((*addr, peer.clone()));
}
}
result
};
let filter_peers =
self.pool.peers_with_service(ServiceFlags::COMPACT_FILTERS).await;
(filter_peers, true)
}
NetworkMessage::GetHeaders(_) | NetworkMessage::GetHeaders2(_) => {
// Prefer headers2 peers, fall back to all
let headers2_peers: Vec<_> = {
let mut result = Vec::new();
for (addr, peer) in &peers {
let peer_guard = peer.read().await;
if peer_guard.supports_headers2()
&& !self.headers2_disabled.lock().await.contains(addr)
{
result.push((*addr, peer.clone()));
}
}
result
};
// Prefer headers2 peers (excluding disabled), fall back to all
let disabled = self.headers2_disabled.lock().await;
let mut headers2_peers =
self.pool.peers_with_service(ServiceFlags::NODE_HEADERS_COMPRESSED).await;
headers2_peers.retain(|(addr, _)| !disabled.contains(addr));
drop(disabled);
if headers2_peers.is_empty() {
(peers.clone(), false)
} else {
Expand Down
11 changes: 10 additions & 1 deletion dash-spv/src/network/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -833,15 +833,24 @@ impl Peer {
}
}

#[cfg(test)]
impl Peer {
    /// Test-only helper: override the service bits this peer advertises.
    ///
    /// Stores the raw `u64` representation of `flags` so capability checks
    /// (e.g. `has_service`) behave as if the peer announced them in its
    /// version handshake.
    pub(crate) fn set_services(&mut self, flags: ServiceFlags) {
        let bits = flags.as_u64();
        self.services = Some(bits);
    }
}

#[cfg(test)]
mod tests {
use std::net::SocketAddr;
use std::time::{Duration, SystemTime};

use super::Peer;

#[test]
fn remove_expired_pings() {
let mut peer = Peer::dummy();
let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap();
let mut peer = Peer::dummy(addr);
let now = SystemTime::now();
let expired = now - Duration::from_secs(61);

Expand Down
84 changes: 84 additions & 0 deletions dash-spv/src/network/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use crate::error::{NetworkError, SpvError as Error};
use crate::network::constants::TARGET_PEERS;
use crate::network::peer::Peer;
use dashcore::network::constants::ServiceFlags;
use dashcore::prelude::CoreBlockHeight;
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
Expand Down Expand Up @@ -145,6 +146,35 @@ impl PeerPool {
}
}

/// Find the first connected peer that advertises the given service flags.
///
/// Scans the pool in map-iteration order and returns the address and a
/// cloned handle of the first peer whose advertised services include
/// `flags`, or `None` when no connected peer qualifies.
pub(crate) async fn peer_with_service(
    &self,
    flags: ServiceFlags,
) -> Option<(SocketAddr, Arc<RwLock<Peer>>)> {
    let pool_guard = self.peers.read().await;
    for (address, handle) in pool_guard.iter() {
        // Each peer sits behind its own lock; check capability under a short read guard.
        let qualifies = handle.read().await.has_service(flags);
        if qualifies {
            return Some((*address, Arc::clone(handle)));
        }
    }
    None
}

/// Collect all connected peers that advertise the given service flags.
///
/// Returns `(address, handle)` pairs for every peer whose advertised
/// services include `flags`; the vector is empty when none qualify.
pub(crate) async fn peers_with_service(
    &self,
    flags: ServiceFlags,
) -> Vec<(SocketAddr, Arc<RwLock<Peer>>)> {
    let pool_guard = self.peers.read().await;
    let mut matching = Vec::new();
    for (address, handle) in pool_guard.iter() {
        // Capability check happens under each peer's own read lock.
        if handle.read().await.has_service(flags) {
            matching.push((*address, Arc::clone(handle)));
        }
    }
    matching
}

/// Check if we need more peers
pub async fn needs_more_peers(&self) -> bool {
self.peer_count().await < TARGET_PEERS
Expand Down Expand Up @@ -189,6 +219,15 @@ impl Default for PeerPool {
}
}

#[cfg(test)]
impl PeerPool {
    /// Test-only helper: insert a dummy peer at `addr` that advertises
    /// exactly the given service `flags`.
    ///
    /// Builds the peer via `Peer::dummy`, overrides its services, and
    /// registers it in the pool so service-lookup tests can exercise
    /// `peer_with_service` / `peers_with_service` against known state.
    async fn insert_peer_with_services(&self, addr: SocketAddr, flags: ServiceFlags) {
        let mut peer = Peer::dummy(addr);
        peer.set_services(flags);
        self.peers.write().await.insert(addr, Arc::new(RwLock::new(peer)));
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -208,4 +247,49 @@ mod tests {
assert!(!pool.mark_connecting(addr).await); // Already marked
assert!(pool.is_connecting(&addr).await);
}

#[tokio::test]
async fn test_service_lookup() {
let pool = PeerPool::new();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove one test_peer_with_service function and merge missing cases if they exist into one

let compact_filters = ServiceFlags::COMPACT_FILTERS;
let combined = compact_filters | ServiceFlags::NODE_HEADERS_COMPRESSED;

// No matches on empty pool
assert!(pool.peer_with_service(compact_filters).await.is_none());
assert!(pool.peers_with_service(compact_filters).await.is_empty());

// No matches when peers lack the requested flag
let addr1: SocketAddr = "127.0.0.1:1001".parse().unwrap();
pool.insert_peer_with_services(addr1, ServiceFlags::NETWORK).await;
assert!(pool.peer_with_service(compact_filters).await.is_none());
assert!(pool.peers_with_service(compact_filters).await.is_empty());

// Single-flag lookup returns matching peers
let addr2: SocketAddr = "127.0.0.1:1002".parse().unwrap();
let addr3: SocketAddr = "127.0.0.1:1003".parse().unwrap();
pool.insert_peer_with_services(addr2, ServiceFlags::NETWORK | compact_filters).await;
pool.insert_peer_with_services(addr3, ServiceFlags::NETWORK | combined).await;

let (found_addr, found_peer) = pool.peer_with_service(compact_filters).await.unwrap();
assert!(found_addr == addr2 || found_addr == addr3);
assert!(found_peer.read().await.has_service(compact_filters));

let filter_peers: HashMap<SocketAddr, _> =
pool.peers_with_service(compact_filters).await.into_iter().collect();
assert_eq!(filter_peers.len(), 2);
assert!(filter_peers.contains_key(&addr2));
assert!(filter_peers.contains_key(&addr3));

// Combined flags require all bits present
let (found_addr, _) = pool.peer_with_service(combined).await.unwrap();
assert_eq!(found_addr, addr3);
let combined_peers = pool.peers_with_service(combined).await;
assert_eq!(combined_peers.len(), 1);
assert_eq!(combined_peers[0].0, addr3);

// NONE matches every peer in the pool
assert!(pool.peer_with_service(ServiceFlags::NONE).await.is_some());
let all = pool.peers_with_service(ServiceFlags::NONE).await;
assert_eq!(all.len(), 3);
}
}
3 changes: 1 addition & 2 deletions dash-spv/src/test_utils/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,7 @@ impl NetworkManager for MockNetworkManager {
}

impl Peer {
pub fn dummy() -> Self {
let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap();
pub fn dummy(addr: SocketAddr) -> Self {
Peer::new(addr, Duration::from_secs(10), Network::Mainnet)
}
}
Loading