diff --git a/CLAUDE.md b/CLAUDE.md index 27caf07..a8bac12 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -11,7 +11,7 @@ This is a procedural macro crate that provides the `#[controller]` attribute mac * Signal mechanism for broadcasting events. * Pub/sub system for state change notifications. -The macro works by processing both struct definitions (to add publishers) and impl blocks (to generate the controller's `run` method, client API, and signal infrastructure). +The macro is applied to a module containing both the controller struct definition and its impl block, allowing coordinated code generation of the controller infrastructure, client API, and communication channels. ## Build & Test Commands @@ -41,10 +41,24 @@ cargo doc --locked ## Architecture ### Macro Entry Point (`src/lib.rs`) -The `controller` attribute macro dispatches to either `item_struct` or `item_impl` based on input type. +The `controller` attribute macro parses the input as an `ItemMod` (module) and calls `controller::expand_module()`. + +### Module Processing (`src/controller/mod.rs`) +The `expand_module()` function: +* Validates the module has a body with exactly one struct and one impl block. +* Extracts the struct and impl items from the module. +* Validates that the impl block matches the struct name. +* Calls `item_struct::expand()` and `item_impl::expand()` to process each component. +* Combines the generated code back into the module structure along with any other items. + +Channel capacities and subscriber limits are also defined here: +* `ALL_CHANNEL_CAPACITY`: 8 +* `SIGNAL_CHANNEL_CAPACITY`: 8 +* `BROADCAST_MAX_PUBLISHERS`: 1 +* `BROADCAST_MAX_SUBSCRIBERS`: 16 ### Struct Processing (`src/controller/item_struct.rs`) -Processes `#[controller]` on struct definitions. For fields marked with `#[controller(publish)]`: +Processes the controller struct definition. For fields marked with `#[controller(publish)]`: * Adds publisher fields to the struct. * Generates setters (`set_`) that broadcast changes. * Creates `` stream type and `Changed` event struct. @@ -52,7 +66,7 @@ Processes `#[controller]` on struct definitions. For fields marked with `#[contr The generated `new()` method initializes both user fields and generated publisher fields. ### Impl Processing (`src/controller/item_impl.rs`) -Processes `#[controller]` on impl blocks. Distinguishes between: +Processes the controller impl block. Distinguishes between: **Proxied methods** (normal methods): * Creates request/response channels for each method. @@ -67,13 +81,6 @@ Processes `#[controller]` on impl blocks. Distinguishes between: The generated `run()` method contains a `select_biased!` loop that receives method calls from clients and dispatches them to the user's implementations. -### Constants (`src/controller/mod.rs`) -Channel capacities and subscriber limits are defined here: -* `ALL_CHANNEL_CAPACITY`: 8 -* `SIGNAL_CHANNEL_CAPACITY`: 8 -* `BROADCAST_MAX_PUBLISHERS`: 1 -* `BROADCAST_MAX_SUBSCRIBERS`: 16 - ### Utilities (`src/util.rs`) Case conversion functions (`pascal_to_snake_case`, `snake_to_pascal_case`) used for generating type and method names. diff --git a/Cargo.lock b/Cargo.lock index 2d445d0..8e1e255 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -230,8 +230,9 @@ dependencies = [ [[package]] name = "firmware-controller" -version = "0.2.0" +version = "0.3.0" dependencies = [ + "critical-section", "embassy-executor", "embassy-sync", "embassy-time", @@ -256,6 +257,7 @@ checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -278,6 +280,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-io" version = "0.3.31" @@ -313,12 +326,16 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", + "futures-io", "futures-macro", "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -384,6 +401,12 @@ dependencies = [ "scopeguard", ] +[[package]] +name = "memchr" +version = "2.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" + [[package]] name = "nb" version = "0.1.3" @@ -472,6 +495,12 @@ version = "1.0.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" +[[package]] +name = "slab" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" + [[package]] name = "spin" version = "0.9.8" diff --git a/Cargo.toml b/Cargo.toml index f264032..0a75c33 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "firmware-controller" description = "Controller to decouple interactions between components in a no_std environment." -version = "0.2.0" +version = "0.3.0" edition = "2021" authors = [ "Zeeshan Ali Khan ", @@ -22,7 +22,10 @@ syn = { version = "2", features = ["extra-traits", "fold", "full"] } heapless = { version = "0.7", default-features = false } futures = { version = "0.3", default-features = false, features = [ "async-await", + "std", + "executor", ] } +critical-section = { version = "1.2", features = ["std"] } embassy-sync = "0.7.2" embassy-executor = { version = "0.9.1", features = [ "arch-std", diff --git a/README.md b/README.md index 4010508..914896e 100644 --- a/README.md +++ b/README.md @@ -39,53 +39,58 @@ pub enum State { Disabled, } -// The controller struct. This is where you define the state of your firmware. #[controller] -pub struct Controller { - #[controller(publish)] - state: State, - // Other fields. Note: No all of them need to be published. -} +mod controller { + use super::*; + + // The controller struct. This is where you define the state of your firmware. + pub struct Controller { + #[controller(publish)] + state: State, + // Other fields. Note: Not all of them need to be published. + } -// The controller implementation. This is where you define the logic of your firmware. -#[controller] -impl Controller { - // The `signal` attribute marks this method signature (note: no implementation body) as a - // signal, that you can use to notify other parts of your code about specific events. - #[controller(signal)] - pub async fn power_error(&self, description: heapless::String<64>); - - pub async fn enable_power(&mut self) -> Result<(), MyFirmwareError> { - if self.state != State::Disabled { - return Err(MyFirmwareError::InvalidState); - } + // The controller implementation. This is where you define the logic of your firmware. + impl Controller { + // The `signal` attribute marks this method signature (note: no implementation body) as a + // signal, that you can use to notify other parts of your code about specific events. + #[controller(signal)] + pub async fn power_error(&self, description: heapless::String<64>); - // Any other logic you want to run when enabling power. + pub async fn enable_power(&mut self) -> Result<(), MyFirmwareError> { + if self.state != State::Disabled { + return Err(MyFirmwareError::InvalidState); + } - self.set_state(State::Enabled).await; - self.power_error("Dummy error just for the showcase".try_into().unwrap()) - .await; + // Any other logic you want to run when enabling power. - Ok(()) - } + self.set_state(State::Enabled).await; + self.power_error("Dummy error just for the showcase".try_into().unwrap()) + .await; - pub async fn disable_power(&mut self) -> Result<(), MyFirmwareError> { - if self.state != State::Enabled { - return Err(MyFirmwareError::InvalidState); + Ok(()) } - // Any other logic you want to run when enabling power. + pub async fn disable_power(&mut self) -> Result<(), MyFirmwareError> { + if self.state != State::Enabled { + return Err(MyFirmwareError::InvalidState); + } + + // Any other logic you want to run when enabling power. - self.set_state(State::Disabled).await; + self.set_state(State::Disabled).await; - Ok(()) - } + Ok(()) + } - // Method that doesn't return anything. - pub async fn return_nothing(&self) { + // Method that doesn't return anything. + pub async fn return_nothing(&self) { + } } } +use controller::*; + #[embassy_executor::main] async fn main(spawner: embassy_executor::Spawner) { let mut controller = Controller::new(State::Disabled); @@ -104,9 +109,8 @@ async fn client() { use embassy_time::{Timer, Duration}; let mut client = ControllerClient::new(); - // SAFETY: We don't create more than 16 instances so we won't panic. - let state_changed = ControllerState::new().unwrap().map(Either::Left); - let error_stream = ControllerPowerError::new().unwrap().map(Either::Right); + let state_changed = client.receive_state_changed().unwrap().map(Either::Left); + let error_stream = client.receive_power_error().unwrap().map(Either::Right); let mut stream = select(state_changed, error_stream); client.enable_power().await.unwrap(); @@ -141,34 +145,39 @@ async fn client() { # Details -The `controller` macro will generated the following for you: +The `controller` macro will generate the following for you: + +## Controller struct * A `new` method that takes the fields of the struct as arguments and returns the struct. * For each `published` field: - * Setter for this field, named `set_` (so`set_state` here), which broadcasts any + * Setter for this field, named `set_` (e.g., `set_state`), which broadcasts any changes made to this field. - * Two client-side types: - * struct named `Changed` (so `ControllerStateChanged` - for `state` field), containing two public fields, named `previous` and `new` fields - representing the previous and new values of the field, respectively. - * Type named `` (so `ControllerState` for - `state` field), which implements `futures::Stream`, yielding each state change as the change - struct described above. -* `run` method with signature `pub async fn run(&mut self);` which runs the controller logic, - proxying calls from the client to the implementations here and their return value back to - the clients (internally via channels). Typically you'd call it at the end of your `main` - or run it as a task. -* Client-side API for this struct, named `Client` (`ControllerClient` here) - which provides exactly the same methods (except signal methods) defined in this implementation - that other parts of the code use to call these methods. +* A `run` method with signature `pub async fn run(mut self);` which runs the controller logic, + proxying calls from the client to the implementations and their return values back to the + clients (internally via channels). Typically you'd call it at the end of your `main` or run it + as a task. +* For each `signal` method: + * The method body, that broadcasts the signal to all clients that are listening to it. + +## Client API + +A client struct named `Client` (`ControllerClient` in the example) with the following +methods: + +* All methods defined in the controller impl (except signal methods), which proxy calls to the + controller and return the results. +* For each `published` field: + * `receive__changed()` method (e.g., `receive_state_changed()`) that returns a + stream of state changes. The stream yields `Changed` + structs (e.g., `ControllerStateChanged`) containing `previous` and `new` fields. + * If the field is marked with `#[controller(publish(pub_setter))]`, a public + `set_()` method (e.g., `set_state()`) is also generated on the client, allowing + external code to update the field value through the client API. * For each `signal` method: - * The method body, that broadcasts the signal to all the clients that are listening to it. - * Two client-side types: - * struct, named `Args` (`ControllerPowerErrorArgs` - here), containing all the arguments of this method, as public fields. - * Type named `` (`ControllerPowerError` here) which - implements `futures::Stream`, yielding each signal broadcasted as the args struct described - above. + * `receive_()` method (e.g., `receive_power_error()`) that returns a stream of + signal events. The stream yields `Args` structs + (e.g., `ControllerPowerErrorArgs`) containing all signal arguments as public fields. ## Dependencies assumed diff --git a/src/controller/item_impl.rs b/src/controller/item_impl.rs index e6b72a5..0934d07 100644 --- a/src/controller/item_impl.rs +++ b/src/controller/item_impl.rs @@ -7,17 +7,21 @@ use syn::{ Attribute, Ident, ImplItem, ImplItemFn, ItemImpl, Result, Signature, Token, Visibility, }; -use crate::util::snake_to_pascal_case; +use crate::{controller::item_struct::PublishedFieldInfo, util::snake_to_pascal_case}; -pub(crate) fn expand(mut input: ItemImpl) -> Result { +pub(crate) fn expand( + mut input: ItemImpl, + published_fields: &[PublishedFieldInfo], +) -> Result { let struct_name = get_struct_name(&input)?; let struct_name_str = struct_name.to_string(); let methods = get_methods(&mut input, &struct_name)?; - let signal_declarations = methods.iter().filter_map(|m| match m { - Method::Signal(signal) => Some(&signal.declarations), + let signals = methods.iter().filter_map(|m| match m { + Method::Signal(signal) => Some(signal), _ => None, }); + let signal_declarations = signals.clone().map(|s| &s.declarations); let methods = methods.iter().filter_map(|m| match m { Method::Proxied(method) => Some(method), @@ -27,19 +31,61 @@ pub(crate) fn expand(mut input: ItemImpl) -> Result { let args_channels_rx_tx = methods.clone().map(|m| &m.args_channels_rx_tx); let select_arms = methods.clone().map(|m| &m.select_arm); + // Generate public setters for published fields with pub_setter. + let pub_setters: Vec<_> = published_fields + .iter() + .filter(|field| field.pub_setter) + .map(|field| generate_pub_setter(field, &struct_name)) + .collect(); + let pub_setter_channel_declarations = pub_setters.iter().map(|s| &s.channel_declarations); + let pub_setter_rx_tx = pub_setters.iter().map(|s| &s.rx_tx); + let pub_setter_select_arms = pub_setters.iter().map(|s| &s.select_arm); + let pub_setter_client_methods = pub_setters.iter().map(|s| &s.client_method); + let pub_setter_client_tx_rx_declarations = + pub_setters.iter().map(|s| &s.client_tx_rx_declarations); + let pub_setter_client_tx_rx_initializations = + pub_setters.iter().map(|s| &s.client_tx_rx_initializations); + let run_method = quote! { pub async fn run(mut self) { #(#args_channels_rx_tx)* + #(#pub_setter_rx_tx)* loop { futures::select_biased! { - #(#select_arms),* + #(#select_arms,)* + #(#pub_setter_select_arms),* } } } }; input.items.push(syn::parse2(run_method)?); + // Generate stream getter methods for published fields. + let published_field_getters = published_fields.iter().map(|field| { + let method_name = Ident::new( + &format!("receive_{}_changed", field.field_name), + field.field_name.span(), + ); + let subscriber_type = &field.subscriber_struct_name; + quote! { + pub fn #method_name(&self) -> Option<#subscriber_type> { + #subscriber_type::new() + } + } + }); + + // Generate stream getter methods for signals. + let signal_getters = signals.clone().map(|signal| { + let method_name = &signal.receive_method_name; + let subscriber_type = &signal.subscriber_struct_name; + quote! { + pub fn #method_name(&self) -> Option<#subscriber_type> { + #subscriber_type::new() + } + } + }); + let client_name = Ident::new(&format!("{}Client", struct_name_str), input.span()); let client_methods = methods.clone().map(|m| &m.client_method); let client_method_tx_rx_declarations = @@ -50,21 +96,30 @@ pub(crate) fn expand(mut input: ItemImpl) -> Result { Ok(quote! { #(#args_channel_declarations)* + #(#pub_setter_channel_declarations)* #input pub struct #client_name { #(#client_method_tx_rx_declarations)* + #(#pub_setter_client_tx_rx_declarations)* } impl #client_name { pub fn new() -> Self { Self { #(#client_method_tx_rx_initializations)* + #(#pub_setter_client_tx_rx_initializations)* } } #(#client_methods)* + + #(#pub_setter_client_methods)* + + #(#published_field_getters)* + + #(#signal_getters)* } #(#signal_declarations)* @@ -329,6 +384,10 @@ impl ProxiedMethodArgs<'_> { struct Signal { /// The input arguments' channel and client-side struct declarations. declarations: TokenStream, + /// Name of the receive method (e.g., receive_power_error). + receive_method_name: Ident, + /// Name of the subscriber struct (e.g., ControllerPowerError). + subscriber_struct_name: Ident, } impl Signal { @@ -429,7 +488,14 @@ impl Signal { ).await; }); - Ok(Self { declarations }) + let receive_method_name = + Ident::new(&format!("receive_{}", method_name_str), method.span()); + + Ok(Self { + declarations, + receive_method_name, + subscriber_struct_name, + }) } } @@ -523,3 +589,109 @@ impl MethodInputArgs { }) } } + +#[derive(Debug)] +struct PubSetter { + channel_declarations: TokenStream, + rx_tx: TokenStream, + select_arm: TokenStream, + client_method: TokenStream, + client_tx_rx_declarations: TokenStream, + client_tx_rx_initializations: TokenStream, +} + +fn generate_pub_setter(field: &PublishedFieldInfo, struct_name: &Ident) -> PubSetter { + let field_name = &field.field_name; + let field_type = &field.field_type; + let setter_method_name = &field.setter_name; + let field_name_str = field_name.to_string(); + + let struct_name_caps = struct_name.to_string().to_uppercase(); + let field_name_caps = field_name_str.to_uppercase(); + let input_channel_name = Ident::new( + &format!("{}_SET_{}_INPUT_CHANNEL", struct_name_caps, field_name_caps), + field_name.span(), + ); + let output_channel_name = Ident::new( + &format!( + "{}_SET_{}_OUTPUT_CHANNEL", + struct_name_caps, field_name_caps + ), + field_name.span(), + ); + let capacity = super::ALL_CHANNEL_CAPACITY; + + let channel_declarations = quote! { + static #input_channel_name: + embassy_sync::channel::Channel< + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + #field_type, + #capacity, + > = embassy_sync::channel::Channel::new(); + static #output_channel_name: + embassy_sync::channel::Channel< + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + (), + #capacity, + > = embassy_sync::channel::Channel::new(); + }; + + let input_channel_rx_name = + Ident::new(&format!("{}_value_rx", field_name_str), field_name.span()); + let output_channel_tx_name = + Ident::new(&format!("{}_ack_tx", field_name_str), field_name.span()); + let rx_tx = quote! { + let #input_channel_rx_name = embassy_sync::channel::Channel::receiver(&#input_channel_name); + let #output_channel_tx_name = embassy_sync::channel::Channel::sender(&#output_channel_name); + }; + + let select_arm = quote! { + value = futures::FutureExt::fuse( + embassy_sync::channel::Receiver::receive(&#input_channel_rx_name), + ) => { + self.#setter_method_name(value).await; + + embassy_sync::channel::Sender::send(&#output_channel_tx_name, ()).await; + } + }; + + let input_channel_tx_name = + Ident::new(&format!("{}_value_tx", field_name_str), field_name.span()); + let output_channel_rx_name = + Ident::new(&format!("{}_ack_rx", field_name_str), field_name.span()); + let client_method = quote! { + pub async fn #setter_method_name(&self, value: #field_type) { + embassy_sync::channel::Sender::send(&self.#input_channel_tx_name, value).await; + embassy_sync::channel::Receiver::receive(&self.#output_channel_rx_name).await + } + }; + + let client_tx_rx_declarations = quote! { + #input_channel_tx_name: embassy_sync::channel::Sender< + 'static, + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + #field_type, + #capacity, + >, + #output_channel_rx_name: embassy_sync::channel::Receiver< + 'static, + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + (), + #capacity, + > + }; + + let client_tx_rx_initializations = quote! { + #input_channel_tx_name: embassy_sync::channel::Channel::sender(&#input_channel_name), + #output_channel_rx_name: embassy_sync::channel::Channel::receiver(&#output_channel_name) + }; + + PubSetter { + channel_declarations, + rx_tx, + select_arm, + client_method, + client_tx_rx_declarations, + client_tx_rx_initializations, + } +} diff --git a/src/controller/item_struct.rs b/src/controller/item_struct.rs index 95e3588..da929ac 100644 --- a/src/controller/item_struct.rs +++ b/src/controller/item_struct.rs @@ -1,9 +1,25 @@ use crate::util::*; use proc_macro2::TokenStream; use quote::quote; -use syn::{spanned::Spanned, Field, Fields, Ident, ItemStruct, Result}; +use syn::{spanned::Spanned, Field, Fields, Ident, ItemStruct, Result, Token}; + +/// Information about a published field, to be used by impl processing. +#[derive(Debug, Clone)] +pub(crate) struct PublishedFieldInfo { + pub field_name: Ident, + pub field_type: syn::Type, + pub setter_name: Ident, + pub subscriber_struct_name: Ident, + pub pub_setter: bool, +} + +/// Result of expanding a struct. +pub(crate) struct ExpandedStruct { + pub tokens: TokenStream, + pub published_fields: Vec, +} -pub(crate) fn expand(mut input: ItemStruct) -> Result { +pub(crate) fn expand(mut input: ItemStruct) -> Result { let struct_name = &input.ident; let fields = StructFields::parse(&mut input.fields, struct_name)?; @@ -14,14 +30,16 @@ pub(crate) fn expand(mut input: ItemStruct) -> Result { publisher_fields_initializations, setters, subscriber_declarations, + published_fields_info, ) = fields.published().fold( - (quote!(), quote!(), quote!(), quote!(), quote!()), + (quote!(), quote!(), quote!(), quote!(), quote!(), Vec::new()), |( publish_channels, publisher_fields_declarations, publisher_fields_initializations, setters, subscribers, + mut infos, ), f| { let (publish_channel, publisher_field, publisher_field_init, setter, subscriber) = ( @@ -32,39 +50,45 @@ pub(crate) fn expand(mut input: ItemStruct) -> Result { &f.subscriber_declaration, ); + infos.push(f.info.clone()); + ( quote! { #publish_channels #publish_channel }, quote! { #publisher_fields_declarations #publisher_field, }, quote! { #publisher_fields_initializations #publisher_field_init, }, quote! { #setters #setter }, quote! { #subscribers #subscriber }, + infos, ) }, ); let fields = fields.raw_fields().collect::>(); let vis = &input.vis; - Ok(quote! { - #vis struct #struct_name { - #(#fields),*, - #publisher_fields_declarations - } + Ok(ExpandedStruct { + tokens: quote! { + #vis struct #struct_name { + #(#fields),*, + #publisher_fields_declarations + } - impl #struct_name { - #[allow(clippy::too_many_arguments)] - pub fn new(#(#fields),*) -> Self { - Self { - #(#field_names),*, - #publisher_fields_initializations + impl #struct_name { + #[allow(clippy::too_many_arguments)] + pub fn new(#(#fields),*) -> Self { + Self { + #(#field_names),*, + #publisher_fields_initializations + } } - } - #setters - } + #setters + } - #publish_channel_declarations + #publish_channel_declarations - #subscriber_declarations + #subscriber_declarations + }, + published_fields: published_fields_info, }) } @@ -110,7 +134,7 @@ impl StructFields { /// All the published fields. fn published(&self) -> impl Iterator { self.fields.iter().filter_map(|field| match field { - StructField::Published(published) => Some(published), + StructField::Published(published) => Some(published.as_ref()), _ => None, }) } @@ -120,9 +144,9 @@ impl StructFields { #[derive(Debug)] enum StructField { /// Private field. - Private(Field), + Private(Box), /// Published field. - Published(PublishedField), + Published(Box), } impl StructField { @@ -130,15 +154,16 @@ impl StructField { fn parse(field: &mut Field, struct_name: &Ident) -> Result { PublishedField::parse(field, struct_name).map(|published| { published - .map(StructField::Published) - .unwrap_or_else(|| StructField::Private(field.clone())) + .map(|p| StructField::Published(Box::new(p))) + .unwrap_or_else(|| StructField::Private(Box::new(field.clone()))) }) } /// Get the field. fn field(&self) -> &Field { match self { - Self::Private(field) | Self::Published(PublishedField { field, .. }) => field, + Self::Private(field) => field.as_ref(), + Self::Published(published) => &published.field, } } } @@ -158,6 +183,8 @@ struct PublishedField { publish_channel_declaration: proc_macro2::TokenStream, /// Subscriber struct declaration. subscriber_declaration: proc_macro2::TokenStream, + /// Information to be passed to impl processing. + info: PublishedFieldInfo, } impl PublishedField { @@ -171,6 +198,7 @@ impl PublishedField { Some(attr) => attr, None => return Ok(None), }; + let mut pub_setter = false; attr.parse_nested_meta(|meta| { if !meta.path.is_ident("publish") { let e = format!( @@ -181,6 +209,25 @@ impl PublishedField { return Err(syn::Error::new_spanned(attr, e)); } + if meta.input.peek(syn::token::Paren) { + let content; + syn::parenthesized!(content in meta.input); + while !content.is_empty() { + let nested_ident: Ident = content.parse()?; + if nested_ident == "pub_setter" { + pub_setter = true; + } else { + let e = + format!("expected `pub_setter` attribute, found `{}`", nested_ident); + return Err(syn::Error::new_spanned(&nested_ident, e)); + } + + if !content.is_empty() { + content.parse::()?; + } + } + } + Ok(()) })?; field @@ -209,6 +256,7 @@ impl PublishedField { let max_subscribers = super::BROADCAST_MAX_SUBSCRIBERS; let max_publishers = super::BROADCAST_MAX_PUBLISHERS; + let setter_name = Ident::new(&format!("set_{field_name_str}"), field.span()); let publisher_name = Ident::new(&format!("{field_name_str}_publisher"), field.span()); let publisher_field_declaration = quote! { #publisher_name: @@ -225,7 +273,6 @@ impl PublishedField { // We only create one publisher so we can't fail. #publisher_name: embassy_sync::pubsub::PubSubChannel::publisher(&#publish_channel_name).unwrap() }; - let setter_name = Ident::new(&format!("set_{field_name_str}"), field.span()); let setter = quote! { pub async fn #setter_name(&mut self, mut value: #ty) { core::mem::swap(&mut self.#field_name, &mut value); @@ -291,6 +338,14 @@ impl PublishedField { } }; + let info = PublishedFieldInfo { + field_name: field_name.clone(), + field_type: ty.clone(), + setter_name: setter_name.clone(), + subscriber_struct_name: subscriber_struct_name.clone(), + pub_setter, + }; + Ok(Some(PublishedField { field: field.clone(), publisher_field_declaration, @@ -298,6 +353,7 @@ impl PublishedField { setter, publish_channel_declaration, subscriber_declaration, + info, })) } } diff --git a/src/controller/mod.rs b/src/controller/mod.rs index c433e0d..5a4583e 100644 --- a/src/controller/mod.rs +++ b/src/controller/mod.rs @@ -1,7 +1,89 @@ pub(crate) mod item_impl; pub(crate) mod item_struct; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{spanned::Spanned, Item, ItemMod, Result}; + const ALL_CHANNEL_CAPACITY: usize = 8; const SIGNAL_CHANNEL_CAPACITY: usize = 8; const BROADCAST_MAX_PUBLISHERS: usize = 1; const BROADCAST_MAX_SUBSCRIBERS: usize = 16; + +pub(crate) fn expand_module(input: ItemMod) -> Result { + let vis = &input.vis; + let mod_name = &input.ident; + let span = input.span(); + + let (_, items) = input + .content + .ok_or_else(|| syn::Error::new(span, "Module must have a body"))?; + + let mut struct_item = None; + let mut impl_item = None; + let mut other_items = Vec::new(); + + for item in items { + match item { + Item::Struct(s) => { + if struct_item.is_some() { + return Err(syn::Error::new( + s.span(), + "Module must contain exactly one struct definition", + )); + } + struct_item = Some(s); + } + Item::Impl(i) => { + if impl_item.is_some() { + return Err(syn::Error::new( + i.span(), + "Module must contain exactly one impl block", + )); + } + impl_item = Some(i); + } + other => other_items.push(other), + } + } + + let struct_item = struct_item.ok_or_else(|| { + syn::Error::new( + span, + "Module must contain a struct definition for the controller", + ) + })?; + + let impl_item = impl_item.ok_or_else(|| { + syn::Error::new(span, "Module must contain an impl block for the controller") + })?; + + let struct_name = &struct_item.ident; + if let syn::Type::Path(type_path) = &*impl_item.self_ty { + if let Some(ident) = type_path.path.get_ident() { + if ident != struct_name { + return Err(syn::Error::new( + impl_item.span(), + format!( + "Impl block is for type '{}' but controller struct is named '{}'", + ident, struct_name + ), + )); + } + } + } + + let expanded_struct = item_struct::expand(struct_item)?; + let expanded_impl = item_impl::expand(impl_item, &expanded_struct.published_fields)?; + let struct_tokens = expanded_struct.tokens; + + Ok(quote! { + #vis mod #mod_name { + #(#other_items)* + + #struct_tokens + + #expanded_impl + } + }) +} diff --git a/src/lib.rs b/src/lib.rs index 4378af5..e897e92 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,7 @@ #![doc = include_str!("../README.md")] use proc_macro::TokenStream; -use syn::{parse_macro_input, punctuated::Punctuated, ItemImpl, ItemStruct, Meta, Token}; +use syn::{parse_macro_input, punctuated::Punctuated, ItemMod, Meta, Token}; mod controller; mod util; @@ -11,15 +11,8 @@ mod util; pub fn controller(attr: TokenStream, item: TokenStream) -> TokenStream { let _args = parse_macro_input!(attr with Punctuated::parse_terminated); - if let Ok(input) = syn::parse::(item.clone()) { - controller::item_struct::expand(input) - .unwrap_or_else(|e| e.to_compile_error()) - .into() - } else if let Ok(input) = syn::parse::(item) { - controller::item_impl::expand(input) - .unwrap_or_else(|e| e.to_compile_error()) - .into() - } else { - panic!("Expected struct or trait") - } + let input = parse_macro_input!(item as ItemMod); + controller::expand_module(input) + .unwrap_or_else(|e| e.to_compile_error()) + .into() } diff --git a/tests/integration.rs b/tests/integration.rs new file mode 100644 index 0000000..5c5dea3 --- /dev/null +++ b/tests/integration.rs @@ -0,0 +1,205 @@ +use firmware_controller::controller; +use futures::StreamExt; + +#[derive(Debug, PartialEq, Copy, Clone)] +pub enum State { + Idle, + Active, + Error, +} + +#[derive(Debug, PartialEq, Copy, Clone)] +pub enum Mode { + Normal, + Debug, +} + +#[derive(Debug, PartialEq)] +pub enum TestError { + InvalidState, + OperationFailed, +} + +#[controller] +mod test_controller { + use super::*; + + pub struct Controller { + #[controller(publish)] + state: State, + #[controller(publish(pub_setter))] + mode: Mode, + counter: u32, + } + + impl Controller { + #[controller(signal)] + pub async fn error_occurred(&self, code: u32, message: heapless::String<32>); + + #[controller(signal)] + pub async fn operation_complete(&self); + + pub async fn increment(&mut self) -> u32 { + self.counter += 1; + self.counter + } + + pub async fn get_counter(&self) -> u32 { + self.counter + } + + pub async fn activate(&mut self) -> Result<(), TestError> { + if self.state != State::Idle { + return Err(TestError::InvalidState); + } + self.set_state(State::Active).await; + self.operation_complete().await; + Ok(()) + } + + pub async fn trigger_error(&mut self) -> Result<(), TestError> { + self.set_state(State::Error).await; + self.error_occurred(42, "Test error".try_into().unwrap()) + .await; + Err(TestError::OperationFailed) + } + + pub async fn return_nothing(&self) {} + } +} + +use test_controller::*; + +#[test] +fn test_controller_basic_functionality() { + // Create the controller before spawning the thread to avoid any race conditions. + // The channels used for communication will buffer requests, so it's safe for the + // client to start making calls even if the controller task hasn't fully started yet. + let controller = Controller::new(State::Idle, Mode::Normal, 0); + + // Run the controller in a background thread. + std::thread::spawn(move || { + let executor = Box::leak(Box::new(embassy_executor::Executor::new())); + executor.run(move |spawner| { + spawner.spawn(controller_task(controller)).unwrap(); + }); + }); + + // Run the test logic. + futures::executor::block_on(async { + // Create client. + let mut client = ControllerClient::new(); + + // Test 1: Subscribe to state changes. + let mut state_stream = client.receive_state_changed().expect("Failed to subscribe"); + + // Test 2: Subscribe to signals. + let mut error_stream = client + .receive_error_occurred() + .expect("Failed to subscribe to error"); + let mut complete_stream = client + .receive_operation_complete() + .expect("Failed to subscribe to complete"); + + // Test 3: Call a method and verify return value. + let counter = client.get_counter().await; + assert_eq!(counter, 0, "Initial counter should be 0"); + + // Test 4: Call increment and verify it increases. + let counter = client.increment().await; + assert_eq!(counter, 1, "Counter should be 1 after increment"); + + let counter = client.increment().await; + assert_eq!(counter, 2, "Counter should be 2 after second increment"); + + // Test 5: Call method that changes state and emits signal. + let activate_result = client.activate().await; + assert!( + activate_result.is_ok(), + "Activate should succeed from Idle state" + ); + + // Verify we received the state change. + let state_change = state_stream + .next() + .await + .expect("Should receive state change"); + assert_eq!( + state_change.previous, + State::Idle, + "Previous state should be Idle" + ); + assert_eq!( + state_change.new, + State::Active, + "New state should be Active" + ); + + // Verify we received the operation_complete signal. + let _complete = complete_stream + .next() + .await + .expect("Should receive operation complete signal"); + + // Test 6: Call method that returns error. + let error_result = client.trigger_error().await; + assert!( + error_result.is_err(), + "trigger_error should return an error" + ); + assert_eq!( + error_result.unwrap_err(), + TestError::OperationFailed, + "Should return OperationFailed error" + ); + + // Verify state changed to Error. + let state_change = state_stream + .next() + .await + .expect("Should receive state change"); + assert_eq!( + state_change.previous, + State::Active, + "Previous state should be Active" + ); + assert_eq!(state_change.new, State::Error, "New state should be Error"); + + // Verify we received the error signal. + let error_signal = error_stream + .next() + .await + .expect("Should receive error signal"); + assert_eq!(error_signal.code, 42, "Error code should be 42"); + assert_eq!( + error_signal.message.as_str(), + "Test error", + "Error message should match" + ); + + // Test 7: Try to activate again (should fail due to invalid state). + let activate_result = client.activate().await; + assert!( + activate_result.is_err(), + "Activate should fail from Error state" + ); + assert_eq!( + activate_result.unwrap_err(), + TestError::InvalidState, + "Should return InvalidState error" + ); + + // Test 8: Use pub_setter to change mode. + client.set_mode(Mode::Debug).await; + + // Test 9: Call method with no return value. + client.return_nothing().await; + + // If we get here, all tests passed. + }); +} + +#[embassy_executor::task] +async fn controller_task(controller: Controller) { + controller.run().await; +}