Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 70 additions & 25 deletions noq/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,28 @@ use proto::{
/// In-progress connection attempt future
#[derive(Debug)]
pub struct Connecting {
conn: Option<ConnectionRef>,
connected: oneshot::Receiver<bool>,
handshake_data_ready: Option<oneshot::Receiver<()>>,
state: ConnectingState,
}

#[derive(Debug)]
enum ConnectingState {
Active {
conn: ConnectionRef,
connected: oneshot::Receiver<bool>,
handshake_data_ready: Option<oneshot::Receiver<()>>,
},
Consumed,
}

impl Drop for Connecting {
fn drop(&mut self) {
if let ConnectingState::Active { conn, .. } = &mut self.state {
let mut state = conn.lock_without_waking("connecting_drop");
if !state.inner.is_closed() {
state.implicit_close(&conn.shared);
}
}
}
}

impl Connecting {
Expand Down Expand Up @@ -81,9 +100,11 @@ impl Connecting {
));

Self {
conn: Some(conn),
connected: on_connected_recv,
handshake_data_ready: Some(on_handshake_data_recv),
state: ConnectingState::Active {
conn,
connected: on_connected_recv,
handshake_data_ready: Some(on_handshake_data_recv),
},
}
}

Expand Down Expand Up @@ -132,16 +153,20 @@ impl Connecting {
/// before TLS client authentication has occurred, and should therefore not be used to send
/// data for which client authentication is being used.
pub fn into_0rtt(mut self) -> Result<(Connection, ZeroRttAccepted), Self> {
// This lock borrows `self` and would normally be dropped at the end of this scope, so we'll
// have to release it explicitly before returning `self` by value.
let conn = (self.conn.as_mut().unwrap()).lock_without_waking("into_0rtt");

let is_ok = conn.inner.has_0rtt() || conn.inner.side().is_server();
drop(conn);
let is_ok = match &self.state {
ConnectingState::Active { conn, .. } => {
let inner = conn.lock_without_waking("into_0rtt");
inner.inner.has_0rtt() || inner.inner.side().is_server()
}
ConnectingState::Consumed => false,
};

if is_ok {
let conn = self.conn.take().unwrap();
Ok((Connection(conn), ZeroRttAccepted(self.connected)))
if let ConnectingState::Active { conn, connected, .. } = std::mem::replace(&mut self.state, ConnectingState::Consumed) {
Ok((Connection(conn), ZeroRttAccepted(connected)))
} else {
unreachable!("state must be Active since is_ok is true")
}
} else {
Err(self)
}
Expand All @@ -153,14 +178,19 @@ impl Connecting {
/// [`Session`](proto::crypto::Session). For the default `rustls` session, the return value can
/// be [`downcast`](Box::downcast) to a
/// [`crypto::rustls::HandshakeData`](crate::crypto::rustls::HandshakeData).
///
/// Will panic if called after `poll` has returned `Ready`.
pub async fn handshake_data(&mut self) -> Result<Box<dyn Any>, ConnectionError> {
// Taking &mut self allows us to use a single oneshot channel rather than dealing with
// potentially many tasks waiting on the same event. It's a bit of a hack, but keeps things
// simple.
if let Some(x) = self.handshake_data_ready.take() {
let ConnectingState::Active { conn, handshake_data_ready, .. } = &mut self.state else {
panic!("used after yielding Ready");
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a very confusing panic here. I see this was copied from local_ip. But this function has clear docs that would need to be added here for this to be clear for the reader. Verbatim from those ones:

/// Will panic if called after `poll` has returned `Ready`.

};

if let Some(x) = handshake_data_ready.take() {
let _ = x.await;
}
let conn = self.conn.as_ref().unwrap();
let inner = conn.lock_without_waking("handshake");
inner
.inner
Expand All @@ -186,7 +216,9 @@ impl Connecting {
///
/// Will panic if called after `poll` has returned `Ready`.
pub fn local_ip(&self) -> Option<IpAddr> {
let conn = self.conn.as_ref().expect("used after yielding Ready");
let ConnectingState::Active { conn, .. } = &self.state else {
panic!("used after yielding Ready");
};
let inner = conn.lock_without_waking("local_ip");

inner
Expand All @@ -200,8 +232,10 @@ impl Connecting {
///
/// Will panic if called after `poll` has returned `Ready`.
pub fn remote_address(&self) -> SocketAddr {
let conn_ref: &ConnectionRef = self.conn.as_ref().expect("used after yielding Ready");
conn_ref
let ConnectingState::Active { conn, .. } = &self.state else {
panic!("used after yielding Ready");
};
conn
.lock_without_waking("remote_address")
.inner
.network_path(PathId::ZERO)
Expand All @@ -213,19 +247,30 @@ impl Connecting {
impl Future for Connecting {
type Output = Result<Connection, ConnectionError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.connected).poll(cx).map(|_| {
let conn = self.conn.take().unwrap();
match &mut self.state {
ConnectingState::Active { connected, .. } => {
match Pin::new(connected).poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(_) => {}
}
}
ConnectingState::Consumed => panic!("polled after yielding Ready"),
}

if let ConnectingState::Active { conn, .. } = std::mem::replace(&mut self.state, ConnectingState::Consumed) {
let inner = conn.lock_without_waking("connecting");
if inner.connected {
drop(inner);
Ok(Connection(conn))
Poll::Ready(Ok(Connection(conn)))
} else {
Err(inner
Poll::Ready(Err(inner
.error
.clone()
.expect("connected signaled without connection success or error"))
.expect("connected signaled without connection success or error")))
}
})
} else {
unreachable!()
}
}
}

Expand Down
9 changes: 9 additions & 0 deletions noq/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1762,3 +1762,12 @@ unsafe fn wake_by_ref_waker(data: *const ()) {
unsafe fn drop_waker(data: *const ()) {
drop(unsafe { Arc::<WakeCounter>::from_raw(data as *const WakeCounter) });
}

#[tokio::test]
async fn drop_connecting_cleans_up() {
let ep = endpoint();
let addr = "127.0.0.1:1234".parse().unwrap();
let connecting = ep.connect(addr, "localhost").unwrap();
drop(connecting);
ep.wait_idle().await;
}