From 43118c7f543b9d6a58ac171a4ebdc835d28d2ff6 Mon Sep 17 00:00:00 2001 From: Justin Kovacich Date: Mon, 27 Apr 2026 11:38:42 -0400 Subject: [PATCH] =?UTF-8?q?phase=2010:=20lock-handle=20abstraction=20(Arc>=20=E2=86=92=20trait=20handles)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces three new traits in transport.rs and subscription_manager.rs: - E2ERegistryHandle — wraps Arc> on std, allows alternative implementations for bare-metal targets - InterfaceHandle — wraps Arc> on client - SubscriptionHandle — wraps Arc> on server Client and Server / EventPublisher are now generic over these handles with the existing Arc-backed types as defaults, so all existing call sites compile unchanged. Std implementations live in tokio_transport.rs. Gate: all production lock sites route through handle traits; cargo test --all-features passes (454 unit + 11 integration tests). Co-Authored-By: Claude Sonnet 4.6 --- src/client/inner.rs | 21 ++- src/client/mod.rs | 72 +++++---- src/client/socket_manager.rs | 44 +++--- src/lib.rs | 6 +- src/server/event_publisher.rs | 65 ++++---- src/server/mod.rs | 239 +++++++++++++---------------- src/server/subscription_manager.rs | 94 +++++++++++- src/tokio_transport.rs | 55 ++++++- src/transport.rs | 126 +++++++++++++++ 9 files changed, 489 insertions(+), 233 deletions(-) diff --git a/src/client/inner.rs b/src/client/inner.rs index 4234abb..e822a1c 100644 --- a/src/client/inner.rs +++ b/src/client/inner.rs @@ -25,7 +25,7 @@ use crate::{ protocol::{self, Message}, tokio_transport::{TokioSpawner, TokioTimer, TokioTransport}, traits::PayloadWireFormat, - transport::Spawner, + transport::{E2ERegistryHandle, Spawner}, }; use super::error::Error; @@ -290,7 +290,11 @@ impl ControlMessage

{ } } -pub(super) struct Inner { +pub(super) struct Inner< + PayloadDefinitions: PayloadWireFormat, + S: Spawner = TokioSpawner, + R: E2ERegistryHandle = Arc>, +> { /// MPSC Receiver used to receive control messages from outer client control_receiver: Receiver>, /// Queue of pending control messages to process @@ -322,7 +326,7 @@ pub(super) struct Inner>, + e2e_registry: R, /// Enable multicast loopback on SD sockets for same-host testing multicast_loopback: bool, /// Task-spawner used by `bind_*` to drive per-socket I/O loops. @@ -333,7 +337,7 @@ pub(super) struct Inner, } -impl std::fmt::Debug for Inner { +impl std::fmt::Debug for Inner { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Inner") .field("interface", &self.interface) @@ -345,10 +349,11 @@ impl std::fmt::Debug for Inner { } } -impl Inner +impl Inner where PayloadDefinitions: PayloadWireFormat + Clone + std::fmt::Debug + 'static, S: Spawner + Send + Sync + 'static, + R: E2ERegistryHandle, { /// Construct an `Inner` and return the control/update channels plus /// the run-loop future. The caller must drive the future on a Tokio @@ -362,7 +367,7 @@ where /// exists yet — it's planned alongside the bare-metal port. pub fn build( interface: Ipv4Addr, - e2e_registry: Arc>, + e2e_registry: R, multicast_loopback: bool, spawner: S, ) -> ( @@ -404,7 +409,7 @@ where &TokioTransport, &self.spawner, self.interface, - Arc::clone(&self.e2e_registry), + self.e2e_registry.clone(), self.sd_session_id, self.sd_session_has_wrapped, self.multicast_loopback, @@ -449,7 +454,7 @@ where &TokioTransport, &self.spawner, port, - Arc::clone(&self.e2e_registry), + self.e2e_registry.clone(), ) .await?; let bound_port = unicast_socket.port(); diff --git a/src/client/mod.rs b/src/client/mod.rs index 15453fe..9545603 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -39,7 +39,7 @@ pub use error::Error; use crate::Timer; use crate::e2e::{E2ECheckStatus, E2EKey, E2EProfile, E2ERegistry}; use crate::tokio_transport::{TokioSpawner, TokioTimer}; -use crate::transport::Spawner; +use crate::transport::{E2ERegistryHandle, InterfaceHandle, Spawner}; use crate::{protocol, protocol::Message, traits::PayloadWireFormat}; use inner::{ControlMessage, Inner}; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; @@ -166,25 +166,40 @@ impl ClientUpdates { /// /// `Client` is cheaply [`Clone`]-able. All clones share the same underlying /// event loop and can be used concurrently from different tasks. +/// +/// The optional type parameters `R` and `I` let callers substitute their own +/// [`E2ERegistryHandle`] and [`InterfaceHandle`] implementations (for example, +/// bare-metal handles backed by a critical-section mutex rather than +/// `Arc>`). On `std + tokio`, the defaults +/// (`Arc>` and `Arc>`) are used by the +/// standard constructors [`Self::new`] / [`Self::new_with_loopback`] / +/// [`Self::new_with_spawner_and_loopback`]. #[derive(Clone)] -pub struct Client { - interface: Arc>, +pub struct Client< + MessageDefinitions: PayloadWireFormat, + R: E2ERegistryHandle = Arc>, + I: InterfaceHandle = Arc>, +> { + interface: I, control_sender: mpsc::Sender>, - e2e_registry: Arc>, + e2e_registry: R, } -impl std::fmt::Debug for Client { +impl std::fmt::Debug for Client +where + MessageDefinitions: PayloadWireFormat, + R: E2ERegistryHandle, + I: InterfaceHandle, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") - .field( - "interface", - &*self.interface.read().expect("interface lock poisoned"), - ) + .field("interface", &self.interface.get()) .finish_non_exhaustive() } } -impl Client +/// Constructors that create the default `Arc`-backed handles for `std + tokio`. +impl Client>, Arc>> where MessageDefinitions: PayloadWireFormat + Clone + std::fmt::Debug + 'static, { @@ -319,15 +334,19 @@ where let updates = ClientUpdates { update_receiver }; (client, updates, run_future) } +} +/// Methods available on all `Client` regardless of handle types. +impl Client +where + MessageDefinitions: PayloadWireFormat + Clone + std::fmt::Debug + 'static, + R: E2ERegistryHandle, + I: InterfaceHandle, +{ /// Returns the current network interface address. - /// - /// # Panics - /// - /// Panics if the interface lock is poisoned. #[must_use] pub fn interface(&self) -> Ipv4Addr { - *self.interface.read().expect("interface lock poisoned") + self.interface.get() } /// Changes the network interface and rebinds sockets. @@ -339,11 +358,6 @@ where /// Returns [`Error::Shutdown`] if the client's run-loop future has /// exited before this call — the control-channel send cannot /// complete without its receiver. - /// - /// # Panics - /// - /// Panics if the interface lock is poisoned (indicates prior panic - /// while the lock was held). pub async fn set_interface(&self, interface: Ipv4Addr) -> Result<(), Error> { let (response, message) = ControlMessage::set_interface(interface); self.control_sender @@ -351,7 +365,7 @@ where .await .map_err(|_| Error::Shutdown)?; response.await.map_err(|_| Error::Shutdown)??; - *self.interface.write().expect("interface lock poisoned") = interface; + self.interface.set(interface); Ok(()) } @@ -860,22 +874,12 @@ where /// /// Panics if the E2E registry mutex is poisoned. pub fn register_e2e(&self, key: E2EKey, profile: E2EProfile) { - self.e2e_registry - .lock() - .expect("e2e registry lock poisoned") - .register(key, profile); + self.e2e_registry.register(key, profile); } /// Remove E2E configuration for the given key. - /// - /// # Panics - /// - /// Panics if the E2E registry mutex is poisoned. pub fn unregister_e2e(&self, key: &E2EKey) { - self.e2e_registry - .lock() - .expect("e2e registry lock poisoned") - .unregister(key); + self.e2e_registry.unregister(key); } /// Shuts down the client by dropping the control channel. @@ -895,7 +899,7 @@ mod tests { use crate::traits::WireFormat; use std::format; - type TestClient = Client; + type TestClient = Client>, Arc>>; #[tokio::test] async fn test_client_new_and_interface() { diff --git a/src/client/socket_manager.rs b/src/client/socket_manager.rs index 6966a09..f06d625 100644 --- a/src/client/socket_manager.rs +++ b/src/client/socket_manager.rs @@ -51,17 +51,19 @@ use crate::{ UDP_BUFFER_SIZE, - e2e::{E2ECheckStatus, E2EKey, E2ERegistry}, + e2e::{E2ECheckStatus, E2EKey}, protocol::{Message, MessageView, sd}, traits::{PayloadWireFormat, WireFormat}, - transport::{ReceivedDatagram, SocketOptions, Spawner, TransportFactory, TransportSocket}, + transport::{ + E2ERegistryHandle, ReceivedDatagram, SocketOptions, Spawner, TransportFactory, + TransportSocket, + }, }; use super::error::Error; use futures::{FutureExt, pin_mut, select}; use std::{ net::{Ipv4Addr, SocketAddr, SocketAddrV4}, - sync::{Arc, Mutex}, task::{Context, Poll}, }; use tokio::sync::mpsc; @@ -151,9 +153,9 @@ where /// socket through the `_with_transport` variant so the `Spawner` /// trait can be exercised end-to-end. #[cfg(test)] - pub async fn bind_discovery_seeded( + pub async fn bind_discovery_seeded( interface: Ipv4Addr, - e2e_registry: Arc>, + e2e_registry: R, session_id: u16, session_has_wrapped: bool, multicast_loopback: bool, @@ -200,11 +202,11 @@ where /// build a small orchestrator directly on top of `protocol`, `e2e`, /// and the `transport` traits — the `bare_metal` example workspace /// member demonstrates the trait layer in isolation. - pub async fn bind_discovery_seeded_with_transport( + pub async fn bind_discovery_seeded_with_transport( factory: &F, spawner: &S, interface: Ipv4Addr, - e2e_registry: Arc>, + e2e_registry: R, session_id: u16, session_has_wrapped: bool, multicast_loopback: bool, @@ -212,6 +214,7 @@ where where F: TransportFactory, S: Spawner, + R: E2ERegistryHandle, { let (rx_tx, rx_rx) = mpsc::channel(16); let (tx_tx, tx_rx) = mpsc::channel(16); @@ -259,7 +262,7 @@ where /// socket through the `_with_transport` variant so the `Spawner` /// trait can be exercised end-to-end. #[cfg(test)] - pub async fn bind(port: u16, e2e_registry: Arc>) -> Result { + pub async fn bind(port: u16, e2e_registry: R) -> Result { use crate::tokio_transport::{TokioSpawner, TokioTransport}; Self::bind_with_transport(&TokioTransport, &TokioSpawner, port, e2e_registry).await } @@ -269,15 +272,16 @@ where /// socket's I/O loop through a caller-supplied [`Spawner`]. See /// [`Self::bind_discovery_seeded_with_transport`] for the factory /// bound rationale. - pub async fn bind_with_transport( + pub async fn bind_with_transport( factory: &F, spawner: &S, port: u16, - e2e_registry: Arc>, + e2e_registry: R, ) -> Result where F: TransportFactory, S: Spawner, + R: E2ERegistryHandle, { let (rx_tx, rx_rx) = mpsc::channel(4); let (tx_tx, tx_rx) = mpsc::channel(4); @@ -394,11 +398,11 @@ where /// return-type notation to express `Send` bounds on the trait's /// RPITIT methods — still nightly as of this writing. #[allow(clippy::too_many_lines)] - async fn socket_loop_future( + async fn socket_loop_future( socket: crate::tokio_transport::TokioSocket, rx_tx: mpsc::Sender, Error>>, mut tx_rx: mpsc::Receiver>, - e2e_registry: Arc>, + e2e_registry: R, ) { // Maximum number of consecutive `recv_from` errors tolerated before // the socket loop gives up. A single failure (transient I/O, peer @@ -458,12 +462,11 @@ where { let key = E2EKey::from_message_id(send_message.message.header().message_id()); - let mut registry = e2e_registry.lock().expect("e2e registry lock poisoned"); - if registry.contains_key(&key) { + if e2e_registry.contains_key(&key) { let upper_header: [u8; 8] = buf[8..16].try_into().expect("upper header slice"); let mut protected = [0u8; UDP_BUFFER_SIZE]; - let result = registry.protect( + let result = e2e_registry.protect( key, &buf[16..message_length], upper_header, @@ -553,14 +556,11 @@ where let payload_bytes = view.payload_bytes(); // Apply E2E check if configured - let (e2e_status, effective_payload) = { - let mut registry = - e2e_registry.lock().expect("e2e registry lock poisoned"); - match registry.check(key, payload_bytes, upper_header) { + let (e2e_status, effective_payload) = + match e2e_registry.check(key, payload_bytes, upper_header) { Some((status, stripped)) => (Some(status), stripped), None => (None, payload_bytes), - } - }; + }; let payload = MessageDefinitions::from_payload_bytes( header.message_id(), @@ -607,9 +607,11 @@ where #[cfg(test)] mod tests { use super::*; + use crate::e2e::E2ERegistry; use crate::protocol::sd::test_support::{TestPayload, empty_sd_header}; use crate::tokio_transport::TokioSpawner; use std::format; + use std::sync::{Arc, Mutex}; use std::vec; // Tests build ad-hoc UDP peers via tokio directly; this is not part of // the production code path, which goes through the `TransportSocket` diff --git a/src/lib.rs b/src/lib.rs index e0d67b4..477e43c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -172,6 +172,8 @@ pub use server::Server; #[cfg(any(feature = "client", feature = "server"))] pub use tokio_transport::{TokioSocket, TokioSpawner, TokioTimer, TokioTransport}; pub use transport::{ - IoErrorKind, ReceivedDatagram, SocketOptions, Spawner, Timer, TransportError, TransportFactory, - TransportSocket, + E2ERegistryHandle, InterfaceHandle, IoErrorKind, ReceivedDatagram, SocketOptions, Spawner, + Timer, TransportError, TransportFactory, TransportSocket, }; +#[cfg(feature = "server")] +pub use server::SubscriptionHandle; diff --git a/src/server/event_publisher.rs b/src/server/event_publisher.rs index 683d47b..2181f7d 100644 --- a/src/server/event_publisher.rs +++ b/src/server/event_publisher.rs @@ -1,29 +1,29 @@ //! Event publishing functionality use super::Error; -use super::subscription_manager::SubscriptionManager; +use super::subscription_manager::{SubscriptionHandle, SubscriptionManager}; use crate::UDP_BUFFER_SIZE; use crate::e2e::{E2EKey, E2ERegistry}; use crate::protocol::{Header, Message}; use crate::traits::{PayloadWireFormat, WireFormat}; +use crate::transport::E2ERegistryHandle; use std::sync::{Arc, Mutex}; use tokio::net::UdpSocket; use tokio::sync::RwLock; /// Publishes events to subscribers -pub struct EventPublisher { - subscriptions: Arc>, +pub struct EventPublisher< + R: E2ERegistryHandle = Arc>, + S: SubscriptionHandle = Arc>, +> { + subscriptions: S, socket: Arc, - e2e_registry: Arc>, + e2e_registry: R, } -impl EventPublisher { +impl EventPublisher { /// Create a new event publisher - pub fn new( - subscriptions: Arc>, - socket: Arc, - e2e_registry: Arc>, - ) -> Self { + pub fn new(subscriptions: S, socket: Arc, e2e_registry: R) -> Self { Self { subscriptions, socket, @@ -54,10 +54,10 @@ impl EventPublisher { message: &Message

, ) -> Result { // Get subscribers - let subscribers = { - let mgr = self.subscriptions.read().await; - mgr.get_subscribers(service_id, instance_id, event_group_id) - }; + let subscribers = self + .subscriptions + .get_subscribers(service_id, instance_id, event_group_id) + .await; if subscribers.is_empty() { tracing::trace!( @@ -96,14 +96,10 @@ impl EventPublisher { // directly out of `buffer[16..]` without a separate copy. { let key = E2EKey::from_message_id(message.header().message_id()); - let mut registry = self - .e2e_registry - .lock() - .expect("e2e registry lock poisoned"); - if registry.contains_key(&key) { + if self.e2e_registry.contains_key(&key) { let upper_header: [u8; 8] = buffer[8..16].try_into().expect("upper header slice"); let mut protected = [0u8; UDP_BUFFER_SIZE]; - let result = registry.protect( + let result = self.e2e_registry.protect( key, &buffer[16..message_length], upper_header, @@ -196,10 +192,10 @@ impl EventPublisher { payload: &[u8], ) -> Result { // Get subscribers - let subscribers = { - let mgr = self.subscriptions.read().await; - mgr.get_subscribers(service_id, instance_id, event_group_id) - }; + let subscribers = self + .subscriptions + .get_subscribers(service_id, instance_id, event_group_id) + .await; if subscribers.is_empty() { return Ok(0); @@ -293,8 +289,10 @@ impl EventPublisher { instance_id: u16, event_group_id: u16, ) -> bool { - let mgr = self.subscriptions.read().await; - !mgr.get_subscribers(service_id, instance_id, event_group_id) + !self + .subscriptions + .get_subscribers(service_id, instance_id, event_group_id) + .await .is_empty() } @@ -346,8 +344,9 @@ impl EventPublisher { event_group_id: u16, subscriber_addr: std::net::SocketAddrV4, ) -> Result<(), crate::server::SubscribeError> { - let mut mgr = self.subscriptions.write().await; - mgr.subscribe(service_id, instance_id, event_group_id, subscriber_addr) + self.subscriptions + .subscribe(service_id, instance_id, event_group_id, subscriber_addr) + .await } /// Remove a previously-registered subscriber from an event group. @@ -367,8 +366,9 @@ impl EventPublisher { event_group_id: u16, subscriber_addr: std::net::SocketAddrV4, ) { - let mut mgr = self.subscriptions.write().await; - mgr.unsubscribe(service_id, instance_id, event_group_id, subscriber_addr); + self.subscriptions + .unsubscribe(service_id, instance_id, event_group_id, subscriber_addr) + .await; } /// Get the current number of subscribers for a specific event group @@ -378,8 +378,9 @@ impl EventPublisher { instance_id: u16, event_group_id: u16, ) -> usize { - let mgr = self.subscriptions.read().await; - mgr.get_subscribers(service_id, instance_id, event_group_id) + self.subscriptions + .get_subscribers(service_id, instance_id, event_group_id) + .await .len() } } diff --git a/src/server/mod.rs b/src/server/mod.rs index 9b1ac18..f871764 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -15,7 +15,7 @@ mod subscription_manager; pub use error::Error; pub use event_publisher::EventPublisher; pub use service_info::{EventGroupInfo, ServiceInfo}; -pub use subscription_manager::{SubscribeError, SubscriptionManager}; +pub use subscription_manager::{SubscribeError, SubscriptionHandle, SubscriptionManager}; use sd_state::SdStateManager; @@ -23,6 +23,7 @@ use crate::Timer; use crate::e2e::{E2EKey, E2EProfile, E2ERegistry}; use crate::protocol::sd::{self, Entry, Flags, OptionsCount, ServiceEntry, TransportProtocol}; use crate::tokio_transport::TokioTimer; +use crate::transport::E2ERegistryHandle; use futures::{FutureExt, pin_mut, select}; use std::{ format, @@ -69,20 +70,23 @@ impl ServerConfig { } /// SOME/IP Server that can offer services and publish events -pub struct Server { +pub struct Server< + R: E2ERegistryHandle = Arc>, + S: SubscriptionHandle = Arc>, +> { config: ServerConfig, /// Socket for receiving subscription requests unicast_socket: Arc, /// Socket for sending SD announcements sd_socket: Arc, /// Subscription manager - subscriptions: Arc>, + subscriptions: S, /// Event publisher - publisher: Arc, + publisher: Arc>, /// SD session-ID counter and announcement emitter sd_state: Arc, /// Shared E2E registry for runtime E2E configuration - e2e_registry: Arc>, + e2e_registry: R, /// `true` if this server was constructed via [`Server::new_passive`]. /// Passive servers have no real SD socket bound to port 30490; their /// SD handling is managed externally. Calling [`Self::announcement_loop`] @@ -177,12 +181,13 @@ impl Server { ); } - let subscriptions = Arc::new(RwLock::new(SubscriptionManager::new())); - let e2e_registry = Arc::new(Mutex::new(E2ERegistry::new())); + let subscriptions: Arc> = + Arc::new(RwLock::new(SubscriptionManager::new())); + let e2e_registry: Arc> = Arc::new(Mutex::new(E2ERegistry::new())); let publisher = Arc::new(EventPublisher::new( - Arc::clone(&subscriptions), + subscriptions.clone(), Arc::clone(&unicast_socket), - Arc::clone(&e2e_registry), + e2e_registry.clone(), )); Ok(Self { @@ -246,12 +251,13 @@ impl Server { sd_socket.local_addr() ); - let subscriptions = Arc::new(RwLock::new(SubscriptionManager::new())); - let e2e_registry = Arc::new(Mutex::new(E2ERegistry::new())); + let subscriptions: Arc> = + Arc::new(RwLock::new(SubscriptionManager::new())); + let e2e_registry: Arc> = Arc::new(Mutex::new(E2ERegistry::new())); let publisher = Arc::new(EventPublisher::new( - Arc::clone(&subscriptions), + subscriptions.clone(), Arc::clone(&unicast_socket), - Arc::clone(&e2e_registry), + e2e_registry.clone(), )); Ok(Self { @@ -265,7 +271,9 @@ impl Server { is_passive: true, }) } +} +impl Server { /// Build the periodic-SD-announcement future. /// /// Returns a future that sends an `OfferService` message to the SD @@ -391,7 +399,7 @@ impl Server { /// Get the event publisher for sending events #[must_use] - pub fn publisher(&self) -> Arc { + pub fn publisher(&self) -> Arc> { Arc::clone(&self.publisher) } @@ -413,27 +421,13 @@ impl Server { /// /// Once registered, outgoing events published via [`EventPublisher::publish_event`] /// will have E2E protection applied automatically. - /// - /// # Panics - /// - /// Panics if the E2E registry mutex is poisoned. pub fn register_e2e(&self, key: E2EKey, profile: E2EProfile) { - self.e2e_registry - .lock() - .expect("e2e registry lock poisoned") - .register(key, profile); + self.e2e_registry.register(key, profile); } /// Remove E2E configuration for the given key. - /// - /// # Panics - /// - /// Panics if the E2E registry mutex is poisoned. pub fn unregister_e2e(&self, key: &E2EKey) { - self.e2e_registry - .lock() - .expect("e2e registry lock poisoned") - .unregister(key); + self.e2e_registry.unregister(key); } /// Run the server event loop @@ -643,24 +637,22 @@ impl Server { let first_count = entry_view.options_count().first_options_count as usize; let second_index = entry_view.index_second_options_run() as usize; let second_count = entry_view.options_count().second_options_count as usize; - if let Some(endpoint_addr) = Self::extract_subscriber_endpoint( + if let Some(endpoint_addr) = extract_subscriber_endpoint( &sd_view.options(), first_index, first_count, second_index, second_count, ) { - let mut subs = self.subscriptions.write().await; - let subscribe_result = subs.subscribe( - entry_view.service_id(), - entry_view.instance_id(), - entry_view.event_group_id(), - endpoint_addr, - ); - // Release the write lock before any await on the - // SD socket (keeps this arm off the lock while we - // emit the response). - drop(subs); + let subscribe_result = self + .subscriptions + .subscribe( + entry_view.service_id(), + entry_view.instance_id(), + entry_view.event_group_id(), + endpoint_addr, + ) + .await; match subscribe_result { Ok(()) => { @@ -726,94 +718,75 @@ impl Server { Ok(()) } +} - /// Extract a single subscriber endpoint from the options runs - /// associated with an SD entry. - /// - /// Each SD entry owns up to two options runs. A run is a contiguous - /// slice of the options array starting at `index_*_options_run` with - /// `*_options_count` entries. This helper walks both runs, collects - /// every `IpV4Endpoint` option it finds, returns the first, and logs - /// a `warn!` if more than one endpoint is present (we do not yet - /// support multi-endpoint subscribers — e.g. TCP+UDP — and will pick - /// an arbitrary one). - /// - /// Returns `None` if no `IpV4Endpoint` is found in either run. - fn extract_subscriber_endpoint( - options: &sd::OptionIter<'_>, - first_index: usize, - first_count: usize, - second_index: usize, - second_count: usize, - ) -> Option { - // Walk each run by cloning the iterator — `OptionIter` is a - // cheap view over borrowed bytes so `clone` is free. Taking - // `options` by reference lets the caller keep ownership and - // keeps the clippy `needless_pass_by_value` lint quiet. - // - // We only ever return the first `IpV4Endpoint` found, so rather - // than collect into a `Vec` (heap alloc on every Subscribe) we - // track the first hit in an `Option` and keep a count so the - // multi-endpoint warn path still reports how many additional - // endpoints were present. This keeps the SD receive loop - // allocation-free on the happy path. - let mut first_endpoint: Option = None; - let mut endpoint_count: usize = 0; - let mut ignored_other: usize = 0; - - let mut walk_run = |index: usize, count: usize| { - if count == 0 { - return; - } - for option_view in options.clone().skip(index).take(count) { - match option_view.option_type() { - Ok(sd::OptionType::IpV4Endpoint) => { - if let Ok((ip, _, port)) = option_view.as_ipv4() { - endpoint_count += 1; - if first_endpoint.is_none() { - first_endpoint = Some(SocketAddrV4::new(ip, port)); - } +/// Extract a single subscriber endpoint from the options runs associated with +/// an SD entry. Walks both option runs, returns the first `IpV4Endpoint` +/// found, and logs a `warn!` if more than one is present. +fn extract_subscriber_endpoint( + options: &sd::OptionIter<'_>, + first_index: usize, + first_count: usize, + second_index: usize, + second_count: usize, +) -> Option { + let mut first_endpoint: Option = None; + let mut endpoint_count: usize = 0; + let mut ignored_other: usize = 0; + + let mut walk_run = |index: usize, count: usize| { + if count == 0 { + return; + } + for option_view in options.clone().skip(index).take(count) { + match option_view.option_type() { + Ok(sd::OptionType::IpV4Endpoint) => { + if let Ok((ip, _, port)) = option_view.as_ipv4() { + endpoint_count += 1; + if first_endpoint.is_none() { + first_endpoint = Some(SocketAddrV4::new(ip, port)); } } - Ok(_) | Err(_) => ignored_other += 1, } + Ok(_) | Err(_) => ignored_other += 1, } - }; + } + }; - walk_run(first_index, first_count); - walk_run(second_index, second_count); + walk_run(first_index, first_count); + walk_run(second_index, second_count); - match endpoint_count { - 0 => { - tracing::warn!( - "No IPv4 endpoint in options runs \ - (first: idx={first_index}, count={first_count}; \ - second: idx={second_index}, count={second_count}; \ - ignored={ignored_other})" - ); - None - } - 1 => { - // Unwrap is safe: count == 1 implies we set `first_endpoint`. - let ep = first_endpoint.expect("endpoint_count=1 implies first_endpoint is Some"); - tracing::trace!("Found IPv4 endpoint {}", ep); - Some(ep) - } - n => { - let ep = first_endpoint.expect("endpoint_count>=1 implies first_endpoint is Some"); - tracing::warn!( - "{} IPv4 endpoints found in subscribe options runs; \ - using first ({}) and ignoring {} additional. \ - Multi-endpoint (e.g. TCP+UDP) subscribers are not yet supported.", - n, - ep, - n - 1 - ); - Some(ep) - } + match endpoint_count { + 0 => { + tracing::warn!( + "No IPv4 endpoint in options runs \ + (first: idx={first_index}, count={first_count}; \ + second: idx={second_index}, count={second_count}; \ + ignored={ignored_other})" + ); + None + } + 1 => { + let ep = first_endpoint.expect("endpoint_count=1 implies first_endpoint is Some"); + tracing::trace!("Found IPv4 endpoint {}", ep); + Some(ep) + } + n => { + let ep = first_endpoint.expect("endpoint_count>=1 implies first_endpoint is Some"); + tracing::warn!( + "{} IPv4 endpoints found in subscribe options runs; \ + using first ({}) and ignoring {} additional. \ + Multi-endpoint (e.g. TCP+UDP) subscribers are not yet supported.", + n, + ep, + n - 1 + ); + Some(ep) } } +} +impl Server { /// Send `SubscribeAck` from an entry view async fn send_subscribe_ack_from_view( &self, @@ -1667,7 +1640,7 @@ mod tests { let total = fill_ipv4_endpoints(&mut buf, 1, 30000); let iter = sd::OptionIter::new(&buf[..total]); - let got = Server::extract_subscriber_endpoint(&iter, 0, 1, 0, 0); + let got = extract_subscriber_endpoint(&iter, 0, 1, 0, 0); assert_eq!( got, Some(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 30000)) @@ -1677,7 +1650,7 @@ mod tests { #[test] fn extract_endpoint_zero_options_in_both_runs_returns_none() { let iter = sd::OptionIter::new(&[]); - assert_eq!(Server::extract_subscriber_endpoint(&iter, 0, 0, 0, 0), None); + assert_eq!(extract_subscriber_endpoint(&iter, 0, 0, 0, 0), None); } #[test] @@ -1689,7 +1662,7 @@ mod tests { let total = fill_ipv4_endpoints(&mut buf, 2, 30100); let iter = sd::OptionIter::new(&buf[..total]); - assert_eq!(Server::extract_subscriber_endpoint(&iter, 1, 0, 0, 0), None); + assert_eq!(extract_subscriber_endpoint(&iter, 1, 0, 0, 0), None); } #[test] @@ -1701,7 +1674,7 @@ mod tests { let total = fill_ipv4_endpoints(&mut buf, 2, 30200); let iter = sd::OptionIter::new(&buf[..total]); - let got = Server::extract_subscriber_endpoint(&iter, 0, 2, 0, 0); + let got = extract_subscriber_endpoint(&iter, 0, 2, 0, 0); assert_eq!( got, Some(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 30200)) @@ -1720,7 +1693,7 @@ mod tests { let total = fill_ipv4_endpoints(&mut buf, 3, 30300); let iter = sd::OptionIter::new(&buf[..total]); - let got = Server::extract_subscriber_endpoint(&iter, 0, 1, 2, 1); + let got = extract_subscriber_endpoint(&iter, 0, 1, 2, 1); assert_eq!( got, Some(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 30300)) @@ -1735,7 +1708,7 @@ mod tests { let total = fill_ipv4_endpoints(&mut buf, 4, 30400); let iter = sd::OptionIter::new(&buf[..total]); - let got = Server::extract_subscriber_endpoint(&iter, 2, 1, 0, 0); + let got = extract_subscriber_endpoint(&iter, 2, 1, 0, 0); assert_eq!( got, Some(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 30402)) @@ -1751,7 +1724,7 @@ mod tests { let iter = sd::OptionIter::new(&buf[..total]); // Take only 1 option starting at index 1 -> port 30501. - let got = Server::extract_subscriber_endpoint(&iter, 1, 1, 0, 0); + let got = extract_subscriber_endpoint(&iter, 1, 1, 0, 0); assert_eq!( got, Some(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 30501)) @@ -1775,7 +1748,7 @@ mod tests { offset += write_load_balancing_option(&mut buf[offset..], 3, 4); let iter = sd::OptionIter::new(&buf[..offset]); - let got = Server::extract_subscriber_endpoint(&iter, 0, 3, 0, 0); + let got = extract_subscriber_endpoint(&iter, 0, 3, 0, 0); assert_eq!( got, Some(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 30600)) @@ -1790,7 +1763,7 @@ mod tests { offset += write_load_balancing_option(&mut buf[offset..], 3, 4); let iter = sd::OptionIter::new(&buf[..offset]); - assert_eq!(Server::extract_subscriber_endpoint(&iter, 0, 2, 0, 0), None); + assert_eq!(extract_subscriber_endpoint(&iter, 0, 2, 0, 0), None); } #[test] @@ -1801,7 +1774,7 @@ mod tests { let total = fill_ipv4_endpoints(&mut buf, 2, 30700); let iter = sd::OptionIter::new(&buf[..total]); - let got = Server::extract_subscriber_endpoint(&iter, 0, 0, 1, 1); + let got = extract_subscriber_endpoint(&iter, 0, 0, 1, 1); assert_eq!( got, Some(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 30701)) @@ -2268,7 +2241,7 @@ mod tests { // 0 endpoints → warn! "No IPv4 endpoint" branch. let iter_empty = sd::OptionIter::new(&[]); assert_eq!( - Server::extract_subscriber_endpoint(&iter_empty, 0, 0, 0, 0), + extract_subscriber_endpoint(&iter_empty, 0, 0, 0, 0), None ); @@ -2277,7 +2250,7 @@ mod tests { let len_one = fill_ipv4_endpoints(&mut buf_one, 1, 31000); let iter_one = sd::OptionIter::new(&buf_one[..len_one]); assert_eq!( - Server::extract_subscriber_endpoint(&iter_one, 0, 1, 0, 0), + extract_subscriber_endpoint(&iter_one, 0, 1, 0, 0), Some(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 31000)) ); @@ -2286,7 +2259,7 @@ mod tests { let len_many = fill_ipv4_endpoints(&mut buf_many, 3, 31100); let iter_many = sd::OptionIter::new(&buf_many[..len_many]); assert_eq!( - Server::extract_subscriber_endpoint(&iter_many, 0, 3, 0, 0), + extract_subscriber_endpoint(&iter_many, 0, 3, 0, 0), Some(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 31100)) ); }); diff --git a/src/server/subscription_manager.rs b/src/server/subscription_manager.rs index bdf548c..d561b83 100644 --- a/src/server/subscription_manager.rs +++ b/src/server/subscription_manager.rs @@ -1,8 +1,10 @@ //! Manages event group subscriptions use super::service_info::Subscriber; +use core::future::Future; use heapless::{Vec as HeaplessVec, index_map::FnvIndexMap}; -use std::{net::SocketAddrV4, vec::Vec}; +use std::{net::SocketAddrV4, sync::Arc, vec::Vec}; +use tokio::sync::RwLock; /// Max number of distinct `(service_id, instance_id, event_group_id)` event /// groups with active subscribers. Must be a power of two. @@ -254,6 +256,96 @@ impl Default for SubscriptionManager { } } +/// Shared handle to the server's subscription table. +/// +/// Abstracts over `Arc>` on `std` and over +/// critical-section-backed equivalents on bare metal. All methods return +/// futures so the implementation can block on an async read/write lock +/// without holding a guard across an `await` point visible to callers. +/// +/// Both `Server` and `EventPublisher` clone the same handle at construction +/// time; the underlying subscription state is shared between them. +pub trait SubscriptionHandle: Clone + Send + Sync + 'static { + /// Add a subscriber to an event group. + /// + /// Idempotent: if the subscriber is already present, this is a no-op + /// returning `Ok(())`. Returns `Err(SubscribeError)` if a capacity + /// limit would be exceeded. + fn subscribe( + &self, + service_id: u16, + instance_id: u16, + event_group_id: u16, + subscriber_addr: SocketAddrV4, + ) -> impl Future> + Send + '_; + + /// Remove a subscriber from an event group. + fn unsubscribe( + &self, + service_id: u16, + instance_id: u16, + event_group_id: u16, + subscriber_addr: SocketAddrV4, + ) -> impl Future + Send + '_; + + /// Returns a snapshot of all subscribers for the given event group. + /// + /// The snapshot is owned — the caller may iterate over it after this + /// future resolves without holding any lock. + fn get_subscribers( + &self, + service_id: u16, + instance_id: u16, + event_group_id: u16, + ) -> impl Future> + Send + '_; +} + +impl SubscriptionHandle for Arc> { + fn subscribe( + &self, + service_id: u16, + instance_id: u16, + event_group_id: u16, + subscriber_addr: SocketAddrV4, + ) -> impl Future> + Send + '_ { + let this = self.clone(); + async move { + this.write() + .await + .subscribe(service_id, instance_id, event_group_id, subscriber_addr) + } + } + + fn unsubscribe( + &self, + service_id: u16, + instance_id: u16, + event_group_id: u16, + subscriber_addr: SocketAddrV4, + ) -> impl Future + Send + '_ { + let this = self.clone(); + async move { + this.write() + .await + .unsubscribe(service_id, instance_id, event_group_id, subscriber_addr); + } + } + + fn get_subscribers( + &self, + service_id: u16, + instance_id: u16, + event_group_id: u16, + ) -> impl Future> + Send + '_ { + let this = self.clone(); + async move { + this.read() + .await + .get_subscribers(service_id, instance_id, event_group_id) + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/tokio_transport.rs b/src/tokio_transport.rs index f53ca6b..c363f3c 100644 --- a/src/tokio_transport.rs +++ b/src/tokio_transport.rs @@ -36,11 +36,15 @@ use core::future::Future; use core::net::{Ipv4Addr, SocketAddrV4}; use core::time::Duration; use std::net::{IpAddr, SocketAddr}; +use std::sync::{Arc, Mutex, RwLock}; use tokio::net::UdpSocket; +use crate::e2e::{E2ECheckStatus, E2EKey, E2EProfile}; +use crate::e2e::Error as E2EError; +use crate::e2e::E2ERegistry; use crate::transport::{ - IoErrorKind, ReceivedDatagram, SocketOptions, Timer, TransportError, TransportFactory, - TransportSocket, + E2ERegistryHandle, InterfaceHandle, IoErrorKind, ReceivedDatagram, SocketOptions, Timer, + TransportError, TransportFactory, TransportSocket, }; /// Factory that binds [`TokioSocket`]s configured via `socket2`. @@ -187,6 +191,53 @@ impl crate::transport::Spawner for TokioSpawner { } } +impl E2ERegistryHandle for Arc> { + fn register(&self, key: E2EKey, profile: E2EProfile) { + self.lock().expect("e2e registry lock poisoned").register(key, profile); + } + + fn unregister(&self, key: &E2EKey) { + self.lock().expect("e2e registry lock poisoned").unregister(key); + } + + fn contains_key(&self, key: &E2EKey) -> bool { + self.lock().expect("e2e registry lock poisoned").contains_key(key) + } + + fn protect( + &self, + key: E2EKey, + payload: &[u8], + upper_header: [u8; 8], + output: &mut [u8], + ) -> Option> { + self.lock() + .expect("e2e registry lock poisoned") + .protect(key, payload, upper_header, output) + } + + fn check<'a>( + &self, + key: E2EKey, + payload: &'a [u8], + upper_header: [u8; 8], + ) -> Option<(E2ECheckStatus, &'a [u8])> { + self.lock() + .expect("e2e registry lock poisoned") + .check(key, payload, upper_header) + } +} + +impl InterfaceHandle for Arc> { + fn get(&self) -> Ipv4Addr { + *self.read().expect("interface lock poisoned") + } + + fn set(&self, addr: Ipv4Addr) { + *self.write().expect("interface lock poisoned") = addr; + } +} + /// Synchronously create and configure a UDP socket via `socket2`, then /// hand it to tokio. Mirrors the existing bind paths in /// `crate::client::socket_manager` and `crate::server` (rendered as diff --git a/src/transport.rs b/src/transport.rs index acbeedf..aa3ab67 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -214,6 +214,9 @@ use core::future::Future; use core::net::{Ipv4Addr, SocketAddrV4}; use core::time::Duration; +use crate::e2e::{E2ECheckStatus, E2EKey, E2EProfile}; +use crate::e2e::Error as E2EError; + /// Portable I/O error kinds surfaced by transport implementations. /// /// This is a deliberately small vocabulary — anything that does not fit @@ -593,6 +596,70 @@ pub trait Spawner { fn spawn(&self, future: impl Future + Send + 'static); } +/// Shared handle to the runtime E2E configuration registry. +/// +/// Abstracts over `Arc>` on `std` and over +/// critical-section-backed primitives (e.g. `embassy_sync::blocking_mutex`) +/// on bare metal. All methods take `&self` and provide interior-mutable +/// access. Implementations are required to be `Clone` so the handle can be +/// cheaply shared between the `Client` (or `Server`) handle and its inner +/// event loop. +pub trait E2ERegistryHandle: Clone + Send + Sync + 'static { + /// Register an E2E profile for the given key, replacing any prior entry. + fn register(&self, key: E2EKey, profile: E2EProfile); + + /// Remove the E2E configuration for the given key. No-op if absent. + fn unregister(&self, key: &E2EKey); + + /// Returns `true` if a profile is registered for `key`. + fn contains_key(&self, key: &E2EKey) -> bool; + + /// Run E2E protect for `key` if configured, writing to `output`. + /// + /// Returns `None` if no profile is registered for `key`. + /// Returns `Some(Err(_))` if protection fails (e.g. buffer too small). + /// Returns `Some(Ok(len))` on success; `len` is the number of bytes + /// written to `output`. + fn protect( + &self, + key: E2EKey, + payload: &[u8], + upper_header: [u8; 8], + output: &mut [u8], + ) -> Option>; + + /// Run E2E check for `key` if configured. + /// + /// Returns `None` if no profile is registered for `key`. Otherwise + /// returns the check status and the effective payload slice — the + /// E2E header is stripped on success; the original bytes are returned + /// on check failure so the caller can decide how to handle it. + /// + /// The returned slice borrows from `payload`, not from this handle. + fn check<'a>( + &self, + key: E2EKey, + payload: &'a [u8], + upper_header: [u8; 8], + ) -> Option<(E2ECheckStatus, &'a [u8])>; +} + +/// Shared handle to the local interface address. +/// +/// Abstracts over `Arc>` on `std`. All clones of a +/// `Client` share the same handle, so writes from one clone (e.g. +/// `Client::set_interface`) are visible to all others. +/// +/// On bare metal, where `Client` is not `Clone`, a trivial implementation +/// wrapping a `core::cell::Cell` suffices. +pub trait InterfaceHandle: Clone + Send + Sync + 'static { + /// Returns the current interface address. + fn get(&self) -> Ipv4Addr; + + /// Updates the stored interface address. + fn set(&self, addr: Ipv4Addr); +} + #[cfg(test)] mod tests { //! The traits are pure interfaces — these tests only verify that @@ -755,4 +822,63 @@ mod tests { assert_eq!(e, TransportError::Io(IoErrorKind::TimedOut)); assert_ne!(e, TransportError::AddressInUse); } + + // Minimal no-op implementations to verify that E2ERegistryHandle and + // InterfaceHandle are implementable without any executor machinery. + #[derive(Clone)] + struct NullE2ERegistry; + + impl E2ERegistryHandle for NullE2ERegistry { + fn register(&self, _key: E2EKey, _profile: E2EProfile) {} + fn unregister(&self, _key: &E2EKey) {} + fn contains_key(&self, _key: &E2EKey) -> bool { + false + } + fn protect( + &self, + _key: E2EKey, + _payload: &[u8], + _upper_header: [u8; 8], + _output: &mut [u8], + ) -> Option> { + None + } + fn check<'a>( + &self, + _key: E2EKey, + _payload: &'a [u8], + _upper_header: [u8; 8], + ) -> Option<(E2ECheckStatus, &'a [u8])> { + None + } + } + + #[derive(Clone)] + struct NullInterface(Ipv4Addr); + + impl InterfaceHandle for NullInterface { + fn get(&self) -> Ipv4Addr { + self.0 + } + fn set(&self, _addr: Ipv4Addr) {} + } + + #[test] + fn null_e2e_registry_compiles() { + let r = NullE2ERegistry; + let key = E2EKey::new(0, 0); + r.register(key, crate::e2e::E2EProfile::Profile4( + crate::e2e::Profile4Config::new(0, 8), + )); + assert!(!r.contains_key(&key)); + assert!(r.check(key, b"hello", [0; 8]).is_none()); + } + + #[test] + fn null_interface_get_set() { + let h = NullInterface(Ipv4Addr::LOCALHOST); + assert_eq!(h.get(), Ipv4Addr::LOCALHOST); + h.set(Ipv4Addr::UNSPECIFIED); // no-op in null impl + assert_eq!(h.get(), Ipv4Addr::LOCALHOST); // unchanged + } }