Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions crates/flux-network/src/tcp/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ struct ConnectionManager {
to_be_reconnected: Vec<(Token, ConnectionVariant)>,
// Outbound connections that completed during maybe_reconnect, drained in poll_with.
reconnected_to: Vec<Token>,
// Connections dropped outside event handling, drained in poll_with before reconnects.
pending_disconnects: Vec<Token>,
next_token: usize,

/// Scratch buffers for [`SendBehavior::Broadcast`]: the frame is serialised
Expand All @@ -98,6 +100,7 @@ impl Default for ConnectionManager {
nodelay: true,
to_be_reconnected: Vec::with_capacity(10),
reconnected_to: Vec::with_capacity(10),
pending_disconnects: Vec::with_capacity(10),
poll: Poll::new().expect("couldn't set up a poll for tcp connector"),
next_token: 0,
bcast_header: [0; FRAME_HEADER_SIZE],
Expand Down Expand Up @@ -133,6 +136,12 @@ impl ConnectionManager {
}
}

fn disconnect_at_index_pending(&mut self, index: usize) {
let token = self.conns[index].0;
self.disconnect_at_index(index);
self.pending_disconnects.push(token);
}

fn disconnect_token(&mut self, token: Token) {
if let Some(i) = self.conns.iter().position(|(t, _)| *t == token) {
self.disconnect_at_index(i);
Expand Down Expand Up @@ -166,7 +175,7 @@ impl ConnectionManager {
&self.bcast_payload,
) == ConnState::Disconnected
{
self.disconnect_at_index(i);
self.disconnect_at_index_pending(i);
}
}
ConnectionVariant::Listener(_tcp_listener) => {}
Expand All @@ -190,7 +199,7 @@ impl ConnectionManager {
ConnState::Disconnected
{
tracing::warn!("issue when writing to {token:?} disconnecting");
self.disconnect_at_index(i);
self.disconnect_at_index_pending(i);
}
}
ConnectionVariant::Listener(_tcp_listener) => error!(
Expand All @@ -216,7 +225,7 @@ impl ConnectionManager {
};
if stream.has_backlog() {
if stream.drain_backlog(self.poll.registry()) == ConnState::Disconnected {
self.disconnect_at_index(i);
self.disconnect_at_index_pending(i);
continue;
}
if let Some((max, timeout)) = self.max_backlog {
Expand All @@ -233,7 +242,7 @@ impl ConnectionManager {
?elapsed,
"backlog exceeded limit for too long, disconnecting"
);
self.disconnect_at_index(i);
self.disconnect_at_index_pending(i);
}
} else {
// Back below threshold — reset the timer.
Expand Down Expand Up @@ -379,6 +388,18 @@ impl ConnectionManager {
self.maybe_reconnect();
}

#[inline]
fn drain_pending_disconnects<F>(&mut self, handler: &mut F) -> bool
where
F: for<'a> FnMut(PollEvent<&'a [u8]>),
{
let had_pending = !self.pending_disconnects.is_empty();
for token in self.pending_disconnects.drain(..) {
handler(PollEvent::Disconnect { token });
}
had_pending
}

#[inline]
fn handle_event<F>(&mut self, e: &Event, handler: &mut F)
where
Expand Down Expand Up @@ -428,6 +449,7 @@ impl ConnectionManager {
continue;
}
}
set_user_timeout(&stream, self.user_timeout_ms);
let mut conn = TcpStream::from_stream_with_telemetry(
stream,
token,
Expand Down Expand Up @@ -515,6 +537,7 @@ impl ConnectionManager {
continue;
}
}
set_user_timeout(&stream, self.user_timeout_ms);
let mut conn = TcpStream::from_stream_with_telemetry(
stream,
token,
Expand Down Expand Up @@ -675,21 +698,23 @@ impl TcpConnector {
where
F: for<'a> FnMut(PollEvent<&'a [u8]>),
{
let mut o = self.conn_mgr.drain_pending_disconnects(&mut handler);
self.conn_mgr.maybe_reconnect();
for token in self.conn_mgr.reconnected_to.drain(..) {
handler(PollEvent::Reconnect { token });
o = true;
}
if let Err(e) = self.conn_mgr.poll.poll(&mut self.events, Some(std::time::Duration::ZERO)) {
safe_panic!("got error polling {e}");
return false;
}
let mut o = false;

for e in &self.events {
o = true;
self.conn_mgr.handle_event(e, &mut handler);
}
self.conn_mgr.flush_backlogs();
o |= self.conn_mgr.drain_pending_disconnects(&mut handler);
o
}

Expand All @@ -706,20 +731,26 @@ impl TcpConnector {
P: SpineProducers + AsRef<SpineProducerWithDCache<T>>,
F: for<'a> FnMut(PollEvent<&'a [u8]>) -> Option<T>,
{
let mut o = self.conn_mgr.drain_pending_disconnects(&mut |event| {
let _ = on_msg(event);
});
self.conn_mgr.maybe_reconnect();
for token in self.conn_mgr.reconnected_to.drain(..) {
let _ = on_msg(PollEvent::Reconnect { token });
o = true;
}
if let Err(e) = self.conn_mgr.poll.poll(&mut self.events, Some(std::time::Duration::ZERO)) {
safe_panic!("got error polling {e}");
return false;
}
let mut o = false;
for e in &self.events {
o = true;
self.conn_mgr.handle_event_produce(e, produce, &mut on_msg);
}
self.conn_mgr.flush_backlogs();
o |= self.conn_mgr.drain_pending_disconnects(&mut |event| {
let _ = on_msg(event);
});
o
}

Expand Down
50 changes: 49 additions & 1 deletion crates/flux-network/tests/tcp_roundtrip.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
net::{IpAddr, Ipv4Addr, SocketAddr},
net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream},
thread,
time::Duration,
};
Expand Down Expand Up @@ -91,3 +91,51 @@ fn tcp_roundtrip() {
server.join().unwrap();
client.join().unwrap();
}

#[test]
fn backlog_disconnect_is_reported() {
let probe =
std::net::TcpListener::bind(SocketAddr::from((Ipv4Addr::LOCALHOST, 0))).expect("probe");
let bind_addr = probe.local_addr().unwrap();
drop(probe);

let mut listener = TcpConnector::default()
.with_socket_buf_size(1024)
.with_max_backlog(0, flux_timing::Duration::ZERO);
listener.listen_at(bind_addr).unwrap();

let client = TcpStream::connect(bind_addr).expect("failed to connect client");

let mut stream_token = None;
let deadline = std::time::Instant::now() + Duration::from_secs(5);
while stream_token.is_none() && std::time::Instant::now() < deadline {
listener.poll_with(|event| {
if let PollEvent::Accept { stream, .. } = event {
stream_token = Some(stream);
}
});
thread::sleep(Duration::from_millis(1));
}

let stream_token = stream_token.expect("listener did not accept client");

let payload = vec![7_u8; 16 * 1024 * 1024];
listener.write_or_enqueue_with(SendBehavior::Single(stream_token), |buf| {
buf.extend_from_slice(&payload);
});

let mut disconnected = false;
let deadline = std::time::Instant::now() + Duration::from_secs(5);
while !disconnected && std::time::Instant::now() < deadline {
listener.poll_with(|event| {
if let PollEvent::Disconnect { token } = event {
assert_eq!(token, stream_token);
disconnected = true;
}
});
thread::sleep(Duration::from_millis(1));
}

drop(client);
assert!(disconnected, "backlog disconnect was not reported");
}
Loading