Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions src/client/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::{
protocol::{self, Message},
tokio_transport::{TokioSpawner, TokioTimer, TokioTransport},
traits::PayloadWireFormat,
transport::Spawner,
transport::{E2ERegistryHandle, Spawner},
};

use super::error::Error;
Expand Down Expand Up @@ -290,7 +290,11 @@ impl<P: PayloadWireFormat> ControlMessage<P> {
}
}

pub(super) struct Inner<PayloadDefinitions: PayloadWireFormat, S: Spawner = TokioSpawner> {
pub(super) struct Inner<
PayloadDefinitions: PayloadWireFormat,
S: Spawner = TokioSpawner,
R: E2ERegistryHandle = Arc<Mutex<E2ERegistry>>,
> {
/// MPSC Receiver used to receive control messages from outer client
control_receiver: Receiver<ControlMessage<PayloadDefinitions>>,
/// Queue of pending control messages to process
Expand Down Expand Up @@ -322,7 +326,7 @@ pub(super) struct Inner<PayloadDefinitions: PayloadWireFormat, S: Spawner = Toki
sd_session_id: u16,
sd_session_has_wrapped: bool,
/// Shared E2E registry for runtime E2E configuration
e2e_registry: Arc<Mutex<E2ERegistry>>,
e2e_registry: R,
/// Enable multicast loopback on SD sockets for same-host testing
multicast_loopback: bool,
/// Task-spawner used by `bind_*` to drive per-socket I/O loops.
Expand All @@ -333,7 +337,7 @@ pub(super) struct Inner<PayloadDefinitions: PayloadWireFormat, S: Spawner = Toki
phantom: std::marker::PhantomData<PayloadDefinitions>,
}

impl<P: PayloadWireFormat, S: Spawner> std::fmt::Debug for Inner<P, S> {
impl<P: PayloadWireFormat, S: Spawner, R: E2ERegistryHandle> std::fmt::Debug for Inner<P, S, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Inner")
.field("interface", &self.interface)
Expand All @@ -345,10 +349,11 @@ impl<P: PayloadWireFormat, S: Spawner> std::fmt::Debug for Inner<P, S> {
}
}

impl<PayloadDefinitions, S> Inner<PayloadDefinitions, S>
impl<PayloadDefinitions, S, R> Inner<PayloadDefinitions, S, R>
where
PayloadDefinitions: PayloadWireFormat + Clone + std::fmt::Debug + 'static,
S: Spawner + Send + Sync + 'static,
R: E2ERegistryHandle,
{
/// Construct an `Inner` and return the control/update channels plus
/// the run-loop future. The caller must drive the future on a Tokio
Expand All @@ -362,7 +367,7 @@ where
/// exists yet — it's planned alongside the bare-metal port.
pub fn build(
interface: Ipv4Addr,
e2e_registry: Arc<Mutex<E2ERegistry>>,
e2e_registry: R,
multicast_loopback: bool,
spawner: S,
) -> (
Expand Down Expand Up @@ -404,7 +409,7 @@ where
&TokioTransport,
&self.spawner,
self.interface,
Arc::clone(&self.e2e_registry),
self.e2e_registry.clone(),
self.sd_session_id,
self.sd_session_has_wrapped,
self.multicast_loopback,
Expand Down Expand Up @@ -449,7 +454,7 @@ where
&TokioTransport,
&self.spawner,
port,
Arc::clone(&self.e2e_registry),
self.e2e_registry.clone(),
)
.await?;
let bound_port = unicast_socket.port();
Expand Down
72 changes: 38 additions & 34 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub use error::Error;
use crate::Timer;
use crate::e2e::{E2ECheckStatus, E2EKey, E2EProfile, E2ERegistry};
use crate::tokio_transport::{TokioSpawner, TokioTimer};
use crate::transport::Spawner;
use crate::transport::{E2ERegistryHandle, InterfaceHandle, Spawner};
use crate::{protocol, protocol::Message, traits::PayloadWireFormat};
use inner::{ControlMessage, Inner};
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
Expand Down Expand Up @@ -166,25 +166,40 @@ impl<MessageDefinitions: PayloadWireFormat> ClientUpdates<MessageDefinitions> {
///
/// `Client` is cheaply [`Clone`]-able. All clones share the same underlying
/// event loop and can be used concurrently from different tasks.
///
/// The optional type parameters `R` and `I` let callers substitute their own
/// [`E2ERegistryHandle`] and [`InterfaceHandle`] implementations (for example,
/// bare-metal handles backed by a critical-section mutex rather than
/// `Arc<Mutex<_>>`). On `std + tokio`, the defaults
/// (`Arc<Mutex<E2ERegistry>>` and `Arc<RwLock<Ipv4Addr>>`) are used by the
/// standard constructors [`Self::new`] / [`Self::new_with_loopback`] /
/// [`Self::new_with_spawner_and_loopback`].
#[derive(Clone)]
pub struct Client<MessageDefinitions: PayloadWireFormat> {
interface: Arc<RwLock<Ipv4Addr>>,
pub struct Client<
MessageDefinitions: PayloadWireFormat,
R: E2ERegistryHandle = Arc<Mutex<E2ERegistry>>,
I: InterfaceHandle = Arc<RwLock<Ipv4Addr>>,
> {
interface: I,
control_sender: mpsc::Sender<inner::ControlMessage<MessageDefinitions>>,
e2e_registry: Arc<Mutex<E2ERegistry>>,
e2e_registry: R,
}

impl<MessageDefinitions: PayloadWireFormat> std::fmt::Debug for Client<MessageDefinitions> {
impl<MessageDefinitions, R, I> std::fmt::Debug for Client<MessageDefinitions, R, I>
where
MessageDefinitions: PayloadWireFormat,
R: E2ERegistryHandle,
I: InterfaceHandle,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Client")
.field(
"interface",
&*self.interface.read().expect("interface lock poisoned"),
)
.field("interface", &self.interface.get())
.finish_non_exhaustive()
}
}

impl<MessageDefinitions> Client<MessageDefinitions>
/// Constructors that create the default `Arc`-backed handles for `std + tokio`.
impl<MessageDefinitions> Client<MessageDefinitions, Arc<Mutex<E2ERegistry>>, Arc<RwLock<Ipv4Addr>>>
where
MessageDefinitions: PayloadWireFormat + Clone + std::fmt::Debug + 'static,
{
Expand Down Expand Up @@ -319,15 +334,19 @@ where
let updates = ClientUpdates { update_receiver };
(client, updates, run_future)
}
}

/// Methods available on all `Client<M, R, I>` regardless of handle types.
impl<MessageDefinitions, R, I> Client<MessageDefinitions, R, I>
where
MessageDefinitions: PayloadWireFormat + Clone + std::fmt::Debug + 'static,
R: E2ERegistryHandle,
I: InterfaceHandle,
{
/// Returns the current network interface address.
///
/// # Panics
///
/// Panics if the interface lock is poisoned.
#[must_use]
pub fn interface(&self) -> Ipv4Addr {
*self.interface.read().expect("interface lock poisoned")
self.interface.get()
}

/// Changes the network interface and rebinds sockets.
Expand All @@ -339,19 +358,14 @@ where
/// Returns [`Error::Shutdown`] if the client's run-loop future has
/// exited before this call — the control-channel send cannot
/// complete without its receiver.
///
/// # Panics
///
/// Panics if the interface lock is poisoned (indicates prior panic
/// while the lock was held).
pub async fn set_interface(&self, interface: Ipv4Addr) -> Result<(), Error> {
let (response, message) = ControlMessage::set_interface(interface);
self.control_sender
.send(message)
.await
.map_err(|_| Error::Shutdown)?;
response.await.map_err(|_| Error::Shutdown)??;
*self.interface.write().expect("interface lock poisoned") = interface;
self.interface.set(interface);
Ok(())
}

Expand Down Expand Up @@ -860,22 +874,12 @@ where
///
/// Panics if the E2E registry mutex is poisoned.

Copilot AI Apr 27, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The # Panics section says this panics if the E2E registry mutex is poisoned, but Client is now generic over R: E2ERegistryHandle and no longer directly uses a mutex here. This doc is now inaccurate for non-mutex handle implementations.

Suggested fix: either remove the # Panics section, or reword it to something like “May panic if the underlying handle implementation panics (e.g. a poisoned mutex in the std implementation).”

Suggested change
/// Panics if the E2E registry mutex is poisoned.
/// May panic if the underlying E2E registry handle implementation panics
/// (for example, due to a poisoned mutex in the standard implementation).

Copilot uses AI. Check for mistakes.
pub fn register_e2e(&self, key: E2EKey, profile: E2EProfile) {
self.e2e_registry
.lock()
.expect("e2e registry lock poisoned")
.register(key, profile);
self.e2e_registry.register(key, profile);
}

/// Remove E2E configuration for the given key.
///
/// # Panics
///
/// Panics if the E2E registry mutex is poisoned.
pub fn unregister_e2e(&self, key: &E2EKey) {
self.e2e_registry
.lock()
.expect("e2e registry lock poisoned")
.unregister(key);
self.e2e_registry.unregister(key);
}

/// Shuts down the client by dropping the control channel.
Expand All @@ -895,7 +899,7 @@ mod tests {
use crate::traits::WireFormat;
use std::format;

type TestClient = Client<TestPayload>;
type TestClient = Client<TestPayload, Arc<Mutex<E2ERegistry>>, Arc<RwLock<Ipv4Addr>>>;

#[tokio::test]
async fn test_client_new_and_interface() {
Expand Down
44 changes: 23 additions & 21 deletions src/client/socket_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,19 @@

use crate::{
UDP_BUFFER_SIZE,
e2e::{E2ECheckStatus, E2EKey, E2ERegistry},
e2e::{E2ECheckStatus, E2EKey},
protocol::{Message, MessageView, sd},
traits::{PayloadWireFormat, WireFormat},
transport::{ReceivedDatagram, SocketOptions, Spawner, TransportFactory, TransportSocket},
transport::{
E2ERegistryHandle, ReceivedDatagram, SocketOptions, Spawner, TransportFactory,
TransportSocket,
},
};

use super::error::Error;
use futures::{FutureExt, pin_mut, select};
use std::{
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
sync::{Arc, Mutex},
task::{Context, Poll},
};
use tokio::sync::mpsc;
Expand Down Expand Up @@ -151,9 +153,9 @@ where
/// socket through the `_with_transport` variant so the `Spawner`
/// trait can be exercised end-to-end.
#[cfg(test)]
pub async fn bind_discovery_seeded(
pub async fn bind_discovery_seeded<R: E2ERegistryHandle>(
interface: Ipv4Addr,
e2e_registry: Arc<Mutex<E2ERegistry>>,
e2e_registry: R,
session_id: u16,
session_has_wrapped: bool,
multicast_loopback: bool,
Expand Down Expand Up @@ -200,18 +202,19 @@ where
/// build a small orchestrator directly on top of `protocol`, `e2e`,
/// and the `transport` traits — the `bare_metal` example workspace
/// member demonstrates the trait layer in isolation.
pub async fn bind_discovery_seeded_with_transport<F, S>(
pub async fn bind_discovery_seeded_with_transport<F, S, R>(
factory: &F,
spawner: &S,
interface: Ipv4Addr,
e2e_registry: Arc<Mutex<E2ERegistry>>,
e2e_registry: R,
session_id: u16,
session_has_wrapped: bool,
multicast_loopback: bool,
) -> Result<Self, Error>
where
F: TransportFactory<Socket = crate::tokio_transport::TokioSocket>,
S: Spawner,
R: E2ERegistryHandle,
{
let (rx_tx, rx_rx) = mpsc::channel(16);
let (tx_tx, tx_rx) = mpsc::channel(16);
Expand Down Expand Up @@ -259,7 +262,7 @@ where
/// socket through the `_with_transport` variant so the `Spawner`
/// trait can be exercised end-to-end.
#[cfg(test)]
pub async fn bind(port: u16, e2e_registry: Arc<Mutex<E2ERegistry>>) -> Result<Self, Error> {
pub async fn bind<R: E2ERegistryHandle>(port: u16, e2e_registry: R) -> Result<Self, Error> {
use crate::tokio_transport::{TokioSpawner, TokioTransport};
Self::bind_with_transport(&TokioTransport, &TokioSpawner, port, e2e_registry).await
}
Expand All @@ -269,15 +272,16 @@ where
/// socket's I/O loop through a caller-supplied [`Spawner`]. See
/// [`Self::bind_discovery_seeded_with_transport`] for the factory
/// bound rationale.
pub async fn bind_with_transport<F, S>(
pub async fn bind_with_transport<F, S, R>(
factory: &F,
spawner: &S,
port: u16,
e2e_registry: Arc<Mutex<E2ERegistry>>,
e2e_registry: R,
) -> Result<Self, Error>
where
F: TransportFactory<Socket = crate::tokio_transport::TokioSocket>,
S: Spawner,
R: E2ERegistryHandle,
{
let (rx_tx, rx_rx) = mpsc::channel(4);
let (tx_tx, tx_rx) = mpsc::channel(4);
Expand Down Expand Up @@ -394,11 +398,11 @@ where
/// return-type notation to express `Send` bounds on the trait's
/// RPITIT methods — still nightly as of this writing.
#[allow(clippy::too_many_lines)]
async fn socket_loop_future(
async fn socket_loop_future<R: E2ERegistryHandle>(
socket: crate::tokio_transport::TokioSocket,
rx_tx: mpsc::Sender<Result<ReceivedMessage<MessageDefinitions>, Error>>,
mut tx_rx: mpsc::Receiver<SendMessage<MessageDefinitions>>,
e2e_registry: Arc<Mutex<E2ERegistry>>,
e2e_registry: R,
) {
// Maximum number of consecutive `recv_from` errors tolerated before
// the socket loop gives up. A single failure (transient I/O, peer
Expand Down Expand Up @@ -458,12 +462,11 @@ where
{
let key =
E2EKey::from_message_id(send_message.message.header().message_id());
let mut registry = e2e_registry.lock().expect("e2e registry lock poisoned");
if registry.contains_key(&key) {
if e2e_registry.contains_key(&key) {
let upper_header: [u8; 8] =
buf[8..16].try_into().expect("upper header slice");
let mut protected = [0u8; UDP_BUFFER_SIZE];
let result = registry.protect(
let result = e2e_registry.protect(
key,
&buf[16..message_length],
upper_header,
Comment on lines 468 to 472

Copilot AI Apr 27, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the E2E-protect path, a protect() failure is currently handled by logging and then proceeding to send the original (unprotected) message while still reporting Ok(()) to the caller. This can silently violate the contract implied by contains_key(&key) (i.e., that E2E protection is applied when configured).

Consider treating Some(Err(e)) as a send failure: reply on send_message.response with Err(Error::E2e(e)) (or another appropriate error) and continue so the datagram is not sent unprotected.

Copilot uses AI. Check for mistakes.
Expand Down Expand Up @@ -553,14 +556,11 @@ where
let payload_bytes = view.payload_bytes();

// Apply E2E check if configured
let (e2e_status, effective_payload) = {
let mut registry =
e2e_registry.lock().expect("e2e registry lock poisoned");
match registry.check(key, payload_bytes, upper_header) {
let (e2e_status, effective_payload) =
match e2e_registry.check(key, payload_bytes, upper_header) {
Some((status, stripped)) => (Some(status), stripped),
None => (None, payload_bytes),
}
};
};

let payload = MessageDefinitions::from_payload_bytes(
header.message_id(),
Expand Down Expand Up @@ -607,9 +607,11 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::e2e::E2ERegistry;
use crate::protocol::sd::test_support::{TestPayload, empty_sd_header};
use crate::tokio_transport::TokioSpawner;
use std::format;
use std::sync::{Arc, Mutex};
use std::vec;
// Tests build ad-hoc UDP peers via tokio directly; this is not part of
// the production code path, which goes through the `TransportSocket`
Expand Down
6 changes: 4 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ pub use server::Server;
#[cfg(any(feature = "client", feature = "server"))]
pub use tokio_transport::{TokioSocket, TokioSpawner, TokioTimer, TokioTransport};
pub use transport::{
IoErrorKind, ReceivedDatagram, SocketOptions, Spawner, Timer, TransportError, TransportFactory,
TransportSocket,
E2ERegistryHandle, InterfaceHandle, IoErrorKind, ReceivedDatagram, SocketOptions, Spawner,
Timer, TransportError, TransportFactory, TransportSocket,
};
#[cfg(feature = "server")]
pub use server::SubscriptionHandle;
Loading
Loading