From cfa0bf171d4276ead08f57682ab1a71234b237a8 Mon Sep 17 00:00:00 2001 From: Serhii Potapov Date: Fri, 30 Jan 2026 22:26:16 +0100 Subject: [PATCH] Suppoer where clause in generics --- CHANGELOG.md | 1 + README.md | 25 ++ nutype/src/lib.rs | 27 ++ .../src/any/generate/traits/arbitrary.rs | 28 +- .../src/any/generate/traits/into_iter.rs | 65 +++- nutype_macros/src/common/generate/generics.rs | 143 +++++++ nutype_macros/src/common/generate/mod.rs | 110 +++--- .../src/common/generate/parse_error.rs | 47 ++- nutype_macros/src/common/generate/traits.rs | 187 +++++++--- test_suite/tests/where_clause.rs | 350 ++++++++++++++++++ 10 files changed, 822 insertions(+), 161 deletions(-) create mode 100644 nutype_macros/src/common/generate/generics.rs create mode 100644 test_suite/tests/where_clause.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 381de946..50a3ce14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ - **[FEATURE]** Ability to derive [`Valuable`](https://docs.rs/valuable/0.1.1/valuable/trait.Valuable.html) (requires `valuable` feature). - **[FEATURE]** Ability to control constructor visibility with `constructor(visibility = ...)` attribute (see [#211](https://github.com/greyblake/nutype/issues/211)). - **[FEATURE]** Add `len_utf16_min` and `len_utf16_max` validators for string types to validate UTF-16 code unit length (useful for JavaScript interop) (see [#162](https://github.com/greyblake/nutype/issues/162)). +- **[FEATURE]** Support `where` clauses in generic newtypes, including Higher-Ranked Trait Bounds (HRTB) like `for<'a> &'a C: IntoIterator` (see [#160](https://github.com/greyblake/nutype/issues/160)). ### v0.6.2 - 2025-06-30 - **[FEATURE]** Introduce `derive_unsafe(..)` attribute to derive any arbitrary trait (requires `derive_unsafe` feature to be enabled). diff --git a/README.md b/README.md index 30b4e3ef..50cea7dc 100644 --- a/README.md +++ b/README.md @@ -266,7 +266,32 @@ assert_eq!(numbers.as_ref(), &[1, 2, 4, 7]); assert_eq!(numbers.len(), 4); ``` +### Where clauses +Nutype fully supports `where` clauses in generic newtypes, including Higher-Ranked Trait Bounds (HRTB): + +```rust +use nutype::nutype; + +// Simple where clause +#[nutype(derive(Debug, Clone))] +struct Wrapper(T) +where + T: Default + Clone; + +// HRTB for collections - validate that collection is non-empty +#[nutype( + validate(predicate = |c| c.into_iter().next().is_some()), + derive(Debug) +)] +struct NonEmpty(C) +where + for<'a> &'a C: IntoIterator; + +// Usage +let non_empty = NonEmpty::try_new(vec![1, 2, 3]).unwrap(); +assert!(NonEmpty::try_new(Vec::::new()).is_err()); +``` ## Custom sanitizers diff --git a/nutype/src/lib.rs b/nutype/src/lib.rs index 1b4cf296..3761faac 100644 --- a/nutype/src/lib.rs +++ b/nutype/src/lib.rs @@ -311,6 +311,33 @@ //! assert_eq!(numbers.len(), 4); //! ``` //! +//! ### Where clauses +//! +//! Nutype fully supports `where` clauses in generic newtypes, including Higher-Ranked Trait Bounds (HRTB): +//! +//! ``` +//! use nutype::nutype; +//! +//! // Simple where clause +//! #[nutype(derive(Debug, Clone))] +//! struct Wrapper(T) +//! where +//! T: Default + Clone; +//! +//! // HRTB for collections - validate that collection is non-empty +//! #[nutype( +//! validate(predicate = |c| c.into_iter().next().is_some()), +//! derive(Debug) +//! )] +//! struct NonEmpty(C) +//! where +//! for<'a> &'a C: IntoIterator; +//! +//! // Usage +//! let non_empty = NonEmpty::try_new(vec![1, 2, 3]).unwrap(); +//! assert!(NonEmpty::try_new(Vec::::new()).is_err()); +//! ``` +//! //! ## Custom sanitizers //! //! You can set custom sanitizers using the `with` option. diff --git a/nutype_macros/src/any/generate/traits/arbitrary.rs b/nutype_macros/src/any/generate/traits/arbitrary.rs index e313119c..5beaacf6 100644 --- a/nutype_macros/src/any/generate/traits/arbitrary.rs +++ b/nutype_macros/src/any/generate/traits/arbitrary.rs @@ -4,7 +4,7 @@ use syn::Generics; use crate::{ any::models::{AnyGuard, AnyInnerType}, - common::generate::{add_bound_to_all_type_params, add_param, strip_trait_bounds_on_generics}, + common::generate::generics::{SplitGenerics, add_bound_to_all_type_params, add_generic_param}, common::models::TypeName, }; @@ -25,14 +25,32 @@ pub fn gen_impl_trait_arbitrary( // Generate implementation of `Arbitrary` trait, assuming that inner type implements Arbitrary // too. - let generics_without_bounds = strip_trait_bounds_on_generics(generics); - let generics_with_lifetime = add_param(&generics_without_bounds, quote!('nu_arb)); + // + // We need to: + // 1. Add a lifetime 'nu_arb + // 2. Add Arbitrary<'nu_arb> bound to all type params + let generics_with_lifetime = add_generic_param(generics, syn::parse_quote!('nu_arb)); let generics_with_bounds = add_bound_to_all_type_params( &generics_with_lifetime, - quote!(::arbitrary::Arbitrary<'nu_arb>), + syn::parse_quote!(::arbitrary::Arbitrary<'nu_arb>), ); + + let SplitGenerics { + impl_generics, + type_generics: _, + where_clause, + } = SplitGenerics::new(&generics_with_bounds); + + // Get type generics without the added lifetime + let SplitGenerics { type_generics, .. } = SplitGenerics::new(generics); + + // Example for `struct Wrapper(T) where T: Clone`: + // + // impl<'nu_arb, T: Arbitrary<'nu_arb>> Arbitrary<'nu_arb> for Wrapper where T: Clone { + // fn arbitrary(u: &mut Unstructured<'nu_arb>) -> Result { ... } + // } Ok(quote!( - impl #generics_with_bounds ::arbitrary::Arbitrary<'nu_arb> for #type_name #generics_without_bounds { + impl #impl_generics ::arbitrary::Arbitrary<'nu_arb> for #type_name #type_generics #where_clause { fn arbitrary(u: &mut ::arbitrary::Unstructured<'nu_arb>) -> ::arbitrary::Result { let inner_value: #inner_type = u.arbitrary()?; Ok(#type_name::new(inner_value)) diff --git a/nutype_macros/src/any/generate/traits/into_iter.rs b/nutype_macros/src/any/generate/traits/into_iter.rs index 83be807d..6c39ec0a 100644 --- a/nutype_macros/src/any/generate/traits/into_iter.rs +++ b/nutype_macros/src/any/generate/traits/into_iter.rs @@ -5,7 +5,7 @@ use syn::Generics; use crate::{ any::models::AnyInnerType, common::{ - generate::{add_param, strip_trait_bounds_on_generics}, + generate::generics::{SplitGenerics, add_generic_param}, models::TypeName, }, }; @@ -15,36 +15,61 @@ pub fn gen_impl_trait_into_iter( generics: &Generics, inner_type: &AnyInnerType, ) -> TokenStream { - let generics_without_bounds = strip_trait_bounds_on_generics(generics); - let generics_with_iter_lifetime = add_param(generics, quote!('__nutype_iter)); + let generics_with_iter_lifetime = + add_generic_param(generics, syn::parse_quote!('__nutype_iter)); + + let SplitGenerics { + impl_generics, + type_generics, + where_clause, + } = SplitGenerics::new(generics); + + let SplitGenerics { + impl_generics: impl_generics_with_lifetime, + type_generics: _, + where_clause: where_clause_with_lifetime, + } = SplitGenerics::new(&generics_with_iter_lifetime); // In the comments below, we assume that IntoIterator is derived for the following type // - // struct Names<'a, T: Display>(Vec<&'a T>); + // struct Names<'a, T: Display>(Vec<&'a T>) where T: Clone; // // NOTE: We deliberately do not generate an iterator over mutable references, because // this would allow the user to modify the elements of the collection, which may violate // the guarantees that nutype is supposed to provide. + // + // Example generated code: + // + // impl<'a, T: Display> IntoIterator for Names<'a, T> where T: Clone { + // type Item = as IntoIterator>::Item; + // type IntoIter = as IntoIterator>::IntoIter; + // fn into_iter(self) -> Self::IntoIter { self.0.into_iter() } + // } + // + // impl<'a, '__nutype_iter, T: Display> IntoIterator for &'__nutype_iter Names<'a, T> where T: Clone { + // type Item = <&'__nutype_iter Vec<&'a T> as IntoIterator>::Item; + // ... + // } quote!( // Implement IntoIterator for the type. - impl #generics ::core::iter::IntoIterator for #type_name #generics_without_bounds { // impl<'a, T: Display> ::core::iter::IntoIterator for Names<'a, T> { - type Item = <#inner_type as ::core::iter::IntoIterator>::Item; // type Item = as ::core::iter::IntoIterator>::Item; - type IntoIter = <#inner_type as ::core::iter::IntoIterator>::IntoIter; // type IntoIter = as ::core::iter::IntoIterator>::IntoIter; - // - fn into_iter(self) -> Self::IntoIter { // fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() // self.0.into_iter() - } // } - } // } + impl #impl_generics ::core::iter::IntoIterator for #type_name #type_generics #where_clause { + type Item = <#inner_type as ::core::iter::IntoIterator>::Item; + type IntoIter = <#inner_type as ::core::iter::IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } + } // IntoIterator for the reference to the type (so it can be iterated over references). - impl #generics_with_iter_lifetime ::core::iter::IntoIterator // impl<'a, '__nutype_iter, T: Display> ::core::iter::IntoIterator - for &'__nutype_iter #type_name #generics_without_bounds { // for &'__nutype_iter Names<'a, T> { - type Item = <&'__nutype_iter #inner_type as ::core::iter::IntoIterator>::Item; // type Item = <&'__nutype_iter Vec<&'a T> as ::core::iter::IntoIterator>::Item; - type IntoIter = <&'__nutype_iter #inner_type as ::core::iter::IntoIterator>::IntoIter; // type IntoIter = <&'__nutype_iter Vec<&'a T> as ::core::iter::IntoIterator>::IntoIter; - - fn into_iter(self) -> Self::IntoIter { // fn into_iter(self) -> Self::IntoIter { - self.0.iter().into_iter() // self.0.iter().into_iter() - } // } + impl #impl_generics_with_lifetime ::core::iter::IntoIterator + for &'__nutype_iter #type_name #type_generics #where_clause_with_lifetime { + type Item = <&'__nutype_iter #inner_type as ::core::iter::IntoIterator>::Item; + type IntoIter = <&'__nutype_iter #inner_type as ::core::iter::IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter().into_iter() + } } ) } diff --git a/nutype_macros/src/common/generate/generics.rs b/nutype_macros/src/common/generate/generics.rs new file mode 100644 index 00000000..40abd085 --- /dev/null +++ b/nutype_macros/src/common/generate/generics.rs @@ -0,0 +1,143 @@ +//! Utilities for handling generics and where clauses in code generation. +//! +//! This module provides helper functions that properly handle `where` clauses +//! when generating impl blocks, including support for Higher-Ranked Trait Bounds (HRTB). + +use proc_macro2::TokenStream; +use quote::quote; +use syn::Generics; + +/// Split generics for use in impl blocks. +/// +/// This properly separates the generics into three parts: +/// - `impl_generics`: Goes after `impl` keyword (e.g., ``) +/// - `type_generics`: Goes after type name (e.g., ``) +/// - `where_clause`: Goes at the end of impl signature (e.g., `where T: Default`) +/// +/// # Example +/// +/// For `struct Foo(T) where T: Default`: +/// +/// ```ignore +/// let split = split_generics_for_impl(&generics); +/// quote! { +/// impl #impl_generics SomeTrait for Foo #type_generics #where_clause { +/// // ... +/// } +/// } +/// ``` +/// +/// Generates: +/// ```ignore +/// impl SomeTrait for Foo where T: Default { +/// // ... +/// } +/// ``` +pub struct SplitGenerics { + pub impl_generics: TokenStream, + pub type_generics: TokenStream, + pub where_clause: TokenStream, +} + +impl SplitGenerics { + pub fn new(generics: &Generics) -> Self { + let (impl_generics, type_generics, where_clause) = generics.split_for_impl(); + Self { + impl_generics: quote!(#impl_generics), + type_generics: quote!(#type_generics), + where_clause: quote!(#where_clause), + } + } +} + +/// Add a bound to all type parameters in generics. +/// +/// This adds the bound to inline type parameters. +/// +/// # Arguments +/// * `generics` - The original generics +/// * `bound` - The bound to add (e.g., `Display`, `Serialize`) +/// +/// # Example +/// +/// Input: `` with bound `Display` +/// Output: `` +pub fn add_bound_to_all_type_params(generics: &Generics, bound: syn::TypeParamBound) -> Generics { + let mut result = generics.clone(); + for param in &mut result.params { + if let syn::GenericParam::Type(type_param) = param { + type_param.bounds.push(bound.clone()); + } + } + result +} + +/// Add a generic parameter (typically a lifetime) to generics. +/// +/// The parameter is added at the end of the params list. +/// +/// # Example +/// +/// Input: `` with param `'de` +/// Output: `` +pub fn add_generic_param(generics: &Generics, param: syn::GenericParam) -> Generics { + let mut result = generics.clone(); + result.params.push(param); + result +} + +#[cfg(test)] +mod tests { + use super::*; + use syn::parse_quote; + + #[test] + fn test_split_generics_simple() { + let generics: Generics = parse_quote!(); + let split = SplitGenerics::new(&generics); + + // Just verify it doesn't panic and produces some output + assert!(!split.impl_generics.is_empty()); + assert!(!split.type_generics.is_empty()); + } + + #[test] + fn test_split_generics_with_where_clause() { + // Parse a full struct to get generics with where clause + let item: syn::ItemStruct = parse_quote! { + struct Foo where T: Clone { field: T } + }; + let split = SplitGenerics::new(&item.generics); + + // Verify where clause is captured + assert!(!split.where_clause.is_empty()); + } + + #[test] + fn test_split_generics_with_hrtb() { + // Parse a full struct to get generics with HRTB where clause + let item: syn::ItemStruct = parse_quote! { + struct Foo where for<'a> &'a C: IntoIterator { field: C } + }; + let split = SplitGenerics::new(&item.generics); + + // Verify HRTB where clause is captured + let where_str = split.where_clause.to_string(); + assert!(where_str.contains("for")); + assert!(where_str.contains("IntoIterator")); + } + + #[test] + fn test_add_bound() { + let generics: Generics = parse_quote!(); + let bound: syn::TypeParamBound = parse_quote!(Clone); + let result = add_bound_to_all_type_params(&generics, bound); + + // Verify bounds were added + for param in &result.params { + if let syn::GenericParam::Type(tp) = param { + assert!(!tp.bounds.is_empty()); + } + } + } +} diff --git a/nutype_macros/src/common/generate/mod.rs b/nutype_macros/src/common/generate/mod.rs index c51588f4..d603dc6f 100644 --- a/nutype_macros/src/common/generate/mod.rs +++ b/nutype_macros/src/common/generate/mod.rs @@ -1,4 +1,5 @@ pub mod error; +pub mod generics; pub mod new_unchecked; pub mod parse_error; pub mod tests; @@ -138,9 +139,22 @@ pub fn gen_impl_into_inner( inner_type: impl ToTokens, const_fn: ConstFn, ) -> TokenStream { - let generics_without_bounds = strip_trait_bounds_on_generics(generics); + let generics::SplitGenerics { + impl_generics, + type_generics, + where_clause, + } = generics::SplitGenerics::new(generics); + + // Example for `struct Wrapper(T) where T: Default`: + // + // impl Wrapper where T: Default { + // #[inline] + // pub fn into_inner(self) -> T { + // self.0 + // } + // } quote! { - impl #generics #type_name #generics_without_bounds { + impl #impl_generics #type_name #type_generics #where_clause { #[inline] pub #const_fn fn into_inner(self) -> #inner_type { self.0 @@ -149,59 +163,6 @@ pub fn gen_impl_into_inner( } } -/// Remove trait bounds from generics. -/// -/// Input: -/// -/// -/// Output: -/// -pub fn strip_trait_bounds_on_generics(original: &Generics) -> Generics { - let mut generics = original.clone(); - for param in &mut generics.params { - if let syn::GenericParam::Type(syn::TypeParam { bounds, .. }) = param { - *bounds = syn::punctuated::Punctuated::new(); - } - } - generics -} - -/// Add a bound to all generics types. -/// -/// Input: -/// -/// Serialize -/// -/// Output: -/// -pub fn add_bound_to_all_type_params(generics: &Generics, bound: TokenStream) -> Generics { - let mut generics = generics.clone(); - let parsed_bound: syn::TypeParamBound = - syn::parse2(bound).expect("Failed to parse TypeParamBound"); - for param in &mut generics.params { - if let syn::GenericParam::Type(syn::TypeParam { bounds, .. }) = param { - bounds.push(parsed_bound.clone()); - } - } - generics -} - -/// Add a parameter to generics. -/// -/// Input: -/// -/// 'a -/// -/// Output: -/// <'a, T, U> -/// -pub fn add_param(generics: &Generics, param: TokenStream) -> Generics { - let mut generics = generics.clone(); - let parsed_param: syn::GenericParam = syn::parse2(param).expect("Failed to parse GenericParam"); - generics.params.push(parsed_param); - generics -} - pub trait GenerateNewtype { type Sanitizer; type Validator; @@ -259,7 +220,11 @@ pub trait GenerateNewtype { const_fn: ConstFn, constructor_visibility: &ConstructorVisibility, ) -> TokenStream { - let generics_without_bounds = strip_trait_bounds_on_generics(generics); + let generics::SplitGenerics { + impl_generics, + type_generics, + where_clause, + } = generics::SplitGenerics::new(generics); let fn_sanitize = Self::gen_fn_sanitize(inner_type, sanitizers, const_fn); let maybe_generated_validation_error = match validation { @@ -296,10 +261,15 @@ pub trait GenerateNewtype { let error_type_path = validation.error_type_path(); + // Example for `struct Wrapper(T) where T: Default`: + // + // impl Wrapper where T: Default { + // pub fn try_new(raw_value: T) -> Result { ... } + // } quote!( #maybe_generated_validation_error - impl #generics #type_name #generics_without_bounds { + impl #impl_generics #type_name #type_generics #where_clause { #constructor_visibility #const_fn fn try_new(raw_value: #input_type) -> ::core::result::Result { #convert_raw_value_if_necessary @@ -330,7 +300,11 @@ pub trait GenerateNewtype { const_fn: ConstFn, constructor_visibility: &ConstructorVisibility, ) -> TokenStream { - let generics_without_bounds = strip_trait_bounds_on_generics(generics); + let generics::SplitGenerics { + impl_generics, + type_generics, + where_clause, + } = generics::SplitGenerics::new(generics); let fn_sanitize = Self::gen_fn_sanitize(inner_type, sanitizers, const_fn); let (input_type, convert_raw_value_if_necessary) = if Self::NEW_CONVERT_INTO_INNER_TYPE { @@ -342,8 +316,13 @@ pub trait GenerateNewtype { (quote!(#inner_type), quote!()) }; + // Example for `struct Wrapper(T) where T: Default`: + // + // impl Wrapper where T: Default { + // pub fn new(raw_value: T) -> Self { ... } + // } quote!( - impl #generics #type_name #generics_without_bounds { + impl #impl_generics #type_name #type_generics #where_clause { #constructor_visibility #const_fn fn new(raw_value: #input_type) -> Self { #convert_raw_value_if_necessary Self(Self::__sanitize__(raw_value)) @@ -484,6 +463,17 @@ pub trait GenerateNewtype { &guard, )?; + // Split generics for struct definition to properly handle where clauses + let generics::SplitGenerics { + impl_generics: struct_generics, + type_generics: _, + where_clause: struct_where_clause, + } = generics::SplitGenerics::new(&generics); + + // Example for `struct Wrapper(T) where T: Default`: + // + // #[derive(Debug, Clone)] + // pub struct Wrapper(T) where T: Default; Ok(quote!( #[doc(hidden)] #[allow(non_snake_case, reason = "we keep original structure name which is probably CamelCase")] @@ -492,7 +482,7 @@ pub trait GenerateNewtype { #(#doc_attrs)* #derive_transparent_traits - pub struct #type_name #generics(#inner_type); + pub struct #type_name #struct_generics (#inner_type) #struct_where_clause; #implementation #implement_traits diff --git a/nutype_macros/src/common/generate/parse_error.rs b/nutype_macros/src/common/generate/parse_error.rs index 198b27eb..a4ef69a6 100644 --- a/nutype_macros/src/common/generate/parse_error.rs +++ b/nutype_macros/src/common/generate/parse_error.rs @@ -4,7 +4,7 @@ use quote::{format_ident, quote}; use syn::Generics; use crate::common::{ - generate::{add_bound_to_all_type_params, strip_trait_bounds_on_generics}, + generate::generics::{SplitGenerics, add_bound_to_all_type_params}, models::{ErrorTypePath, InnerType, ParseErrorTypeName, TypeName}, }; @@ -26,21 +26,33 @@ pub fn gen_def_parse_error( let inner_type: InnerType = inner_type.into(); let type_name_str = type_name.to_string(); - let generics_without_bounds = strip_trait_bounds_on_generics(generics); let generics_with_fromstr_bound = add_bound_to_all_type_params( - &generics_without_bounds, + generics, syn::parse_quote!(::core::str::FromStr), ); + let SplitGenerics { + impl_generics: enum_generics, + type_generics, + where_clause, + } = SplitGenerics::new(&generics_with_fromstr_bound); + + // Example for `struct Wrapper(T) where T: Clone`: + // + // #[derive(Debug)] + // pub enum WrapperParseError> where T: Clone { + // Parse(::Err), + // Validate(WrapperError), + // } let definition = if let Some(error_type_name) = maybe_error_type_name { quote! { - #[derive(Debug)] // #[derive(Debug)] - pub enum #parse_error_type_name #generics_with_fromstr_bound { // pub enum ParseErrorFoo> { - Parse(<#inner_type as ::core::str::FromStr>::Err), // Parse(::Err), - Validate(#error_type_name), // Validate(ErrorFoo), - } // } + #[derive(Debug)] + pub enum #parse_error_type_name #enum_generics #where_clause { + Parse(<#inner_type as ::core::str::FromStr>::Err), + Validate(#error_type_name), + } - impl #generics_with_fromstr_bound ::core::fmt::Display for #parse_error_type_name #generics_without_bounds { + impl #enum_generics ::core::fmt::Display for #parse_error_type_name #type_generics #where_clause { fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { match self { #parse_error_type_name::Parse(err) => write!(f, "Failed to parse {}: {:?}", #type_name_str, err), @@ -52,12 +64,12 @@ pub fn gen_def_parse_error( } } else { quote! { - #[derive(Debug)] // #[derive(Debug) - pub enum #parse_error_type_name #generics_with_fromstr_bound { // pub enum ParseErrorFoo> { - Parse(<#inner_type as ::core::str::FromStr>::Err), // Parse(::Err), - } // } + #[derive(Debug)] + pub enum #parse_error_type_name #enum_generics #where_clause { + Parse(<#inner_type as ::core::str::FromStr>::Err), + } - impl #generics_with_fromstr_bound ::core::fmt::Display for #parse_error_type_name #generics_without_bounds { + impl #enum_generics ::core::fmt::Display for #parse_error_type_name #type_generics #where_clause { fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { match self { #parse_error_type_name::Parse(err) => write!(f, "Failed to parse {}: {:?}", #type_name_str, err), @@ -80,8 +92,13 @@ pub fn gen_def_parse_error( &generics_with_fromstr_bound, syn::parse_quote!(::core::fmt::Debug), ); + let SplitGenerics { + impl_generics: error_impl_generics, + type_generics: error_type_generics, + where_clause: error_where_clause, + } = SplitGenerics::new(&generics_with_fromstr_and_debug_bounds); let impl_error = quote! { - impl #generics_with_fromstr_and_debug_bounds #error for #parse_error_type_name #generics_without_bounds { + impl #error_impl_generics #error for #parse_error_type_name #error_type_generics #error_where_clause { fn source(&self) -> Option<&(dyn #error + 'static)> { None } diff --git a/nutype_macros/src/common/generate/traits.rs b/nutype_macros/src/common/generate/traits.rs index a469708f..4518c5b8 100644 --- a/nutype_macros/src/common/generate/traits.rs +++ b/nutype_macros/src/common/generate/traits.rs @@ -5,7 +5,7 @@ use quote::{ToTokens, quote}; use syn::Generics; use crate::common::{ - generate::{add_bound_to_all_type_params, strip_trait_bounds_on_generics}, + generate::generics::{SplitGenerics, add_bound_to_all_type_params}, models::{ErrorTypePath, InnerType, TypeName}, }; @@ -61,15 +61,20 @@ pub fn gen_impl_trait_into( inner_type: impl Into, ) -> TokenStream { let inner_type: InnerType = inner_type.into(); + let SplitGenerics { + impl_generics, + type_generics, + where_clause, + } = SplitGenerics::new(generics); // NOTE: We're getting blank implementation of // Into for Type // by implementing // From for Inner quote! { - impl #generics ::core::convert::From<#type_name #generics> for #inner_type { + impl #impl_generics ::core::convert::From<#type_name #type_generics> for #inner_type #where_clause { #[inline] - fn from(value: #type_name #generics) -> Self { + fn from(value: #type_name #type_generics) -> Self { value.into_inner() } } @@ -81,15 +86,27 @@ pub fn gen_impl_trait_as_ref( generics: &Generics, inner_type: impl ToTokens, ) -> TokenStream { - let generics_without_bounds = strip_trait_bounds_on_generics(generics); - + let SplitGenerics { + impl_generics, + type_generics, + where_clause, + } = SplitGenerics::new(generics); + + // Example for `struct Collection(Vec) where T: Clone`: + // + // impl AsRef> for Collection where T: Clone { + // #[inline] + // fn as_ref(&self) -> &Vec { + // &self.0 + // } + // } quote! { - impl #generics ::core::convert::AsRef<#inner_type> for #type_name #generics_without_bounds { // impl AsRef> for Collection { - #[inline] // #[inline] - fn as_ref(&self) -> &#inner_type { // fn as_ref(&self) -> &Vec { - &self.0 // &self.0 - } // } - } // } + impl #impl_generics ::core::convert::AsRef<#inner_type> for #type_name #type_generics #where_clause { + #[inline] + fn as_ref(&self) -> &#inner_type { + &self.0 + } + } } } @@ -98,26 +115,35 @@ pub fn gen_impl_trait_deref( generics: &Generics, inner_type: impl ToTokens, ) -> TokenStream { - let generics_without_bounds = strip_trait_bounds_on_generics(generics); + let SplitGenerics { + impl_generics, + type_generics, + where_clause, + } = SplitGenerics::new(generics); quote! { - impl #generics ::core::ops::Deref for #type_name #generics_without_bounds { // impl Deref for Collection { - type Target = #inner_type; // type Target = Vec; - // - #[inline] // #[inline] - fn deref(&self) -> &Self::Target { // fn deref(&self) -> &Self::Target { - &self.0 // &self.0 - } // } - } // } + impl #impl_generics ::core::ops::Deref for #type_name #type_generics #where_clause { + type Target = #inner_type; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } + } } } pub fn gen_impl_trait_display(type_name: &TypeName, generics: &Generics) -> TokenStream { - let generics_without_bounds = strip_trait_bounds_on_generics(generics); let generics_with_display_bound = add_bound_to_all_type_params(generics, syn::parse_quote!(::core::fmt::Display)); + let SplitGenerics { + impl_generics, + type_generics, + where_clause, + } = SplitGenerics::new(&generics_with_display_bound); + quote! { - impl #generics_with_display_bound ::core::fmt::Display for #type_name #generics_without_bounds { + impl #impl_generics ::core::fmt::Display for #type_name #type_generics #where_clause { #[inline] fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { // A tiny wrapper function with trait boundary that improves error reporting. @@ -139,15 +165,19 @@ pub fn gen_impl_trait_borrow( generics: &Generics, borrowed_type: impl ToTokens, ) -> TokenStream { - let generics_without_bounds = strip_trait_bounds_on_generics(generics); + let SplitGenerics { + impl_generics, + type_generics, + where_clause, + } = SplitGenerics::new(generics); quote! { - impl #generics ::core::borrow::Borrow<#borrowed_type> for #type_name #generics_without_bounds { // impl Borrow> for Collection { - #[inline] // #[inline] - fn borrow(&self) -> &#borrowed_type { // fn borrow(&self) -> &Vec { - &self.0 // &self.0 - } // } - } // } + impl #impl_generics ::core::borrow::Borrow<#borrowed_type> for #type_name #type_generics #where_clause { + #[inline] + fn borrow(&self) -> &#borrowed_type { + &self.0 + } + } } } @@ -156,15 +186,19 @@ pub fn gen_impl_trait_from( generics: &Generics, inner_type: impl ToTokens, ) -> TokenStream { - let generics_without_bounds = strip_trait_bounds_on_generics(generics); + let SplitGenerics { + impl_generics, + type_generics, + where_clause, + } = SplitGenerics::new(generics); quote! { - impl #generics ::core::convert::From<#inner_type> for #type_name #generics_without_bounds { // impl From> for Collection { - #[inline] // #[inline] - fn from(raw_value: #inner_type) -> Self { // fn from(raw_value: Vec) -> Self { - Self::new(raw_value) // Self::new(raw_value) - } // } - } // } + impl #impl_generics ::core::convert::From<#inner_type> for #type_name #type_generics #where_clause { + #[inline] + fn from(raw_value: #inner_type) -> Self { + Self::new(raw_value) + } + } } } @@ -174,18 +208,21 @@ pub fn gen_impl_trait_try_from( inner_type: impl ToTokens, maybe_error_type_name: Option<&ErrorTypePath>, ) -> TokenStream { - let generics_without_bounds = strip_trait_bounds_on_generics(generics); + let SplitGenerics { + impl_generics, + type_generics, + where_clause, + } = SplitGenerics::new(generics); match maybe_error_type_name { Some(error_type_name) => { // The case when there are validation - // quote! { - impl #generics ::core::convert::TryFrom<#inner_type> for #type_name #generics_without_bounds { + impl #impl_generics ::core::convert::TryFrom<#inner_type> for #type_name #type_generics #where_clause { type Error = #error_type_name; #[inline] - fn try_from(raw_value: #inner_type) -> ::core::result::Result<#type_name #generics_without_bounds, Self::Error> { + fn try_from(raw_value: #inner_type) -> ::core::result::Result<#type_name #type_generics, Self::Error> { Self::try_new(raw_value) } } @@ -193,13 +230,12 @@ pub fn gen_impl_trait_try_from( } None => { // The case when there are no validation - // quote! { - impl #generics ::core::convert::TryFrom<#inner_type> for #type_name #generics_without_bounds { + impl #impl_generics ::core::convert::TryFrom<#inner_type> for #type_name #type_generics #where_clause { type Error = ::core::convert::Infallible; #[inline] - fn try_from(raw_value: #inner_type) -> ::core::result::Result<#type_name #generics_without_bounds, Self::Error> { + fn try_from(raw_value: #inner_type) -> ::core::result::Result<#type_name #type_generics, Self::Error> { Ok(Self::new(raw_value)) } } @@ -225,19 +261,23 @@ pub fn gen_impl_trait_from_str( &parse_error_type_name, ); - let generics_without_bounds = strip_trait_bounds_on_generics(generics); let generics_with_fromstr_bound = add_bound_to_all_type_params( generics, syn::parse_quote!(::core::str::FromStr), ); + let SplitGenerics { + impl_generics, + type_generics, + where_clause, + } = SplitGenerics::new(&generics_with_fromstr_bound); if let Some(_error_type_name) = maybe_error_type_name { // The case with validation quote! { #def_parse_error - impl #generics_with_fromstr_bound ::core::str::FromStr for #type_name #generics_without_bounds { - type Err = #parse_error_type_name #generics_without_bounds; + impl #impl_generics ::core::str::FromStr for #type_name #type_generics #where_clause { + type Err = #parse_error_type_name #type_generics; fn from_str(raw_string: &str) -> ::core::result::Result { let raw_value: #inner_type = raw_string.parse().map_err(#parse_error_type_name::Parse)?; @@ -250,8 +290,8 @@ pub fn gen_impl_trait_from_str( quote! { #def_parse_error - impl #generics_with_fromstr_bound ::core::str::FromStr for #type_name #generics_without_bounds { - type Err = #parse_error_type_name #generics_without_bounds; + impl #impl_generics ::core::str::FromStr for #type_name #type_generics #where_clause { + type Err = #parse_error_type_name #type_generics; fn from_str(raw_string: &str) -> ::core::result::Result { let value: #inner_type = raw_string.parse().map_err(#parse_error_type_name::Parse)?; @@ -263,15 +303,18 @@ pub fn gen_impl_trait_from_str( } pub fn gen_impl_trait_serde_serialize(type_name: &TypeName, generics: &Generics) -> TokenStream { - let generics_without_bounds = strip_trait_bounds_on_generics(generics); - // Turn `` into `` let all_generics_with_serialize_bound = add_bound_to_all_type_params(generics, syn::parse_quote!(::serde::Serialize)); + let SplitGenerics { + impl_generics, + type_generics, + where_clause, + } = SplitGenerics::new(&all_generics_with_serialize_bound); let type_name_str = type_name.to_string(); quote! { - impl #all_generics_with_serialize_bound ::serde::Serialize for #type_name #generics_without_bounds { + impl #impl_generics ::serde::Serialize for #type_name #type_generics #where_clause { fn serialize(&self, serializer: S) -> ::core::result::Result where S: ::serde::Serializer @@ -312,23 +355,41 @@ pub fn gen_impl_trait_serde_deserialize( all_generics.params.push(syn::parse_quote!('de)); all_generics }; - let all_generics_without_bounds = strip_trait_bounds_on_generics(&all_generics); - let type_generics_without_bounds = strip_trait_bounds_on_generics(type_generics); // Turn `<'de, T>` into `<'de, T: Deserialize<'de>>` let all_generics_with_deserialize_bound = add_bound_to_all_type_params(&all_generics, syn::parse_quote!(::serde::Deserialize<'de>)); + // Split for the outer impl (with 'de) + let SplitGenerics { + impl_generics: all_impl_generics, + type_generics: all_type_generics, + where_clause: all_where_clause, + } = SplitGenerics::new(&all_generics_with_deserialize_bound); + + // Split for the type itself (without 'de) + let SplitGenerics { + type_generics: inner_type_generics, + .. + } = SplitGenerics::new(type_generics); + + // For the visitor struct, we need generics without bounds but with 'de + let SplitGenerics { + impl_generics: visitor_impl_generics, + type_generics: _visitor_type_generics, + where_clause: visitor_where_clause, + } = SplitGenerics::new(&all_generics); + quote! { - impl #all_generics_with_deserialize_bound ::serde::Deserialize<'de> for #type_name #type_generics_without_bounds { + impl #all_impl_generics ::serde::Deserialize<'de> for #type_name #inner_type_generics #all_where_clause { fn deserialize>(deserializer: D) -> ::core::result::Result { - struct __Visitor #all_generics { - marker: ::core::marker::PhantomData<#type_name #type_generics_without_bounds>, + struct __Visitor #visitor_impl_generics #visitor_where_clause { + marker: ::core::marker::PhantomData<#type_name #inner_type_generics>, lifetime: ::core::marker::PhantomData<&'de ()>, } - impl #all_generics_with_deserialize_bound ::serde::de::Visitor<'de> for __Visitor #all_generics_without_bounds { - type Value = #type_name #type_generics_without_bounds; + impl #all_impl_generics ::serde::de::Visitor<'de> for __Visitor #all_type_generics #all_where_clause { + type Value = #type_name #inner_type_generics; fn expecting(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { write!(formatter, #expecting_str) @@ -365,12 +426,16 @@ pub fn gen_impl_trait_default( default_value: impl ToTokens, has_validation: bool, ) -> TokenStream { - let generics_without_bounds = strip_trait_bounds_on_generics(generics); + let SplitGenerics { + impl_generics, + type_generics, + where_clause, + } = SplitGenerics::new(generics); if has_validation { let tp = type_name.to_string(); quote!( - impl #generics ::core::default::Default for #type_name #generics_without_bounds { + impl #impl_generics ::core::default::Default for #type_name #type_generics #where_clause { fn default() -> Self { Self::try_new(#default_value) .unwrap_or_else(|err| { @@ -382,7 +447,7 @@ pub fn gen_impl_trait_default( ) } else { quote!( - impl #generics ::core::default::Default for #type_name #generics_without_bounds { + impl #impl_generics ::core::default::Default for #type_name #type_generics #where_clause { #[inline] fn default() -> Self { Self::new(#default_value) diff --git a/test_suite/tests/where_clause.rs b/test_suite/tests/where_clause.rs new file mode 100644 index 00000000..967b5b04 --- /dev/null +++ b/test_suite/tests/where_clause.rs @@ -0,0 +1,350 @@ +//! Tests for where clause support in generic newtypes (Issue #160) +//! +//! These tests verify that nutype properly handles `where` clauses, +//! including Higher-Ranked Trait Bounds (HRTB). + +mod basic_where_clause { + use nutype::nutype; + + #[test] + fn test_simple_where_clause() { + #[nutype(derive(Debug, Clone))] + struct Wrapper(T) + where + T: Default; + + let w = Wrapper::new(42i32); + assert_eq!(w.into_inner(), 42); + } + + #[test] + fn test_where_clause_with_multiple_predicates() { + // Test where clause with multiple type parameters + #[nutype(derive(Debug, Clone))] + struct Pair((T, U)) + where + T: Clone + Default, + U: Clone + Default; + + let p: Pair = Pair::new((42, String::from("hello"))); + assert_eq!(p.into_inner(), (42, String::from("hello"))); + } + + #[test] + fn test_where_clause_with_inline_bounds_combined() { + // Both inline bounds AND where clause + #[nutype(derive(Debug, Clone))] + struct Combined(T) + where + T: Default; + + let c = Combined::new(String::from("test")); + let cloned = c.clone(); + assert_eq!(cloned.into_inner(), "test"); + } +} + +mod hrtb_where_clause { + use nutype::nutype; + + #[test] + fn test_hrtb_into_iterator() { + #[nutype( + validate(predicate = |c| c.into_iter().next().is_some()), + derive(Debug) + )] + struct NonEmpty(C) + where + for<'a> &'a C: IntoIterator; + + let vec = vec![1, 2, 3]; + let non_empty = NonEmpty::try_new(vec).unwrap(); + assert_eq!(non_empty.into_inner(), vec![1, 2, 3]); + } + + #[test] + fn test_hrtb_validation_failure() { + #[nutype( + validate(predicate = |c| c.into_iter().next().is_some()), + derive(Debug) + )] + struct NonEmpty(C) + where + for<'a> &'a C: IntoIterator; + + let empty: Vec = vec![]; + assert!(NonEmpty::try_new(empty).is_err()); + } + + #[test] + fn test_hrtb_with_clone() { + #[nutype(derive(Debug, Clone))] + struct Cloneable(C) + where + for<'a> &'a C: IntoIterator, + C: Clone; + + let c = Cloneable::new(vec![1, 2, 3]); + let cloned = c.clone(); + assert_eq!(cloned.into_inner(), vec![1, 2, 3]); + } +} + +mod where_clause_with_sanitize { + use nutype::nutype; + + #[test] + fn test_sanitize_with_where() { + #[nutype( + sanitize(with = |mut v: Vec| { v.sort(); v }), + derive(Debug) + )] + struct Sorted(Vec) + where + T: Ord; + + let sorted = Sorted::new(vec![3, 1, 2]); + assert_eq!(sorted.into_inner(), vec![1, 2, 3]); + } + + #[test] + fn test_sanitize_and_validate_with_where() { + #[nutype( + sanitize(with = |mut v: Vec| { v.sort(); v }), + validate(predicate = |v| !v.is_empty()), + derive(Debug) + )] + struct SortedNonEmpty(Vec) + where + T: Ord; + + let sorted = SortedNonEmpty::try_new(vec![3, 1, 2]).unwrap(); + assert_eq!(sorted.into_inner(), vec![1, 2, 3]); + + // Empty should fail validation + let empty: Vec = vec![]; + assert!(SortedNonEmpty::try_new(empty).is_err()); + } +} + +mod where_clause_with_traits { + use nutype::nutype; + use std::fmt::Display; + + #[test] + fn test_display_with_where() { + #[nutype(derive(Debug, Display))] + struct ShowIt(T) + where + T: Display; + + let s = ShowIt::new("hello"); + assert_eq!(format!("{}", s), "hello"); + } + + #[test] + fn test_as_ref_with_where() { + #[nutype(derive(Debug, AsRef))] + struct RefIt(T) + where + T: Clone; + + let r = RefIt::new(String::from("test")); + assert_eq!(r.as_ref(), "test"); + } + + #[test] + fn test_deref_with_where() { + use std::ops::Deref; + + #[nutype(derive(Debug, Deref))] + struct DerefIt(T) + where + T: Clone; + + let d = DerefIt::new(String::from("test")); + assert_eq!(d.deref(), "test"); + } + + #[test] + fn test_borrow_with_where() { + use std::borrow::Borrow; + + #[nutype(derive(Debug, Borrow))] + struct BorrowIt(Vec) + where + T: Clone; + + let b = BorrowIt::new(vec![1, 2, 3]); + let borrowed: &Vec = b.borrow(); + assert_eq!(borrowed, &vec![1, 2, 3]); + } + + #[test] + fn test_from_with_where() { + #[nutype(derive(Debug, From))] + struct FromIt(Vec) + where + T: Clone; + + let f: FromIt = vec![1, 2, 3].into(); + assert_eq!(f.into_inner(), vec![1, 2, 3]); + } + + #[test] + fn test_try_from_with_where_and_validation() { + #[nutype( + validate(predicate = |v| !v.is_empty()), + derive(Debug, TryFrom) + )] + struct NonEmptyVec(Vec) + where + T: Clone; + + use std::convert::TryFrom; + let f = NonEmptyVec::try_from(vec![1, 2, 3]).unwrap(); + assert_eq!(f.into_inner(), vec![1, 2, 3]); + + let empty: Vec = vec![]; + assert!(NonEmptyVec::try_from(empty).is_err()); + } + + #[test] + fn test_into_with_where() { + #[nutype(derive(Debug, Into))] + struct IntoIt(Vec) + where + T: Clone; + + let i = IntoIt::new(vec![1, 2, 3]); + let v: Vec = i.into(); + assert_eq!(v, vec![1, 2, 3]); + } + + #[test] + fn test_default_with_where() { + #[nutype(derive(Debug, Default), default = vec![])] + struct DefaultVec(Vec) + where + T: Default; + + let d: DefaultVec = DefaultVec::default(); + assert_eq!(d.into_inner(), Vec::::new()); + } + + #[test] + fn test_default_with_where_and_validation() { + #[nutype( + validate(predicate = |v| v.len() <= 10), + derive(Debug, Default), + default = vec![] + )] + struct BoundedVec(Vec) + where + T: Default; + + let d: BoundedVec = BoundedVec::default(); + assert_eq!(d.into_inner(), Vec::::new()); + } +} + +#[cfg(feature = "serde")] +mod where_clause_with_serde { + use nutype::nutype; + use serde::Serialize; + + #[test] + fn test_serialize_with_where() { + #[nutype(derive(Debug, Serialize))] + struct SerIt(Vec) + where + T: Serialize + Clone; + + let s = SerIt::new(vec![1, 2, 3]); + let json = serde_json::to_string(&s).unwrap(); + assert_eq!(json, "[1,2,3]"); + } + + #[test] + fn test_deserialize_with_simple_where() { + // Note: Using simple Clone bound - the Deserialize bound is added by nutype + #[nutype(derive(Debug, Deserialize))] + struct DeIt(Vec) + where + T: Clone; + + let d: DeIt = serde_json::from_str("[1,2,3]").unwrap(); + assert_eq!(d.into_inner(), vec![1, 2, 3]); + } + + #[test] + fn test_serde_roundtrip_with_where() { + // Note: Using simple Clone bound - Serialize/Deserialize bounds are added by nutype + #[nutype(derive(Debug, Clone, Serialize, Deserialize))] + struct RoundTrip(Vec) + where + T: Clone; + + let original = RoundTrip::new(vec![1, 2, 3]); + let json = serde_json::to_string(&original).unwrap(); + let deserialized: RoundTrip = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.into_inner(), vec![1, 2, 3]); + } + + #[test] + fn test_serde_with_validation_and_where() { + #[nutype( + validate(predicate = |v| !v.is_empty()), + derive(Debug, Serialize, Deserialize) + )] + struct NonEmptySerde(Vec) + where + T: Clone; + + let original = NonEmptySerde::try_new(vec![1, 2, 3]).unwrap(); + let json = serde_json::to_string(&original).unwrap(); + let deserialized: NonEmptySerde = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.into_inner(), vec![1, 2, 3]); + } + + #[test] + fn test_serialize_with_non_generic_where() { + // Test where clause on non-generic type (using static bound) + #[nutype(derive(Debug, Serialize))] + struct StaticStr(String) + where + String: Clone; + + let s = StaticStr::new(String::from("hello")); + let json = serde_json::to_string(&s).unwrap(); + assert_eq!(json, "\"hello\""); + } +} + +mod into_iterator_with_where { + use nutype::nutype; + + #[test] + fn test_into_iterator_with_where_clause() { + #[nutype(derive(Debug, IntoIterator))] + struct IterableVec(Vec) + where + T: Clone; + + let v = IterableVec::new(vec![1, 2, 3]); + let collected: Vec = v.into_iter().collect(); + assert_eq!(collected, vec![1, 2, 3]); + } + + #[test] + fn test_into_iterator_ref_with_where_clause() { + #[nutype(derive(Debug, IntoIterator))] + struct IterableVec(Vec) + where + T: Clone; + + let v = IterableVec::new(vec![1, 2, 3]); + let collected: Vec<&i32> = (&v).into_iter().collect(); + assert_eq!(collected, vec![&1, &2, &3]); + } +}