Skip to content
33 changes: 28 additions & 5 deletions crates/cgp-macro-lib/src/derive_getter/blanket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use alloc::string::ToString;

use proc_macro2::TokenStream;
use quote::{ToTokens, quote};
use syn::{Ident, ItemImpl, ItemTrait, parse2};
use syn::{Ident, ItemImpl, ItemTrait, TraitItemType, parse2};

use crate::derive_getter::getter_field::GetterField;
use crate::derive_getter::{
Expand All @@ -14,19 +14,38 @@ pub fn derive_blanket_impl(
context_type: &Ident,
consumer_trait: &ItemTrait,
fields: &[GetterField],
field_assoc_type: &Option<TraitItemType>,
) -> syn::Result<ItemImpl> {
let consumer_name = &consumer_trait.ident;

let supertrait_constraints = consumer_trait.supertraits.clone();

let mut methods: TokenStream = TokenStream::new();
let mut items: TokenStream = TokenStream::new();

let mut generics = consumer_trait.generics.clone();

generics
.params
.insert(0, parse2(context_type.to_token_stream())?);

if let Some(field_assoc_type) = field_assoc_type {
let field_assoc_type_ident = &field_assoc_type.ident;

generics
.params
.push(parse2(field_assoc_type_ident.to_token_stream())?);

items.extend(quote! {
type #field_assoc_type_ident = #field_assoc_type_ident;
});

let field_constraints = &field_assoc_type.bounds;

generics.make_where_clause().predicates.push(parse2(quote! {
#field_assoc_type_ident: #field_constraints
})?);
}

let where_clause = generics.make_where_clause();

if !supertrait_constraints.is_empty() {
Expand All @@ -53,9 +72,13 @@ pub fn derive_blanket_impl(
None,
);

methods.extend(method);
items.extend(method);

let constraint = derive_getter_constraint(field, quote! { #field_symbol })?;
let constraint = derive_getter_constraint(
field,
quote! { #field_symbol },
&field_assoc_type.as_ref().map(|item| item.ident.clone()),
)?;

where_clause.predicates.push(parse2(quote! {
#receiver_type: #constraint
Expand All @@ -69,7 +92,7 @@ pub fn derive_blanket_impl(
impl #impl_generics #consumer_name #type_generics for #context_type
#where_clause
{
#methods
#items
}
})?;

Expand Down
14 changes: 9 additions & 5 deletions crates/cgp-macro-lib/src/derive_getter/constraint.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,32 @@
use proc_macro2::TokenStream;
use quote::quote;
use syn::{TypeParamBound, parse2};
use syn::{Ident, TypeParamBound, parse_quote, parse2};

use crate::derive_getter::{FieldMode, GetterField};

pub fn derive_getter_constraint(
spec: &GetterField,
field_symbol: TokenStream,
field_assoc_type: &Option<Ident>,
) -> syn::Result<TypeParamBound> {
let provider_type = &spec.field_type;
let field_type = match field_assoc_type {
Some(field_assoc_type) => parse_quote! { #field_assoc_type },
None => spec.field_type.clone(),
};

let constraint = if spec.field_mut.is_none() {
if let FieldMode::Slice = spec.field_mode {
quote! {
HasField< #field_symbol, Value: AsRef< [ #provider_type ] > + 'static >
HasField< #field_symbol, Value: AsRef< [ #field_type ] > + 'static >
}
} else {
quote! {
HasField< #field_symbol, Value = #provider_type >
HasField< #field_symbol, Value = #field_type >
}
}
} else {
quote! {
HasFieldMut< #field_symbol, Value = #provider_type >
HasFieldMut< #field_symbol, Value = #field_type >
}
};

Expand Down
76 changes: 68 additions & 8 deletions crates/cgp-macro-lib/src/derive_getter/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use syn::spanned::Spanned;
use syn::token::{Comma, Mut};
use syn::{
Error, FnArg, GenericArgument, Ident, ItemTrait, PathArguments, PathSegment, ReturnType,
Signature, TraitItem, TraitItemFn, Type, TypePath, parse_quote, parse2,
Signature, TraitItem, TraitItemFn, TraitItemType, Type, TypePath, parse_quote, parse2,
};

use crate::derive_getter::getter_field::GetterField;
Expand All @@ -16,16 +16,45 @@ use crate::replace_self::replace_self_type;
pub fn parse_getter_fields(
context_type: &Ident,
consumer_trait: &ItemTrait,
) -> syn::Result<Vec<GetterField>> {
) -> syn::Result<(Vec<GetterField>, Option<TraitItemType>)> {
let mut fields = Vec::new();
let mut field_assoc_type: Option<TraitItemType> = None;

// Extract optional associated type first
for item in consumer_trait.items.iter() {
if let TraitItem::Type(item_type) = item {
if field_assoc_type.is_some() {
return Err(Error::new(
item_type.span(),
"at most one associated type is allowed in getter trait",
));
}

if !item_type.generics.params.is_empty() {
return Err(Error::new(
item_type.generics.params.span(),
"associated type in getter trait must not contain generic params",
));
}

field_assoc_type = Some(item_type.clone());
}
}

for item in consumer_trait.items.iter() {
match item {
TraitItem::Fn(method) => {
let getter_spec = parse_getter_method(context_type, method)?;
let getter_spec = parse_getter_method(
context_type,
method,
&field_assoc_type.as_ref().map(|item| item.ident.clone()),
)?;

fields.push(getter_spec);
}
TraitItem::Type(_) => {
// Already processed in the previous loop
}
_ => {
return Err(Error::new(
item.span(),
Expand All @@ -35,10 +64,37 @@ pub fn parse_getter_fields(
}
}

Ok(fields)
match (&field_assoc_type, fields.first(), fields.len()) {
(None, _, _) => {}
(Some(field_assoc_type), Some(field), 1) => {
let field_assoc_type_ident = &field_assoc_type.ident;
let field_type = &field.field_type;

if field_type != &parse_quote! { Self :: #field_assoc_type_ident }
&& field_type != &parse_quote! { #context_type :: #field_assoc_type_ident }
{
return Err(Error::new(
field.field_type.span(),
"getter method return type must match the associated type",
));
}
}
_ => {
return Err(Error::new(
consumer_trait.span(),
"if associated type is defined, exactly one getter method must be defined",
));
}
}

Ok((fields, field_assoc_type))
}

fn parse_getter_method(context_type: &Ident, method: &TraitItemFn) -> syn::Result<GetterField> {
fn parse_getter_method(
context_type: &Ident,
method: &TraitItemFn,
field_assoc_type: &Option<Ident>,
) -> syn::Result<GetterField> {
let signature = &method.sig;

validate_getter_method_signature(signature)?;
Expand All @@ -49,7 +105,7 @@ fn parse_getter_method(context_type: &Ident, method: &TraitItemFn) -> syn::Resul

let (receiver_mode, field_mut) = parse_receiver(context_type, arg)?;

let return_type = parse_return_type(context_type, &signature.output)?;
let return_type = parse_return_type(context_type, &signature.output, field_assoc_type)?;

let (field_type, field_mode) = parse_field_type(&return_type, &field_mut)?;

Expand Down Expand Up @@ -187,12 +243,16 @@ fn parse_receiver(context_ident: &Ident, arg: &FnArg) -> syn::Result<(ReceiverMo
}
}

fn parse_return_type(context_type: &Ident, return_type: &ReturnType) -> syn::Result<Type> {
fn parse_return_type(
context_type: &Ident,
return_type: &ReturnType,
field_assoc_type: &Option<Ident>,
) -> syn::Result<Type> {
match return_type {
ReturnType::Type(_, ty) => parse2(replace_self_type(
ty.to_token_stream(),
context_type.to_token_stream(),
&Vec::new(),
&field_assoc_type.iter().cloned().collect::<Vec<_>>(),
)),
_ => Err(Error::new(
return_type.span(),
Expand Down
49 changes: 42 additions & 7 deletions crates/cgp-macro-lib/src/derive_getter/use_field.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use proc_macro2::TokenStream;
use quote::{ToTokens, quote};
use syn::punctuated::Punctuated;
use syn::token::Plus;
use syn::{Generics, ItemImpl, ItemTrait, TypeParamBound, parse2};
use syn::{Generics, ItemImpl, ItemTrait, TraitItemType, TypeParamBound, parse2};

use crate::derive_getter::getter_field::GetterField;
use crate::derive_getter::{
Expand All @@ -13,6 +14,7 @@ pub fn derive_use_field_impl(
spec: &ComponentSpec,
provider_trait: &ItemTrait,
field: &GetterField,
field_assoc_type: &Option<TraitItemType>,
) -> syn::Result<ItemImpl> {
let context_type = &spec.context_type;
let provider_name = &provider_trait.ident;
Expand All @@ -26,20 +28,53 @@ pub fn derive_use_field_impl(

let tag_type = quote! { __Tag__ };

let method = derive_getter_method(&ContextArg::Ident(receiver_type.clone()), field, None, None);
let mut items = TokenStream::new();

let constraint = derive_getter_constraint(field, quote! { #tag_type })?;
let mut provider_generics = provider_trait.generics.clone();

field_constraints.push(constraint);
if let Some(field_assoc_type) = field_assoc_type {
let field_assoc_type_ident = &field_assoc_type.ident;

let mut provider_generics = provider_trait.generics.clone();
provider_generics
.params
.push(parse2(field_assoc_type_ident.to_token_stream())?);

items.extend(quote! {
type #field_assoc_type_ident = #field_assoc_type_ident;
});

let field_constraints = &field_assoc_type.bounds;

provider_generics
.make_where_clause()
.predicates
.push(parse2(quote! {
#field_assoc_type_ident: #field_constraints
})?);
}

items.extend(derive_getter_method(
&ContextArg::Ident(receiver_type.clone()),
field,
None,
None,
));

let constraint = derive_getter_constraint(
field,
quote! { #tag_type },
&field_assoc_type.as_ref().map(|item| item.ident.clone()),
)?;

field_constraints.push(constraint);

let mut where_clause = provider_generics.make_where_clause().clone();
where_clause
.predicates
.push(parse2(quote! { #receiver_type: #field_constraints })?);

let (impl_generics, type_generics, _) = provider_generics.split_for_impl();
let (_, type_generics, _) = provider_trait.generics.split_for_impl();
let (impl_generics, _, _) = provider_generics.split_for_impl();

let impl_generics = {
let mut generics: Generics = parse2(impl_generics.to_token_stream())?;
Expand All @@ -51,7 +86,7 @@ pub fn derive_use_field_impl(
impl #impl_generics #provider_name #type_generics for UseField< #tag_type >
#where_clause
{
#method
#items
}
})?;

Expand Down
Loading
Loading