Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 39 additions & 128 deletions dash-spv/src/network/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ pub struct PeerNetworkManager {
tasks: Arc<Mutex<JoinSet<()>>>,
/// Initial peer addresses
initial_peers: Vec<SocketAddr>,
/// Current sync peer (sticky during sync operations)
current_sync_peer: Arc<Mutex<Option<SocketAddr>>>,
/// Data directory for storage
data_dir: PathBuf,
/// Mempool strategy from config
Expand Down Expand Up @@ -113,7 +111,6 @@ impl PeerNetworkManager {
shutdown_token: CancellationToken::new(),
tasks: Arc::new(Mutex::new(JoinSet::new())),
initial_peers: config.peers.clone(),
current_sync_peer: Arc::new(Mutex::new(None)),
data_dir,
mempool_strategy: config.mempool_strategy,
user_agent: config.user_agent.clone(),
Expand Down Expand Up @@ -962,112 +959,41 @@ impl PeerNetworkManager {
});
}

/// Send a message to a single peer (using sticky peer selection for sync consistency)
/// Send a message to a single peer selected by message type requirements.
async fn send_to_single_peer(&self, message: NetworkMessage) -> NetworkResult<()> {
let peers = self.pool.get_all_peers().await;

if peers.is_empty() {
return Err(NetworkError::ConnectionFailed("No connected peers".to_string()));
}

// For filter-related messages, we need a peer that supports compact filters
let requires_compact_filters =
matches!(&message, NetworkMessage::GetCFHeaders(_) | NetworkMessage::GetCFilters(_));
let check_headers2 =
matches!(&message, NetworkMessage::GetHeaders(_) | NetworkMessage::GetHeaders2(_));

let selected_peer = if requires_compact_filters {
// Find a peer that supports compact filters
let mut filter_peer = None;
for (addr, peer) in &peers {
let peer_guard = peer.read().await;

if peer_guard.supports_compact_filters() {
filter_peer = Some(*addr);
break;
}
let preferred_service = match &message {
NetworkMessage::GetCFHeaders(_) | NetworkMessage::GetCFilters(_) => {
Some((ServiceFlags::COMPACT_FILTERS, true))
}

match filter_peer {
Some(addr) => {
log::debug!("Selected peer {} for compact filter request", addr);
addr
}
None => {
log::warn!("No peers support compact filters, cannot send {}", message.cmd());
return Err(NetworkError::ProtocolError(
"No peers support compact filters".to_string(),
));
}
NetworkMessage::GetHeaders(_) | NetworkMessage::GetHeaders2(_) => {
Some((ServiceFlags::NODE_HEADERS_COMPRESSED, false))
}
} else if check_headers2 {
// Prefer a peer that advertises headers2 support
let mut current_sync_peer = self.current_sync_peer.lock().await;
let mut selected: Option<SocketAddr> = None;
_ => None,
};

if let Some(current_addr) = *current_sync_peer {
if let Some((_, peer)) = peers.iter().find(|(addr, _)| *addr == current_addr) {
let peer_guard = peer.read().await;
if peer_guard.supports_headers2() {
selected = Some(current_addr);
}
let (addr, peer) = if let Some((flags, required)) = preferred_service {
match self.pool.peer_with_service(flags).await {
Some((address, peer)) => {
log::debug!("Selected peer {} with {} for {}", address, flags, message.cmd());
(address, peer)
}
}

if selected.is_none() {
for (addr, peer) in &peers {
let peer_guard = peer.read().await;
if peer_guard.supports_headers2() {
selected = Some(*addr);
break;
}
None if required => {
log::warn!("No peers support {}, cannot send {}", flags, message.cmd());
return Err(NetworkError::ProtocolError(format!("No peers support {}", flags)));
}
None => self.next_peer(&peers),
}

let chosen = selected.unwrap_or(peers[0].0);
if Some(chosen) != *current_sync_peer {
log::info!("Sync peer selected for Headers2: {}", chosen);
*current_sync_peer = Some(chosen);
}
drop(current_sync_peer);
chosen
} else {
// For non-filter messages, use the sticky sync peer
let mut current_sync_peer = self.current_sync_peer.lock().await;
let selected = if let Some(current_addr) = *current_sync_peer {
// Check if current sync peer is still connected
if peers.iter().any(|(addr, _)| *addr == current_addr) {
// Keep using the same peer for sync consistency
current_addr
} else {
// Current sync peer disconnected, pick a new one
let new_addr = peers[0].0;
log::info!(
"Sync peer switched from {} to {} (previous peer disconnected)",
current_addr,
new_addr
);
*current_sync_peer = Some(new_addr);
new_addr
}
} else {
// No current sync peer, pick the first available
let new_addr = peers[0].0;
log::info!("Sync peer selected: {}", new_addr);
*current_sync_peer = Some(new_addr);
new_addr
};
drop(current_sync_peer);
selected
self.next_peer(&peers)
};

// Find the peer for the selected address
let (addr, peer) = peers
.iter()
.find(|(a, _)| *a == selected_peer)
.ok_or_else(|| NetworkError::ConnectionFailed("Selected peer not found".to_string()))?;

self.send_message_to_peer(addr, peer, message).await
self.send_message_to_peer(&addr, &peer, message).await
}

/// Send a message distributed across connected peers using round-robin selection.
Expand All @@ -1086,33 +1012,17 @@ impl PeerNetworkManager {
// Select eligible peers based on message type
let (selected_peers, require_capability) = match &message {
NetworkMessage::GetCFHeaders(_) | NetworkMessage::GetCFilters(_) => {
// Filter requests require compact filter support
let filter_peers: Vec<_> = {
let mut result = Vec::new();
for (addr, peer) in &peers {
let peer_guard = peer.read().await;
if peer_guard.supports_compact_filters() {
result.push((*addr, peer.clone()));
}
}
result
};
let filter_peers =
self.pool.peers_with_service(ServiceFlags::COMPACT_FILTERS).await;
(filter_peers, true)
}
NetworkMessage::GetHeaders(_) | NetworkMessage::GetHeaders2(_) => {
// Prefer headers2 peers, fall back to all
let headers2_peers: Vec<_> = {
let mut result = Vec::new();
for (addr, peer) in &peers {
let peer_guard = peer.read().await;
if peer_guard.supports_headers2()
&& !self.headers2_disabled.lock().await.contains(addr)
{
result.push((*addr, peer.clone()));
}
}
result
};
// Prefer headers2 peers (excluding disabled), fall back to all
let disabled = self.headers2_disabled.lock().await;
let mut headers2_peers =
self.pool.peers_with_service(ServiceFlags::NODE_HEADERS_COMPRESSED).await;
headers2_peers.retain(|(addr, _)| !disabled.contains(addr));
drop(disabled);
if headers2_peers.is_empty() {
(peers.clone(), false)
} else {
Expand All @@ -1133,18 +1043,20 @@ impl PeerNetworkManager {
};
}

// Round-robin selection
let idx = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % selected_peers.len();
let (addr, peer) = &selected_peers[idx];
let (addr, peer) = self.next_peer(&selected_peers);

log::debug!(
"Distributing {} request to peer {} (round-robin idx {})",
message.cmd(),
addr,
idx
);
log::debug!("Distributing {} request to peer {}", message.cmd(), addr);

self.send_message_to_peer(&addr, &peer, message).await
}

self.send_message_to_peer(addr, peer, message).await
/// Pick the next peer from `peers` using round-robin rotation.
fn next_peer(
&self,
peers: &[(SocketAddr, Arc<RwLock<Peer>>)],
) -> (SocketAddr, Arc<RwLock<Peer>>) {
let idx = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % peers.len();
(peers[idx].0, peers[idx].1.clone())
}

/// Send a message to the given peer.
Expand Down Expand Up @@ -1307,7 +1219,6 @@ impl Clone for PeerNetworkManager {
shutdown_token: self.shutdown_token.clone(),
tasks: self.tasks.clone(),
initial_peers: self.initial_peers.clone(),
current_sync_peer: self.current_sync_peer.clone(),
data_dir: self.data_dir.clone(),
mempool_strategy: self.mempool_strategy,
user_agent: self.user_agent.clone(),
Expand Down
11 changes: 10 additions & 1 deletion dash-spv/src/network/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -833,15 +833,24 @@ impl Peer {
}
}

#[cfg(test)]
impl Peer {
pub(crate) fn set_services(&mut self, flags: ServiceFlags) {
self.services = Some(flags.as_u64());
}
}

#[cfg(test)]
mod tests {
use std::net::SocketAddr;
use std::time::{Duration, SystemTime};

use super::Peer;

#[test]
fn remove_expired_pings() {
let mut peer = Peer::dummy();
let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap();
let mut peer = Peer::dummy(addr);
let now = SystemTime::now();
let expired = now - Duration::from_secs(61);

Expand Down
113 changes: 113 additions & 0 deletions dash-spv/src/network/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use crate::error::{NetworkError, SpvError as Error};
use crate::network::constants::TARGET_PEERS;
use crate::network::peer::Peer;
use dashcore::network::constants::ServiceFlags;
use dashcore::prelude::CoreBlockHeight;
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
Expand Down Expand Up @@ -145,6 +146,35 @@ impl PeerPool {
}
}

/// Find the first connected peer that advertises the given service flags.
pub(crate) async fn peer_with_service(
&self,
flags: ServiceFlags,
) -> Option<(SocketAddr, Arc<RwLock<Peer>>)> {
let peers = self.peers.read().await;
for (addr, peer) in peers.iter() {
if peer.read().await.has_service(flags) {
return Some((*addr, Arc::clone(peer)));
}
}
None
}

/// Collect all connected peers that advertise the given service flags.
pub(crate) async fn peers_with_service(
&self,
flags: ServiceFlags,
) -> Vec<(SocketAddr, Arc<RwLock<Peer>>)> {
let peers = self.peers.read().await;
let mut result = Vec::new();
for (addr, peer) in peers.iter() {
if peer.read().await.has_service(flags) {
result.push((*addr, peer.clone()));
}
}
result
}

/// Check if we need more peers
pub async fn needs_more_peers(&self) -> bool {
self.peer_count().await < TARGET_PEERS
Expand Down Expand Up @@ -189,6 +219,15 @@ impl Default for PeerPool {
}
}

#[cfg(test)]
impl PeerPool {
async fn insert_peer_with_services(&self, addr: SocketAddr, flags: ServiceFlags) {
let mut peer = Peer::dummy(addr);
peer.set_services(flags);
self.peers.write().await.insert(addr, Arc::new(RwLock::new(peer)));
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -208,4 +247,78 @@ mod tests {
assert!(!pool.mark_connecting(addr).await); // Already marked
assert!(pool.is_connecting(&addr).await);
}

#[tokio::test]
async fn test_peer_with_service() {
let pool = PeerPool::new();
let flags = ServiceFlags::COMPACT_FILTERS;

// No match on empty pool
assert!(pool.peer_with_service(flags).await.is_none());

// No match when peers lack the requested flag
let addr1: SocketAddr = "127.0.0.1:1001".parse().unwrap();
pool.insert_peer_with_services(addr1, ServiceFlags::NETWORK).await;
assert!(pool.peer_with_service(flags).await.is_none());

// Returns the matching peer with correct address and services
let addr2: SocketAddr = "127.0.0.1:1002".parse().unwrap();
pool.insert_peer_with_services(addr2, ServiceFlags::NETWORK | flags).await;
let (found_addr, found_peer) = pool.peer_with_service(flags).await.unwrap();
assert_eq!(found_addr, addr2);
assert!(found_peer.read().await.has_service(flags));
}

#[tokio::test]
async fn test_peers_with_service() {
let pool = PeerPool::new();
let flags = ServiceFlags::COMPACT_FILTERS;

// Empty on empty pool
assert!(pool.peers_with_service(flags).await.is_empty());

// Empty when no peer has the flag
let addr1: SocketAddr = "127.0.0.1:1001".parse().unwrap();
pool.insert_peer_with_services(addr1, ServiceFlags::NETWORK).await;
assert!(pool.peers_with_service(flags).await.is_empty());

// Returns all matching peers, skips non-matching
let addr2: SocketAddr = "127.0.0.1:1002".parse().unwrap();
let addr3: SocketAddr = "127.0.0.1:1003".parse().unwrap();
pool.insert_peer_with_services(addr2, flags).await;
pool.insert_peer_with_services(addr3, ServiceFlags::NETWORK | flags).await;

let results: HashMap<SocketAddr, _> =
pool.peers_with_service(flags).await.into_iter().collect();
assert_eq!(results.len(), 2);
assert!(results[&addr2].read().await.has_service(flags));
assert!(results[&addr3].read().await.has_service(flags));
}

#[tokio::test]
async fn test_service_lookup_with_combined_flags() {
let pool = PeerPool::new();
let combined = ServiceFlags::COMPACT_FILTERS | ServiceFlags::NODE_HEADERS_COMPRESSED;

// Peer with only one of the two flags does not match combined query
let addr1: SocketAddr = "127.0.0.1:1001".parse().unwrap();
pool.insert_peer_with_services(addr1, ServiceFlags::COMPACT_FILTERS).await;
assert!(pool.peer_with_service(combined).await.is_none());
assert!(pool.peers_with_service(combined).await.is_empty());

// Peer with both flags matches
let addr2: SocketAddr = "127.0.0.1:1002".parse().unwrap();
pool.insert_peer_with_services(addr2, combined | ServiceFlags::NETWORK).await;
let (found, _) = pool.peer_with_service(combined).await.unwrap();
assert_eq!(found, addr2);

let all = pool.peers_with_service(combined).await;
assert_eq!(all.len(), 1);
assert_eq!(all[0].0, addr2);

// Single-flag queries still match both peers appropriately
assert!(pool.peer_with_service(ServiceFlags::COMPACT_FILTERS).await.is_some());
let cf_peers = pool.peers_with_service(ServiceFlags::COMPACT_FILTERS).await;
assert_eq!(cf_peers.len(), 2);
}
}
Loading
Loading