diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9a49a8ed..3ce43c0e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -150,6 +150,7 @@ jobs: cargo build else cargo run + cargo test fi done diff --git a/CHANGELOG.md b/CHANGELOG.md index 50a3ce14..f3d5b959 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,10 @@ ### v0.x.x - 202x-xx-xx - **[BREAKING]** Rename `derive_unsafe` to `derive_unchecked` (both the feature flag and the attribute). -- **[FEATURE]** Ability to derive [`Valuable`](https://docs.rs/valuable/0.1.1/valuable/trait.Valuable.html) (requires `valuable` feature). +- **[FEATURE]** Support `cfg_attr` for conditional derives, e.g. `cfg_attr(feature = "serde", derive(Serialize, Deserialize))`. Supports complex predicates and multiple entries. +- **[FEATURE]** Support `where` clauses in generic newtypes, including Higher-Ranked Trait Bounds (HRTB) like `for<'a> &'a C: IntoIterator` (see [#160](https://github.com/greyblake/nutype/issues/160)). - **[FEATURE]** Ability to control constructor visibility with `constructor(visibility = ...)` attribute (see [#211](https://github.com/greyblake/nutype/issues/211)). - **[FEATURE]** Add `len_utf16_min` and `len_utf16_max` validators for string types to validate UTF-16 code unit length (useful for JavaScript interop) (see [#162](https://github.com/greyblake/nutype/issues/162)). -- **[FEATURE]** Support `where` clauses in generic newtypes, including Higher-Ranked Trait Bounds (HRTB) like `for<'a> &'a C: IntoIterator` (see [#160](https://github.com/greyblake/nutype/issues/160)). +- **[FEATURE]** Ability to derive [`Valuable`](https://docs.rs/valuable/0.1.1/valuable/trait.Valuable.html) (requires `valuable` feature). ### v0.6.2 - 2025-06-30 - **[FEATURE]** Introduce `derive_unsafe(..)` attribute to derive any arbitrary trait (requires `derive_unsafe` feature to be enabled). diff --git a/Cargo.lock b/Cargo.lock index ccf9304c..cd86090e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -83,6 +83,15 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +[[package]] +name = "cfg_attr_example" +version = "0.1.0" +dependencies = [ + "nutype", + "serde", + "serde_json", +] + [[package]] name = "const_example" version = "0.1.0" @@ -208,6 +217,12 @@ version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +[[package]] +name = "impls" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a46645bbd70538861a90d0f26c31537cdf1e44aae99a794fb75a664b70951bc" + [[package]] name = "indexmap" version = "2.13.0" @@ -673,6 +688,7 @@ version = "0.1.0" dependencies = [ "arbitrary", "arbtest 0.2.0", + "impls", "lazy_static", "num", "nutype", diff --git a/Cargo.toml b/Cargo.toml index 9e40d4c3..f03dad09 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,4 +24,5 @@ members = [ "examples/any_generics", "examples/custom_error", "examples/const_example", "examples/valuable_example", + "examples/cfg_attr_example", ] diff --git a/Justfile b/Justfile index 605254c5..1fa05b3b 100644 --- a/Justfile +++ b/Justfile @@ -45,6 +45,7 @@ examples: cargo build else cargo run + cargo test fi done diff --git a/README.md b/README.md index 50cea7dc..dd5644c3 100644 --- a/README.md +++ b/README.md @@ -409,6 +409,44 @@ However, **use this with caution**: `nutype` cannot verify that these traits pre It is the developer's responsibility to ensure that the derived traits do not introduce ways to bypass validation (e.g., by allowing mutable access to the inner value). +### `cfg_attr` + +You can use `cfg_attr` to conditionally derive traits based on `cfg` predicates: + +```rust +#[nutype( + derive(Debug, PartialEq), + cfg_attr(feature = "serde", derive(Serialize, Deserialize)), +)] +pub struct Email(String); +``` + +Only `derive(...)` and `derive_unchecked(...)` are supported inside `cfg_attr`. + +Complex predicates work as well: + +```rust +#[nutype( + derive(Debug), + cfg_attr(all(test, debug_assertions), derive(Clone, Display)), +)] +pub struct Label(String); +``` + +Multiple `cfg_attr` entries are allowed: + +```rust +#[nutype( + derive(Debug), + cfg_attr(test, derive(Clone)), + cfg_attr(feature = "serde", derive(Serialize, Deserialize)), +)] +pub struct Tag(String); +``` + +Note that a trait cannot appear in both unconditional `derive` and `cfg_attr` `derive` at the same time. + + ## Constants You can mark a type with the `const_fn` flag. In that case, its `new` and `try_new` functions will be declared as `const`: diff --git a/examples/cfg_attr_example/Cargo.toml b/examples/cfg_attr_example/Cargo.toml new file mode 100644 index 00000000..6de5f451 --- /dev/null +++ b/examples/cfg_attr_example/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "cfg_attr_example" +version = "0.1.0" +edition = "2024" +publish = false + +[features] +serde = ["dep:serde", "dep:serde_json"] + +[dependencies] +nutype = { path = "../../nutype", features = ["serde"] } +serde = { version = "1.0", features = ["derive"], optional = true } +serde_json = { version = "1.0", optional = true } diff --git a/examples/cfg_attr_example/src/main.rs b/examples/cfg_attr_example/src/main.rs new file mode 100644 index 00000000..8315ffab --- /dev/null +++ b/examples/cfg_attr_example/src/main.rs @@ -0,0 +1,106 @@ +use nutype::nutype; + +// 1. Conditional serde behind a feature flag +// Serialize and Deserialize are only derived when the "serde" feature is enabled. +// Note: nutype/serde must be enabled at compile time so the macro accepts these traits, +// but the actual derive is gated by cfg_attr. +#[nutype( + sanitize(trim, lowercase), + validate(not_empty, len_char_max = 100), + derive(Debug, Clone, PartialEq, AsRef), + cfg_attr(feature = "serde", derive(Serialize, Deserialize)) +)] +pub struct Email(String); + +// 2. Conditional Default for tests +// Default is only derived under `cfg(test)`, but the default value is always specified. +#[nutype( + validate(greater_or_equal = 1, less_or_equal = 65535), + default = 8080, + derive(Debug, Clone, Copy, PartialEq, Into), + cfg_attr(test, derive(Default)) +)] +pub struct Port(u16); + +// 3. Complex predicate +// Clone and Display are only derived when both `test` and `debug_assertions` are active. +#[nutype( + sanitize(trim), + validate(not_empty, len_char_max = 50), + derive(Debug, PartialEq, AsRef), + cfg_attr(all(test, debug_assertions), derive(Clone, Display)) +)] +pub struct Label(String); + +// 4. Multiple cfg_attr entries +// Each cfg_attr line is independent and can gate different traits behind different predicates. +#[nutype( + validate(not_empty), + derive(Debug), + cfg_attr(test, derive(Clone)), + cfg_attr(feature = "serde", derive(Serialize, Deserialize)) +)] +pub struct Tag(String); + +fn main() { + // Exercise Email + let email = Email::try_new(" Alice@Example.COM ").unwrap(); + assert_eq!(email.as_ref(), "alice@example.com"); + println!("Email: {email:?}"); + + // Exercise Email with serde (only when the feature is enabled) + #[cfg(feature = "serde")] + { + let json = serde_json::to_string(&email).unwrap(); + println!("Email as JSON: {json}"); + + let parsed: Email = serde_json::from_str(&json).unwrap(); + assert_eq!(email, parsed); + println!("Round-tripped email: {parsed:?}"); + } + + // Exercise Port + let port = Port::try_new(3000).unwrap(); + let port_val: u16 = port.into(); + assert_eq!(port_val, 3000u16); + println!("Port: {port:?}"); + + // Exercise Label + let label = Label::try_new(" Rust ").unwrap(); + assert_eq!(label.as_ref(), "Rust"); + println!("Label: {label:?}"); + + // Exercise Tag + let tag = Tag::try_new("nutype").unwrap(); + println!("Tag: {tag:?}"); + + // Exercise Tag with serde (only when the feature is enabled) + #[cfg(feature = "serde")] + { + let json = serde_json::to_string(&tag).unwrap(); + println!("Tag as JSON: {json}"); + } + + println!("All cfg_attr examples passed!"); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_port_default() { + // Default is conditionally derived under cfg(test), so this works in tests. + let port = Port::default(); + let port_val: u16 = port.into(); + assert_eq!(port_val, 8080u16); + } + + #[test] + fn test_tag_clone() { + // Clone is conditionally derived under cfg(test). + let tag = Tag::try_new("example").unwrap(); + let tag2 = tag.clone(); + assert_eq!(format!("{tag:?}"), format!("{tag2:?}")); + } +} diff --git a/nutype/src/lib.rs b/nutype/src/lib.rs index 3761faac..f27d8cbf 100644 --- a/nutype/src/lib.rs +++ b/nutype/src/lib.rs @@ -469,6 +469,50 @@ //! It is the developer's responsibility to ensure that the derived traits do not introduce ways to bypass validation (e.g., by allowing mutable access to the inner value). //! //! +//! ### `cfg_attr` +//! +//! You can use `cfg_attr` to conditionally derive traits based on `cfg` predicates: +//! +//! ```rust +//! use nutype::nutype; +//! +//! #[nutype( +//! derive(Debug, PartialEq), +//! cfg_attr(test, derive(Clone)), +//! )] +//! pub struct Email(String); +//! ``` +//! +//! Only `derive(...)` and `derive_unchecked(...)` are supported inside `cfg_attr`. +//! +//! Complex predicates work as well: +//! +//! ```rust +//! use nutype::nutype; +//! +//! #[nutype( +//! derive(Debug), +//! cfg_attr(all(test, debug_assertions), derive(Clone, Display)), +//! )] +//! pub struct Label(String); +//! ``` +//! +//! Multiple `cfg_attr` entries are allowed: +//! +//! ```rust +//! use nutype::nutype; +//! +//! #[nutype( +//! derive(Debug), +//! cfg_attr(test, derive(Clone)), +//! cfg_attr(test, derive(Display)), +//! )] +//! pub struct Tag(String); +//! ``` +//! +//! Note that a trait cannot appear in both unconditional `derive` and `cfg_attr` `derive` at the same time. +//! +//! //! ## Constants //! //! You can mark a type with the `const_fn` flag. In that case, its `new` and `try_new` functions will be declared as `const`: diff --git a/nutype_macros/src/any/generate/mod.rs b/nutype_macros/src/any/generate/mod.rs index 5d3dd6b9..8de786c8 100644 --- a/nutype_macros/src/any/generate/mod.rs +++ b/nutype_macros/src/any/generate/mod.rs @@ -12,7 +12,8 @@ use crate::common::{ GenerateNewtype, tests::gen_test_should_have_valid_default_value, traits::GeneratedTraits, }, models::{ - ConstFn, ErrorTypePath, Guard, SpannedDeriveUnsafeTrait, TypeName, TypedCustomFunction, + ConditionalDeriveGroup, ConstFn, ErrorTypePath, Guard, SpannedDeriveUnsafeTrait, TypeName, + TypedCustomFunction, }, }; @@ -125,6 +126,7 @@ impl GenerateNewtype for AnyNewtype { unsafe_traits: &[SpannedDeriveUnsafeTrait], maybe_default_value: Option, guard: &AnyGuard, + conditional_derives: &[ConditionalDeriveGroup], ) -> Result { gen_traits( type_name, @@ -134,6 +136,7 @@ impl GenerateNewtype for AnyNewtype { unsafe_traits, maybe_default_value, guard, + conditional_derives, ) } diff --git a/nutype_macros/src/any/generate/traits/mod.rs b/nutype_macros/src/any/generate/traits/mod.rs index e1b18e08..65278c4d 100644 --- a/nutype_macros/src/any/generate/traits/mod.rs +++ b/nutype_macros/src/any/generate/traits/mod.rs @@ -9,13 +9,14 @@ use crate::{ any::models::{AnyDeriveTrait, AnyGuard, AnyInnerType}, common::{ generate::traits::{ - GeneratableTrait, GeneratableTraits, GeneratedTraits, gen_impl_trait_as_ref, - gen_impl_trait_borrow, gen_impl_trait_default, gen_impl_trait_deref, - gen_impl_trait_display, gen_impl_trait_from, gen_impl_trait_from_str, - gen_impl_trait_into, gen_impl_trait_serde_deserialize, gen_impl_trait_serde_serialize, - gen_impl_trait_try_from, split_into_generatable_traits, + ConditionalTraits, GeneratableTrait, GeneratableTraits, GeneratedTraits, + HasGeneratedParseError, gen_impl_trait_as_ref, gen_impl_trait_borrow, + gen_impl_trait_default, gen_impl_trait_deref, gen_impl_trait_display, + gen_impl_trait_from, gen_impl_trait_from_str, gen_impl_trait_into, + gen_impl_trait_serde_deserialize, gen_impl_trait_serde_serialize, + gen_impl_trait_try_from, process_conditional_derives, split_into_generatable_traits, }, - models::{SpannedDeriveUnsafeTrait, TypeName}, + models::{ConditionalDeriveGroup, SpannedDeriveUnsafeTrait, TypeName}, }, }; @@ -114,6 +115,15 @@ enum AnyIrregularTrait { ArbitraryArbitrary, } +/// Any's `FromStr` generates a `ParseError` type via `gen_impl_trait_from_str` -> +/// `gen_def_parse_error`, which needs module-level re-export in conditional derives. +impl HasGeneratedParseError for AnyIrregularTrait { + fn has_generated_parse_error(&self) -> bool { + matches!(self, Self::FromStr) + } +} + +#[allow(clippy::too_many_arguments)] pub fn gen_traits( type_name: &TypeName, generics: &syn::Generics, @@ -122,6 +132,7 @@ pub fn gen_traits( unsafe_traits: &[SpannedDeriveUnsafeTrait], maybe_default_value: Option, guard: &AnyGuard, + conditional_derives: &[ConditionalDeriveGroup], ) -> Result { let GeneratableTraits { transparent_traits, @@ -140,13 +151,31 @@ pub fn gen_traits( generics, inner_type, irregular_traits, - maybe_default_value, + maybe_default_value.clone(), guard, )?; + let ConditionalTraits { + derive_transparent_traits: conditional_derive_transparent_traits, + implement_traits: conditional_implement_traits, + from_str_parse_errors: conditional_from_str_parse_errors, + } = process_conditional_derives(conditional_derives, type_name, |irregular| { + gen_implemented_traits( + type_name, + generics, + inner_type, + irregular, + maybe_default_value.clone(), + guard, + ) + })?; + Ok(GeneratedTraits { derive_transparent_traits, implement_traits, + conditional_derive_transparent_traits, + conditional_implement_traits, + conditional_from_str_parse_errors, }) } diff --git a/nutype_macros/src/any/mod.rs b/nutype_macros/src/any/mod.rs index 8d83692a..e41ba0d1 100644 --- a/nutype_macros/src/any/mod.rs +++ b/nutype_macros/src/any/mod.rs @@ -4,14 +4,15 @@ pub mod parse; pub mod validate; use proc_macro2::TokenStream; -use std::collections::HashSet; use self::models::{AnyDeriveTrait, AnyGuard, AnyInnerType, AnySanitizer, AnyValidator}; use crate::common::generate::GenerateNewtype; use crate::common::models::TypeName; use crate::{ any::validate::validate_any_derive_traits, - common::models::{Attributes, GenerateParams, Newtype, SpannedDeriveTrait}, + common::models::{ + Attributes, CfgAttrEntry, GenerateParams, Newtype, SpannedDeriveTrait, ValidatedDerives, + }, }; pub struct AnyNewtype; @@ -32,8 +33,17 @@ impl Newtype for AnyNewtype { fn validate( guard: &AnyGuard, derive_traits: Vec, - ) -> Result, syn::Error> { - validate_any_derive_traits(guard, derive_traits) + cfg_attr_entries: &[CfgAttrEntry], + maybe_default_value: &Option, + type_name: &TypeName, + ) -> Result, syn::Error> { + validate_any_derive_traits( + guard, + derive_traits, + cfg_attr_entries, + maybe_default_value, + type_name, + ) } fn generate( diff --git a/nutype_macros/src/any/models.rs b/nutype_macros/src/any/models.rs index a83833cb..6d8dea48 100644 --- a/nutype_macros/src/any/models.rs +++ b/nutype_macros/src/any/models.rs @@ -24,7 +24,7 @@ pub enum AnyValidator { pub type SpannedAnyValidator = SpannedItem; -#[derive(Debug, Eq, PartialEq, Hash)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] pub enum AnyDeriveTrait { // Standard Debug, @@ -57,6 +57,9 @@ impl TypeTrait for AnyDeriveTrait { fn is_from_str(&self) -> bool { self == &AnyDeriveTrait::FromStr } + fn is_default(&self) -> bool { + self == &AnyDeriveTrait::Default + } } pub type AnyRawGuard = RawGuard; diff --git a/nutype_macros/src/any/parse.rs b/nutype_macros/src/any/parse.rs index b9fa7396..d78af751 100644 --- a/nutype_macros/src/any/parse.rs +++ b/nutype_macros/src/any/parse.rs @@ -31,6 +31,7 @@ pub fn parse_attributes( default, derive_traits, derive_unchecked_traits, + cfg_attr_entries, } = attrs; let raw_guard = AnyRawGuard { sanitizers, @@ -45,6 +46,7 @@ pub fn parse_attributes( default, derive_traits, derive_unchecked_traits, + cfg_attr_entries, }) } diff --git a/nutype_macros/src/any/validate.rs b/nutype_macros/src/any/validate.rs index 8b6d42ae..2d62532a 100644 --- a/nutype_macros/src/any/validate.rs +++ b/nutype_macros/src/any/validate.rs @@ -1,10 +1,8 @@ -use std::collections::HashSet; - use proc_macro2::Span; use crate::common::{ - models::{DeriveTrait, SpannedDeriveTrait, TypeName}, - validate::{validate_duplicates, validate_guard, validate_traits_from_xor_try_from}, + models::{CfgAttrEntry, DeriveTrait, SpannedDeriveTrait, TypeName, ValidatedDerives}, + validate::{validate_all_derive_traits, validate_duplicates, validate_guard}, }; use super::models::{ @@ -50,23 +48,22 @@ fn validate_sanitizers( pub fn validate_any_derive_traits( guard: &AnyGuard, - spanned_derive_traits: Vec, -) -> Result, syn::Error> { - validate_traits_from_xor_try_from(&spanned_derive_traits)?; - - let mut traits = HashSet::with_capacity(24); - let has_validation = guard.has_validation(); - - for spanned_trait in spanned_derive_traits { - let string_derive_trait = - to_any_derive_trait(spanned_trait.item, has_validation, spanned_trait.span)?; - traits.insert(string_derive_trait); - } - - Ok(traits) + derive_traits: Vec, + cfg_attr_entries: &[CfgAttrEntry], + maybe_default_value: &Option, + type_name: &TypeName, +) -> Result, syn::Error> { + validate_all_derive_traits( + guard.has_validation(), + derive_traits, + cfg_attr_entries, + maybe_default_value, + type_name, + to_any_derive_trait, + ) } -fn to_any_derive_trait( +pub(crate) fn to_any_derive_trait( tr: DeriveTrait, _has_validation: bool, span: Span, diff --git a/nutype_macros/src/common/generate/mod.rs b/nutype_macros/src/common/generate/mod.rs index d603dc6f..95142047 100644 --- a/nutype_macros/src/common/generate/mod.rs +++ b/nutype_macros/src/common/generate/mod.rs @@ -11,8 +11,9 @@ use std::collections::HashSet; use self::traits::GeneratedTraits; use super::models::{ - ConstFn, ConstructorVisibility, CustomFunction, ErrorTypePath, GenerateParams, Guard, - NewUnchecked, ParseErrorTypeName, SpannedDeriveUnsafeTrait, TypeName, TypeTrait, + ConditionalDeriveGroup, ConstFn, ConstructorVisibility, CustomFunction, ErrorTypePath, + GenerateParams, Guard, NewUnchecked, ParseErrorTypeName, SpannedDeriveUnsafeTrait, TypeName, + TypeTrait, }; use crate::common::{ generate::{new_unchecked::gen_new_unchecked, parse_error::gen_parse_error_name}, @@ -103,6 +104,7 @@ pub fn gen_reimports( module_name: &ModuleName, maybe_error_type_path: Option<&ErrorTypePath>, maybe_parse_error_type_name: Option<&ParseErrorTypeName>, + conditional_parse_error_reimports: &[(TokenStream, ParseErrorTypeName)], ) -> TokenStream { let reimport_main_type = quote! { #vis use #module_name::#type_name; @@ -126,10 +128,21 @@ pub fn gen_reimports( } }; + let reimport_conditional_parse_errors: TokenStream = conditional_parse_error_reimports + .iter() + .map(|(pred, parse_error_name)| { + quote! { + #[cfg(#pred)] + #vis use #module_name::#parse_error_name; + } + }) + .collect(); + quote! { #reimport_main_type #reimport_error_type_if_needed #reimport_parse_error_type_if_needed + #reimport_conditional_parse_errors } } @@ -201,6 +214,7 @@ pub trait GenerateNewtype { validators: &[Self::Validator], ) -> TokenStream; + #[allow(clippy::too_many_arguments)] fn gen_traits( type_name: &TypeName, generics: &Generics, @@ -209,6 +223,7 @@ pub trait GenerateNewtype { unsafe_traits: &[SpannedDeriveUnsafeTrait], maybe_default_value: Option, guard: &Guard, + conditional_derives: &[ConditionalDeriveGroup], ) -> Result; fn gen_try_new( @@ -402,6 +417,7 @@ pub trait GenerateNewtype { maybe_default_value, inner_type, generics, + conditional_derives, } = params; let module_name = gen_module_name_for_type(&type_name); @@ -442,17 +458,12 @@ pub trait GenerateNewtype { }, }; - let reimports = gen_reimports( - vis, - &type_name, - &module_name, - maybe_reimported_error_type_path, - maybe_parse_error_type_path.as_ref(), - ); - let GeneratedTraits { derive_transparent_traits, implement_traits, + conditional_derive_transparent_traits, + conditional_implement_traits, + conditional_from_str_parse_errors, } = Self::gen_traits( &type_name, &generics, @@ -461,8 +472,18 @@ pub trait GenerateNewtype { &unsafe_traits, maybe_default_value, &guard, + &conditional_derives, )?; + let reimports = gen_reimports( + vis, + &type_name, + &module_name, + maybe_reimported_error_type_path, + maybe_parse_error_type_path.as_ref(), + &conditional_from_str_parse_errors, + ); + // Split generics for struct definition to properly handle where clauses let generics::SplitGenerics { impl_generics: struct_generics, @@ -482,10 +503,12 @@ pub trait GenerateNewtype { #(#doc_attrs)* #derive_transparent_traits + #conditional_derive_transparent_traits pub struct #type_name #struct_generics (#inner_type) #struct_where_clause; #implementation #implement_traits + #conditional_implement_traits #[cfg(test)] mod tests { diff --git a/nutype_macros/src/common/generate/traits.rs b/nutype_macros/src/common/generate/traits.rs index 4518c5b8..d6b93310 100644 --- a/nutype_macros/src/common/generate/traits.rs +++ b/nutype_macros/src/common/generate/traits.rs @@ -1,3 +1,4 @@ +use core::hash::Hash; use std::collections::HashSet; use proc_macro2::TokenStream; @@ -6,7 +7,7 @@ use syn::Generics; use crate::common::{ generate::generics::{SplitGenerics, add_bound_to_all_type_params}, - models::{ErrorTypePath, InnerType, TypeName}, + models::{ConditionalDeriveGroup, ErrorTypePath, InnerType, ParseErrorTypeName, TypeName}, }; use super::parse_error::{gen_def_parse_error, gen_parse_error_name}; @@ -18,6 +19,15 @@ pub struct GeneratedTraits { /// Implementation of traits. pub implement_traits: TokenStream, + + /// Conditional `#[cfg_attr(pred, derive(...))]` attributes. + pub conditional_derive_transparent_traits: TokenStream, + + /// Conditional `#[cfg(pred)] impl ...` blocks. + pub conditional_implement_traits: TokenStream, + + /// (predicate, ParseErrorTypeName) pairs for conditional `FromStr` re-exports. + pub conditional_from_str_parse_errors: Vec<(TokenStream, ParseErrorTypeName)>, } /// Split traits into 2 groups for generation: @@ -55,6 +65,115 @@ where } } +/// Output of processing conditional derives. +pub struct ConditionalTraits { + pub derive_transparent_traits: TokenStream, + pub implement_traits: TokenStream, + pub from_str_parse_errors: Vec<(TokenStream, ParseErrorTypeName)>, +} + +/// Indicates whether an irregular trait variant generates a `ParseError` type definition +/// that must be re-exported at module level when used inside conditional derives. +/// +/// For non-string types (integer, float, any), `FromStr` generates a `ParseError` enum +/// (via `gen_impl_trait_from_str` -> `gen_def_parse_error`). When such a trait appears +/// inside a conditional derive group, the generated code must use a `mod` wrapper +/// instead of `const _: () = { ... }`, so that `ParseError` is accessible for re-export. +/// +/// String's `FromStr` does **not** generate a `ParseError` type -- it reuses the +/// validation error type directly -- so it returns `false` for all variants. +pub trait HasGeneratedParseError { + /// Returns `true` if this irregular trait variant generates a `ParseError` type + /// definition that needs module-level re-export in conditional derives. + fn has_generated_parse_error(&self) -> bool; +} + +/// Process conditional derive groups, splitting each into transparent and irregular traits, +/// and generating the appropriate `#[cfg_attr(...)]` / `#[cfg(...)]` wrappers. +/// +/// This is shared logic used by all type-specific `gen_traits` functions. +/// +/// For each conditional group: +/// 1. Transparent traits + unchecked traits -> `#[cfg_attr(pred, derive(...))]` +/// 2. Irregular traits -> wrapped in either: +/// - `mod __fromstr_impl__ { ... }` + `pub use ...` if any trait +/// [`has_generated_parse_error`](HasGeneratedParseError::has_generated_parse_error), +/// so the `ParseError` type is accessible for re-export. +/// - `const _: () = { ... };` otherwise. +pub fn process_conditional_derives( + conditional_derives: &[ConditionalDeriveGroup], + type_name: &TypeName, + gen_impl_traits: impl Fn(Vec) -> Result, +) -> Result +where + InputTrait: Eq + Hash + Clone, + TransparentTrait: ToTokens, + IrregularTrait: HasGeneratedParseError, + GeneratableTrait: From, +{ + let mut derive_transparent_traits = TokenStream::new(); + let mut implement_traits = TokenStream::new(); + let mut from_str_parse_errors: Vec<(TokenStream, ParseErrorTypeName)> = vec![]; + + for group in conditional_derives { + let pred = &group.predicate; + + let cond_traits: HashSet = group.typed_traits.iter().cloned().collect(); + let GeneratableTraits { + transparent_traits: cond_transparent, + irregular_traits: cond_irregular, + } = split_into_generatable_traits(cond_traits); + + let cond_unchecked = &group.unchecked_traits; + if !cond_transparent.is_empty() || !cond_unchecked.is_empty() { + derive_transparent_traits.extend(quote! { + #[cfg_attr(#pred, derive( + #(#cond_transparent,)* + #(#cond_unchecked,)* + ))] + }); + } + + if !cond_irregular.is_empty() { + let needs_parse_error_reexport = cond_irregular + .iter() + .any(HasGeneratedParseError::has_generated_parse_error); + + let impl_tokens = gen_impl_traits(cond_irregular)?; + + if needs_parse_error_reexport { + // When FromStr is conditional, use a module wrapper so ParseError + // is accessible for re-export (not trapped inside const block). + let fromstr_mod_name = quote::format_ident!("__fromstr_impl__"); + let parse_error_name = gen_parse_error_name(type_name); + implement_traits.extend(quote! { + #[cfg(#pred)] + mod #fromstr_mod_name { + use super::*; + #impl_tokens + } + #[cfg(#pred)] + pub use #fromstr_mod_name::#parse_error_name; + }); + from_str_parse_errors.push((pred.clone(), parse_error_name)); + } else { + implement_traits.extend(quote! { + #[cfg(#pred)] + const _: () = { + #impl_tokens + }; + }); + } + } + } + + Ok(ConditionalTraits { + derive_transparent_traits, + implement_traits, + from_str_parse_errors, + }) +} + pub fn gen_impl_trait_into( type_name: &TypeName, generics: &Generics, diff --git a/nutype_macros/src/common/models.rs b/nutype_macros/src/common/models.rs index c16a0f90..45baa4eb 100644 --- a/nutype_macros/src/common/models.rs +++ b/nutype_macros/src/common/models.rs @@ -276,6 +276,9 @@ pub struct Attributes { /// List of unchecked traits that are derived with `derive_unchecked(...)` attribute. pub derive_unchecked_traits: Vec, + + /// Conditional entries from `cfg_attr(...)`. + pub cfg_attr_entries: Vec, } /// Represents a value known at compile time or an expression. @@ -344,7 +347,7 @@ pub struct RawGuard { pub validation: Option>, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum DeriveTrait { // Standard library Debug, @@ -383,8 +386,98 @@ pub enum DeriveTrait { ValuableValuable, } +impl core::fmt::Display for DeriveTrait { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let name = match self { + DeriveTrait::Debug => "Debug", + DeriveTrait::Clone => "Clone", + DeriveTrait::Copy => "Copy", + DeriveTrait::PartialEq => "PartialEq", + DeriveTrait::Eq => "Eq", + DeriveTrait::PartialOrd => "PartialOrd", + DeriveTrait::Ord => "Ord", + DeriveTrait::FromStr => "FromStr", + DeriveTrait::AsRef => "AsRef", + DeriveTrait::From => "From", + DeriveTrait::TryFrom => "TryFrom", + DeriveTrait::Into => "Into", + DeriveTrait::Hash => "Hash", + DeriveTrait::Borrow => "Borrow", + DeriveTrait::Display => "Display", + DeriveTrait::Default => "Default", + DeriveTrait::Deref => "Deref", + DeriveTrait::IntoIterator => "IntoIterator", + DeriveTrait::SerdeSerialize => "Serialize", + DeriveTrait::SerdeDeserialize => "Deserialize", + DeriveTrait::SchemarsJsonSchema => "JsonSchema", + DeriveTrait::ArbitraryArbitrary => "Arbitrary", + DeriveTrait::ValuableValuable => "Valuable", + }; + write!(f, "{name}") + } +} + pub type SpannedDeriveTrait = SpannedItem; +/// The inner attribute of a `cfg_attr(...)` entry. +#[derive(Debug)] +// Suppress dead_code warning: `DeriveUnchecked` is only constructed when +// the `derive_unchecked` feature is enabled. +#[allow(dead_code)] +pub enum CfgAttrContent { + /// `cfg_attr(, derive(...))` + Derive(Vec), + + /// `cfg_attr(, derive_unchecked(...))` + DeriveUnchecked(Vec), +} + +/// A single `cfg_attr(, )` entry parsed from `#[nutype(...)]`. +/// The predicate is stored as raw tokens - the proc macro does not evaluate it. +#[derive(Debug)] +pub struct CfgAttrEntry { + pub predicate: TokenStream, + pub content: CfgAttrContent, +} + +/// Result of trait validation, containing typed traits for both unconditional +/// and conditional derive entries. +pub struct ValidatedDerives { + /// Typed traits from unconditional `derive(...)`. + pub unconditional: HashSet, + + /// Typed traits from `cfg_attr(...)` entries, grouped by predicate. + pub conditional: Vec>, +} + +impl ValidatedDerives { + pub fn has_default_trait(&self) -> bool { + self.unconditional.iter().any(|t| t.is_default()) + || self + .conditional + .iter() + .any(|entry| entry.traits.iter().any(|t| t.is_default())) + } +} + +/// A single cfg_attr derive group after validation and type conversion. +pub struct ValidatedCfgAttrDerives { + pub predicate: TokenStream, + pub traits: Vec, +} + +/// A single predicate group for conditional code generation. +/// Contains either typed derive traits or unchecked derive traits (not both). +pub struct ConditionalDeriveGroup { + pub predicate: TokenStream, + + /// Typed traits from `cfg_attr(pred, derive(...))` - already validated and converted. + pub typed_traits: Vec, + + /// Unchecked traits from `cfg_attr(pred, derive_unchecked(...))` - passed through as-is. + pub unchecked_traits: Vec, +} + /// A trait that is derive with `derive_unchecked(...)` attribute. /// `derive_unchecked` simply bypasses traits into `derive(...)`. This allows /// allows to derive traits that nutype is not aware of. @@ -418,8 +511,8 @@ impl Parse for SpannedDeriveUnsafeTrait { } pub trait TypeTrait { - // If this is FromStr variant? fn is_from_str(&self) -> bool; + fn is_default(&self) -> bool; } /// The flag that indicates that a newtype will be generated with extra constructor, @@ -504,12 +597,14 @@ pub struct GenerateParams { pub const_fn: ConstFn, pub constructor_visibility: ConstructorVisibility, pub maybe_default_value: Option, + /// Conditional derive groups, one per predicate. + pub conditional_derives: Vec>, } pub trait Newtype { type Sanitizer; type Validator; - type TypedTrait; + type TypedTrait: TypeTrait; type InnerType; #[allow(clippy::type_complexity)] @@ -521,7 +616,10 @@ pub trait Newtype { fn validate( guard: &Guard, derive_traits: Vec, - ) -> Result, syn::Error>; + cfg_attr_entries: &[CfgAttrEntry], + maybe_default_value: &Option, + type_name: &TypeName, + ) -> Result, syn::Error>; #[allow(clippy::type_complexity)] fn generate( @@ -551,11 +649,23 @@ pub trait Newtype { default: maybe_default_value, derive_traits, derive_unchecked_traits, + cfg_attr_entries, } = Self::parse_attributes(attrs, &type_name)?; - let traits = Self::validate(&guard, derive_traits)?; + + let validated = Self::validate( + &guard, + derive_traits, + &cfg_attr_entries, + &maybe_default_value, + &type_name, + )?; + + let conditional_derives = + build_conditional_derive_groups(validated.conditional, &cfg_attr_entries); + let generated_output = Self::generate(GenerateParams { doc_attrs, - traits, + traits: validated.unconditional, unsafe_traits: derive_unchecked_traits, vis, type_name, @@ -566,11 +676,43 @@ pub trait Newtype { constructor_visibility, maybe_default_value, inner_type, + conditional_derives, })?; Ok(generated_output) } } +/// Build a list of ConditionalDeriveGroups from validated conditional derives and cfg_attr entries. +/// Each validated Derive entry becomes its own group (with typed_traits only). +/// Each DeriveUnchecked entry becomes its own group (with unchecked_traits only). +/// Groups with the same predicate are NOT merged. +pub fn build_conditional_derive_groups( + validated_conditional: Vec>, + cfg_attr_entries: &[CfgAttrEntry], +) -> Vec> { + let mut groups: Vec> = validated_conditional + .into_iter() + .map(|v| ConditionalDeriveGroup { + predicate: v.predicate, + typed_traits: v.traits, + unchecked_traits: vec![], + }) + .collect(); + + // Append DeriveUnchecked entries as their own groups + for entry in cfg_attr_entries { + if let CfgAttrContent::DeriveUnchecked(ref unchecked) = entry.content { + groups.push(ConditionalDeriveGroup { + predicate: entry.predicate.clone(), + typed_traits: vec![], + unchecked_traits: unchecked.clone(), + }); + } + } + + groups +} + /// Represents a function that is used for custom sanitizers and validators specified /// with `with =`. /// It can be either pass to an existing function or a closure. diff --git a/nutype_macros/src/common/parse/mod.rs b/nutype_macros/src/common/parse/mod.rs index f1b8c269..870009dc 100644 --- a/nutype_macros/src/common/parse/mod.rs +++ b/nutype_macros/src/common/parse/mod.rs @@ -17,7 +17,9 @@ use syn::{ token::Paren, }; -use crate::common::models::{SpannedDeriveTrait, SpannedDeriveUnsafeTrait}; +use crate::common::models::{ + CfgAttrContent, CfgAttrEntry, SpannedDeriveTrait, SpannedDeriveUnsafeTrait, +}; use super::models::{ ConstFn, ConstructorVisibility, CustomFunction, ErrorTypePath, NewUnchecked, @@ -82,6 +84,9 @@ pub struct ParseableAttributes { /// Parse from `derive_unchecked(...)` attribute pub derive_unchecked_traits: Vec, + + /// Parsed from `cfg_attr(...)` entries + pub cfg_attr_entries: Vec, } enum ValidateAttr { @@ -243,6 +248,7 @@ impl Default for ParseableAttributes default: None, derive_traits: vec![], derive_unchecked_traits: vec![], + cfg_attr_entries: vec![], } } } @@ -346,6 +352,20 @@ where return Err(syn::Error::new(ident.span(), msg)); } } + } else if ident == "cfg_attr" { + if input.peek(Paren) { + let content; + parenthesized!(content in input); + let entry = parse_cfg_attr_content(&content)?; + attrs.cfg_attr_entries.push(entry); + } else { + let msg = concat!( + "`cfg_attr` must be used with parenthesis.\n", + "For example:\n\n", + " cfg_attr(feature = \"serde\", derive(Serialize, Deserialize))\n\n" + ); + return Err(syn::Error::new(ident.span(), msg)); + } } else if ident == "constructor" { if input.peek(Paren) { let content; @@ -486,6 +506,68 @@ where } } +fn parse_cfg_predicate(input: ParseStream) -> syn::Result { + let mut tokens = Vec::new(); + + while !input.is_empty() && !input.peek(Token![,]) { + tokens.push(input.parse::()?); + } + + if tokens.is_empty() { + return Err(input.error("expected cfg predicate")); + } + + Ok(tokens.into_iter().collect()) +} + +fn parse_cfg_attr_content(input: ParseStream) -> syn::Result { + // 1. Parse the predicate: everything before the first top-level `,` + let predicate = parse_cfg_predicate(input)?; + let _comma: Token![,] = input.parse()?; + + // 2. Parse the inner attribute keyword + let attr_ident: Ident = input.parse()?; + + let content = if attr_ident == "derive" { + let inner; + parenthesized!(inner in input); + let items = inner.parse_terminated(SpannedDeriveTrait::parse, Token![,])?; + CfgAttrContent::Derive(items.into_iter().collect()) + } else if attr_ident == "derive_unchecked" { + cfg_if! { + if #[cfg(feature = "derive_unchecked")] { + let inner; + parenthesized!(inner in input); + let items = inner.parse_terminated(SpannedDeriveUnsafeTrait::parse, Token![,])?; + CfgAttrContent::DeriveUnchecked(items.into_iter().collect()) + } else { + let msg = concat!( + "To use derive_unchecked() function, the feature `derive_unchecked` ", + "of crate `nutype` needs to be enabled.\n\n", + "DID YOU KNOW?\n", + "It's called `derive_unchecked` because it enables to derive any traits ", + "that nutype is not aware of.\n", + "So it is developer's responsibility to ensure that the derived traits ", + "do not create a loophole to bypass the constraints.\n", + ); + return Err(syn::Error::new(attr_ident.span(), msg)); + } + } + } else { + let msg = format!( + "Attribute `{attr_ident}` is not supported inside `cfg_attr()`.\n\ + Only `derive(...)` and `derive_unchecked(...)` are allowed." + ); + return Err(syn::Error::new(attr_ident.span(), msg)); + }; + + if !input.is_empty() { + return Err(input.error("unexpected tokens after `derive(...)` inside `cfg_attr()`")); + } + + Ok(CfgAttrEntry { predicate, content }) +} + const CONSTRUCTOR_VISIBILITY_ERROR: &str = concat!( "Invalid constructor visibility.\n\n", "Valid options:\n", diff --git a/nutype_macros/src/common/validate.rs b/nutype_macros/src/common/validate.rs index 4f95b21a..9176938e 100644 --- a/nutype_macros/src/common/validate.rs +++ b/nutype_macros/src/common/validate.rs @@ -1,11 +1,14 @@ +use core::hash::Hash; use kinded::Kinded; use proc_macro2::Span; +use std::collections::HashSet; use super::{ r#generate::error::gen_error_type_name, models::{ - DeriveTrait, Guard, NumericBoundValidator, RawGuard, SpannedDeriveTrait, SpannedItem, - TypeName, Validation, + CfgAttrContent, CfgAttrEntry, DeriveTrait, Guard, NumericBoundValidator, RawGuard, + SpannedDeriveTrait, SpannedItem, TypeName, TypeTrait, ValidatedCfgAttrDerives, + ValidatedDerives, Validation, }, parse::RawValidation, }; @@ -171,3 +174,111 @@ pub fn validate_traits_from_xor_try_from( _ => Ok(()), } } + +/// Check that no trait appears in both unconditional `derive(...)` and any conditional +/// `cfg_attr(..., derive(...))`, and that no trait appears in multiple `cfg_attr` entries. +pub fn check_cfg_attr_no_duplicates( + unconditional: &[SpannedDeriveTrait], + cfg_attr_entries: &[CfgAttrEntry], +) -> Result<(), syn::Error> { + let unconditional_set: HashSet = unconditional.iter().map(|s| s.item).collect(); + + let mut conditional_seen: HashSet = HashSet::new(); + + for entry in cfg_attr_entries { + if let CfgAttrContent::Derive(ref traits) = entry.content { + for spanned in traits { + if unconditional_set.contains(&spanned.item) { + let msg = format!( + "Trait `{}` appears in both unconditional `derive()` and \ + conditional `cfg_attr(..., derive())`. Remove it from one of them.", + spanned.item + ); + return Err(syn::Error::new(spanned.span, msg)); + } + + if !conditional_seen.insert(spanned.item) { + let msg = format!( + "Trait `{}` appears in multiple `cfg_attr(...)` entries. \ + If their predicates overlap at compile time, this will cause \ + a compilation error. Combine them under a single predicate or \ + ensure predicates are mutually exclusive.", + spanned.item + ); + return Err(syn::Error::new(spanned.span, msg)); + } + } + } + } + Ok(()) +} + +/// Validate all derive traits (unconditional + conditional) in a single pass. +/// +/// The `convert` function is the only type-specific part - it converts a generic +/// `DeriveTrait` to the type-specific `TypedTrait`. +pub fn validate_all_derive_traits( + has_validation: bool, + derive_traits: Vec, + cfg_attr_entries: &[CfgAttrEntry], + maybe_default_value: &Option, + type_name: &TypeName, + convert: impl Fn(DeriveTrait, bool, Span) -> Result, +) -> Result, syn::Error> +where + TypedTrait: Eq + Hash + TypeTrait, +{ + // 0. Check for unconditional-vs-conditional duplicates + check_cfg_attr_no_duplicates(&derive_traits, cfg_attr_entries)?; + + // 1. Build the union of all derive traits for cross-trait dependency checks + let mut all_spanned = derive_traits.clone(); + for entry in cfg_attr_entries { + if let CfgAttrContent::Derive(ref traits) = entry.content { + all_spanned.extend(traits.iter().cloned()); + } + } + + // 2. Run cross-trait checks on the union (e.g., From XOR TryFrom) + validate_traits_from_xor_try_from(&all_spanned)?; + + // 3. Convert and collect unconditional traits (with type-compatibility checks) + let unconditional = derive_traits + .iter() + .map(|st| convert(st.item, has_validation, st.span)) + .collect::, _>>()?; + + // 4. Convert conditional traits (same conversion, per entry) + let conditional = cfg_attr_entries + .iter() + .filter_map(|entry| match &entry.content { + CfgAttrContent::Derive(traits) => Some((entry, traits)), + _ => None, + }) + .map(|(entry, traits)| { + let typed = traits + .iter() + .map(|st| convert(st.item, has_validation, st.span)) + .collect::, _>>()?; + Ok(ValidatedCfgAttrDerives { + predicate: entry.predicate.clone(), + traits: typed, + }) + }) + .collect::, syn::Error>>()?; + + let validated = ValidatedDerives { + unconditional, + conditional, + }; + + // 5. If Default appears ANYWHERE (unconditional or conditional), require default = + if validated.has_default_trait() && maybe_default_value.is_none() { + let msg = format!( + "Trait `Default` is derived for type {type_name}, but `default = ` parameter is missing in #[nutype] macro" + ); + return Err(syn::Error::new(proc_macro2::Span::call_site(), msg)); + } + + Ok(validated) +} diff --git a/nutype_macros/src/float/generate/mod.rs b/nutype_macros/src/float/generate/mod.rs index 0455d3f3..23f5a0d9 100644 --- a/nutype_macros/src/float/generate/mod.rs +++ b/nutype_macros/src/float/generate/mod.rs @@ -22,7 +22,10 @@ use crate::{ }, traits::GeneratedTraits, }, - models::{ConstFn, ErrorTypePath, Guard, SpannedDeriveUnsafeTrait, TypeName}, + models::{ + ConditionalDeriveGroup, ConstFn, ErrorTypePath, Guard, SpannedDeriveUnsafeTrait, + TypeName, + }, }, float::models::FloatInnerType, }; @@ -143,6 +146,7 @@ where unsafe_traits: &[SpannedDeriveUnsafeTrait], maybe_default_value: Option, guard: &FloatGuard, + conditional_derives: &[ConditionalDeriveGroup], ) -> Result { gen_traits( type_name, @@ -152,6 +156,7 @@ where traits, unsafe_traits, guard, + conditional_derives, ) } diff --git a/nutype_macros/src/float/generate/traits/mod.rs b/nutype_macros/src/float/generate/traits/mod.rs index 27d99e7c..9c503c5d 100644 --- a/nutype_macros/src/float/generate/traits/mod.rs +++ b/nutype_macros/src/float/generate/traits/mod.rs @@ -8,13 +8,14 @@ use syn::Generics; use crate::{ common::{ generate::traits::{ - GeneratableTrait, GeneratableTraits, GeneratedTraits, gen_impl_trait_as_ref, - gen_impl_trait_borrow, gen_impl_trait_default, gen_impl_trait_deref, - gen_impl_trait_display, gen_impl_trait_from, gen_impl_trait_from_str, - gen_impl_trait_into, gen_impl_trait_serde_deserialize, gen_impl_trait_serde_serialize, - gen_impl_trait_try_from, split_into_generatable_traits, + ConditionalTraits, GeneratableTrait, GeneratableTraits, GeneratedTraits, + HasGeneratedParseError, gen_impl_trait_as_ref, gen_impl_trait_borrow, + gen_impl_trait_default, gen_impl_trait_deref, gen_impl_trait_display, + gen_impl_trait_from, gen_impl_trait_from_str, gen_impl_trait_into, + gen_impl_trait_serde_deserialize, gen_impl_trait_serde_serialize, + gen_impl_trait_try_from, process_conditional_derives, split_into_generatable_traits, }, - models::{SpannedDeriveUnsafeTrait, TypeName}, + models::{ConditionalDeriveGroup, SpannedDeriveUnsafeTrait, TypeName}, }, float::models::{FloatDeriveTrait, FloatGuard, FloatInnerType}, }; @@ -53,6 +54,14 @@ enum FloatIrregularTrait { ArbitraryArbitrary, } +/// Float's `FromStr` generates a `ParseError` type via `gen_impl_trait_from_str` -> +/// `gen_def_parse_error`, which needs module-level re-export in conditional derives. +impl HasGeneratedParseError for FloatIrregularTrait { + fn has_generated_parse_error(&self) -> bool { + matches!(self, Self::FromStr) + } +} + impl From for FloatGeneratableTrait { fn from(derive_trait: FloatDeriveTrait) -> FloatGeneratableTrait { match derive_trait { @@ -126,6 +135,7 @@ impl ToTokens for FloatTransparentTrait { } } +#[allow(clippy::too_many_arguments)] pub fn gen_traits( type_name: &TypeName, generics: &Generics, @@ -134,6 +144,7 @@ pub fn gen_traits( traits: HashSet, unsafe_traits: &[SpannedDeriveUnsafeTrait], guard: &FloatGuard, + conditional_derives: &[ConditionalDeriveGroup], ) -> Result { let GeneratableTraits { transparent_traits, @@ -167,14 +178,32 @@ pub fn gen_traits( type_name, generics, inner_type, - maybe_default_value, + maybe_default_value.clone(), irregular_traits, guard, )?; + let ConditionalTraits { + derive_transparent_traits: conditional_derive_transparent_traits, + implement_traits: conditional_implement_traits, + from_str_parse_errors: conditional_from_str_parse_errors, + } = process_conditional_derives(conditional_derives, type_name, |irregular| { + gen_implemented_traits( + type_name, + generics, + inner_type, + maybe_default_value.clone(), + irregular, + guard, + ) + })?; + Ok(GeneratedTraits { derive_transparent_traits, implement_traits, + conditional_derive_transparent_traits, + conditional_implement_traits, + conditional_from_str_parse_errors, }) } diff --git a/nutype_macros/src/float/mod.rs b/nutype_macros/src/float/mod.rs index 87514125..ebe35d67 100644 --- a/nutype_macros/src/float/mod.rs +++ b/nutype_macros/src/float/mod.rs @@ -3,14 +3,16 @@ use core::{ marker::PhantomData, str::FromStr, }; -use std::collections::HashSet; use proc_macro2::TokenStream; use quote::ToTokens; use crate::common::{ generate::GenerateNewtype, - models::{Attributes, GenerateParams, Guard, Newtype, SpannedDeriveTrait, TypeName}, + models::{ + Attributes, CfgAttrEntry, GenerateParams, Guard, Newtype, SpannedDeriveTrait, TypeName, + ValidatedDerives, + }, }; use self::{ @@ -47,8 +49,17 @@ where fn validate( guard: &Guard, derive_traits: Vec, - ) -> Result, syn::Error> { - validate_float_derive_traits(derive_traits, guard) + cfg_attr_entries: &[CfgAttrEntry], + maybe_default_value: &Option, + type_name: &TypeName, + ) -> Result, syn::Error> { + validate_float_derive_traits( + derive_traits, + guard, + cfg_attr_entries, + maybe_default_value, + type_name, + ) } fn generate( diff --git a/nutype_macros/src/float/models.rs b/nutype_macros/src/float/models.rs index ea1966ef..774ff1af 100644 --- a/nutype_macros/src/float/models.rs +++ b/nutype_macros/src/float/models.rs @@ -39,7 +39,7 @@ pub type SpannedFloatValidator = SpannedItem>; // Traits // -#[derive(Debug, Eq, PartialEq, Hash)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] pub enum FloatDeriveTrait { // Standard Debug, @@ -71,6 +71,9 @@ impl TypeTrait for FloatDeriveTrait { fn is_from_str(&self) -> bool { self == &FloatDeriveTrait::FromStr } + fn is_default(&self) -> bool { + self == &FloatDeriveTrait::Default + } } pub type FloatRawGuard = RawGuard, SpannedFloatValidator>; diff --git a/nutype_macros/src/float/parse.rs b/nutype_macros/src/float/parse.rs index 10c6f8ae..4f8608df 100644 --- a/nutype_macros/src/float/parse.rs +++ b/nutype_macros/src/float/parse.rs @@ -44,6 +44,7 @@ where default, derive_traits, derive_unchecked_traits, + cfg_attr_entries, } = attrs; let raw_guard = FloatRawGuard { sanitizers, @@ -58,6 +59,7 @@ where default, derive_traits, derive_unchecked_traits, + cfg_attr_entries, }) } diff --git a/nutype_macros/src/float/validate.rs b/nutype_macros/src/float/validate.rs index 888e0a7f..043c5b13 100644 --- a/nutype_macros/src/float/validate.rs +++ b/nutype_macros/src/float/validate.rs @@ -2,10 +2,12 @@ use proc_macro2::Span; use std::collections::HashSet; use crate::common::{ - models::{DeriveTrait, SpannedDeriveTrait, TypeName, Validation}, + models::{ + CfgAttrContent, CfgAttrEntry, DeriveTrait, SpannedDeriveTrait, TypeName, ValidatedDerives, + Validation, + }, validate::{ - validate_duplicates, validate_guard, validate_numeric_bounds, - validate_traits_from_xor_try_from, + validate_all_derive_traits, validate_duplicates, validate_guard, validate_numeric_bounds, }, }; @@ -74,7 +76,7 @@ fn has_validation_against_nan(guard: &FloatGuard) -> bool { } #[derive(Debug, Clone, Copy)] -struct ValidationInfo { +pub(crate) struct ValidationInfo { has_validation: bool, has_nan_validation: bool, } @@ -91,25 +93,32 @@ impl ValidationInfo { } pub fn validate_float_derive_traits( - spanned_derive_traits: Vec, + derive_traits: Vec, guard: &FloatGuard, -) -> Result, syn::Error> { - validate_traits_from_xor_try_from(&spanned_derive_traits)?; - + cfg_attr_entries: &[CfgAttrEntry], + maybe_default_value: &Option, + type_name: &TypeName, +) -> Result, syn::Error> { let validation = ValidationInfo::from_guard(guard); - let mut traits = HashSet::with_capacity(24); - for spanned_trait in spanned_derive_traits.iter() { - let normal_trait = spanned_trait.item; - let string_derive_trait = - to_float_derive_trait(normal_trait, validation, spanned_trait.span)?; - traits.insert(string_derive_trait); + // Build union of all spanned traits for inter-trait dependency checks + let mut all_spanned = derive_traits.clone(); + for entry in cfg_attr_entries { + if let CfgAttrContent::Derive(ref traits) = entry.content { + all_spanned.extend(traits.iter().cloned()); + } + } + + // Convert all traits for dependency checks + let mut all_typed = HashSet::with_capacity(24); + for spanned_trait in all_spanned.iter() { + let typed = to_float_derive_trait(spanned_trait.item, validation, spanned_trait.span)?; + all_typed.insert(typed); } - // Get a span of a given trait, so we can render a better message below - // when we validate inter trait dependencies. + // Get a span of a given trait from the full union let get_span_for = |needle: DeriveTrait| -> Span { - spanned_derive_traits + all_spanned .iter() .flat_map(|spanned_tr| { if spanned_tr.item == needle { @@ -122,29 +131,38 @@ pub fn validate_float_derive_traits( .unwrap_or_else(Span::call_site) }; - // Validate inter trait dependencies - // - if traits.contains(&FloatDeriveTrait::Eq) && !traits.contains(&FloatDeriveTrait::PartialEq) { + // Validate inter trait dependencies on the union + if all_typed.contains(&FloatDeriveTrait::Eq) + && !all_typed.contains(&FloatDeriveTrait::PartialEq) + { let span = get_span_for(DeriveTrait::Eq); let msg = "Trait Eq requires PartialEq.\nEvery expert was once a beginner."; return Err(syn::Error::new(span, msg)); } - if traits.contains(&FloatDeriveTrait::Ord) { - if !traits.contains(&FloatDeriveTrait::PartialOrd) { + if all_typed.contains(&FloatDeriveTrait::Ord) { + if !all_typed.contains(&FloatDeriveTrait::PartialOrd) { let span = get_span_for(DeriveTrait::Ord); let msg = "Trait Ord requires PartialOrd.\nÜbung macht den Meister."; return Err(syn::Error::new(span, msg)); - } else if !traits.contains(&FloatDeriveTrait::Eq) { + } else if !all_typed.contains(&FloatDeriveTrait::Eq) { let span = get_span_for(DeriveTrait::Ord); let msg = "Trait Ord requires Eq.\nFestina lente."; return Err(syn::Error::new(span, msg)); } } - Ok(traits) + // Use shared helper for the rest (From XOR TryFrom, conversion) + validate_all_derive_traits( + validation.has_validation, + derive_traits, + cfg_attr_entries, + maybe_default_value, + type_name, + |tr, _has_validation, span| to_float_derive_trait(tr, validation, span), + ) } -fn to_float_derive_trait( +pub(crate) fn to_float_derive_trait( tr: DeriveTrait, validation: ValidationInfo, span: Span, diff --git a/nutype_macros/src/integer/generate/mod.rs b/nutype_macros/src/integer/generate/mod.rs index b094423b..d568415f 100644 --- a/nutype_macros/src/integer/generate/mod.rs +++ b/nutype_macros/src/integer/generate/mod.rs @@ -24,7 +24,9 @@ use crate::common::{ }, traits::GeneratedTraits, }, - models::{ConstFn, ErrorTypePath, Guard, SpannedDeriveUnsafeTrait, TypeName}, + models::{ + ConditionalDeriveGroup, ConstFn, ErrorTypePath, Guard, SpannedDeriveUnsafeTrait, TypeName, + }, }; impl GenerateNewtype for IntegerNewtype @@ -135,6 +137,7 @@ where unsafe_traits: &[SpannedDeriveUnsafeTrait], maybe_default_value: Option, guard: &IntegerGuard, + conditional_derives: &[ConditionalDeriveGroup], ) -> Result { gen_traits( type_name, @@ -144,6 +147,7 @@ where unsafe_traits, maybe_default_value, guard, + conditional_derives, ) } diff --git a/nutype_macros/src/integer/generate/traits/mod.rs b/nutype_macros/src/integer/generate/traits/mod.rs index e4388954..98710b9c 100644 --- a/nutype_macros/src/integer/generate/traits/mod.rs +++ b/nutype_macros/src/integer/generate/traits/mod.rs @@ -9,19 +9,21 @@ use syn::Generics; use crate::{ common::{ generate::traits::{ - GeneratableTrait, GeneratableTraits, GeneratedTraits, gen_impl_trait_as_ref, - gen_impl_trait_borrow, gen_impl_trait_default, gen_impl_trait_deref, - gen_impl_trait_display, gen_impl_trait_from, gen_impl_trait_from_str, - gen_impl_trait_into, gen_impl_trait_serde_deserialize, gen_impl_trait_serde_serialize, - gen_impl_trait_try_from, split_into_generatable_traits, + ConditionalTraits, GeneratableTrait, GeneratableTraits, GeneratedTraits, + HasGeneratedParseError, gen_impl_trait_as_ref, gen_impl_trait_borrow, + gen_impl_trait_default, gen_impl_trait_deref, gen_impl_trait_display, + gen_impl_trait_from, gen_impl_trait_from_str, gen_impl_trait_into, + gen_impl_trait_serde_deserialize, gen_impl_trait_serde_serialize, + gen_impl_trait_try_from, process_conditional_derives, split_into_generatable_traits, }, - models::{SpannedDeriveUnsafeTrait, TypeName}, + models::{ConditionalDeriveGroup, SpannedDeriveUnsafeTrait, TypeName}, }, integer::models::{IntegerDeriveTrait, IntegerGuard, IntegerInnerType}, }; type IntegerGeneratableTrait = GeneratableTrait; +#[allow(clippy::too_many_arguments)] pub fn gen_traits( type_name: &TypeName, generics: &Generics, @@ -30,6 +32,7 @@ pub fn gen_traits( unsafe_traits: &[SpannedDeriveUnsafeTrait], maybe_default_value: Option, guard: &IntegerGuard, + conditional_derives: &[ConditionalDeriveGroup], ) -> Result { let GeneratableTraits { transparent_traits, @@ -48,13 +51,31 @@ pub fn gen_traits( generics, inner_type, irregular_traits, - maybe_default_value, + maybe_default_value.clone(), guard, )?; + let ConditionalTraits { + derive_transparent_traits: conditional_derive_transparent_traits, + implement_traits: conditional_implement_traits, + from_str_parse_errors: conditional_from_str_parse_errors, + } = process_conditional_derives(conditional_derives, type_name, |irregular| { + gen_implemented_traits( + type_name, + generics, + inner_type, + irregular, + maybe_default_value.clone(), + guard, + ) + })?; + Ok(GeneratedTraits { derive_transparent_traits, implement_traits, + conditional_derive_transparent_traits, + conditional_implement_traits, + conditional_from_str_parse_errors, }) } @@ -164,6 +185,14 @@ enum IntegerIrregularTrait { ArbitraryArbitrary, } +/// Integer's `FromStr` generates a `ParseError` type via `gen_impl_trait_from_str` -> +/// `gen_def_parse_error`, which needs module-level re-export in conditional derives. +impl HasGeneratedParseError for IntegerIrregularTrait { + fn has_generated_parse_error(&self) -> bool { + matches!(self, Self::FromStr) + } +} + impl ToTokens for IntegerTransparentTrait { fn to_tokens(&self, token_stream: &mut TokenStream) { let tokens = match self { diff --git a/nutype_macros/src/integer/mod.rs b/nutype_macros/src/integer/mod.rs index 530463f9..0ea31db2 100644 --- a/nutype_macros/src/integer/mod.rs +++ b/nutype_macros/src/integer/mod.rs @@ -3,14 +3,16 @@ use core::{ marker::PhantomData, str::FromStr, }; -use std::collections::HashSet; use proc_macro2::TokenStream; use quote::ToTokens; use crate::common::{ generate::GenerateNewtype, - models::{Attributes, GenerateParams, Guard, Newtype, SpannedDeriveTrait, TypeName}, + models::{ + Attributes, CfgAttrEntry, GenerateParams, Guard, Newtype, SpannedDeriveTrait, TypeName, + ValidatedDerives, + }, }; use self::{ @@ -48,9 +50,18 @@ where fn validate( guard: &Guard, derive_traits: Vec, - ) -> Result, syn::Error> { + cfg_attr_entries: &[CfgAttrEntry], + maybe_default_value: &Option, + type_name: &TypeName, + ) -> Result, syn::Error> { let has_validation = guard.has_validation(); - validate_integer_derive_traits(derive_traits, has_validation) + validate_integer_derive_traits( + derive_traits, + has_validation, + cfg_attr_entries, + maybe_default_value, + type_name, + ) } fn generate( diff --git a/nutype_macros/src/integer/models.rs b/nutype_macros/src/integer/models.rs index 2ce7d557..99f5bd29 100644 --- a/nutype_macros/src/integer/models.rs +++ b/nutype_macros/src/integer/models.rs @@ -38,7 +38,7 @@ pub type SpannedIntegerValidator = SpannedItem>; // Traits // -#[derive(Debug, Eq, PartialEq, Hash)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] pub enum IntegerDeriveTrait { // Standard Debug, @@ -71,6 +71,9 @@ impl TypeTrait for IntegerDeriveTrait { fn is_from_str(&self) -> bool { self == &IntegerDeriveTrait::FromStr } + fn is_default(&self) -> bool { + self == &IntegerDeriveTrait::Default + } } pub type IntegerRawGuard = RawGuard, SpannedIntegerValidator>; diff --git a/nutype_macros/src/integer/parse.rs b/nutype_macros/src/integer/parse.rs index ae69aa56..1ae89346 100644 --- a/nutype_macros/src/integer/parse.rs +++ b/nutype_macros/src/integer/parse.rs @@ -44,6 +44,7 @@ where default, derive_traits, derive_unchecked_traits, + cfg_attr_entries, } = attrs; let raw_guard = IntegerRawGuard { sanitizers, @@ -58,6 +59,7 @@ where default, derive_traits, derive_unchecked_traits, + cfg_attr_entries, }) } diff --git a/nutype_macros/src/integer/validate.rs b/nutype_macros/src/integer/validate.rs index 4cdeb9dc..65541735 100644 --- a/nutype_macros/src/integer/validate.rs +++ b/nutype_macros/src/integer/validate.rs @@ -1,12 +1,9 @@ -use std::collections::HashSet; - use proc_macro2::Span; use crate::common::{ - models::{DeriveTrait, SpannedDeriveTrait, TypeName}, + models::{CfgAttrEntry, DeriveTrait, SpannedDeriveTrait, TypeName, ValidatedDerives}, validate::{ - validate_duplicates, validate_guard, validate_numeric_bounds, - validate_traits_from_xor_try_from, + validate_all_derive_traits, validate_duplicates, validate_guard, validate_numeric_bounds, }, }; @@ -63,23 +60,23 @@ where } pub fn validate_integer_derive_traits( - spanned_derive_traits: Vec, + derive_traits: Vec, has_validation: bool, -) -> Result, syn::Error> { - validate_traits_from_xor_try_from(&spanned_derive_traits)?; - - let mut traits = HashSet::with_capacity(24); - - for spanned_trait in spanned_derive_traits { - let string_derive_trait = - to_integer_derive_trait(spanned_trait.item, has_validation, spanned_trait.span)?; - traits.insert(string_derive_trait); - } - - Ok(traits) + cfg_attr_entries: &[CfgAttrEntry], + maybe_default_value: &Option, + type_name: &TypeName, +) -> Result, syn::Error> { + validate_all_derive_traits( + has_validation, + derive_traits, + cfg_attr_entries, + maybe_default_value, + type_name, + to_integer_derive_trait, + ) } -fn to_integer_derive_trait( +pub(crate) fn to_integer_derive_trait( tr: DeriveTrait, has_validation: bool, span: Span, diff --git a/nutype_macros/src/string/generate/mod.rs b/nutype_macros/src/string/generate/mod.rs index ae421506..dd022509 100644 --- a/nutype_macros/src/string/generate/mod.rs +++ b/nutype_macros/src/string/generate/mod.rs @@ -14,7 +14,10 @@ use crate::{ GenerateNewtype, tests::gen_test_should_have_valid_default_value, traits::GeneratedTraits, }, - models::{ConstFn, ErrorTypePath, Guard, SpannedDeriveUnsafeTrait, TypeName}, + models::{ + ConditionalDeriveGroup, ConstFn, ErrorTypePath, Guard, SpannedDeriveUnsafeTrait, + TypeName, + }, }, string::models::{RegexDef, StringInnerType, StringSanitizer, StringValidator}, }; @@ -209,6 +212,7 @@ impl GenerateNewtype for StringNewtype { unsafe_traits: &[SpannedDeriveUnsafeTrait], maybe_default_value: Option, guard: &StringGuard, + conditional_derives: &[ConditionalDeriveGroup], ) -> Result { gen_traits( type_name, @@ -217,6 +221,7 @@ impl GenerateNewtype for StringNewtype { unsafe_traits, maybe_default_value, guard, + conditional_derives, ) } diff --git a/nutype_macros/src/string/generate/traits/mod.rs b/nutype_macros/src/string/generate/traits/mod.rs index 26f35292..aefa5c3e 100644 --- a/nutype_macros/src/string/generate/traits/mod.rs +++ b/nutype_macros/src/string/generate/traits/mod.rs @@ -9,13 +9,14 @@ use syn::Generics; use crate::{ common::{ generate::traits::{ - GeneratableTrait, GeneratableTraits, GeneratedTraits, gen_impl_trait_as_ref, - gen_impl_trait_borrow, gen_impl_trait_default, gen_impl_trait_deref, - gen_impl_trait_display, gen_impl_trait_from, gen_impl_trait_into, - gen_impl_trait_serde_deserialize, gen_impl_trait_serde_serialize, - gen_impl_trait_try_from, split_into_generatable_traits, + ConditionalTraits, GeneratableTrait, GeneratableTraits, GeneratedTraits, + HasGeneratedParseError, gen_impl_trait_as_ref, gen_impl_trait_borrow, + gen_impl_trait_default, gen_impl_trait_deref, gen_impl_trait_display, + gen_impl_trait_from, gen_impl_trait_into, gen_impl_trait_serde_deserialize, + gen_impl_trait_serde_serialize, gen_impl_trait_try_from, process_conditional_derives, + split_into_generatable_traits, }, - models::{ErrorTypePath, SpannedDeriveUnsafeTrait, TypeName}, + models::{ConditionalDeriveGroup, ErrorTypePath, SpannedDeriveUnsafeTrait, TypeName}, }, string::models::{StringDeriveTrait, StringGuard, StringInnerType}, }; @@ -54,6 +55,15 @@ enum StringIrregularTrait { ArbitraryArbitrary, } +/// Always returns `false`: String's `FromStr` implementation reuses the validation error +/// type directly (via `gen_impl_from_str`) and does **not** generate a separate `ParseError` +/// type definition. Therefore no module-level re-export is needed in conditional derives. +impl HasGeneratedParseError for StringIrregularTrait { + fn has_generated_parse_error(&self) -> bool { + false + } +} + impl From for StringGeneratableTrait { fn from(derive_trait: StringDeriveTrait) -> StringGeneratableTrait { match derive_trait { @@ -141,6 +151,7 @@ impl ToTokens for StringTransparentTrait { } } +#[allow(clippy::too_many_arguments)] pub fn gen_traits( type_name: &TypeName, generics: &Generics, @@ -148,6 +159,7 @@ pub fn gen_traits( unsafe_traits: &[SpannedDeriveUnsafeTrait], maybe_default_value: Option, guard: &StringGuard, + conditional_derives: &[ConditionalDeriveGroup], ) -> Result { let GeneratableTraits { transparent_traits, @@ -164,14 +176,31 @@ pub fn gen_traits( let implement_traits = gen_implemented_traits( type_name, generics, - maybe_default_value, + maybe_default_value.clone(), irregular_traits, guard, )?; + let ConditionalTraits { + derive_transparent_traits: conditional_derive_transparent_traits, + implement_traits: conditional_implement_traits, + from_str_parse_errors: conditional_from_str_parse_errors, + } = process_conditional_derives(conditional_derives, type_name, |irregular| { + gen_implemented_traits( + type_name, + generics, + maybe_default_value.clone(), + irregular, + guard, + ) + })?; + Ok(GeneratedTraits { derive_transparent_traits, implement_traits, + conditional_derive_transparent_traits, + conditional_implement_traits, + conditional_from_str_parse_errors, }) } diff --git a/nutype_macros/src/string/mod.rs b/nutype_macros/src/string/mod.rs index adf8f497..1a9efd68 100644 --- a/nutype_macros/src/string/mod.rs +++ b/nutype_macros/src/string/mod.rs @@ -3,11 +3,12 @@ pub mod models; pub mod parse; pub mod validate; -use std::collections::HashSet; - use crate::common::{ generate::GenerateNewtype, - models::{Attributes, GenerateParams, Newtype, SpannedDeriveTrait, TypeName}, + models::{ + Attributes, CfgAttrEntry, GenerateParams, Newtype, SpannedDeriveTrait, TypeName, + ValidatedDerives, + }, }; use models::{StringDeriveTrait, StringSanitizer, StringValidator}; @@ -36,8 +37,17 @@ impl Newtype for StringNewtype { fn validate( guard: &StringGuard, derive_traits: Vec, - ) -> Result, syn::Error> { - validate_string_derive_traits(guard, derive_traits) + cfg_attr_entries: &[CfgAttrEntry], + maybe_default_value: &Option, + type_name: &TypeName, + ) -> Result, syn::Error> { + validate_string_derive_traits( + guard, + derive_traits, + cfg_attr_entries, + maybe_default_value, + type_name, + ) } fn generate( diff --git a/nutype_macros/src/string/models.rs b/nutype_macros/src/string/models.rs index 5520f589..e077ac53 100644 --- a/nutype_macros/src/string/models.rs +++ b/nutype_macros/src/string/models.rs @@ -86,6 +86,9 @@ impl TypeTrait for StringDeriveTrait { fn is_from_str(&self) -> bool { self == &Self::FromStr } + fn is_default(&self) -> bool { + self == &Self::Default + } } pub type StringRawGuard = RawGuard; diff --git a/nutype_macros/src/string/parse.rs b/nutype_macros/src/string/parse.rs index 83c61aee..5d31ad0b 100644 --- a/nutype_macros/src/string/parse.rs +++ b/nutype_macros/src/string/parse.rs @@ -40,6 +40,7 @@ pub fn parse_attributes( default, derive_traits, derive_unchecked_traits, + cfg_attr_entries, } = attrs; let raw_guard = StringRawGuard { sanitizers, @@ -54,6 +55,7 @@ pub fn parse_attributes( default, derive_traits, derive_unchecked_traits, + cfg_attr_entries, }) } diff --git a/nutype_macros/src/string/validate.rs b/nutype_macros/src/string/validate.rs index 737f7ecb..cebe1f8f 100644 --- a/nutype_macros/src/string/validate.rs +++ b/nutype_macros/src/string/validate.rs @@ -1,12 +1,13 @@ use kinded::Kinded; -use std::collections::HashSet; use proc_macro2::Span; use crate::{ common::{ - models::{DeriveTrait, SpannedDeriveTrait, TypeName, ValueOrExpr}, - validate::{validate_duplicates, validate_guard, validate_traits_from_xor_try_from}, + models::{ + CfgAttrEntry, DeriveTrait, SpannedDeriveTrait, TypeName, ValidatedDerives, ValueOrExpr, + }, + validate::{validate_all_derive_traits, validate_duplicates, validate_guard}, }, string::models::{StringGuard, StringRawGuard, StringSanitizer, StringValidator}, }; @@ -134,23 +135,22 @@ fn validate_sanitizers( pub fn validate_string_derive_traits( guard: &StringGuard, - spanned_derive_traits: Vec, -) -> Result, syn::Error> { - validate_traits_from_xor_try_from(&spanned_derive_traits)?; - - let mut traits = HashSet::with_capacity(24); - let has_validation = guard.has_validation(); - - for spanned_trait in spanned_derive_traits { - let string_derive_trait = - to_string_derive_trait(spanned_trait.item, has_validation, spanned_trait.span)?; - traits.insert(string_derive_trait); - } - - Ok(traits) + derive_traits: Vec, + cfg_attr_entries: &[CfgAttrEntry], + maybe_default_value: &Option, + type_name: &TypeName, +) -> Result, syn::Error> { + validate_all_derive_traits( + guard.has_validation(), + derive_traits, + cfg_attr_entries, + maybe_default_value, + type_name, + to_string_derive_trait, + ) } -fn to_string_derive_trait( +pub(crate) fn to_string_derive_trait( tr: DeriveTrait, has_validation: bool, span: Span, diff --git a/test_suite/Cargo.toml b/test_suite/Cargo.toml index a5c4ad15..88e27522 100644 --- a/test_suite/Cargo.toml +++ b/test_suite/Cargo.toml @@ -16,6 +16,7 @@ lazy_static = { version = "1", optional = true } regex = { version = "1", optional = true } once_cell = { version = "1", optional = true } arbitrary = "1.3.0" +impls = "1" arbtest = "0.2.0" ron = "0.8.1" rmp-serde = "1.1.2" diff --git a/test_suite/tests/any.rs b/test_suite/tests/any.rs index 35db2584..fafd4874 100644 --- a/test_suite/tests/any.rs +++ b/test_suite/tests/any.rs @@ -1144,3 +1144,57 @@ mod into_iter { assert_eq!(iter.next(), None); } } + +#[cfg(test)] +mod cfg_attr { + use super::*; + + #[test] + fn test_cfg_attr_derive_transparent_trait() { + #[nutype(derive(Debug, PartialEq), cfg_attr(test, derive(Clone)))] + pub struct Wrapper(Vec); + + let w = Wrapper::new(vec![1, 2, 3]); + let w2 = w.clone(); + assert_eq!(w, w2); + } + + #[test] + fn test_cfg_attr_derive_irregular_trait() { + #[nutype(derive(Debug), cfg_attr(test, derive(Display, Deref)))] + pub struct Num(i32); + + let num = Num::new(42); + assert_eq!(format!("{num}"), "42"); + assert_eq!(*num, 42); + } + + #[test] + fn test_cfg_attr_with_validation() { + #[nutype( + validate(predicate = |v: &Vec| !v.is_empty()), + derive(Debug, PartialEq), + cfg_attr(test, derive(Clone, Into)), + )] + pub struct NonEmptyVec(Vec); + + let v = NonEmptyVec::try_new(vec![1, 2]).unwrap(); + let v2 = v.clone(); + assert_eq!(v, v2); + let inner: Vec = v2.into(); + assert_eq!(inner, vec![1, 2]); + } + + #[test] + fn test_cfg_attr_derive_default() { + #[nutype( + derive(Debug, PartialEq), + default = vec![0], + cfg_attr(test, derive(Default)) + )] + pub struct DefaultVec(Vec); + + let val = DefaultVec::default(); + assert_eq!(val.into_inner(), vec![0]); + } +} diff --git a/test_suite/tests/float.rs b/test_suite/tests/float.rs index 26fe6c87..d84e2313 100644 --- a/test_suite/tests/float.rs +++ b/test_suite/tests/float.rs @@ -874,3 +874,63 @@ mod constants { assert_eq!(FIFTY.into_inner(), 50.0); } } + +#[cfg(test)] +mod cfg_attr { + use super::*; + + #[test] + fn test_cfg_attr_derive_transparent_trait() { + #[nutype( + validate(finite, greater_or_equal = 0.0), + derive(Debug, PartialEq), + cfg_attr(test, derive(Clone, Copy)) + )] + pub struct PositiveFloat(f64); + + let val = PositiveFloat::try_new(3.14).unwrap(); + let val2 = val; + let val3 = val; + assert_eq!(val2, val3); + } + + #[test] + fn test_cfg_attr_derive_irregular_trait() { + #[nutype(derive(Debug), cfg_attr(test, derive(Display, Into)))] + pub struct Temperature(f64); + + let temp = Temperature::new(36.6); + assert_eq!(format!("{temp}"), "36.6"); + let inner: f64 = temp.into(); + assert_eq!(inner, 36.6); + } + + #[test] + fn test_cfg_attr_derive_from_str() { + // Conditional FromStr on float type + #[nutype(validate(finite), derive(Debug), cfg_attr(test, derive(FromStr)))] + pub struct FiniteFloat(f64); + + let val: FiniteFloat = "3.14".parse().unwrap(); + assert_eq!(val.into_inner(), 3.14); + + // Invalid parse + assert!("not_a_number".parse::().is_err()); + + // Valid parse but fails validation (NaN) + assert!("NaN".parse::().is_err()); + } + + #[test] + fn test_cfg_attr_derive_default() { + #[nutype( + derive(Debug, PartialEq), + default = 0.0, + cfg_attr(test, derive(Default)) + )] + pub struct Score(f64); + + let val = Score::default(); + assert_eq!(val, Score::new(0.0)); + } +} diff --git a/test_suite/tests/integer.rs b/test_suite/tests/integer.rs index a7261f2d..3eebf78d 100644 --- a/test_suite/tests/integer.rs +++ b/test_suite/tests/integer.rs @@ -980,3 +980,119 @@ mod constants { } } } + +#[cfg(test)] +mod cfg_attr { + use super::*; + + #[test] + fn test_cfg_attr_derive_transparent_trait() { + #[nutype( + validate(greater_or_equal = 0, less_or_equal = 100), + derive(Debug, PartialEq), + cfg_attr(test, derive(Clone, Copy)) + )] + pub struct Percent(i32); + + let p = Percent::try_new(50).unwrap(); + let p2 = p; + let p3 = p; + assert_eq!(p2, p3); + } + + #[test] + fn test_cfg_attr_derive_irregular_trait() { + #[nutype( + validate(greater_or_equal = 1), + derive(Debug), + cfg_attr(test, derive(Display, AsRef)) + )] + pub struct PositiveInt(i64); + + let val = PositiveInt::try_new(42).unwrap(); + assert_eq!(format!("{val}"), "42"); + let inner: &i64 = val.as_ref(); + assert_eq!(*inner, 42); + } + + #[test] + fn test_cfg_attr_without_validation() { + #[nutype(derive(Debug, PartialEq), cfg_attr(test, derive(Clone, Copy, Into)))] + pub struct Count(u32); + + let c = Count::new(10); + let c2 = c; + let val: u32 = c2.into(); + assert_eq!(val, 10); + } + + #[test] + fn test_cfg_attr_derive_from_str() { + // Conditional FromStr on integer type + #[nutype( + validate(greater_or_equal = 1), + derive(Debug), + cfg_attr(test, derive(FromStr)) + )] + pub struct PositiveNum(i32); + + let val: PositiveNum = "42".parse().unwrap(); + assert_eq!(val.into_inner(), 42); + + // Invalid parse (not a number) + assert!("abc".parse::().is_err()); + + // Valid parse but fails validation + assert!("0".parse::().is_err()); + + // Verify ParseError type is accessible by name (re-exported) + let err = "abc".parse::().unwrap_err(); + assert!(matches!(err, PositiveNumParseError::Parse(_))); + + let err = "0".parse::().unwrap_err(); + assert!(matches!(err, PositiveNumParseError::Validate(_))); + } + + #[test] + fn test_cfg_attr_derive_default() { + // Conditional Default with unconditional default value + #[nutype( + derive(Debug, PartialEq), + default = 10, + cfg_attr(test, derive(Default)) + )] + pub struct DefNum(i32); + + let val = DefNum::default(); + assert_eq!(val, DefNum::new(10)); + } + + #[test] + fn test_cfg_attr_complex_predicate() { + // Complex cfg predicate with all(...) + #[nutype( + derive(Debug), + cfg_attr(all(test, target_pointer_width = "64"), derive(Clone, Copy)) + )] + pub struct Width(u64); + + let w = Width::new(100); + #[cfg(all(test, target_pointer_width = "64"))] + { + let w2 = w; + let _w3 = w2; + } + let _ = w; + } + + #[test] + fn test_cfg_attr_cross_predicate_traits() { + // PartialEq unconditional, Eq conditional - should work when predicate is true + #[nutype(derive(Debug, PartialEq), cfg_attr(test, derive(Eq)))] + pub struct Level(i32); + + let a = Level::new(5); + let b = Level::new(5); + assert_eq!(a, b); + } +} diff --git a/test_suite/tests/string.rs b/test_suite/tests/string.rs index ccf52ceb..6f6896f8 100644 --- a/test_suite/tests/string.rs +++ b/test_suite/tests/string.rs @@ -1034,3 +1034,132 @@ mod constructor_visibility { } } } + +#[cfg(test)] +mod cfg_attr { + use super::*; + + #[test] + fn test_cfg_attr_derive_transparent_trait() { + // cfg_attr with a trait that is always true (test cfg) + #[nutype( + sanitize(trim), + validate(not_empty), + derive(Debug, PartialEq), + cfg_attr(test, derive(Clone)) + )] + pub struct Name(String); + + let name = Name::try_new("hello").unwrap(); + let name2 = name.clone(); + assert_eq!(name, name2); + } + + #[test] + fn test_cfg_attr_derive_irregular_trait() { + // cfg_attr with Display which requires a custom impl + #[nutype(sanitize(trim), derive(Debug), cfg_attr(test, derive(Display)))] + pub struct Greeting(String); + + let greeting = Greeting::new("hello"); + assert_eq!(format!("{greeting}"), "hello"); + } + + #[test] + fn test_cfg_attr_with_false_predicate() { + // cfg_attr with a predicate that is never true should not derive + #[nutype(derive(Debug), cfg_attr(not(test), derive(Clone)))] + pub struct Label(String); + + let _label = Label::new("test"); + // Clone should NOT be available here (not(test) is false in test context) + assert!(!impls::impls!(Label: Clone)); + } + + #[test] + fn test_cfg_attr_multiple_traits() { + // cfg_attr deriving multiple traits at once + #[nutype(derive(Debug), cfg_attr(test, derive(Clone, PartialEq, Eq, Display)))] + pub struct Tag(String); + + let tag1 = Tag::new("rust"); + let tag2 = tag1.clone(); + assert_eq!(tag1, tag2); + assert_eq!(format!("{tag1}"), "rust"); + } + + #[test] + fn test_cfg_attr_with_validation() { + // cfg_attr with a validated type + #[nutype( + validate(not_empty, len_char_max = 100), + derive(Debug, PartialEq), + cfg_attr(test, derive(Clone, Display)) + )] + pub struct Title(String); + + let title = Title::try_new("Hello World").unwrap(); + let title2 = title.clone(); + assert_eq!(title, title2); + assert_eq!(format!("{title}"), "Hello World"); + } + + #[test] + fn test_cfg_attr_multiple_cfg_attr_entries() { + // Multiple cfg_attr entries + #[nutype( + derive(Debug), + cfg_attr(test, derive(Clone)), + cfg_attr(test, derive(Display)) + )] + pub struct Item(String); + + let item = Item::new("thing"); + let _item2 = item.clone(); + assert_eq!(format!("{item}"), "thing"); + } + + #[test] + fn test_cfg_attr_derive_default() { + #[nutype( + derive(Debug, PartialEq), + default = "default", + cfg_attr(test, derive(Default)) + )] + pub struct DefString(String); + + let val = DefString::default(); + assert_eq!(val.into_inner(), "default"); + } + + #[test] + fn test_cfg_attr_complex_predicate_any() { + // Complex cfg predicate with any(...) + #[nutype( + derive(Debug), + cfg_attr(any(test, debug_assertions), derive(Clone, Display)) + )] + pub struct Msg(String); + + let msg = Msg::new("hello"); + let msg2 = msg.clone(); + assert_eq!(format!("{msg2}"), "hello"); + } + + #[cfg(feature = "serde")] + #[test] + fn test_cfg_attr_conditional_serde() { + // Conditional serde derives + #[nutype( + derive(Debug, PartialEq), + cfg_attr(feature = "serde", derive(Serialize, Deserialize)) + )] + pub struct Label(String); + + let label = Label::new("hello"); + let json = serde_json::to_string(&label).unwrap(); + assert_eq!(json, r#""hello""#); + let deserialized: Label = serde_json::from_str(&json).unwrap(); + assert_eq!(label, deserialized); + } +} diff --git a/test_suite/tests/ui/common/cfg_attr_default_missing_value.rs b/test_suite/tests/ui/common/cfg_attr_default_missing_value.rs new file mode 100644 index 00000000..299a71ad --- /dev/null +++ b/test_suite/tests/ui/common/cfg_attr_default_missing_value.rs @@ -0,0 +1,10 @@ +use nutype::nutype; + +// Default is derived conditionally but default = is missing +#[nutype( + derive(Debug), + cfg_attr(test, derive(Default)), +)] +struct Name(String); + +fn main() { } diff --git a/test_suite/tests/ui/common/cfg_attr_default_missing_value.stderr b/test_suite/tests/ui/common/cfg_attr_default_missing_value.stderr new file mode 100644 index 00000000..f56c6fcb --- /dev/null +++ b/test_suite/tests/ui/common/cfg_attr_default_missing_value.stderr @@ -0,0 +1,10 @@ +error: Trait `Default` is derived for type Name, but `default = ` parameter is missing in #[nutype] macro + --> tests/ui/common/cfg_attr_default_missing_value.rs:4:1 + | +4 | / #[nutype( +5 | | derive(Debug), +6 | | cfg_attr(test, derive(Default)), +7 | | )] + | |__^ + | + = note: this error originates in the attribute macro `nutype` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/test_suite/tests/ui/common/cfg_attr_duplicate_cross_conditional.rs b/test_suite/tests/ui/common/cfg_attr_duplicate_cross_conditional.rs new file mode 100644 index 00000000..bb411c19 --- /dev/null +++ b/test_suite/tests/ui/common/cfg_attr_duplicate_cross_conditional.rs @@ -0,0 +1,11 @@ +use nutype::nutype; + +// Same trait in multiple cfg_attr entries +#[nutype( + derive(Debug), + cfg_attr(test, derive(Clone)), + cfg_attr(test, derive(Clone)), +)] +struct Name(String); + +fn main() { } diff --git a/test_suite/tests/ui/common/cfg_attr_duplicate_cross_conditional.stderr b/test_suite/tests/ui/common/cfg_attr_duplicate_cross_conditional.stderr new file mode 100644 index 00000000..ae0b9cc8 --- /dev/null +++ b/test_suite/tests/ui/common/cfg_attr_duplicate_cross_conditional.stderr @@ -0,0 +1,5 @@ +error: Trait `Clone` appears in multiple `cfg_attr(...)` entries. If their predicates overlap at compile time, this will cause a compilation error. Combine them under a single predicate or ensure predicates are mutually exclusive. + --> tests/ui/common/cfg_attr_duplicate_cross_conditional.rs:7:27 + | +7 | cfg_attr(test, derive(Clone)), + | ^^^^^ diff --git a/test_suite/tests/ui/common/cfg_attr_duplicate_trait.rs b/test_suite/tests/ui/common/cfg_attr_duplicate_trait.rs new file mode 100644 index 00000000..c53caa9a --- /dev/null +++ b/test_suite/tests/ui/common/cfg_attr_duplicate_trait.rs @@ -0,0 +1,10 @@ +use nutype::nutype; + +// A trait cannot appear in both unconditional derive and cfg_attr derive +#[nutype( + derive(Debug, Clone), + cfg_attr(test, derive(Clone)), +)] +struct Name(String); + +fn main() { } diff --git a/test_suite/tests/ui/common/cfg_attr_duplicate_trait.stderr b/test_suite/tests/ui/common/cfg_attr_duplicate_trait.stderr new file mode 100644 index 00000000..3b721e1c --- /dev/null +++ b/test_suite/tests/ui/common/cfg_attr_duplicate_trait.stderr @@ -0,0 +1,5 @@ +error: Trait `Clone` appears in both unconditional `derive()` and conditional `cfg_attr(..., derive())`. Remove it from one of them. + --> tests/ui/common/cfg_attr_duplicate_trait.rs:6:27 + | +6 | cfg_attr(test, derive(Clone)), + | ^^^^^ diff --git a/test_suite/tests/ui/common/cfg_attr_empty_predicate.rs b/test_suite/tests/ui/common/cfg_attr_empty_predicate.rs new file mode 100644 index 00000000..de4b4344 --- /dev/null +++ b/test_suite/tests/ui/common/cfg_attr_empty_predicate.rs @@ -0,0 +1,10 @@ +use nutype::nutype; + +// Empty predicate should be rejected +#[nutype( + derive(Debug), + cfg_attr(, derive(Clone)), +)] +struct Name(String); + +fn main() { } diff --git a/test_suite/tests/ui/common/cfg_attr_empty_predicate.stderr b/test_suite/tests/ui/common/cfg_attr_empty_predicate.stderr new file mode 100644 index 00000000..fc54c03c --- /dev/null +++ b/test_suite/tests/ui/common/cfg_attr_empty_predicate.stderr @@ -0,0 +1,5 @@ +error: expected cfg predicate + --> tests/ui/common/cfg_attr_empty_predicate.rs:6:14 + | +6 | cfg_attr(, derive(Clone)), + | ^ diff --git a/test_suite/tests/ui/common/cfg_attr_missing_attribute.rs b/test_suite/tests/ui/common/cfg_attr_missing_attribute.rs new file mode 100644 index 00000000..45b011f4 --- /dev/null +++ b/test_suite/tests/ui/common/cfg_attr_missing_attribute.rs @@ -0,0 +1,10 @@ +use nutype::nutype; + +// cfg_attr with predicate but no attribute +#[nutype( + derive(Debug), + cfg_attr(test), +)] +struct Name(String); + +fn main() { } diff --git a/test_suite/tests/ui/common/cfg_attr_missing_attribute.stderr b/test_suite/tests/ui/common/cfg_attr_missing_attribute.stderr new file mode 100644 index 00000000..3080460d --- /dev/null +++ b/test_suite/tests/ui/common/cfg_attr_missing_attribute.stderr @@ -0,0 +1,5 @@ +error: expected `,` + --> tests/ui/common/cfg_attr_missing_attribute.rs:6:18 + | +6 | cfg_attr(test), + | ^ diff --git a/test_suite/tests/ui/common/cfg_attr_trailing_tokens.rs b/test_suite/tests/ui/common/cfg_attr_trailing_tokens.rs new file mode 100644 index 00000000..3d874b26 --- /dev/null +++ b/test_suite/tests/ui/common/cfg_attr_trailing_tokens.rs @@ -0,0 +1,10 @@ +use nutype::nutype; + +// Trailing tokens after derive(...) inside cfg_attr should be rejected +#[nutype( + derive(Debug), + cfg_attr(test, derive(Clone) some_garbage), +)] +struct Name(String); + +fn main() { } diff --git a/test_suite/tests/ui/common/cfg_attr_trailing_tokens.stderr b/test_suite/tests/ui/common/cfg_attr_trailing_tokens.stderr new file mode 100644 index 00000000..66d82b81 --- /dev/null +++ b/test_suite/tests/ui/common/cfg_attr_trailing_tokens.stderr @@ -0,0 +1,5 @@ +error: unexpected tokens after `derive(...)` inside `cfg_attr()` + --> tests/ui/common/cfg_attr_trailing_tokens.rs:6:34 + | +6 | cfg_attr(test, derive(Clone) some_garbage), + | ^^^^^^^^^^^^ diff --git a/test_suite/tests/ui/common/cfg_attr_unsupported_attr.rs b/test_suite/tests/ui/common/cfg_attr_unsupported_attr.rs new file mode 100644 index 00000000..be8ca5eb --- /dev/null +++ b/test_suite/tests/ui/common/cfg_attr_unsupported_attr.rs @@ -0,0 +1,10 @@ +use nutype::nutype; + +// Only derive and derive_unchecked are supported inside cfg_attr +#[nutype( + derive(Debug), + cfg_attr(test, sanitize(trim)), +)] +struct Name(String); + +fn main() { } diff --git a/test_suite/tests/ui/common/cfg_attr_unsupported_attr.stderr b/test_suite/tests/ui/common/cfg_attr_unsupported_attr.stderr new file mode 100644 index 00000000..ecc47816 --- /dev/null +++ b/test_suite/tests/ui/common/cfg_attr_unsupported_attr.stderr @@ -0,0 +1,6 @@ +error: Attribute `sanitize` is not supported inside `cfg_attr()`. + Only `derive(...)` and `derive_unchecked(...)` are allowed. + --> tests/ui/common/cfg_attr_unsupported_attr.rs:6:20 + | +6 | cfg_attr(test, sanitize(trim)), + | ^^^^^^^^