diff --git a/crates/flux-network/src/tcp/connector.rs b/crates/flux-network/src/tcp/connector.rs index a47f6a9..b2152e7 100644 --- a/crates/flux-network/src/tcp/connector.rs +++ b/crates/flux-network/src/tcp/connector.rs @@ -76,6 +76,8 @@ struct ConnectionManager { to_be_reconnected: Vec<(Token, ConnectionVariant)>, // Outbound connections that completed during maybe_reconnect, drained in poll_with. reconnected_to: Vec, + // Connections dropped outside event handling, drained in poll_with before reconnects. + pending_disconnects: Vec, next_token: usize, /// Scratch buffers for [`SendBehavior::Broadcast`]: the frame is serialised @@ -98,6 +100,7 @@ impl Default for ConnectionManager { nodelay: true, to_be_reconnected: Vec::with_capacity(10), reconnected_to: Vec::with_capacity(10), + pending_disconnects: Vec::with_capacity(10), poll: Poll::new().expect("couldn't set up a poll for tcp connector"), next_token: 0, bcast_header: [0; FRAME_HEADER_SIZE], @@ -133,6 +136,12 @@ impl ConnectionManager { } } + fn disconnect_at_index_pending(&mut self, index: usize) { + let token = self.conns[index].0; + self.disconnect_at_index(index); + self.pending_disconnects.push(token); + } + fn disconnect_token(&mut self, token: Token) { if let Some(i) = self.conns.iter().position(|(t, _)| *t == token) { self.disconnect_at_index(i); @@ -166,7 +175,7 @@ impl ConnectionManager { &self.bcast_payload, ) == ConnState::Disconnected { - self.disconnect_at_index(i); + self.disconnect_at_index_pending(i); } } ConnectionVariant::Listener(_tcp_listener) => {} @@ -190,7 +199,7 @@ impl ConnectionManager { ConnState::Disconnected { tracing::warn!("issue when writing to {token:?} disconnecting"); - self.disconnect_at_index(i); + self.disconnect_at_index_pending(i); } } ConnectionVariant::Listener(_tcp_listener) => error!( @@ -216,7 +225,7 @@ impl ConnectionManager { }; if stream.has_backlog() { if stream.drain_backlog(self.poll.registry()) == ConnState::Disconnected { - self.disconnect_at_index(i); + self.disconnect_at_index_pending(i); continue; } if let Some((max, timeout)) = self.max_backlog { @@ -233,7 +242,7 @@ impl ConnectionManager { ?elapsed, "backlog exceeded limit for too long, disconnecting" ); - self.disconnect_at_index(i); + self.disconnect_at_index_pending(i); } } else { // Back below threshold — reset the timer. @@ -379,6 +388,18 @@ impl ConnectionManager { self.maybe_reconnect(); } + #[inline] + fn drain_pending_disconnects(&mut self, handler: &mut F) -> bool + where + F: for<'a> FnMut(PollEvent<&'a [u8]>), + { + let had_pending = !self.pending_disconnects.is_empty(); + for token in self.pending_disconnects.drain(..) { + handler(PollEvent::Disconnect { token }); + } + had_pending + } + #[inline] fn handle_event(&mut self, e: &Event, handler: &mut F) where @@ -428,6 +449,7 @@ impl ConnectionManager { continue; } } + set_user_timeout(&stream, self.user_timeout_ms); let mut conn = TcpStream::from_stream_with_telemetry( stream, token, @@ -515,6 +537,7 @@ impl ConnectionManager { continue; } } + set_user_timeout(&stream, self.user_timeout_ms); let mut conn = TcpStream::from_stream_with_telemetry( stream, token, @@ -675,21 +698,23 @@ impl TcpConnector { where F: for<'a> FnMut(PollEvent<&'a [u8]>), { + let mut o = self.conn_mgr.drain_pending_disconnects(&mut handler); self.conn_mgr.maybe_reconnect(); for token in self.conn_mgr.reconnected_to.drain(..) { handler(PollEvent::Reconnect { token }); + o = true; } if let Err(e) = self.conn_mgr.poll.poll(&mut self.events, Some(std::time::Duration::ZERO)) { safe_panic!("got error polling {e}"); return false; } - let mut o = false; for e in &self.events { o = true; self.conn_mgr.handle_event(e, &mut handler); } self.conn_mgr.flush_backlogs(); + o |= self.conn_mgr.drain_pending_disconnects(&mut handler); o } @@ -706,20 +731,26 @@ impl TcpConnector { P: SpineProducers + AsRef>, F: for<'a> FnMut(PollEvent<&'a [u8]>) -> Option, { + let mut o = self.conn_mgr.drain_pending_disconnects(&mut |event| { + let _ = on_msg(event); + }); self.conn_mgr.maybe_reconnect(); for token in self.conn_mgr.reconnected_to.drain(..) { let _ = on_msg(PollEvent::Reconnect { token }); + o = true; } if let Err(e) = self.conn_mgr.poll.poll(&mut self.events, Some(std::time::Duration::ZERO)) { safe_panic!("got error polling {e}"); return false; } - let mut o = false; for e in &self.events { o = true; self.conn_mgr.handle_event_produce(e, produce, &mut on_msg); } self.conn_mgr.flush_backlogs(); + o |= self.conn_mgr.drain_pending_disconnects(&mut |event| { + let _ = on_msg(event); + }); o } diff --git a/crates/flux-network/tests/tcp_roundtrip.rs b/crates/flux-network/tests/tcp_roundtrip.rs index 29a52b3..6b0a4fc 100644 --- a/crates/flux-network/tests/tcp_roundtrip.rs +++ b/crates/flux-network/tests/tcp_roundtrip.rs @@ -1,5 +1,5 @@ use std::{ - net::{IpAddr, Ipv4Addr, SocketAddr}, + net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream}, thread, time::Duration, }; @@ -91,3 +91,51 @@ fn tcp_roundtrip() { server.join().unwrap(); client.join().unwrap(); } + +#[test] +fn backlog_disconnect_is_reported() { + let probe = + std::net::TcpListener::bind(SocketAddr::from((Ipv4Addr::LOCALHOST, 0))).expect("probe"); + let bind_addr = probe.local_addr().unwrap(); + drop(probe); + + let mut listener = TcpConnector::default() + .with_socket_buf_size(1024) + .with_max_backlog(0, flux_timing::Duration::ZERO); + listener.listen_at(bind_addr).unwrap(); + + let client = TcpStream::connect(bind_addr).expect("failed to connect client"); + + let mut stream_token = None; + let deadline = std::time::Instant::now() + Duration::from_secs(5); + while stream_token.is_none() && std::time::Instant::now() < deadline { + listener.poll_with(|event| { + if let PollEvent::Accept { stream, .. } = event { + stream_token = Some(stream); + } + }); + thread::sleep(Duration::from_millis(1)); + } + + let stream_token = stream_token.expect("listener did not accept client"); + + let payload = vec![7_u8; 16 * 1024 * 1024]; + listener.write_or_enqueue_with(SendBehavior::Single(stream_token), |buf| { + buf.extend_from_slice(&payload); + }); + + let mut disconnected = false; + let deadline = std::time::Instant::now() + Duration::from_secs(5); + while !disconnected && std::time::Instant::now() < deadline { + listener.poll_with(|event| { + if let PollEvent::Disconnect { token } = event { + assert_eq!(token, stream_token); + disconnected = true; + } + }); + thread::sleep(Duration::from_millis(1)); + } + + drop(client); + assert!(disconnected, "backlog disconnect was not reported"); +}