From d46679342550ce2fb6c6ae69cdcabd1234375cb9 Mon Sep 17 00:00:00 2001 From: Zeeshan Ali Khan Date: Tue, 25 Nov 2025 13:14:38 +0100 Subject: [PATCH 1/7] Macro now operates on a module This allows the macro to have a visibility on both the struct and the impl block and would enable us to improve the ergonomics of the API and add new API that's currently not possible due to the decoupling between the macro operating on the struct and impl block. --- CLAUDE.md | 29 ++++++++++------ README.md | 71 +++++++++++++++++++------------------ src/controller/mod.rs | 81 +++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 17 +++------ 4 files changed, 142 insertions(+), 56 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 27caf07..a8bac12 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -11,7 +11,7 @@ This is a procedural macro crate that provides the `#[controller]` attribute mac * Signal mechanism for broadcasting events. * Pub/sub system for state change notifications. -The macro works by processing both struct definitions (to add publishers) and impl blocks (to generate the controller's `run` method, client API, and signal infrastructure). +The macro is applied to a module containing both the controller struct definition and its impl block, allowing coordinated code generation of the controller infrastructure, client API, and communication channels. ## Build & Test Commands @@ -41,10 +41,24 @@ cargo doc --locked ## Architecture ### Macro Entry Point (`src/lib.rs`) -The `controller` attribute macro dispatches to either `item_struct` or `item_impl` based on input type. +The `controller` attribute macro parses the input as an `ItemMod` (module) and calls `controller::expand_module()`. + +### Module Processing (`src/controller/mod.rs`) +The `expand_module()` function: +* Validates the module has a body with exactly one struct and one impl block. +* Extracts the struct and impl items from the module. +* Validates that the impl block matches the struct name. +* Calls `item_struct::expand()` and `item_impl::expand()` to process each component. +* Combines the generated code back into the module structure along with any other items. + +Channel capacities and subscriber limits are also defined here: +* `ALL_CHANNEL_CAPACITY`: 8 +* `SIGNAL_CHANNEL_CAPACITY`: 8 +* `BROADCAST_MAX_PUBLISHERS`: 1 +* `BROADCAST_MAX_SUBSCRIBERS`: 16 ### Struct Processing (`src/controller/item_struct.rs`) -Processes `#[controller]` on struct definitions. For fields marked with `#[controller(publish)]`: +Processes the controller struct definition. For fields marked with `#[controller(publish)]`: * Adds publisher fields to the struct. * Generates setters (`set_`) that broadcast changes. * Creates `` stream type and `Changed` event struct. @@ -52,7 +66,7 @@ Processes `#[controller]` on struct definitions. For fields marked with `#[contr The generated `new()` method initializes both user fields and generated publisher fields. ### Impl Processing (`src/controller/item_impl.rs`) -Processes `#[controller]` on impl blocks. Distinguishes between: +Processes the controller impl block. Distinguishes between: **Proxied methods** (normal methods): * Creates request/response channels for each method. @@ -67,13 +81,6 @@ Processes `#[controller]` on impl blocks. Distinguishes between: The generated `run()` method contains a `select_biased!` loop that receives method calls from clients and dispatches them to the user's implementations. -### Constants (`src/controller/mod.rs`) -Channel capacities and subscriber limits are defined here: -* `ALL_CHANNEL_CAPACITY`: 8 -* `SIGNAL_CHANNEL_CAPACITY`: 8 -* `BROADCAST_MAX_PUBLISHERS`: 1 -* `BROADCAST_MAX_SUBSCRIBERS`: 16 - ### Utilities (`src/util.rs`) Case conversion functions (`pascal_to_snake_case`, `snake_to_pascal_case`) used for generating type and method names. diff --git a/README.md b/README.md index 4010508..c145ec9 100644 --- a/README.md +++ b/README.md @@ -39,53 +39,58 @@ pub enum State { Disabled, } -// The controller struct. This is where you define the state of your firmware. #[controller] -pub struct Controller { - #[controller(publish)] - state: State, - // Other fields. Note: No all of them need to be published. -} +mod controller { + use super::*; + + // The controller struct. This is where you define the state of your firmware. + pub struct Controller { + #[controller(publish)] + state: State, + // Other fields. Note: No all of them need to be published. + } -// The controller implementation. This is where you define the logic of your firmware. -#[controller] -impl Controller { - // The `signal` attribute marks this method signature (note: no implementation body) as a - // signal, that you can use to notify other parts of your code about specific events. - #[controller(signal)] - pub async fn power_error(&self, description: heapless::String<64>); - - pub async fn enable_power(&mut self) -> Result<(), MyFirmwareError> { - if self.state != State::Disabled { - return Err(MyFirmwareError::InvalidState); - } + // The controller implementation. This is where you define the logic of your firmware. + impl Controller { + // The `signal` attribute marks this method signature (note: no implementation body) as a + // signal, that you can use to notify other parts of your code about specific events. + #[controller(signal)] + pub async fn power_error(&self, description: heapless::String<64>); - // Any other logic you want to run when enabling power. + pub async fn enable_power(&mut self) -> Result<(), MyFirmwareError> { + if self.state != State::Disabled { + return Err(MyFirmwareError::InvalidState); + } - self.set_state(State::Enabled).await; - self.power_error("Dummy error just for the showcase".try_into().unwrap()) - .await; + // Any other logic you want to run when enabling power. - Ok(()) - } + self.set_state(State::Enabled).await; + self.power_error("Dummy error just for the showcase".try_into().unwrap()) + .await; - pub async fn disable_power(&mut self) -> Result<(), MyFirmwareError> { - if self.state != State::Enabled { - return Err(MyFirmwareError::InvalidState); + Ok(()) } - // Any other logic you want to run when enabling power. + pub async fn disable_power(&mut self) -> Result<(), MyFirmwareError> { + if self.state != State::Enabled { + return Err(MyFirmwareError::InvalidState); + } - self.set_state(State::Disabled).await; + // Any other logic you want to run when enabling power. - Ok(()) - } + self.set_state(State::Disabled).await; + + Ok(()) + } - // Method that doesn't return anything. - pub async fn return_nothing(&self) { + // Method that doesn't return anything. + pub async fn return_nothing(&self) { + } } } +use controller::*; + #[embassy_executor::main] async fn main(spawner: embassy_executor::Spawner) { let mut controller = Controller::new(State::Disabled); diff --git a/src/controller/mod.rs b/src/controller/mod.rs index c433e0d..2c21aa0 100644 --- a/src/controller/mod.rs +++ b/src/controller/mod.rs @@ -1,7 +1,88 @@ pub(crate) mod item_impl; pub(crate) mod item_struct; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{spanned::Spanned, Item, ItemMod, Result}; + const ALL_CHANNEL_CAPACITY: usize = 8; const SIGNAL_CHANNEL_CAPACITY: usize = 8; const BROADCAST_MAX_PUBLISHERS: usize = 1; const BROADCAST_MAX_SUBSCRIBERS: usize = 16; + +pub(crate) fn expand_module(input: ItemMod) -> Result { + let vis = &input.vis; + let mod_name = &input.ident; + let span = input.span(); + + let (_, items) = input + .content + .ok_or_else(|| syn::Error::new(span, "Module must have a body"))?; + + let mut struct_item = None; + let mut impl_item = None; + let mut other_items = Vec::new(); + + for item in items { + match item { + Item::Struct(s) => { + if struct_item.is_some() { + return Err(syn::Error::new( + s.span(), + "Module must contain exactly one struct definition", + )); + } + struct_item = Some(s); + } + Item::Impl(i) => { + if impl_item.is_some() { + return Err(syn::Error::new( + i.span(), + "Module must contain exactly one impl block", + )); + } + impl_item = Some(i); + } + other => other_items.push(other), + } + } + + let struct_item = struct_item.ok_or_else(|| { + syn::Error::new( + span, + "Module must contain a struct definition for the controller", + ) + })?; + + let impl_item = impl_item.ok_or_else(|| { + syn::Error::new(span, "Module must contain an impl block for the controller") + })?; + + let struct_name = &struct_item.ident; + if let syn::Type::Path(type_path) = &*impl_item.self_ty { + if let Some(ident) = type_path.path.get_ident() { + if ident != struct_name { + return Err(syn::Error::new( + impl_item.span(), + format!( + "Impl block is for type '{}' but controller struct is named '{}'", + ident, struct_name + ), + )); + } + } + } + + let expanded_struct = item_struct::expand(struct_item)?; + let expanded_impl = item_impl::expand(impl_item)?; + + Ok(quote! { + #vis mod #mod_name { + #(#other_items)* + + #expanded_struct + + #expanded_impl + } + }) +} diff --git a/src/lib.rs b/src/lib.rs index 4378af5..e897e92 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,7 @@ #![doc = include_str!("../README.md")] use proc_macro::TokenStream; -use syn::{parse_macro_input, punctuated::Punctuated, ItemImpl, ItemStruct, Meta, Token}; +use syn::{parse_macro_input, punctuated::Punctuated, ItemMod, Meta, Token}; mod controller; mod util; @@ -11,15 +11,8 @@ mod util; pub fn controller(attr: TokenStream, item: TokenStream) -> TokenStream { let _args = parse_macro_input!(attr with Punctuated::parse_terminated); - if let Ok(input) = syn::parse::(item.clone()) { - controller::item_struct::expand(input) - .unwrap_or_else(|e| e.to_compile_error()) - .into() - } else if let Ok(input) = syn::parse::(item) { - controller::item_impl::expand(input) - .unwrap_or_else(|e| e.to_compile_error()) - .into() - } else { - panic!("Expected struct or trait") - } + let input = parse_macro_input!(item as ItemMod); + controller::expand_module(input) + .unwrap_or_else(|e| e.to_compile_error()) + .into() } From 08a5e5f5dddbae6fc9f365b80f3781f56490f9a3 Mon Sep 17 00:00:00 2001 From: Zeeshan Ali Khan Date: Tue, 25 Nov 2025 13:42:48 +0100 Subject: [PATCH 2/7] Abstract the user from signal & state change types Instead provide methods to create streams for state changes. --- README.md | 51 ++++++++++++----------- src/controller/item_impl.rs | 54 +++++++++++++++++++++--- src/controller/item_struct.rs | 78 ++++++++++++++++++++++++----------- src/controller/mod.rs | 5 ++- 4 files changed, 132 insertions(+), 56 deletions(-) diff --git a/README.md b/README.md index c145ec9..cd0030f 100644 --- a/README.md +++ b/README.md @@ -109,9 +109,8 @@ async fn client() { use embassy_time::{Timer, Duration}; let mut client = ControllerClient::new(); - // SAFETY: We don't create more than 16 instances so we won't panic. - let state_changed = ControllerState::new().unwrap().map(Either::Left); - let error_stream = ControllerPowerError::new().unwrap().map(Either::Right); + let state_changed = client.receive_state_changed().unwrap().map(Either::Left); + let error_stream = client.receive_power_error().unwrap().map(Either::Right); let mut stream = select(state_changed, error_stream); client.enable_power().await.unwrap(); @@ -148,32 +147,34 @@ async fn client() { The `controller` macro will generated the following for you: +## Controller struct + * A `new` method that takes the fields of the struct as arguments and returns the struct. * For each `published` field: - * Setter for this field, named `set_` (so`set_state` here), which broadcasts any + * Setter for this field, named `set_` (e.g., `set_state`), which broadcasts any changes made to this field. - * Two client-side types: - * struct named `Changed` (so `ControllerStateChanged` - for `state` field), containing two public fields, named `previous` and `new` fields - representing the previous and new values of the field, respectively. - * Type named `` (so `ControllerState` for - `state` field), which implements `futures::Stream`, yielding each state change as the change - struct described above. -* `run` method with signature `pub async fn run(&mut self);` which runs the controller logic, - proxying calls from the client to the implementations here and their return value back to - the clients (internally via channels). Typically you'd call it at the end of your `main` - or run it as a task. -* Client-side API for this struct, named `Client` (`ControllerClient` here) - which provides exactly the same methods (except signal methods) defined in this implementation - that other parts of the code use to call these methods. +* A `run` method with signature `pub async fn run(&mut self);` which runs the controller logic, + proxying calls from the client to the implementations and their return values back to the + clients (internally via channels). Typically you'd call it at the end of your `main` or run it + as a task. +* For each `signal` method: + * The method body, that broadcasts the signal to all clients that are listening to it. + +## Client API + +A client struct named `Client` (`ControllerClient` in the example) with the following +methods: + +* All methods defined in the controller impl (except signal methods), which proxy calls to the + controller and return the results. +* For each `published` field: + * `receive__changed()` method (e.g., `receive_state_changed()`) that returns a + stream of state changes. The stream yields `Changed` + structs (e.g., `ControllerStateChanged`) containing `previous` and `new` fields. * For each `signal` method: - * The method body, that broadcasts the signal to all the clients that are listening to it. - * Two client-side types: - * struct, named `Args` (`ControllerPowerErrorArgs` - here), containing all the arguments of this method, as public fields. - * Type named `` (`ControllerPowerError` here) which - implements `futures::Stream`, yielding each signal broadcasted as the args struct described - above. + * `receive_()` method (e.g., `receive_power_error()`) that returns a stream of + signal events. The stream yields `Args` structs + (e.g., `ControllerPowerErrorArgs`) containing all signal arguments as public fields. ## Dependencies assumed diff --git a/src/controller/item_impl.rs b/src/controller/item_impl.rs index e6b72a5..0146150 100644 --- a/src/controller/item_impl.rs +++ b/src/controller/item_impl.rs @@ -7,17 +7,21 @@ use syn::{ Attribute, Ident, ImplItem, ImplItemFn, ItemImpl, Result, Signature, Token, Visibility, }; -use crate::util::snake_to_pascal_case; +use crate::{controller::item_struct::PublishedFieldInfo, util::snake_to_pascal_case}; -pub(crate) fn expand(mut input: ItemImpl) -> Result { +pub(crate) fn expand( + mut input: ItemImpl, + published_fields: &[PublishedFieldInfo], +) -> Result { let struct_name = get_struct_name(&input)?; let struct_name_str = struct_name.to_string(); let methods = get_methods(&mut input, &struct_name)?; - let signal_declarations = methods.iter().filter_map(|m| match m { - Method::Signal(signal) => Some(&signal.declarations), + let signals = methods.iter().filter_map(|m| match m { + Method::Signal(signal) => Some(signal), _ => None, }); + let signal_declarations = signals.clone().map(|s| &s.declarations); let methods = methods.iter().filter_map(|m| match m { Method::Proxied(method) => Some(method), @@ -40,6 +44,31 @@ pub(crate) fn expand(mut input: ItemImpl) -> Result { }; input.items.push(syn::parse2(run_method)?); + // Generate stream getter methods for published fields. + let published_field_getters = published_fields.iter().map(|field| { + let method_name = Ident::new( + &format!("receive_{}_changed", field.field_name), + field.field_name.span(), + ); + let subscriber_type = &field.subscriber_struct_name; + quote! { + pub fn #method_name(&self) -> Option<#subscriber_type> { + #subscriber_type::new() + } + } + }); + + // Generate stream getter methods for signals. + let signal_getters = signals.clone().map(|signal| { + let method_name = &signal.receive_method_name; + let subscriber_type = &signal.subscriber_struct_name; + quote! { + pub fn #method_name(&self) -> Option<#subscriber_type> { + #subscriber_type::new() + } + } + }); + let client_name = Ident::new(&format!("{}Client", struct_name_str), input.span()); let client_methods = methods.clone().map(|m| &m.client_method); let client_method_tx_rx_declarations = @@ -65,6 +94,10 @@ pub(crate) fn expand(mut input: ItemImpl) -> Result { } #(#client_methods)* + + #(#published_field_getters)* + + #(#signal_getters)* } #(#signal_declarations)* @@ -329,6 +362,10 @@ impl ProxiedMethodArgs<'_> { struct Signal { /// The input arguments' channel and client-side struct declarations. declarations: TokenStream, + /// Name of the receive method (e.g., receive_power_error). + receive_method_name: Ident, + /// Name of the subscriber struct (e.g., ControllerPowerError). + subscriber_struct_name: Ident, } impl Signal { @@ -429,7 +466,14 @@ impl Signal { ).await; }); - Ok(Self { declarations }) + let receive_method_name = + Ident::new(&format!("receive_{}", method_name_str), method.span()); + + Ok(Self { + declarations, + receive_method_name, + subscriber_struct_name, + }) } } diff --git a/src/controller/item_struct.rs b/src/controller/item_struct.rs index 95e3588..09505a7 100644 --- a/src/controller/item_struct.rs +++ b/src/controller/item_struct.rs @@ -3,7 +3,20 @@ use proc_macro2::TokenStream; use quote::quote; use syn::{spanned::Spanned, Field, Fields, Ident, ItemStruct, Result}; -pub(crate) fn expand(mut input: ItemStruct) -> Result { +/// Information about a published field, to be used by impl processing. +#[derive(Debug, Clone)] +pub(crate) struct PublishedFieldInfo { + pub field_name: Ident, + pub subscriber_struct_name: Ident, +} + +/// Result of expanding a struct. +pub(crate) struct ExpandedStruct { + pub tokens: TokenStream, + pub published_fields: Vec, +} + +pub(crate) fn expand(mut input: ItemStruct) -> Result { let struct_name = &input.ident; let fields = StructFields::parse(&mut input.fields, struct_name)?; @@ -14,14 +27,16 @@ pub(crate) fn expand(mut input: ItemStruct) -> Result { publisher_fields_initializations, setters, subscriber_declarations, + published_fields_info, ) = fields.published().fold( - (quote!(), quote!(), quote!(), quote!(), quote!()), + (quote!(), quote!(), quote!(), quote!(), quote!(), Vec::new()), |( publish_channels, publisher_fields_declarations, publisher_fields_initializations, setters, subscribers, + mut infos, ), f| { let (publish_channel, publisher_field, publisher_field_init, setter, subscriber) = ( @@ -32,39 +47,45 @@ pub(crate) fn expand(mut input: ItemStruct) -> Result { &f.subscriber_declaration, ); + infos.push(f.info.clone()); + ( quote! { #publish_channels #publish_channel }, quote! { #publisher_fields_declarations #publisher_field, }, quote! { #publisher_fields_initializations #publisher_field_init, }, quote! { #setters #setter }, quote! { #subscribers #subscriber }, + infos, ) }, ); let fields = fields.raw_fields().collect::>(); let vis = &input.vis; - Ok(quote! { - #vis struct #struct_name { - #(#fields),*, - #publisher_fields_declarations - } + Ok(ExpandedStruct { + tokens: quote! { + #vis struct #struct_name { + #(#fields),*, + #publisher_fields_declarations + } - impl #struct_name { - #[allow(clippy::too_many_arguments)] - pub fn new(#(#fields),*) -> Self { - Self { - #(#field_names),*, - #publisher_fields_initializations + impl #struct_name { + #[allow(clippy::too_many_arguments)] + pub fn new(#(#fields),*) -> Self { + Self { + #(#field_names),*, + #publisher_fields_initializations + } } - } - #setters - } + #setters + } - #publish_channel_declarations + #publish_channel_declarations - #subscriber_declarations + #subscriber_declarations + }, + published_fields: published_fields_info, }) } @@ -110,7 +131,7 @@ impl StructFields { /// All the published fields. fn published(&self) -> impl Iterator { self.fields.iter().filter_map(|field| match field { - StructField::Published(published) => Some(published), + StructField::Published(published) => Some(published.as_ref()), _ => None, }) } @@ -120,9 +141,9 @@ impl StructFields { #[derive(Debug)] enum StructField { /// Private field. - Private(Field), + Private(Box), /// Published field. - Published(PublishedField), + Published(Box), } impl StructField { @@ -130,15 +151,16 @@ impl StructField { fn parse(field: &mut Field, struct_name: &Ident) -> Result { PublishedField::parse(field, struct_name).map(|published| { published - .map(StructField::Published) - .unwrap_or_else(|| StructField::Private(field.clone())) + .map(|p| StructField::Published(Box::new(p))) + .unwrap_or_else(|| StructField::Private(Box::new(field.clone()))) }) } /// Get the field. fn field(&self) -> &Field { match self { - Self::Private(field) | Self::Published(PublishedField { field, .. }) => field, + Self::Private(field) => field.as_ref(), + Self::Published(published) => &published.field, } } } @@ -158,6 +180,8 @@ struct PublishedField { publish_channel_declaration: proc_macro2::TokenStream, /// Subscriber struct declaration. subscriber_declaration: proc_macro2::TokenStream, + /// Information to be passed to impl processing. + info: PublishedFieldInfo, } impl PublishedField { @@ -291,6 +315,11 @@ impl PublishedField { } }; + let info = PublishedFieldInfo { + field_name: field_name.clone(), + subscriber_struct_name: subscriber_struct_name.clone(), + }; + Ok(Some(PublishedField { field: field.clone(), publisher_field_declaration, @@ -298,6 +327,7 @@ impl PublishedField { setter, publish_channel_declaration, subscriber_declaration, + info, })) } } diff --git a/src/controller/mod.rs b/src/controller/mod.rs index 2c21aa0..5a4583e 100644 --- a/src/controller/mod.rs +++ b/src/controller/mod.rs @@ -74,13 +74,14 @@ pub(crate) fn expand_module(input: ItemMod) -> Result { } let expanded_struct = item_struct::expand(struct_item)?; - let expanded_impl = item_impl::expand(impl_item)?; + let expanded_impl = item_impl::expand(impl_item, &expanded_struct.published_fields)?; + let struct_tokens = expanded_struct.tokens; Ok(quote! { #vis mod #mod_name { #(#other_items)* - #expanded_struct + #struct_tokens #expanded_impl } From adec2440f6442ce7fc20359ec7385e676edeed03 Mon Sep 17 00:00:00 2001 From: Zeeshan Ali Khan Date: Tue, 25 Nov 2025 14:47:06 +0100 Subject: [PATCH 3/7] Attribute to generate client-side setters Published fields can now have a client-side setter method if user asks for it through a new sub-attribute, `pub_setter`. --- README.md | 3 + src/controller/item_impl.rs | 130 +++++++++++++++++++++++++++++++++- src/controller/item_struct.rs | 30 +++++++- 3 files changed, 160 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index cd0030f..16c1e56 100644 --- a/README.md +++ b/README.md @@ -171,6 +171,9 @@ methods: * `receive__changed()` method (e.g., `receive_state_changed()`) that returns a stream of state changes. The stream yields `Changed` structs (e.g., `ControllerStateChanged`) containing `previous` and `new` fields. + * If the field is marked with `#[controller(publish(pub_setter))]`, a public + `set_()` method (e.g., `set_state()`) is also generated on the client, allowing + external code to update the field value through the client API. * For each `signal` method: * `receive_()` method (e.g., `receive_power_error()`) that returns a stream of signal events. The stream yields `Args` structs diff --git a/src/controller/item_impl.rs b/src/controller/item_impl.rs index 0146150..0934d07 100644 --- a/src/controller/item_impl.rs +++ b/src/controller/item_impl.rs @@ -31,13 +31,30 @@ pub(crate) fn expand( let args_channels_rx_tx = methods.clone().map(|m| &m.args_channels_rx_tx); let select_arms = methods.clone().map(|m| &m.select_arm); + // Generate public setters for published fields with pub_setter. + let pub_setters: Vec<_> = published_fields + .iter() + .filter(|field| field.pub_setter) + .map(|field| generate_pub_setter(field, &struct_name)) + .collect(); + let pub_setter_channel_declarations = pub_setters.iter().map(|s| &s.channel_declarations); + let pub_setter_rx_tx = pub_setters.iter().map(|s| &s.rx_tx); + let pub_setter_select_arms = pub_setters.iter().map(|s| &s.select_arm); + let pub_setter_client_methods = pub_setters.iter().map(|s| &s.client_method); + let pub_setter_client_tx_rx_declarations = + pub_setters.iter().map(|s| &s.client_tx_rx_declarations); + let pub_setter_client_tx_rx_initializations = + pub_setters.iter().map(|s| &s.client_tx_rx_initializations); + let run_method = quote! { pub async fn run(mut self) { #(#args_channels_rx_tx)* + #(#pub_setter_rx_tx)* loop { futures::select_biased! { - #(#select_arms),* + #(#select_arms,)* + #(#pub_setter_select_arms),* } } } @@ -79,22 +96,27 @@ pub(crate) fn expand( Ok(quote! { #(#args_channel_declarations)* + #(#pub_setter_channel_declarations)* #input pub struct #client_name { #(#client_method_tx_rx_declarations)* + #(#pub_setter_client_tx_rx_declarations)* } impl #client_name { pub fn new() -> Self { Self { #(#client_method_tx_rx_initializations)* + #(#pub_setter_client_tx_rx_initializations)* } } #(#client_methods)* + #(#pub_setter_client_methods)* + #(#published_field_getters)* #(#signal_getters)* @@ -567,3 +589,109 @@ impl MethodInputArgs { }) } } + +#[derive(Debug)] +struct PubSetter { + channel_declarations: TokenStream, + rx_tx: TokenStream, + select_arm: TokenStream, + client_method: TokenStream, + client_tx_rx_declarations: TokenStream, + client_tx_rx_initializations: TokenStream, +} + +fn generate_pub_setter(field: &PublishedFieldInfo, struct_name: &Ident) -> PubSetter { + let field_name = &field.field_name; + let field_type = &field.field_type; + let setter_method_name = &field.setter_name; + let field_name_str = field_name.to_string(); + + let struct_name_caps = struct_name.to_string().to_uppercase(); + let field_name_caps = field_name_str.to_uppercase(); + let input_channel_name = Ident::new( + &format!("{}_SET_{}_INPUT_CHANNEL", struct_name_caps, field_name_caps), + field_name.span(), + ); + let output_channel_name = Ident::new( + &format!( + "{}_SET_{}_OUTPUT_CHANNEL", + struct_name_caps, field_name_caps + ), + field_name.span(), + ); + let capacity = super::ALL_CHANNEL_CAPACITY; + + let channel_declarations = quote! { + static #input_channel_name: + embassy_sync::channel::Channel< + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + #field_type, + #capacity, + > = embassy_sync::channel::Channel::new(); + static #output_channel_name: + embassy_sync::channel::Channel< + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + (), + #capacity, + > = embassy_sync::channel::Channel::new(); + }; + + let input_channel_rx_name = + Ident::new(&format!("{}_value_rx", field_name_str), field_name.span()); + let output_channel_tx_name = + Ident::new(&format!("{}_ack_tx", field_name_str), field_name.span()); + let rx_tx = quote! { + let #input_channel_rx_name = embassy_sync::channel::Channel::receiver(&#input_channel_name); + let #output_channel_tx_name = embassy_sync::channel::Channel::sender(&#output_channel_name); + }; + + let select_arm = quote! { + value = futures::FutureExt::fuse( + embassy_sync::channel::Receiver::receive(&#input_channel_rx_name), + ) => { + self.#setter_method_name(value).await; + + embassy_sync::channel::Sender::send(&#output_channel_tx_name, ()).await; + } + }; + + let input_channel_tx_name = + Ident::new(&format!("{}_value_tx", field_name_str), field_name.span()); + let output_channel_rx_name = + Ident::new(&format!("{}_ack_rx", field_name_str), field_name.span()); + let client_method = quote! { + pub async fn #setter_method_name(&self, value: #field_type) { + embassy_sync::channel::Sender::send(&self.#input_channel_tx_name, value).await; + embassy_sync::channel::Receiver::receive(&self.#output_channel_rx_name).await + } + }; + + let client_tx_rx_declarations = quote! { + #input_channel_tx_name: embassy_sync::channel::Sender< + 'static, + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + #field_type, + #capacity, + >, + #output_channel_rx_name: embassy_sync::channel::Receiver< + 'static, + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + (), + #capacity, + > + }; + + let client_tx_rx_initializations = quote! { + #input_channel_tx_name: embassy_sync::channel::Channel::sender(&#input_channel_name), + #output_channel_rx_name: embassy_sync::channel::Channel::receiver(&#output_channel_name) + }; + + PubSetter { + channel_declarations, + rx_tx, + select_arm, + client_method, + client_tx_rx_declarations, + client_tx_rx_initializations, + } +} diff --git a/src/controller/item_struct.rs b/src/controller/item_struct.rs index 09505a7..da929ac 100644 --- a/src/controller/item_struct.rs +++ b/src/controller/item_struct.rs @@ -1,13 +1,16 @@ use crate::util::*; use proc_macro2::TokenStream; use quote::quote; -use syn::{spanned::Spanned, Field, Fields, Ident, ItemStruct, Result}; +use syn::{spanned::Spanned, Field, Fields, Ident, ItemStruct, Result, Token}; /// Information about a published field, to be used by impl processing. #[derive(Debug, Clone)] pub(crate) struct PublishedFieldInfo { pub field_name: Ident, + pub field_type: syn::Type, + pub setter_name: Ident, pub subscriber_struct_name: Ident, + pub pub_setter: bool, } /// Result of expanding a struct. @@ -195,6 +198,7 @@ impl PublishedField { Some(attr) => attr, None => return Ok(None), }; + let mut pub_setter = false; attr.parse_nested_meta(|meta| { if !meta.path.is_ident("publish") { let e = format!( @@ -205,6 +209,25 @@ impl PublishedField { return Err(syn::Error::new_spanned(attr, e)); } + if meta.input.peek(syn::token::Paren) { + let content; + syn::parenthesized!(content in meta.input); + while !content.is_empty() { + let nested_ident: Ident = content.parse()?; + if nested_ident == "pub_setter" { + pub_setter = true; + } else { + let e = + format!("expected `pub_setter` attribute, found `{}`", nested_ident); + return Err(syn::Error::new_spanned(&nested_ident, e)); + } + + if !content.is_empty() { + content.parse::()?; + } + } + } + Ok(()) })?; field @@ -233,6 +256,7 @@ impl PublishedField { let max_subscribers = super::BROADCAST_MAX_SUBSCRIBERS; let max_publishers = super::BROADCAST_MAX_PUBLISHERS; + let setter_name = Ident::new(&format!("set_{field_name_str}"), field.span()); let publisher_name = Ident::new(&format!("{field_name_str}_publisher"), field.span()); let publisher_field_declaration = quote! { #publisher_name: @@ -249,7 +273,6 @@ impl PublishedField { // We only create one publisher so we can't fail. #publisher_name: embassy_sync::pubsub::PubSubChannel::publisher(&#publish_channel_name).unwrap() }; - let setter_name = Ident::new(&format!("set_{field_name_str}"), field.span()); let setter = quote! { pub async fn #setter_name(&mut self, mut value: #ty) { core::mem::swap(&mut self.#field_name, &mut value); @@ -317,7 +340,10 @@ impl PublishedField { let info = PublishedFieldInfo { field_name: field_name.clone(), + field_type: ty.clone(), + setter_name: setter_name.clone(), subscriber_struct_name: subscriber_struct_name.clone(), + pub_setter, }; Ok(Some(PublishedField { From 312a885de7949403698630e0084be5462f2d9c1e Mon Sep 17 00:00:00 2001 From: Zeeshan Ali Khan Date: Tue, 25 Nov 2025 15:10:36 +0100 Subject: [PATCH 4/7] Add an e2e testcase --- Cargo.lock | 29 ++++++ Cargo.toml | 3 + tests/integration.rs | 205 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 237 insertions(+) create mode 100644 tests/integration.rs diff --git a/Cargo.lock b/Cargo.lock index 2d445d0..75e2143 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -232,6 +232,7 @@ dependencies = [ name = "firmware-controller" version = "0.2.0" dependencies = [ + "critical-section", "embassy-executor", "embassy-sync", "embassy-time", @@ -256,6 +257,7 @@ checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -278,6 +280,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-io" version = "0.3.31" @@ -313,12 +326,16 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", + "futures-io", "futures-macro", "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -384,6 +401,12 @@ dependencies = [ "scopeguard", ] +[[package]] +name = "memchr" +version = "2.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" + [[package]] name = "nb" version = "0.1.3" @@ -472,6 +495,12 @@ version = "1.0.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" +[[package]] +name = "slab" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" + [[package]] name = "spin" version = "0.9.8" diff --git a/Cargo.toml b/Cargo.toml index f264032..0500bee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,10 @@ syn = { version = "2", features = ["extra-traits", "fold", "full"] } heapless = { version = "0.7", default-features = false } futures = { version = "0.3", default-features = false, features = [ "async-await", + "std", + "executor", ] } +critical-section = { version = "1.2", features = ["std"] } embassy-sync = "0.7.2" embassy-executor = { version = "0.9.1", features = [ "arch-std", diff --git a/tests/integration.rs b/tests/integration.rs new file mode 100644 index 0000000..5c5dea3 --- /dev/null +++ b/tests/integration.rs @@ -0,0 +1,205 @@ +use firmware_controller::controller; +use futures::StreamExt; + +#[derive(Debug, PartialEq, Copy, Clone)] +pub enum State { + Idle, + Active, + Error, +} + +#[derive(Debug, PartialEq, Copy, Clone)] +pub enum Mode { + Normal, + Debug, +} + +#[derive(Debug, PartialEq)] +pub enum TestError { + InvalidState, + OperationFailed, +} + +#[controller] +mod test_controller { + use super::*; + + pub struct Controller { + #[controller(publish)] + state: State, + #[controller(publish(pub_setter))] + mode: Mode, + counter: u32, + } + + impl Controller { + #[controller(signal)] + pub async fn error_occurred(&self, code: u32, message: heapless::String<32>); + + #[controller(signal)] + pub async fn operation_complete(&self); + + pub async fn increment(&mut self) -> u32 { + self.counter += 1; + self.counter + } + + pub async fn get_counter(&self) -> u32 { + self.counter + } + + pub async fn activate(&mut self) -> Result<(), TestError> { + if self.state != State::Idle { + return Err(TestError::InvalidState); + } + self.set_state(State::Active).await; + self.operation_complete().await; + Ok(()) + } + + pub async fn trigger_error(&mut self) -> Result<(), TestError> { + self.set_state(State::Error).await; + self.error_occurred(42, "Test error".try_into().unwrap()) + .await; + Err(TestError::OperationFailed) + } + + pub async fn return_nothing(&self) {} + } +} + +use test_controller::*; + +#[test] +fn test_controller_basic_functionality() { + // Create the controller before spawning the thread to avoid any race conditions. + // The channels used for communication will buffer requests, so it's safe for the + // client to start making calls even if the controller task hasn't fully started yet. + let controller = Controller::new(State::Idle, Mode::Normal, 0); + + // Run the controller in a background thread. + std::thread::spawn(move || { + let executor = Box::leak(Box::new(embassy_executor::Executor::new())); + executor.run(move |spawner| { + spawner.spawn(controller_task(controller)).unwrap(); + }); + }); + + // Run the test logic. + futures::executor::block_on(async { + // Create client. + let mut client = ControllerClient::new(); + + // Test 1: Subscribe to state changes. + let mut state_stream = client.receive_state_changed().expect("Failed to subscribe"); + + // Test 2: Subscribe to signals. + let mut error_stream = client + .receive_error_occurred() + .expect("Failed to subscribe to error"); + let mut complete_stream = client + .receive_operation_complete() + .expect("Failed to subscribe to complete"); + + // Test 3: Call a method and verify return value. + let counter = client.get_counter().await; + assert_eq!(counter, 0, "Initial counter should be 0"); + + // Test 4: Call increment and verify it increases. + let counter = client.increment().await; + assert_eq!(counter, 1, "Counter should be 1 after increment"); + + let counter = client.increment().await; + assert_eq!(counter, 2, "Counter should be 2 after second increment"); + + // Test 5: Call method that changes state and emits signal. + let activate_result = client.activate().await; + assert!( + activate_result.is_ok(), + "Activate should succeed from Idle state" + ); + + // Verify we received the state change. + let state_change = state_stream + .next() + .await + .expect("Should receive state change"); + assert_eq!( + state_change.previous, + State::Idle, + "Previous state should be Idle" + ); + assert_eq!( + state_change.new, + State::Active, + "New state should be Active" + ); + + // Verify we received the operation_complete signal. + let _complete = complete_stream + .next() + .await + .expect("Should receive operation complete signal"); + + // Test 6: Call method that returns error. + let error_result = client.trigger_error().await; + assert!( + error_result.is_err(), + "trigger_error should return an error" + ); + assert_eq!( + error_result.unwrap_err(), + TestError::OperationFailed, + "Should return OperationFailed error" + ); + + // Verify state changed to Error. + let state_change = state_stream + .next() + .await + .expect("Should receive state change"); + assert_eq!( + state_change.previous, + State::Active, + "Previous state should be Active" + ); + assert_eq!(state_change.new, State::Error, "New state should be Error"); + + // Verify we received the error signal. + let error_signal = error_stream + .next() + .await + .expect("Should receive error signal"); + assert_eq!(error_signal.code, 42, "Error code should be 42"); + assert_eq!( + error_signal.message.as_str(), + "Test error", + "Error message should match" + ); + + // Test 7: Try to activate again (should fail due to invalid state). + let activate_result = client.activate().await; + assert!( + activate_result.is_err(), + "Activate should fail from Error state" + ); + assert_eq!( + activate_result.unwrap_err(), + TestError::InvalidState, + "Should return InvalidState error" + ); + + // Test 8: Use pub_setter to change mode. + client.set_mode(Mode::Debug).await; + + // Test 9: Call method with no return value. + client.return_nothing().await; + + // If we get here, all tests passed. + }); +} + +#[embassy_executor::task] +async fn controller_task(controller: Controller) { + controller.run().await; +} From 2ffd30307688f51f28472b924a52e079174605d9 Mon Sep 17 00:00:00 2001 From: Zeeshan Ali Khan Date: Tue, 25 Nov 2025 15:38:18 +0100 Subject: [PATCH 5/7] Bump version We're introducing breaking changes. --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 75e2143..8e1e255 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -230,7 +230,7 @@ dependencies = [ [[package]] name = "firmware-controller" -version = "0.2.0" +version = "0.3.0" dependencies = [ "critical-section", "embassy-executor", diff --git a/Cargo.toml b/Cargo.toml index 0500bee..0a75c33 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "firmware-controller" description = "Controller to decouple interactions between components in a no_std environment." -version = "0.2.0" +version = "0.3.0" edition = "2021" authors = [ "Zeeshan Ali Khan ", From f800fb8ba25d5d80d25dbdd91a8270f04c2fc53b Mon Sep 17 00:00:00 2001 From: Zeeshan Ali Khan Date: Tue, 25 Nov 2025 16:33:28 +0100 Subject: [PATCH 6/7] A few typo fixes --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 16c1e56..0468957 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ mod controller { pub struct Controller { #[controller(publish)] state: State, - // Other fields. Note: No all of them need to be published. + // Other fields. Note: Not all of them need to be published. } // The controller implementation. This is where you define the logic of your firmware. @@ -145,7 +145,7 @@ async fn client() { # Details -The `controller` macro will generated the following for you: +The `controller` macro will generate the following for you: ## Controller struct From 8dbe58e4457d8ceaf3464ac02d1cd1a25ee6f36b Mon Sep 17 00:00:00 2001 From: Zeeshan Ali Khan Date: Tue, 25 Nov 2025 16:33:58 +0100 Subject: [PATCH 7/7] README: Fix documentation --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0468957..914896e 100644 --- a/README.md +++ b/README.md @@ -153,7 +153,7 @@ The `controller` macro will generate the following for you: * For each `published` field: * Setter for this field, named `set_` (e.g., `set_state`), which broadcasts any changes made to this field. -* A `run` method with signature `pub async fn run(&mut self);` which runs the controller logic, +* A `run` method with signature `pub async fn run(mut self);` which runs the controller logic, proxying calls from the client to the implementations and their return values back to the clients (internally via channels). Typically you'd call it at the end of your `main` or run it as a task.