diff --git a/noq/src/connection.rs b/noq/src/connection.rs index 9466ccb0d..8262f37b9 100644 --- a/noq/src/connection.rs +++ b/noq/src/connection.rs @@ -39,9 +39,28 @@ use proto::{ /// In-progress connection attempt future #[derive(Debug)] pub struct Connecting { - conn: Option, - connected: oneshot::Receiver, - handshake_data_ready: Option>, + state: ConnectingState, +} + +#[derive(Debug)] +enum ConnectingState { + Active { + conn: ConnectionRef, + connected: oneshot::Receiver, + handshake_data_ready: Option>, + }, + Consumed, +} + +impl Drop for Connecting { + fn drop(&mut self) { + if let ConnectingState::Active { conn, .. } = &mut self.state { + let mut state = conn.lock_without_waking("connecting_drop"); + if !state.inner.is_closed() { + state.implicit_close(&conn.shared); + } + } + } } impl Connecting { @@ -81,9 +100,11 @@ impl Connecting { )); Self { - conn: Some(conn), - connected: on_connected_recv, - handshake_data_ready: Some(on_handshake_data_recv), + state: ConnectingState::Active { + conn, + connected: on_connected_recv, + handshake_data_ready: Some(on_handshake_data_recv), + }, } } @@ -132,16 +153,20 @@ impl Connecting { /// before TLS client authentication has occurred, and should therefore not be used to send /// data for which client authentication is being used. pub fn into_0rtt(mut self) -> Result<(Connection, ZeroRttAccepted), Self> { - // This lock borrows `self` and would normally be dropped at the end of this scope, so we'll - // have to release it explicitly before returning `self` by value. - let conn = (self.conn.as_mut().unwrap()).lock_without_waking("into_0rtt"); - - let is_ok = conn.inner.has_0rtt() || conn.inner.side().is_server(); - drop(conn); + let is_ok = match &self.state { + ConnectingState::Active { conn, .. } => { + let inner = conn.lock_without_waking("into_0rtt"); + inner.inner.has_0rtt() || inner.inner.side().is_server() + } + ConnectingState::Consumed => false, + }; if is_ok { - let conn = self.conn.take().unwrap(); - Ok((Connection(conn), ZeroRttAccepted(self.connected))) + if let ConnectingState::Active { conn, connected, .. } = std::mem::replace(&mut self.state, ConnectingState::Consumed) { + Ok((Connection(conn), ZeroRttAccepted(connected))) + } else { + unreachable!("state must be Active since is_ok is true") + } } else { Err(self) } @@ -153,14 +178,19 @@ impl Connecting { /// [`Session`](proto::crypto::Session). For the default `rustls` session, the return value can /// be [`downcast`](Box::downcast) to a /// [`crypto::rustls::HandshakeData`](crate::crypto::rustls::HandshakeData). + /// + /// Will panic if called after `poll` has returned `Ready`. pub async fn handshake_data(&mut self) -> Result, ConnectionError> { // Taking &mut self allows us to use a single oneshot channel rather than dealing with // potentially many tasks waiting on the same event. It's a bit of a hack, but keeps things // simple. - if let Some(x) = self.handshake_data_ready.take() { + let ConnectingState::Active { conn, handshake_data_ready, .. } = &mut self.state else { + panic!("used after yielding Ready"); + }; + + if let Some(x) = handshake_data_ready.take() { let _ = x.await; } - let conn = self.conn.as_ref().unwrap(); let inner = conn.lock_without_waking("handshake"); inner .inner @@ -186,7 +216,9 @@ impl Connecting { /// /// Will panic if called after `poll` has returned `Ready`. pub fn local_ip(&self) -> Option { - let conn = self.conn.as_ref().expect("used after yielding Ready"); + let ConnectingState::Active { conn, .. } = &self.state else { + panic!("used after yielding Ready"); + }; let inner = conn.lock_without_waking("local_ip"); inner @@ -200,8 +232,10 @@ impl Connecting { /// /// Will panic if called after `poll` has returned `Ready`. pub fn remote_address(&self) -> SocketAddr { - let conn_ref: &ConnectionRef = self.conn.as_ref().expect("used after yielding Ready"); - conn_ref + let ConnectingState::Active { conn, .. } = &self.state else { + panic!("used after yielding Ready"); + }; + conn .lock_without_waking("remote_address") .inner .network_path(PathId::ZERO) @@ -213,19 +247,30 @@ impl Connecting { impl Future for Connecting { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.connected).poll(cx).map(|_| { - let conn = self.conn.take().unwrap(); + match &mut self.state { + ConnectingState::Active { connected, .. } => { + match Pin::new(connected).poll(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(_) => {} + } + } + ConnectingState::Consumed => panic!("polled after yielding Ready"), + } + + if let ConnectingState::Active { conn, .. } = std::mem::replace(&mut self.state, ConnectingState::Consumed) { let inner = conn.lock_without_waking("connecting"); if inner.connected { drop(inner); - Ok(Connection(conn)) + Poll::Ready(Ok(Connection(conn))) } else { - Err(inner + Poll::Ready(Err(inner .error .clone() - .expect("connected signaled without connection success or error")) + .expect("connected signaled without connection success or error"))) } - }) + } else { + unreachable!() + } } } diff --git a/noq/src/tests.rs b/noq/src/tests.rs index 74c760c3d..37765e2ee 100755 --- a/noq/src/tests.rs +++ b/noq/src/tests.rs @@ -1762,3 +1762,12 @@ unsafe fn wake_by_ref_waker(data: *const ()) { unsafe fn drop_waker(data: *const ()) { drop(unsafe { Arc::::from_raw(data as *const WakeCounter) }); } + +#[tokio::test] +async fn drop_connecting_cleans_up() { + let ep = endpoint(); + let addr = "127.0.0.1:1234".parse().unwrap(); + let connecting = ep.connect(addr, "localhost").unwrap(); + drop(connecting); + ep.wait_idle().await; +}