From ecabdcc52be1908f34d823264aaf215279614c68 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 2 Apr 2025 23:26:55 +0200 Subject: [PATCH 01/12] Add first draft of a derive based type info crate --- Cargo.lock | 17 +++ Cargo.toml | 3 +- marrow-typeinfo-derive/Cargo.toml | 12 ++ marrow-typeinfo-derive/src/lib.rs | 116 ++++++++++++++++++ marrow-typeinfo/Cargo.toml | 8 ++ marrow-typeinfo/src/lib.rs | 189 ++++++++++++++++++++++++++++++ marrow-typeinfo/tests/derive.rs | 77 ++++++++++++ x.py | 6 +- 8 files changed, 423 insertions(+), 5 deletions(-) create mode 100644 marrow-typeinfo-derive/Cargo.toml create mode 100644 marrow-typeinfo-derive/src/lib.rs create mode 100644 marrow-typeinfo/Cargo.toml create mode 100644 marrow-typeinfo/src/lib.rs create mode 100644 marrow-typeinfo/tests/derive.rs diff --git a/Cargo.lock b/Cargo.lock index 19c1897..f9fd416 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1236,6 +1236,23 @@ dependencies = [ "serde", ] +[[package]] +name = "marrow-typeinfo" +version = "0.1.0" +dependencies = [ + "marrow", + "marrow-typeinfo-derive", +] + +[[package]] +name = "marrow-typeinfo-derive" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "memchr" version = "2.7.4" diff --git a/Cargo.toml b/Cargo.toml index e169fc3..a272f7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,5 @@ [workspace] -members = ["marrow", "test_with_arrow"] -default-members = ["marrow", "test_with_arrow"] +members = ["marrow", "marrow-typeinfo", "marrow-typeinfo-derive", "test_with_arrow"] resolver = "2" diff --git a/marrow-typeinfo-derive/Cargo.toml b/marrow-typeinfo-derive/Cargo.toml new file mode 100644 index 0000000..550c8b7 --- /dev/null +++ b/marrow-typeinfo-derive/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "marrow-typeinfo-derive" +version = "0.1.0" +edition = "2024" + +[lib] +proc-macro = true + +[dependencies] +syn = "2.0" +quote = "1.0" +proc-macro2 = "1.0" diff --git a/marrow-typeinfo-derive/src/lib.rs b/marrow-typeinfo-derive/src/lib.rs new file mode 100644 index 0000000..dcfc9e3 --- /dev/null +++ b/marrow-typeinfo-derive/src/lib.rs @@ -0,0 +1,116 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{ + Data, DeriveInput, Fields, FieldsNamed, Ident, Variant, parse_macro_input, + punctuated::Punctuated, token::Comma, +}; + +#[proc_macro_derive(TypeInfo)] +pub fn array_builder(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + + let expanded = match input.data { + Data::Struct(data) => match &data.fields { + Fields::Named(fields) => derive_for_struct(&input.ident, fields), + Fields::Unnamed(_) => { + panic!("Deriving TypeInfo for tuple structs is not yet supported") + } + Fields::Unit => { + panic!("Deriving TypeInfo for unit structs is not yet supported") + } + }, + Data::Enum(data) => derive_for_enum(&input.ident, &data.variants), + Data::Union(_) => { + panic!("Deriving TypeInfo for unions is currently not supported") + } + }; + + TokenStream::from(expanded) +} + +fn derive_for_struct(name: &Ident, fields: &FieldsNamed) -> proc_macro2::TokenStream { + let mut field_exprs = Vec::new(); + + for field in &fields.named { + let field_name = field.ident.as_ref().expect("named filed without ident"); + let ty = &field.ty; + + field_exprs.push(quote! { + fields.push(<#ty as ::marrow_typeinfo::TypeInfo>::get_field(stringify!(#field_name), context)?); + }) + } + + quote! { + const _: () = { + impl ::marrow_typeinfo::TypeInfo for #name { + fn get_field( + name: &::std::primitive::str, + context: &::marrow_typeinfo::Context, + ) -> ::std::result::Result< + ::marrow::datatypes::Field, + ::marrow_typeinfo::Error, + > { + let mut fields = ::std::vec::Vec::<::marrow::datatypes::Field>::new(); + #( #field_exprs; )* + + Ok(::marrow::datatypes::Field { + name: ::std::string::String::from(name), + data_type: ::marrow::datatypes::DataType::Struct(fields), + nullable: false, + metadata: ::std::default::Default::default(), + }) + } + } + }; + } +} + +fn derive_for_enum( + name: &Ident, + variants: &Punctuated, +) -> proc_macro2::TokenStream { + let mut variant_exprs = Vec::new(); + + for (idx, variant) in variants.iter().enumerate() { + let variant_name = &variant.ident; + + match variant.fields { + Fields::Unit => { + variant_exprs.push(quote! { + variants.push((i8::try_from(#idx)?, ::marrow::datatypes::Field { + name: ::std::string::String::from(stringify!(#variant_name)), + data_type: ::marrow::datatypes::DataType::Null, + nullable: true, + metadata: ::std::default::Default::default(), + })); + }); + } + Fields::Named(_) => panic!("enums with named fields are currently supported"), + Fields::Unnamed(_) => panic!("enums with unnamed fields are currently supported"), + } + } + + quote! { + const _: () = { + impl ::marrow_typeinfo::TypeInfo for #name { + fn get_field( + name: &::std::primitive::str, + context: &::marrow_typeinfo::Context, + ) -> ::std::result::Result< + ::marrow::datatypes::Field, + ::marrow_typeinfo::Error, + > { + let mut variants = ::std::vec::Vec::<(::std::primitive::i8, ::marrow::datatypes::Field)>::new(); + #( #variant_exprs; )* + + Ok(::marrow::datatypes::Field { + name: ::std::string::String::from(name), + data_type: ::marrow::datatypes::DataType::Union(variants, ::marrow::datatypes::UnionMode::Dense), + nullable: false, + metadata: ::std::default::Default::default(), + }) + } + } + }; + } +} diff --git a/marrow-typeinfo/Cargo.toml b/marrow-typeinfo/Cargo.toml new file mode 100644 index 0000000..c2d9f82 --- /dev/null +++ b/marrow-typeinfo/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "marrow-typeinfo" +version = "0.1.0" +edition = "2024" + +[dependencies] +marrow = { path = "../marrow" } +marrow-typeinfo-derive = { path = "../marrow-typeinfo-derive" } diff --git a/marrow-typeinfo/src/lib.rs b/marrow-typeinfo/src/lib.rs new file mode 100644 index 0000000..65911e4 --- /dev/null +++ b/marrow-typeinfo/src/lib.rs @@ -0,0 +1,189 @@ +use std::{ + any::{Any, TypeId}, + collections::HashMap, + convert::Infallible, + num::TryFromIntError, + rc::Rc, + sync::Arc, +}; + +use marrow::datatypes::{DataType, Field}; + +pub use marrow_typeinfo_derive::TypeInfo; + +// TODO: include the path in context to allow overwrites +#[derive(Debug, Default)] +pub struct Context { + data: HashMap>, +} + +struct DefaultStringType(DataType); + +impl Context { + pub fn new() -> Self { + Self::default() + } + + pub fn set(&mut self, value: T) { + let type_id = TypeId::of::(); + self.data.insert(type_id, Rc::new(value)); + } + + pub fn get(&self) -> Option<&T> { + let key = TypeId::of::(); + let value = self.data.get(&key)?; + let Some(value) = value.downcast_ref() else { + unreachable!(); + }; + Some(value) + } + + pub fn with_default_string_type(mut self, ty: DataType) -> Self { + // TODO: check that ty is compatible with strings + self.set(DefaultStringType(ty)); + self + } +} + +#[derive(Debug, PartialEq)] +pub struct Error(String); + +impl std::error::Error for Error {} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Error({:?})", self.0) + } +} + +impl From for Error { + fn from(_: Infallible) -> Self { + unreachable!() + } +} + +impl From for Error { + fn from(value: TryFromIntError) -> Self { + Self(value.to_string()) + } +} + +pub trait TypeInfo { + fn get_field(name: &str, context: &Context) -> Result; + + fn get_data_type(context: &Context) -> Result { + Ok(Self::get_field("item", context)?.data_type) + } +} + +macro_rules! define_primitive { + ($(($ty:ty, $dt:expr)),*) => { + $( + impl TypeInfo for $ty { + fn get_field(name: &str, context: &Context) -> Result { + let _ = context; + Ok(Field { + name: name.to_owned(), + data_type: $dt, + nullable: false, + metadata: Default::default(), + }) + } + } + )* + }; +} + +define_primitive!( + ((), DataType::Null), + (bool, DataType::Boolean), + (u8, DataType::UInt8), + (u16, DataType::UInt16), + (u32, DataType::UInt32), + (u64, DataType::UInt64), + (i8, DataType::Int8), + (i16, DataType::Int16), + (i32, DataType::Int32), + (i64, DataType::Int64) +); + +fn get_default_string_type(context: &Context) -> DataType { + if let Some(DefaultStringType(ty)) = context.get() { + ty.clone() + } else { + DataType::LargeUtf8 + } +} + +fn new_field(name: &str, data_type: DataType) -> Field { + Field { + name: name.to_owned(), + data_type, + nullable: false, + metadata: Default::default(), + } +} + +impl TypeInfo for &str { + fn get_field(name: &str, context: &Context) -> Result { + Ok(new_field(name, get_default_string_type(context))) + } +} + +impl TypeInfo for String { + fn get_field(name: &str, context: &Context) -> Result { + Ok(new_field(name, get_default_string_type(context))) + } +} + +impl TypeInfo for Box { + fn get_field(name: &str, context: &Context) -> Result { + Ok(new_field(name, get_default_string_type(context))) + } +} + +impl TypeInfo for Arc { + fn get_field(name: &str, context: &Context) -> Result { + Ok(new_field(name, get_default_string_type(context))) + } +} + +impl TypeInfo for [T; N] { + fn get_field(name: &str, context: &Context) -> Result { + let base_type = T::get_data_type(context)?; + let n = i32::try_from(N)?; + + // TODO: allow to customize + let data_type = match base_type { + DataType::UInt8 => DataType::FixedSizeBinary(n), + base_type => DataType::FixedSizeList( + Box::new(Field { + name: String::from("element"), + data_type: base_type, + nullable: false, + metadata: Default::default(), + }), + n, + ), + }; + + Ok(Field { + name: name.to_owned(), + data_type, + nullable: false, + metadata: Default::default(), + }) + } +} + +#[test] +fn examples() { + assert_eq!( + ::get_data_type(&Default::default()), + Ok(DataType::Int64) + ); + assert_eq!( + <[u8; 8] as TypeInfo>::get_data_type(&Default::default()), + Ok(DataType::FixedSizeBinary(8)) + ); +} diff --git a/marrow-typeinfo/tests/derive.rs b/marrow-typeinfo/tests/derive.rs new file mode 100644 index 0000000..8cebafb --- /dev/null +++ b/marrow-typeinfo/tests/derive.rs @@ -0,0 +1,77 @@ +use marrow::datatypes::{DataType, Field, UnionMode}; +use marrow_typeinfo::TypeInfo; + +#[test] +fn example() { + #[derive(TypeInfo)] + #[allow(dead_code)] + struct S { + a: i64, + b: [u8; 4], + } + + assert_eq!( + ::get_data_type(&Default::default()), + Ok(DataType::Struct(vec![ + Field { + name: String::from("a"), + data_type: DataType::Int64, + nullable: false, + metadata: Default::default(), + }, + Field { + name: String::from("b"), + data_type: DataType::FixedSizeBinary(4), + nullable: false, + metadata: Default::default(), + } + ])) + ); +} + +#[test] +fn fieldless_union() { + #[derive(TypeInfo)] + #[allow(dead_code)] + enum E { + A, + B, + C, + } + + assert_eq!( + ::get_data_type(&Default::default()), + Ok(DataType::Union( + vec![ + ( + 0, + Field { + name: String::from("A"), + data_type: DataType::Null, + nullable: true, + metadata: Default::default(), + } + ), + ( + 1, + Field { + name: String::from("B"), + data_type: DataType::Null, + nullable: true, + metadata: Default::default(), + } + ), + ( + 2, + Field { + name: String::from("C"), + data_type: DataType::Null, + nullable: true, + metadata: Default::default(), + } + ), + ], + UnionMode::Dense + )) + ); +} diff --git a/x.py b/x.py index 7ceb9e5..40488a1 100644 --- a/x.py +++ b/x.py @@ -145,7 +145,7 @@ def _workflow_check_steps(): @cmd(help="Format the code") def format(): - _sh(f"{python} -m black {_q(__file__)}") + #_sh(f"{python} -m black {_q(__file__)}") _sh("cargo fmt") # the impl files are not found by cargo fmt @@ -209,11 +209,11 @@ def doc(private=False, open=False): @cmd() def check_cargo_toml(): - import tomli + import tomllib print(":: check Cargo.toml") with open(self_path / "marrow" / "Cargo.toml", "rb") as fobj: - config = tomli.load(fobj) + config = tomllib.load(fobj) for label, features in [ ( From 5181aef641a5f0eff8c5711b15769b95c4e0c0c6 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 3 Apr 2025 20:02:52 +0200 Subject: [PATCH 02/12] Add more types, allow to overwrit derive logic for select fields --- marrow-typeinfo-derive/src/lib.rs | 59 +++++++++-- marrow-typeinfo/src/lib.rs | 159 +++++++++++++++++++++++------- marrow-typeinfo/tests/derive.rs | 47 ++++++++- marrow/src/types.rs | 3 + 4 files changed, 223 insertions(+), 45 deletions(-) diff --git a/marrow-typeinfo-derive/src/lib.rs b/marrow-typeinfo-derive/src/lib.rs index dcfc9e3..db1ebb8 100644 --- a/marrow-typeinfo-derive/src/lib.rs +++ b/marrow-typeinfo-derive/src/lib.rs @@ -1,11 +1,11 @@ use proc_macro::TokenStream; use quote::quote; use syn::{ - Data, DeriveInput, Fields, FieldsNamed, Ident, Variant, parse_macro_input, - punctuated::Punctuated, token::Comma, + Attribute, Data, DeriveInput, Expr, Fields, FieldsNamed, Ident, Lit, Meta, Token, Variant, + parse_macro_input, punctuated::Punctuated, token::Comma, }; -#[proc_macro_derive(TypeInfo)] +#[proc_macro_derive(TypeInfo, attributes(marrow_type_info))] pub fn array_builder(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); @@ -28,6 +28,37 @@ pub fn array_builder(input: TokenStream) -> TokenStream { TokenStream::from(expanded) } +fn get_use_call(attrs: &[Attribute]) -> Option { + for attr in attrs { + if !attr.path().is_ident("marrow_type_info") { + continue; + } + + let nested = attr + .parse_args_with(Punctuated::::parse_terminated) + .unwrap(); + for meta in nested { + match meta { + Meta::NameValue(meta) => { + if !meta.path.is_ident("with") { + continue; + } + match meta.value { + Expr::Lit(lit) => match lit.lit { + Lit::Str(str) => return Some(Ident::new(&str.value(), str.span())), + _ => unimplemented!(), + }, + _ => unimplemented!(), + } + } + _ => unimplemented!(), + } + } + } + + None +} + fn derive_for_struct(name: &Ident, fields: &FieldsNamed) -> proc_macro2::TokenStream { let mut field_exprs = Vec::new(); @@ -35,9 +66,16 @@ fn derive_for_struct(name: &Ident, fields: &FieldsNamed) -> proc_macro2::TokenSt let field_name = field.ident.as_ref().expect("named filed without ident"); let ty = &field.ty; - field_exprs.push(quote! { - fields.push(<#ty as ::marrow_typeinfo::TypeInfo>::get_field(stringify!(#field_name), context)?); - }) + if let Some(func) = get_use_call(&field.attrs) { + field_exprs.push(quote! { + // TODO: pass context, include type? + fields.push(#func(context.get_context(), stringify!(#field_name))); + }); + } else { + field_exprs.push(quote! { + fields.push(context.get_context().get_field::<#ty>(stringify!(#field_name))?); + }) + } } quote! { @@ -45,7 +83,7 @@ fn derive_for_struct(name: &Ident, fields: &FieldsNamed) -> proc_macro2::TokenSt impl ::marrow_typeinfo::TypeInfo for #name { fn get_field( name: &::std::primitive::str, - context: &::marrow_typeinfo::Context, + context: ::marrow_typeinfo::ContextRef<'_>, ) -> ::std::result::Result< ::marrow::datatypes::Field, ::marrow_typeinfo::Error, @@ -74,6 +112,11 @@ fn derive_for_enum( for (idx, variant) in variants.iter().enumerate() { let variant_name = &variant.ident; + if let Some(func) = get_use_call(&variant.attrs) { + variant_exprs.push(quote! { #func(stringify!(#variant_name)) }); + continue; + } + match variant.fields { Fields::Unit => { variant_exprs.push(quote! { @@ -95,7 +138,7 @@ fn derive_for_enum( impl ::marrow_typeinfo::TypeInfo for #name { fn get_field( name: &::std::primitive::str, - context: &::marrow_typeinfo::Context, + context: ::marrow_typeinfo::ContextRef<'_>, ) -> ::std::result::Result< ::marrow::datatypes::Field, ::marrow_typeinfo::Error, diff --git a/marrow-typeinfo/src/lib.rs b/marrow-typeinfo/src/lib.rs index 65911e4..db6428d 100644 --- a/marrow-typeinfo/src/lib.rs +++ b/marrow-typeinfo/src/lib.rs @@ -7,18 +7,26 @@ use std::{ sync::Arc, }; -use marrow::datatypes::{DataType, Field}; +use marrow::{ + datatypes::{DataType, Field}, + types::f16, +}; +/// Derive [TypeInfo] for a given type +/// +/// Currently structs and enums with any type of lifetime parameters are supported. pub use marrow_typeinfo_derive::TypeInfo; // TODO: include the path in context to allow overwrites -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct Context { data: HashMap>, } struct DefaultStringType(DataType); +struct LargeList(bool); + impl Context { pub fn new() -> Self { Self::default() @@ -43,6 +51,29 @@ impl Context { self.set(DefaultStringType(ty)); self } + + pub fn with_large_list(mut self, large_list: bool) -> Self { + self.set(LargeList(large_list)); + self + } + + pub fn get_field(&self, name: &str) -> Result { + // TODO: allow to overwrite child fields + T::get_field(name, ContextRef(self)) + } + + pub fn get_data_type(&self) -> Result { + Ok(self.get_field::("item")?.data_type) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct ContextRef<'a>(&'a Context); + +impl<'a> ContextRef<'a> { + pub fn get_context(self) -> &'a Context { + self.0 + } } #[derive(Debug, PartialEq)] @@ -68,19 +99,25 @@ impl From for Error { } } +/// Get the Arrow type information for a given Rust type +/// +/// The functions cannot be called directly. First construct a [Context], then call the +/// corresponding methods. pub trait TypeInfo { - fn get_field(name: &str, context: &Context) -> Result; + /// See [Context::get_field] + fn get_field(name: &str, context: ContextRef<'_>) -> Result; - fn get_data_type(context: &Context) -> Result { + /// See [Context::get_data_type] + fn get_data_type(context: ContextRef<'_>) -> Result { Ok(Self::get_field("item", context)?.data_type) } } macro_rules! define_primitive { - ($(($ty:ty, $dt:expr)),*) => { + ($(($ty:ty, $dt:expr),)*) => { $( impl TypeInfo for $ty { - fn get_field(name: &str, context: &Context) -> Result { + fn get_field(name: &str, context: ContextRef<'_>) -> Result { let _ = context; Ok(Field { name: name.to_owned(), @@ -104,7 +141,10 @@ define_primitive!( (i8, DataType::Int8), (i16, DataType::Int16), (i32, DataType::Int32), - (i64, DataType::Int64) + (i64, DataType::Int64), + (f16, DataType::Float16), + (f32, DataType::Float32), + (f64, DataType::Float64), ); fn get_default_string_type(context: &Context) -> DataType { @@ -125,46 +165,51 @@ fn new_field(name: &str, data_type: DataType) -> Field { } impl TypeInfo for &str { - fn get_field(name: &str, context: &Context) -> Result { - Ok(new_field(name, get_default_string_type(context))) + fn get_field(name: &str, context: ContextRef<'_>) -> Result { + Ok(new_field( + name, + get_default_string_type(context.get_context()), + )) } } impl TypeInfo for String { - fn get_field(name: &str, context: &Context) -> Result { - Ok(new_field(name, get_default_string_type(context))) + fn get_field(name: &str, context: ContextRef<'_>) -> Result { + Ok(new_field( + name, + get_default_string_type(context.get_context()), + )) } } impl TypeInfo for Box { - fn get_field(name: &str, context: &Context) -> Result { - Ok(new_field(name, get_default_string_type(context))) + fn get_field(name: &str, context: ContextRef<'_>) -> Result { + Ok(new_field( + name, + get_default_string_type(context.get_context()), + )) } } impl TypeInfo for Arc { - fn get_field(name: &str, context: &Context) -> Result { - Ok(new_field(name, get_default_string_type(context))) + fn get_field(name: &str, context: ContextRef<'_>) -> Result { + Ok(new_field( + name, + get_default_string_type(context.get_context()), + )) } } impl TypeInfo for [T; N] { - fn get_field(name: &str, context: &Context) -> Result { - let base_type = T::get_data_type(context)?; + fn get_field(name: &str, context: ContextRef<'_>) -> Result { + let base_field = context.get_context().get_field::("element")?; let n = i32::try_from(N)?; // TODO: allow to customize - let data_type = match base_type { - DataType::UInt8 => DataType::FixedSizeBinary(n), - base_type => DataType::FixedSizeList( - Box::new(Field { - name: String::from("element"), - data_type: base_type, - nullable: false, - metadata: Default::default(), - }), - n, - ), + let data_type = if matches!(base_field.data_type, DataType::UInt8) { + DataType::FixedSizeBinary(n) + } else { + DataType::FixedSizeList(Box::new(base_field), n) }; Ok(Field { @@ -176,14 +221,62 @@ impl TypeInfo for [T; N] { } } +fn get_list_field(name: &str, context: ContextRef<'_>) -> Result { + let larget_list = if let Some(LargeList(large_list)) = context.get_context().get() { + *large_list + } else { + false + }; + + let base_field = context.get_context().get_field::("element")?; + + Ok(Field { + name: name.to_owned(), + data_type: if larget_list { + DataType::LargeList(Box::new(base_field)) + } else { + DataType::List(Box::new(base_field)) + }, + nullable: false, + metadata: Default::default(), + }) +} + +impl TypeInfo for Vec { + fn get_field(name: &str, context: ContextRef<'_>) -> Result { + get_list_field::(name, context) + } +} + +impl TypeInfo for &[T] { + fn get_field(name: &str, context: ContextRef<'_>) -> Result { + get_list_field::(name, context) + } +} + +impl TypeInfo for Option { + fn get_field(name: &str, context: ContextRef<'_>) -> Result { + let mut base_field = T::get_field(name, context)?; + base_field.nullable = true; + Ok(base_field) + } +} + +impl TypeInfo for HashMap { + fn get_field(name: &str, context: ContextRef<'_>) -> Result { + let key_field = context.get_context().get_field::("key")?; + let value_field = context.get_context().get_field::("value")?; + let entry_field = new_field("entry", DataType::Struct(vec![key_field, value_field])); + + Ok(new_field(name, DataType::Map(Box::new(entry_field), false))) + } +} + #[test] fn examples() { + assert_eq!(Context::new().get_data_type::(), Ok(DataType::Int64)); assert_eq!( - ::get_data_type(&Default::default()), - Ok(DataType::Int64) - ); - assert_eq!( - <[u8; 8] as TypeInfo>::get_data_type(&Default::default()), + Context::new().get_data_type::<[u8; 8]>(), Ok(DataType::FixedSizeBinary(8)) ); } diff --git a/marrow-typeinfo/tests/derive.rs b/marrow-typeinfo/tests/derive.rs index 8cebafb..0a9d0b6 100644 --- a/marrow-typeinfo/tests/derive.rs +++ b/marrow-typeinfo/tests/derive.rs @@ -1,5 +1,5 @@ -use marrow::datatypes::{DataType, Field, UnionMode}; -use marrow_typeinfo::TypeInfo; +use marrow::datatypes::{DataType, Field, TimeUnit, UnionMode}; +use marrow_typeinfo::{Context, TypeInfo}; #[test] fn example() { @@ -11,7 +11,7 @@ fn example() { } assert_eq!( - ::get_data_type(&Default::default()), + Context::default().get_data_type::(), Ok(DataType::Struct(vec![ Field { name: String::from("a"), @@ -29,6 +29,45 @@ fn example() { ); } +#[test] +fn customize() { + #[derive(TypeInfo)] + #[allow(dead_code)] + struct S { + #[marrow_type_info(with = "timestamp_field")] + a: i64, + b: [u8; 4], + } + + assert_eq!( + Context::default().get_data_type::(), + Ok(DataType::Struct(vec![ + Field { + name: String::from("a"), + data_type: DataType::Timestamp(TimeUnit::Millisecond, None), + nullable: false, + metadata: Default::default(), + }, + Field { + name: String::from("b"), + data_type: DataType::FixedSizeBinary(4), + nullable: false, + metadata: Default::default(), + } + ])) + ); +} + +// TODO: pass context +fn timestamp_field(_: &Context, name: &str) -> Field { + Field { + name: String::from(name), + data_type: DataType::Timestamp(TimeUnit::Millisecond, None), + nullable: false, + metadata: Default::default(), + } +} + #[test] fn fieldless_union() { #[derive(TypeInfo)] @@ -40,7 +79,7 @@ fn fieldless_union() { } assert_eq!( - ::get_data_type(&Default::default()), + Context::default().get_data_type::(), Ok(DataType::Union( vec![ ( diff --git a/marrow/src/types.rs b/marrow/src/types.rs index ec64d1e..4a63387 100644 --- a/marrow/src/types.rs +++ b/marrow/src/types.rs @@ -1,5 +1,8 @@ //! Specialized element types of arrays +/// Reexport the used f16 type +pub use half::f16; + /// Represent a calendar interval as days and milliseconds #[derive(Debug, PartialEq, Clone, Copy, bytemuck::AnyBitPattern, bytemuck::NoUninit)] #[repr(C)] From 880fd96d0e8db091761452e7d246d5eccad0c4c6 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 3 Apr 2025 20:25:35 +0200 Subject: [PATCH 03/12] Fix unit field def, implement new type union wrappers --- marrow-typeinfo-derive/src/lib.rs | 114 +++++++++++++++++++++++------- marrow-typeinfo/src/lib.rs | 13 +++- marrow-typeinfo/tests/derive.rs | 56 +++++++++++++++ 3 files changed, 157 insertions(+), 26 deletions(-) diff --git a/marrow-typeinfo-derive/src/lib.rs b/marrow-typeinfo-derive/src/lib.rs index db1ebb8..9d97fa9 100644 --- a/marrow-typeinfo-derive/src/lib.rs +++ b/marrow-typeinfo-derive/src/lib.rs @@ -28,35 +28,87 @@ pub fn array_builder(input: TokenStream) -> TokenStream { TokenStream::from(expanded) } -fn get_use_call(attrs: &[Attribute]) -> Option { - for attr in attrs { - if !attr.path().is_ident("marrow_type_info") { - continue; - } +#[derive(Debug, Default)] +struct FieldArgs { + // TODO: use a path here + with: Option, +} - let nested = attr - .parse_args_with(Punctuated::::parse_terminated) - .unwrap(); - for meta in nested { - match meta { - Meta::NameValue(meta) => { - if !meta.path.is_ident("with") { - continue; - } - match meta.value { - Expr::Lit(lit) => match lit.lit { - Lit::Str(str) => return Some(Ident::new(&str.value(), str.span())), +impl FieldArgs { + pub fn from_attrs(attrs: &[Attribute]) -> Self { + let mut result = Self::default(); + + for attr in attrs { + if !attr.path().is_ident("marrow_type_info") { + continue; + } + + let nested = attr + .parse_args_with(Punctuated::::parse_terminated) + .unwrap(); + for meta in nested { + match meta { + Meta::NameValue(meta) => { + if !meta.path.is_ident("with") { + continue; + } + match meta.value { + Expr::Lit(lit) => match lit.lit { + Lit::Str(str) => { + result.with = Some(Ident::new(&str.value(), str.span())); + } + _ => unimplemented!(), + }, _ => unimplemented!(), - }, - _ => unimplemented!(), + } } + _ => unimplemented!(), } - _ => unimplemented!(), } } + result } +} - None +#[derive(Debug, Default)] +struct VariantArgs { + with: Option, +} + +impl VariantArgs { + pub fn from_attrs(attrs: &[Attribute]) -> Self { + let mut result = Self::default(); + + for attr in attrs { + if !attr.path().is_ident("marrow_type_info") { + continue; + } + + let nested = attr + .parse_args_with(Punctuated::::parse_terminated) + .unwrap(); + for meta in nested { + match meta { + Meta::NameValue(meta) => { + if !meta.path.is_ident("with") { + continue; + } + match meta.value { + Expr::Lit(lit) => match lit.lit { + Lit::Str(str) => { + result.with = Some(Ident::new(&str.value(), str.span())); + } + _ => unimplemented!(), + }, + _ => unimplemented!(), + } + } + _ => unimplemented!(), + } + } + } + result + } } fn derive_for_struct(name: &Ident, fields: &FieldsNamed) -> proc_macro2::TokenStream { @@ -65,8 +117,9 @@ fn derive_for_struct(name: &Ident, fields: &FieldsNamed) -> proc_macro2::TokenSt for field in &fields.named { let field_name = field.ident.as_ref().expect("named filed without ident"); let ty = &field.ty; + let args = FieldArgs::from_attrs(&field.attrs); - if let Some(func) = get_use_call(&field.attrs) { + if let Some(func) = args.with.as_ref() { field_exprs.push(quote! { // TODO: pass context, include type? fields.push(#func(context.get_context(), stringify!(#field_name))); @@ -111,13 +164,14 @@ fn derive_for_enum( for (idx, variant) in variants.iter().enumerate() { let variant_name = &variant.ident; + let variant_args = VariantArgs::from_attrs(&variant.attrs); - if let Some(func) = get_use_call(&variant.attrs) { + if let Some(func) = variant_args.with.as_ref() { variant_exprs.push(quote! { #func(stringify!(#variant_name)) }); continue; } - match variant.fields { + match &variant.fields { Fields::Unit => { variant_exprs.push(quote! { variants.push((i8::try_from(#idx)?, ::marrow::datatypes::Field { @@ -128,8 +182,18 @@ fn derive_for_enum( })); }); } - Fields::Named(_) => panic!("enums with named fields are currently supported"), + Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { + let Some(field) = fields.unnamed.first() else { + unreachable!("checked in guard that exactly 1 field is available"); + }; + + let field_ty = &field.ty; + variant_exprs.push(quote! { + variants.push((i8::try_from(#idx)?, context.get_context().get_field::<#field_ty>(stringify!(#variant_name))?)); + }); + } Fields::Unnamed(_) => panic!("enums with unnamed fields are currently supported"), + Fields::Named(_) => panic!("enums with named fields are currently supported"), } } diff --git a/marrow-typeinfo/src/lib.rs b/marrow-typeinfo/src/lib.rs index db6428d..9948635 100644 --- a/marrow-typeinfo/src/lib.rs +++ b/marrow-typeinfo/src/lib.rs @@ -132,7 +132,6 @@ macro_rules! define_primitive { } define_primitive!( - ((), DataType::Null), (bool, DataType::Boolean), (u8, DataType::UInt8), (u16, DataType::UInt16), @@ -147,6 +146,18 @@ define_primitive!( (f64, DataType::Float64), ); +impl TypeInfo for () { + fn get_field(name: &str, context: ContextRef<'_>) -> Result { + let _ = context; + Ok(Field { + name: name.to_owned(), + data_type: DataType::Null, + nullable: true, + metadata: Default::default(), + }) + } +} + fn get_default_string_type(context: &Context) -> DataType { if let Some(DefaultStringType(ty)) = context.get() { ty.clone() diff --git a/marrow-typeinfo/tests/derive.rs b/marrow-typeinfo/tests/derive.rs index 0a9d0b6..65c7c73 100644 --- a/marrow-typeinfo/tests/derive.rs +++ b/marrow-typeinfo/tests/derive.rs @@ -114,3 +114,59 @@ fn fieldless_union() { )) ); } + +#[test] +fn new_type_enum() { + #[derive(TypeInfo)] + #[allow(dead_code)] + enum Enum { + Struct(Struct), + Int64(i64), + } + + #[derive(TypeInfo)] + struct Struct { + a: bool, + b: (), + } + + assert_eq!( + Context::default().get_data_type::(), + Ok(DataType::Union( + vec![ + ( + 0, + Field { + name: String::from("Struct"), + data_type: DataType::Struct(vec![ + Field { + name: String::from("a"), + data_type: DataType::Boolean, + nullable: false, + metadata: Default::default(), + }, + Field { + name: String::from("b"), + data_type: DataType::Null, + nullable: true, + metadata: Default::default(), + }, + ]), + nullable: false, + metadata: Default::default(), + } + ), + ( + 1, + Field { + name: String::from("Int64"), + data_type: DataType::Int64, + nullable: false, + metadata: Default::default(), + } + ), + ], + UnionMode::Dense + )) + ); +} From 9184306101a9a17ea72266beca5390311bac7630 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Fri, 4 Apr 2025 12:33:05 +0200 Subject: [PATCH 04/12] Implement newtype structs, tuple structs, tuple variants, newtype structs --- marrow-typeinfo-derive/src/lib.rs | 168 +++++++++++++++++++----------- marrow-typeinfo/tests/derive.rs | 152 +++++++++++++++++++++++++-- marrow/src/datatypes.rs | 2 +- 3 files changed, 250 insertions(+), 72 deletions(-) diff --git a/marrow-typeinfo-derive/src/lib.rs b/marrow-typeinfo-derive/src/lib.rs index 9d97fa9..ad744b4 100644 --- a/marrow-typeinfo-derive/src/lib.rs +++ b/marrow-typeinfo-derive/src/lib.rs @@ -1,28 +1,22 @@ use proc_macro::TokenStream; use quote::quote; use syn::{ - Attribute, Data, DeriveInput, Expr, Fields, FieldsNamed, Ident, Lit, Meta, Token, Variant, - parse_macro_input, punctuated::Punctuated, token::Comma, + Attribute, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, Ident, Lit, LitStr, + Meta, Token, parse_macro_input, punctuated::Punctuated, spanned::Spanned, }; #[proc_macro_derive(TypeInfo, attributes(marrow_type_info))] pub fn array_builder(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); + if !input.generics.params.is_empty() { + panic!("Deriving TypeInfo for generic is not supported") + } + let expanded = match input.data { - Data::Struct(data) => match &data.fields { - Fields::Named(fields) => derive_for_struct(&input.ident, fields), - Fields::Unnamed(_) => { - panic!("Deriving TypeInfo for tuple structs is not yet supported") - } - Fields::Unit => { - panic!("Deriving TypeInfo for unit structs is not yet supported") - } - }, - Data::Enum(data) => derive_for_enum(&input.ident, &data.variants), - Data::Union(_) => { - panic!("Deriving TypeInfo for unions is currently not supported") - } + Data::Struct(data) => derive_for_struct(&input.ident, &data), + Data::Enum(data) => derive_for_enum(&input.ident, &data), + Data::Union(_) => panic!("Deriving TypeInfo for unions is not supported"), }; TokenStream::from(expanded) @@ -111,25 +105,47 @@ impl VariantArgs { } } -fn derive_for_struct(name: &Ident, fields: &FieldsNamed) -> proc_macro2::TokenStream { - let mut field_exprs = Vec::new(); - - for field in &fields.named { - let field_name = field.ident.as_ref().expect("named filed without ident"); - let ty = &field.ty; - let args = FieldArgs::from_attrs(&field.attrs); - - if let Some(func) = args.with.as_ref() { - field_exprs.push(quote! { - // TODO: pass context, include type? - fields.push(#func(context.get_context(), stringify!(#field_name))); - }); - } else { - field_exprs.push(quote! { - fields.push(context.get_context().get_field::<#ty>(stringify!(#field_name))?); - }) +fn derive_for_struct(name: &Ident, data: &DataStruct) -> proc_macro2::TokenStream { + let fields = get_fields(&data.fields); + let body = match fields.as_slice() { + [] => panic!(), + [(NameSource::Index, _, field)] => { + // TODO: ensure no args + let field_ty = &field.ty; + quote! { context.get_context().get_field::<#field_ty>(name) } } - } + fields => { + let mut field_exprs = Vec::new(); + + for (_, field_name, field) in fields { + let ty = &field.ty; + let args = FieldArgs::from_attrs(&field.attrs); + + if let Some(func) = args.with.as_ref() { + field_exprs.push(quote! { + // TODO: pass context, include type? + fields.push(#func::<#ty>(context.get_context(), #field_name)); + }); + } else { + field_exprs.push(quote! { + fields.push(context.get_context().get_field::<#ty>(#field_name)?); + }) + } + } + + quote! { + let mut fields = ::std::vec::Vec::<::marrow::datatypes::Field>::new(); + #( #field_exprs; )* + + Ok(::marrow::datatypes::Field { + name: ::std::string::String::from(name), + data_type: ::marrow::datatypes::DataType::Struct(fields), + nullable: false, + metadata: ::std::default::Default::default(), + }) + } + } + }; quote! { const _: () = { @@ -141,28 +157,17 @@ fn derive_for_struct(name: &Ident, fields: &FieldsNamed) -> proc_macro2::TokenSt ::marrow::datatypes::Field, ::marrow_typeinfo::Error, > { - let mut fields = ::std::vec::Vec::<::marrow::datatypes::Field>::new(); - #( #field_exprs; )* - - Ok(::marrow::datatypes::Field { - name: ::std::string::String::from(name), - data_type: ::marrow::datatypes::DataType::Struct(fields), - nullable: false, - metadata: ::std::default::Default::default(), - }) + #body } } }; } } -fn derive_for_enum( - name: &Ident, - variants: &Punctuated, -) -> proc_macro2::TokenStream { +fn derive_for_enum(name: &Ident, data: &DataEnum) -> proc_macro2::TokenStream { let mut variant_exprs = Vec::new(); - for (idx, variant) in variants.iter().enumerate() { + for (idx, variant) in data.variants.iter().enumerate() { let variant_name = &variant.ident; let variant_args = VariantArgs::from_attrs(&variant.attrs); @@ -171,29 +176,43 @@ fn derive_for_enum( continue; } - match &variant.fields { - Fields::Unit => { + let variant_idx = i8::try_from(idx).unwrap(); + + let fields = get_fields(&variant.fields); + match fields.as_slice() { + [] => { variant_exprs.push(quote! { - variants.push((i8::try_from(#idx)?, ::marrow::datatypes::Field { + (#variant_idx, ::marrow::datatypes::Field { name: ::std::string::String::from(stringify!(#variant_name)), data_type: ::marrow::datatypes::DataType::Null, nullable: true, metadata: ::std::default::Default::default(), - })); + }) }); } - Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { - let Some(field) = fields.unnamed.first() else { - unreachable!("checked in guard that exactly 1 field is available"); - }; - + [(NameSource::Index, _, field)] => { let field_ty = &field.ty; variant_exprs.push(quote! { - variants.push((i8::try_from(#idx)?, context.get_context().get_field::<#field_ty>(stringify!(#variant_name))?)); + (#variant_idx, context.get_context().get_field::<#field_ty>(stringify!(#variant_name))?) + }); + } + fields => { + let mut field_exprs = Vec::new(); + for (_, field_name, field) in fields { + let field_ty = &field.ty; + field_exprs.push(quote! { + context.get_context().get_field::<#field_ty>(#field_name)? + }); + } + variant_exprs.push(quote! { + (#variant_idx, ::marrow::datatypes::Field { + name: ::std::string::String::from(stringify!(#variant_name)), + data_type: ::marrow::datatypes::DataType::Struct(vec![#(#field_exprs),*]), + nullable: false, + metadata: ::std::default::Default::default(), + }) }); } - Fields::Unnamed(_) => panic!("enums with unnamed fields are currently supported"), - Fields::Named(_) => panic!("enums with named fields are currently supported"), } } @@ -208,7 +227,7 @@ fn derive_for_enum( ::marrow_typeinfo::Error, > { let mut variants = ::std::vec::Vec::<(::std::primitive::i8, ::marrow::datatypes::Field)>::new(); - #( #variant_exprs; )* + #( variants.push(#variant_exprs); )* Ok(::marrow::datatypes::Field { name: ::std::string::String::from(name), @@ -221,3 +240,32 @@ fn derive_for_enum( }; } } + +fn get_fields(fields: &Fields) -> Vec<(NameSource, LitStr, &Field)> { + let mut result = Vec::new(); + match fields { + Fields::Unit => {} + Fields::Named(fields) => { + for field in &fields.named { + let Some(name) = field.ident.as_ref() else { + unreachable!("Named field must have a name"); + }; + let name = LitStr::new(&name.to_string(), name.span()); + result.push((NameSource::Ident, name, field)); + } + } + Fields::Unnamed(fields) => { + for (idx, field) in fields.unnamed.iter().enumerate() { + let name = LitStr::new(&idx.to_string(), field.span()); + result.push((NameSource::Index, name, field)); + } + } + } + result +} + +#[derive(Debug, Clone, Copy, PartialEq)] +enum NameSource { + Ident, + Index, +} diff --git a/marrow-typeinfo/tests/derive.rs b/marrow-typeinfo/tests/derive.rs index 65c7c73..469315a 100644 --- a/marrow-typeinfo/tests/derive.rs +++ b/marrow-typeinfo/tests/derive.rs @@ -1,4 +1,7 @@ -use marrow::datatypes::{DataType, Field, TimeUnit, UnionMode}; +use marrow::{ + datatypes::{DataType, Field, TimeUnit, UnionMode}, + types::f16, +}; use marrow_typeinfo::{Context, TypeInfo}; #[test] @@ -29,6 +32,41 @@ fn example() { ); } +#[test] +fn newtype() { + #[derive(TypeInfo)] + #[allow(dead_code)] + struct S(f16); + + assert_eq!( + Context::default().get_data_type::(), + Ok(DataType::Float16) + ); +} + +#[test] +fn tuple() { + #[derive(TypeInfo)] + #[allow(dead_code)] + struct S(u8, [u8; 4]); + + assert_eq!( + Context::default().get_data_type::(), + Ok(DataType::Struct(vec![ + Field { + name: String::from("0"), + data_type: DataType::UInt8, + ..Field::default() + }, + Field { + name: String::from("1"), + data_type: DataType::FixedSizeBinary(4), + ..Field::default() + }, + ])) + ); +} + #[test] fn customize() { #[derive(TypeInfo)] @@ -39,6 +77,15 @@ fn customize() { b: [u8; 4], } + fn timestamp_field(_: &Context, name: &str) -> Field { + Field { + name: String::from(name), + data_type: DataType::Timestamp(TimeUnit::Millisecond, None), + nullable: false, + metadata: Default::default(), + } + } + assert_eq!( Context::default().get_data_type::(), Ok(DataType::Struct(vec![ @@ -58,16 +105,6 @@ fn customize() { ); } -// TODO: pass context -fn timestamp_field(_: &Context, name: &str) -> Field { - Field { - name: String::from(name), - data_type: DataType::Timestamp(TimeUnit::Millisecond, None), - nullable: false, - metadata: Default::default(), - } -} - #[test] fn fieldless_union() { #[derive(TypeInfo)] @@ -125,6 +162,7 @@ fn new_type_enum() { } #[derive(TypeInfo)] + #[allow(dead_code)] struct Struct { a: bool, b: (), @@ -170,3 +208,95 @@ fn new_type_enum() { )) ); } + +#[test] +fn new_tuple_enum() { + #[derive(TypeInfo)] + #[allow(dead_code)] + enum Enum { + Int64(i64), + Tuple(i8, u32), + } + + assert_eq!( + Context::default().get_data_type::(), + Ok(DataType::Union( + vec![ + ( + 0, + Field { + name: String::from("Int64"), + data_type: DataType::Int64, + ..Field::default() + } + ), + ( + 1, + Field { + name: String::from("Tuple"), + data_type: DataType::Struct(vec![ + Field { + name: String::from("0"), + data_type: DataType::Int8, + ..Field::default() + }, + Field { + name: String::from("1"), + data_type: DataType::UInt32, + ..Field::default() + }, + ]), + ..Field::default() + } + ), + ], + UnionMode::Dense + )) + ); +} + +#[test] +fn new_struct_enum() { + #[derive(TypeInfo)] + #[allow(dead_code)] + enum Enum { + Int64(i64), + Struct { a: f32, b: String }, + } + + assert_eq!( + Context::default().get_data_type::(), + Ok(DataType::Union( + vec![ + ( + 0, + Field { + name: String::from("Int64"), + data_type: DataType::Int64, + ..Field::default() + } + ), + ( + 1, + Field { + name: String::from("Struct"), + data_type: DataType::Struct(vec![ + Field { + name: String::from("a"), + data_type: DataType::Float32, + ..Field::default() + }, + Field { + name: String::from("b"), + data_type: DataType::LargeUtf8, + ..Field::default() + }, + ]), + ..Field::default() + } + ), + ], + UnionMode::Dense + )) + ); +} diff --git a/marrow/src/datatypes.rs b/marrow/src/datatypes.rs index b5e5bfd..b20ea63 100644 --- a/marrow/src/datatypes.rs +++ b/marrow/src/datatypes.rs @@ -37,7 +37,7 @@ impl std::default::Default for Field { Self { data_type: DataType::Null, name: Default::default(), - nullable: Default::default(), + nullable: false, metadata: Default::default(), } } From 418a66c06dc76cd2b6d6d1246c6455ba21592cee Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Fri, 4 Apr 2025 13:40:56 +0200 Subject: [PATCH 05/12] Implement overwrites --- marrow-typeinfo-derive/src/lib.rs | 53 +++---- marrow-typeinfo/src/lib.rs | 249 ++++++++++++++++++------------ marrow-typeinfo/tests/derive.rs | 60 +++++-- 3 files changed, 224 insertions(+), 138 deletions(-) diff --git a/marrow-typeinfo-derive/src/lib.rs b/marrow-typeinfo-derive/src/lib.rs index ad744b4..6a66b7d 100644 --- a/marrow-typeinfo-derive/src/lib.rs +++ b/marrow-typeinfo-derive/src/lib.rs @@ -112,7 +112,7 @@ fn derive_for_struct(name: &Ident, data: &DataStruct) -> proc_macro2::TokenStrea [(NameSource::Index, _, field)] => { // TODO: ensure no args let field_ty = &field.ty; - quote! { context.get_context().get_field::<#field_ty>(name) } + quote! { <#field_ty>::get_field(context) } } fields => { let mut field_exprs = Vec::new(); @@ -123,12 +123,11 @@ fn derive_for_struct(name: &Ident, data: &DataStruct) -> proc_macro2::TokenStrea if let Some(func) = args.with.as_ref() { field_exprs.push(quote! { - // TODO: pass context, include type? - fields.push(#func::<#ty>(context.get_context(), #field_name)); + fields.push(context.nest(#field_name, #func::<#ty>)?); }); } else { field_exprs.push(quote! { - fields.push(context.get_context().get_field::<#ty>(#field_name)?); + fields.push(context.get_field::<#ty>(#field_name)?); }) } } @@ -138,7 +137,7 @@ fn derive_for_struct(name: &Ident, data: &DataStruct) -> proc_macro2::TokenStrea #( #field_exprs; )* Ok(::marrow::datatypes::Field { - name: ::std::string::String::from(name), + name: ::std::string::String::from(context.get_name()), data_type: ::marrow::datatypes::DataType::Struct(fields), nullable: false, metadata: ::std::default::Default::default(), @@ -151,12 +150,8 @@ fn derive_for_struct(name: &Ident, data: &DataStruct) -> proc_macro2::TokenStrea const _: () = { impl ::marrow_typeinfo::TypeInfo for #name { fn get_field( - name: &::std::primitive::str, - context: ::marrow_typeinfo::ContextRef<'_>, - ) -> ::std::result::Result< - ::marrow::datatypes::Field, - ::marrow_typeinfo::Error, - > { + context: ::marrow_typeinfo::Context<'_>, + ) -> ::marrow_typeinfo::Result<::marrow::datatypes::Field> { #body } } @@ -169,6 +164,7 @@ fn derive_for_enum(name: &Ident, data: &DataEnum) -> proc_macro2::TokenStream { for (idx, variant) in data.variants.iter().enumerate() { let variant_name = &variant.ident; + let variant_name = LitStr::new(&variant_name.to_string(), variant_name.span()); let variant_args = VariantArgs::from_attrs(&variant.attrs); if let Some(func) = variant_args.with.as_ref() { @@ -181,19 +177,22 @@ fn derive_for_enum(name: &Ident, data: &DataEnum) -> proc_macro2::TokenStream { let fields = get_fields(&variant.fields); match fields.as_slice() { [] => { + // use nesting to allow overwrites variant_exprs.push(quote! { - (#variant_idx, ::marrow::datatypes::Field { - name: ::std::string::String::from(stringify!(#variant_name)), - data_type: ::marrow::datatypes::DataType::Null, - nullable: true, - metadata: ::std::default::Default::default(), - }) + (#variant_idx, context.nest(#variant_name, |context| { + Ok(::marrow::datatypes::Field { + name: ::std::string::String::from(context.get_name()), + data_type: ::marrow::datatypes::DataType::Null, + nullable: true, + metadata: ::std::default::Default::default(), + }) + })?) }); } [(NameSource::Index, _, field)] => { let field_ty = &field.ty; variant_exprs.push(quote! { - (#variant_idx, context.get_context().get_field::<#field_ty>(stringify!(#variant_name))?) + (#variant_idx, context.nest(#variant_name, <#field_ty>::get_field)?) }); } fields => { @@ -201,16 +200,16 @@ fn derive_for_enum(name: &Ident, data: &DataEnum) -> proc_macro2::TokenStream { for (_, field_name, field) in fields { let field_ty = &field.ty; field_exprs.push(quote! { - context.get_context().get_field::<#field_ty>(#field_name)? + context.get_field::<#field_ty>(#field_name)? }); } variant_exprs.push(quote! { - (#variant_idx, ::marrow::datatypes::Field { - name: ::std::string::String::from(stringify!(#variant_name)), + (#variant_idx, context.nest(#variant_name, |context| Ok(::marrow::datatypes::Field { + name: ::std::string::String::from(context.get_name()), data_type: ::marrow::datatypes::DataType::Struct(vec![#(#field_exprs),*]), nullable: false, metadata: ::std::default::Default::default(), - }) + }))?) }); } } @@ -220,17 +219,13 @@ fn derive_for_enum(name: &Ident, data: &DataEnum) -> proc_macro2::TokenStream { const _: () = { impl ::marrow_typeinfo::TypeInfo for #name { fn get_field( - name: &::std::primitive::str, - context: ::marrow_typeinfo::ContextRef<'_>, - ) -> ::std::result::Result< - ::marrow::datatypes::Field, - ::marrow_typeinfo::Error, - > { + context: ::marrow_typeinfo::Context<'_>, + ) -> ::marrow_typeinfo::Result<::marrow::datatypes::Field> { let mut variants = ::std::vec::Vec::<(::std::primitive::i8, ::marrow::datatypes::Field)>::new(); #( variants.push(#variant_exprs); )* Ok(::marrow::datatypes::Field { - name: ::std::string::String::from(name), + name: ::std::string::String::from(context.get_name()), data_type: ::marrow::datatypes::DataType::Union(variants, ::marrow::datatypes::UnionMode::Dense), nullable: false, metadata: ::std::default::Default::default(), diff --git a/marrow-typeinfo/src/lib.rs b/marrow-typeinfo/src/lib.rs index 9948635..0103072 100644 --- a/marrow-typeinfo/src/lib.rs +++ b/marrow-typeinfo/src/lib.rs @@ -17,21 +17,38 @@ use marrow::{ /// Currently structs and enums with any type of lifetime parameters are supported. pub use marrow_typeinfo_derive::TypeInfo; -// TODO: include the path in context to allow overwrites -#[derive(Debug, Default, Clone)] -pub struct Context { - data: HashMap>, -} +pub type Result = std::result::Result; -struct DefaultStringType(DataType); +#[derive(Debug, PartialEq)] +pub struct Error(String); -struct LargeList(bool); +impl std::error::Error for Error {} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Error({:?})", self.0) + } +} -impl Context { - pub fn new() -> Self { - Self::default() +impl From for Error { + fn from(_: Infallible) -> Self { + unreachable!() + } +} + +impl From for Error { + fn from(value: TryFromIntError) -> Self { + Self(value.to_string()) } +} + +#[derive(Debug, Default)] +pub struct Options { + data: HashMap>, + overwrites: HashMap, +} +impl Options { pub fn set(&mut self, value: T) { let type_id = TypeId::of::(); self.data.insert(type_id, Rc::new(value)); @@ -46,70 +63,117 @@ impl Context { Some(value) } - pub fn with_default_string_type(mut self, ty: DataType) -> Self { - // TODO: check that ty is compatible with strings - self.set(DefaultStringType(ty)); + pub fn with_default_string_type(mut self, data_type: DataType) -> Self { + // TOOD: check for valid string type + self.set(DefaultStringType(data_type)); self } - pub fn with_large_list(mut self, large_list: bool) -> Self { - self.set(LargeList(large_list)); + pub fn with_default_list_index_type(mut self, list_type: ListIndexType) -> Self { + self.set(LargeList(matches!(list_type, ListIndexType::Int64))); self } - pub fn get_field(&self, name: &str) -> Result { - // TODO: allow to overwrite child fields - T::get_field(name, ContextRef(self)) + pub fn overwrite(mut self, path: &str, field: Field) -> Self { + self.overwrites.insert(path.to_owned(), field); + self } +} - pub fn get_data_type(&self) -> Result { - Ok(self.get_field::("item")?.data_type) +pub enum ListIndexType { + Int32, + Int64, +} + +impl TryFrom for ListIndexType { + type Error = Error; + + fn try_from(value: DataType) -> std::result::Result { + match value { + DataType::Int32 => Ok(Self::Int32), + DataType::Int64 => Ok(Self::Int64), + dt => Err(Error(format!( + "Cannot interpretr {dt:?} as a ListIndexType" + ))), + } } } #[derive(Debug, Clone, Copy)] -pub struct ContextRef<'a>(&'a Context); +pub struct Context<'a> { + path: &'a str, + name: &'a str, + options: &'a Options, +} -impl<'a> ContextRef<'a> { - pub fn get_context(self) -> &'a Context { - self.0 +impl<'a> Context<'a> { + pub fn get_name(&self) -> &str { + self.name } -} -#[derive(Debug, PartialEq)] -pub struct Error(String); + pub fn get_path(&self) -> &str { + self.path + } -impl std::error::Error for Error {} + pub fn get_options(&self) -> &Options { + self.options + } -impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Error({:?})", self.0) + pub fn get_field(&self, name: &str) -> Result { + self.nest(name, T::get_field) } -} -impl From for Error { - fn from(_: Infallible) -> Self { - unreachable!() + pub fn nest) -> Result>( + &self, + name: &str, + scope: F, + ) -> Result { + let path = format!("{}.{}", self.path, name); + + if let Some(overwrite) = self.options.overwrites.get(&path) { + let mut overwrite = overwrite.clone(); + overwrite.name = String::from(name); + return Ok(overwrite); + } + + let child_context = Context { + path: &path, + name, + options: self.options, + }; + + scope(child_context) } } -impl From for Error { - fn from(value: TryFromIntError) -> Self { - Self(value.to_string()) - } +pub fn get_field(name: &str, options: &Options) -> Result { + let context = Context { + path: "$", + name, + options, + }; + T::get_field(context) } +pub fn get_data_type(options: &Options) -> Result { + Ok(get_field::("item", options)?.data_type) +} + +struct DefaultStringType(DataType); + +struct LargeList(bool); + /// Get the Arrow type information for a given Rust type /// /// The functions cannot be called directly. First construct a [Context], then call the /// corresponding methods. pub trait TypeInfo { /// See [Context::get_field] - fn get_field(name: &str, context: ContextRef<'_>) -> Result; + fn get_field(context: Context<'_>) -> Result; /// See [Context::get_data_type] - fn get_data_type(context: ContextRef<'_>) -> Result { - Ok(Self::get_field("item", context)?.data_type) + fn get_data_type(context: Context<'_>) -> Result { + Ok(Self::get_field(context)?.data_type) } } @@ -117,13 +181,11 @@ macro_rules! define_primitive { ($(($ty:ty, $dt:expr),)*) => { $( impl TypeInfo for $ty { - fn get_field(name: &str, context: ContextRef<'_>) -> Result { - let _ = context; + fn get_field(context: Context<'_>) -> Result { Ok(Field { - name: name.to_owned(), + name: context.get_name().to_owned(), data_type: $dt, - nullable: false, - metadata: Default::default(), + ..Field::default() }) } } @@ -147,10 +209,10 @@ define_primitive!( ); impl TypeInfo for () { - fn get_field(name: &str, context: ContextRef<'_>) -> Result { + fn get_field(context: Context<'_>) -> Result { let _ = context; Ok(Field { - name: name.to_owned(), + name: context.get_name().to_owned(), data_type: DataType::Null, nullable: true, metadata: Default::default(), @@ -158,14 +220,6 @@ impl TypeInfo for () { } } -fn get_default_string_type(context: &Context) -> DataType { - if let Some(DefaultStringType(ty)) = context.get() { - ty.clone() - } else { - DataType::LargeUtf8 - } -} - fn new_field(name: &str, data_type: DataType) -> Field { Field { name: name.to_owned(), @@ -175,45 +229,42 @@ fn new_field(name: &str, data_type: DataType) -> Field { } } +fn new_string_field(context: Context<'_>) -> Field { + let ty = if let Some(DefaultStringType(ty)) = context.get_options().get() { + ty.clone() + } else { + DataType::LargeUtf8 + }; + new_field(context.get_name(), ty) +} + impl TypeInfo for &str { - fn get_field(name: &str, context: ContextRef<'_>) -> Result { - Ok(new_field( - name, - get_default_string_type(context.get_context()), - )) + fn get_field(context: Context<'_>) -> Result { + Ok(new_string_field(context)) } } impl TypeInfo for String { - fn get_field(name: &str, context: ContextRef<'_>) -> Result { - Ok(new_field( - name, - get_default_string_type(context.get_context()), - )) + fn get_field(context: Context<'_>) -> Result { + Ok(new_string_field(context)) } } impl TypeInfo for Box { - fn get_field(name: &str, context: ContextRef<'_>) -> Result { - Ok(new_field( - name, - get_default_string_type(context.get_context()), - )) + fn get_field(context: Context<'_>) -> Result { + Ok(new_string_field(context)) } } impl TypeInfo for Arc { - fn get_field(name: &str, context: ContextRef<'_>) -> Result { - Ok(new_field( - name, - get_default_string_type(context.get_context()), - )) + fn get_field(context: Context<'_>) -> Result { + Ok(new_string_field(context)) } } impl TypeInfo for [T; N] { - fn get_field(name: &str, context: ContextRef<'_>) -> Result { - let base_field = context.get_context().get_field::("element")?; + fn get_field(context: Context<'_>) -> Result { + let base_field = context.get_field::("element")?; let n = i32::try_from(N)?; // TODO: allow to customize @@ -224,7 +275,7 @@ impl TypeInfo for [T; N] { }; Ok(Field { - name: name.to_owned(), + name: context.get_name().to_owned(), data_type, nullable: false, metadata: Default::default(), @@ -232,17 +283,17 @@ impl TypeInfo for [T; N] { } } -fn get_list_field(name: &str, context: ContextRef<'_>) -> Result { - let larget_list = if let Some(LargeList(large_list)) = context.get_context().get() { +fn new_list_field(context: Context<'_>) -> Result { + let larget_list = if let Some(LargeList(large_list)) = context.get_options().get() { *large_list } else { false }; - let base_field = context.get_context().get_field::("element")?; + let base_field = context.get_field::("element")?; Ok(Field { - name: name.to_owned(), + name: context.get_name().to_owned(), data_type: if larget_list { DataType::LargeList(Box::new(base_field)) } else { @@ -254,40 +305,46 @@ fn get_list_field(name: &str, context: ContextRef<'_>) -> Result TypeInfo for Vec { - fn get_field(name: &str, context: ContextRef<'_>) -> Result { - get_list_field::(name, context) + fn get_field(context: Context<'_>) -> Result { + new_list_field::(context) } } impl TypeInfo for &[T] { - fn get_field(name: &str, context: ContextRef<'_>) -> Result { - get_list_field::(name, context) + fn get_field(context: Context<'_>) -> Result { + new_list_field::(context) } } impl TypeInfo for Option { - fn get_field(name: &str, context: ContextRef<'_>) -> Result { - let mut base_field = T::get_field(name, context)?; + fn get_field(context: Context<'_>) -> Result { + let mut base_field = T::get_field(context)?; base_field.nullable = true; Ok(base_field) } } impl TypeInfo for HashMap { - fn get_field(name: &str, context: ContextRef<'_>) -> Result { - let key_field = context.get_context().get_field::("key")?; - let value_field = context.get_context().get_field::("value")?; + fn get_field(context: Context<'_>) -> Result { + let key_field = context.get_field::("key")?; + let value_field = context.get_field::("value")?; let entry_field = new_field("entry", DataType::Struct(vec![key_field, value_field])); - Ok(new_field(name, DataType::Map(Box::new(entry_field), false))) + Ok(new_field( + context.get_name(), + DataType::Map(Box::new(entry_field), false), + )) } } #[test] fn examples() { - assert_eq!(Context::new().get_data_type::(), Ok(DataType::Int64)); assert_eq!( - Context::new().get_data_type::<[u8; 8]>(), + get_data_type::(&Options::default()), + Ok(DataType::Int64) + ); + assert_eq!( + get_data_type::<[u8; 8]>(&Options::default()), Ok(DataType::FixedSizeBinary(8)) ); } diff --git a/marrow-typeinfo/tests/derive.rs b/marrow-typeinfo/tests/derive.rs index 469315a..8743b87 100644 --- a/marrow-typeinfo/tests/derive.rs +++ b/marrow-typeinfo/tests/derive.rs @@ -2,7 +2,7 @@ use marrow::{ datatypes::{DataType, Field, TimeUnit, UnionMode}, types::f16, }; -use marrow_typeinfo::{Context, TypeInfo}; +use marrow_typeinfo::{Context, Options, Result, TypeInfo}; #[test] fn example() { @@ -14,7 +14,7 @@ fn example() { } assert_eq!( - Context::default().get_data_type::(), + marrow_typeinfo::get_data_type::(&Options::default()), Ok(DataType::Struct(vec![ Field { name: String::from("a"), @@ -32,6 +32,40 @@ fn example() { ); } +#[test] +fn overwrites() { + #[derive(TypeInfo)] + #[allow(dead_code)] + struct S { + a: i64, + b: [u8; 4], + } + + assert_eq!( + marrow_typeinfo::get_data_type::(&Options::default().overwrite( + "$.b", + Field { + data_type: DataType::Binary, + ..Field::default() + } + )), + Ok(DataType::Struct(vec![ + Field { + name: String::from("a"), + data_type: DataType::Int64, + nullable: false, + metadata: Default::default(), + }, + Field { + name: String::from("b"), + data_type: DataType::Binary, + nullable: false, + metadata: Default::default(), + } + ])) + ); +} + #[test] fn newtype() { #[derive(TypeInfo)] @@ -39,7 +73,7 @@ fn newtype() { struct S(f16); assert_eq!( - Context::default().get_data_type::(), + marrow_typeinfo::get_data_type::(&Options::default()), Ok(DataType::Float16) ); } @@ -51,7 +85,7 @@ fn tuple() { struct S(u8, [u8; 4]); assert_eq!( - Context::default().get_data_type::(), + marrow_typeinfo::get_data_type::(&Options::default()), Ok(DataType::Struct(vec![ Field { name: String::from("0"), @@ -77,17 +111,17 @@ fn customize() { b: [u8; 4], } - fn timestamp_field(_: &Context, name: &str) -> Field { - Field { - name: String::from(name), + fn timestamp_field(context: Context<'_>) -> Result { + Ok(Field { + name: String::from(context.get_name()), data_type: DataType::Timestamp(TimeUnit::Millisecond, None), nullable: false, metadata: Default::default(), - } + }) } assert_eq!( - Context::default().get_data_type::(), + marrow_typeinfo::get_data_type::(&Options::default()), Ok(DataType::Struct(vec![ Field { name: String::from("a"), @@ -116,7 +150,7 @@ fn fieldless_union() { } assert_eq!( - Context::default().get_data_type::(), + marrow_typeinfo::get_data_type::(&Options::default()), Ok(DataType::Union( vec![ ( @@ -169,7 +203,7 @@ fn new_type_enum() { } assert_eq!( - Context::default().get_data_type::(), + marrow_typeinfo::get_data_type::(&Options::default()), Ok(DataType::Union( vec![ ( @@ -219,7 +253,7 @@ fn new_tuple_enum() { } assert_eq!( - Context::default().get_data_type::(), + marrow_typeinfo::get_data_type::(&Options::default()), Ok(DataType::Union( vec![ ( @@ -265,7 +299,7 @@ fn new_struct_enum() { } assert_eq!( - Context::default().get_data_type::(), + marrow_typeinfo::get_data_type::(&Options::default()), Ok(DataType::Union( vec![ ( From ee2a365a0e17b1f990f58ff56ee474e08ddf353c Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Fri, 4 Apr 2025 13:58:59 +0200 Subject: [PATCH 06/12] Add references, tuple impls --- marrow-typeinfo/src/lib.rs | 70 ++++++++++++++++++++++++++++++++------ 1 file changed, 60 insertions(+), 10 deletions(-) diff --git a/marrow-typeinfo/src/lib.rs b/marrow-typeinfo/src/lib.rs index 0103072..9d9e308 100644 --- a/marrow-typeinfo/src/lib.rs +++ b/marrow-typeinfo/src/lib.rs @@ -168,20 +168,15 @@ struct LargeList(bool); /// The functions cannot be called directly. First construct a [Context], then call the /// corresponding methods. pub trait TypeInfo { - /// See [Context::get_field] - fn get_field(context: Context<'_>) -> Result; - - /// See [Context::get_data_type] - fn get_data_type(context: Context<'_>) -> Result { - Ok(Self::get_field(context)?.data_type) - } + /// See [crate::get_field] + fn get_field(context: Context<'_>) -> Result; } macro_rules! define_primitive { ($(($ty:ty, $dt:expr),)*) => { $( impl TypeInfo for $ty { - fn get_field(context: Context<'_>) -> Result { + fn get_field(context: Context<'_>) -> Result { Ok(Field { name: context.get_name().to_owned(), data_type: $dt, @@ -208,8 +203,20 @@ define_primitive!( (f64, DataType::Float64), ); +impl TypeInfo for &T { + fn get_field(context: Context<'_>) -> Result { + T::get_field(context) + } +} + +impl TypeInfo for &mut T { + fn get_field(context: Context<'_>) -> Result { + T::get_field(context) + } +} + impl TypeInfo for () { - fn get_field(context: Context<'_>) -> Result { + fn get_field(context: Context<'_>) -> Result { let _ = context; Ok(Field { name: context.get_name().to_owned(), @@ -239,7 +246,7 @@ fn new_string_field(context: Context<'_>) -> Field { } impl TypeInfo for &str { - fn get_field(context: Context<'_>) -> Result { + fn get_field(context: Context<'_>) -> Result { Ok(new_string_field(context)) } } @@ -337,6 +344,49 @@ impl TypeInfo for HashMap { } } +macro_rules! impl_tuples { + ($( ( $($name:ident,)* ), )*) => { + $( + impl<$($name: TypeInfo),*> TypeInfo for ( $($name,)* ) { + #[allow(unused_assignments)] + fn get_field(context: Context<'_>) -> Result { + let mut idx = 0; + let mut fields = Vec::new(); + $( + fields.push(context.get_field::<$name>(&idx.to_string())?); + idx += 1; + )* + + Ok(Field { + name: context.get_name().to_owned(), + data_type: DataType::Struct(fields), + ..Field::default() + }) + } + } + )* + }; +} + +impl_tuples!( + (A,), + (A, B,), + (A, B, C,), + (A, B, C, D,), + (A, B, C, D, E,), + (A, B, C, D, E, F,), + (A, B, C, D, E, F, G,), + (A, B, C, D, E, F, G, H,), + (A, B, C, D, E, F, G, H, I,), + (A, B, C, D, E, F, G, H, I, J,), + (A, B, C, D, E, F, G, H, I, J, K,), + (A, B, C, D, E, F, G, H, I, J, K, L,), + (A, B, C, D, E, F, G, H, I, J, K, L, M,), + (A, B, C, D, E, F, G, H, I, J, K, L, M, N,), + (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O,), + (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P,), +); + #[test] fn examples() { assert_eq!( From 31701e79f5a8518af501ae6d46c7cc9eb5ae27b4 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 12 Apr 2025 10:36:06 +0200 Subject: [PATCH 07/12] Add jiff, chrono, bigdecimal, uuid types --- Cargo.lock | 70 ++++++- marrow-typeinfo-derive/src/lib.rs | 100 ++++++++-- marrow-typeinfo/Cargo.toml | 8 +- marrow-typeinfo/src/ext/bigdecimal.rs | 14 ++ marrow-typeinfo/src/ext/chrono.rs | 44 +++++ marrow-typeinfo/src/ext/jiff.rs | 43 +++++ marrow-typeinfo/src/ext/mod.rs | 4 + marrow-typeinfo/src/ext/uuid.rs | 20 ++ marrow-typeinfo/src/lib.rs | 2 + marrow-typeinfo/tests/derive.rs | 262 ++++++++++++++++++++++++-- 10 files changed, 522 insertions(+), 45 deletions(-) create mode 100644 marrow-typeinfo/src/ext/bigdecimal.rs create mode 100644 marrow-typeinfo/src/ext/chrono.rs create mode 100644 marrow-typeinfo/src/ext/jiff.rs create mode 100644 marrow-typeinfo/src/ext/mod.rs create mode 100644 marrow-typeinfo/src/ext/uuid.rs diff --git a/Cargo.lock b/Cargo.lock index f9fd416..d50db5e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -933,6 +933,19 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "bigdecimal" +version = "0.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a22f228ab7a1b23027ccc6c350b72868017af7ea8356fbdf19f8d991c690013" +dependencies = [ + "autocfg", + "libm", + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -1126,6 +1139,28 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "jiff" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f33145a5cbea837164362c7bd596106eb7c5198f97d1ba6f6ebb3223952e488" +dependencies = [ + "jiff-static", + "portable-atomic", + "portable-atomic-util", +] + +[[package]] +name = "jiff-static" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43ce13c40ec6956157a3635d97a1ee2df323b263f09ea14165131289cb0f5c19" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "js-sys" version = "0.3.70" @@ -1240,8 +1275,12 @@ dependencies = [ name = "marrow-typeinfo" version = "0.1.0" dependencies = [ + "bigdecimal", + "chrono", + "jiff", "marrow", "marrow-typeinfo-derive", + "uuid", ] [[package]] @@ -1344,24 +1383,33 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.9.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" +checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.37" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" dependencies = [ "proc-macro2", ] @@ -1433,9 +1481,9 @@ checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" [[package]] name = "syn" -version = "2.0.79" +version = "2.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0" dependencies = [ "proc-macro2", "quote", @@ -1503,6 +1551,12 @@ version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +[[package]] +name = "uuid" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" + [[package]] name = "version_check" version = "0.9.5" diff --git a/marrow-typeinfo-derive/src/lib.rs b/marrow-typeinfo-derive/src/lib.rs index 6a66b7d..88c7edc 100644 --- a/marrow-typeinfo-derive/src/lib.rs +++ b/marrow-typeinfo-derive/src/lib.rs @@ -1,25 +1,32 @@ use proc_macro::TokenStream; -use quote::quote; +use quote::{ToTokens, quote}; use syn::{ - Attribute, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, Ident, Lit, LitStr, - Meta, Token, parse_macro_input, punctuated::Punctuated, spanned::Spanned, + Attribute, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, GenericParam, Ident, + Lit, LitStr, Meta, Token, punctuated::Punctuated, spanned::Spanned, }; #[proc_macro_derive(TypeInfo, attributes(marrow_type_info))] -pub fn array_builder(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); +pub fn derive_type_info(input: TokenStream) -> TokenStream { + derive_type_info_impl(input.into()).into() +} + +fn derive_type_info_impl(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream { + let input: DeriveInput = syn::parse2(input).unwrap(); - if !input.generics.params.is_empty() { - panic!("Deriving TypeInfo for generic is not supported") + if input + .generics + .params + .iter() + .any(|p| matches!(p, GenericParam::Type(_))) + { + panic!("Deriving TypeInfo for generics with type parameters is not supported") } - let expanded = match input.data { - Data::Struct(data) => derive_for_struct(&input.ident, &data), - Data::Enum(data) => derive_for_enum(&input.ident, &data), + match &input.data { + Data::Struct(data) => derive_for_struct(&input, data), + Data::Enum(data) => derive_for_enum(&input, data), Data::Union(_) => panic!("Deriving TypeInfo for unions is not supported"), - }; - - TokenStream::from(expanded) + } } #[derive(Debug, Default)] @@ -105,7 +112,23 @@ impl VariantArgs { } } -fn derive_for_struct(name: &Ident, data: &DataStruct) -> proc_macro2::TokenStream { +fn derive_for_struct(input: &DeriveInput, data: &DataStruct) -> proc_macro2::TokenStream { + let name = &input.ident; + + let generics_decl = &input.generics; + let generics_use = if !input.generics.params.is_empty() { + let generics_use = input.generics.params.iter().map(|p| match p { + GenericParam::Const(p) => p.ident.to_token_stream(), + GenericParam::Lifetime(p) => p.lifetime.to_token_stream(), + GenericParam::Type(_) => panic!(), + }); + quote! { + <#(#generics_use),*> + } + } else { + quote! {} + }; + let fields = get_fields(&data.fields); let body = match fields.as_slice() { [] => panic!(), @@ -148,7 +171,7 @@ fn derive_for_struct(name: &Ident, data: &DataStruct) -> proc_macro2::TokenStrea quote! { const _: () = { - impl ::marrow_typeinfo::TypeInfo for #name { + impl #generics_decl ::marrow_typeinfo::TypeInfo for #name #generics_use { fn get_field( context: ::marrow_typeinfo::Context<'_>, ) -> ::marrow_typeinfo::Result<::marrow::datatypes::Field> { @@ -159,9 +182,24 @@ fn derive_for_struct(name: &Ident, data: &DataStruct) -> proc_macro2::TokenStrea } } -fn derive_for_enum(name: &Ident, data: &DataEnum) -> proc_macro2::TokenStream { +fn derive_for_enum(input: &DeriveInput, data: &DataEnum) -> proc_macro2::TokenStream { let mut variant_exprs = Vec::new(); + let name = &input.ident; + let generics_decl = &input.generics; + let generics_use = if !input.generics.params.is_empty() { + let generics_use = input.generics.params.iter().map(|p| match p { + GenericParam::Const(p) => p.ident.to_token_stream(), + GenericParam::Lifetime(p) => p.lifetime.to_token_stream(), + GenericParam::Type(_) => panic!(), + }); + quote! { + <#(#generics_use),*> + } + } else { + quote! {} + }; + for (idx, variant) in data.variants.iter().enumerate() { let variant_name = &variant.ident; let variant_name = LitStr::new(&variant_name.to_string(), variant_name.span()); @@ -217,7 +255,7 @@ fn derive_for_enum(name: &Ident, data: &DataEnum) -> proc_macro2::TokenStream { quote! { const _: () = { - impl ::marrow_typeinfo::TypeInfo for #name { + impl #generics_decl ::marrow_typeinfo::TypeInfo for #name #generics_use { fn get_field( context: ::marrow_typeinfo::Context<'_>, ) -> ::marrow_typeinfo::Result<::marrow::datatypes::Field> { @@ -264,3 +302,31 @@ enum NameSource { Ident, Index, } + +#[test] +#[should_panic(expected = "Deriving TypeInfo for generics with type parameters is not supported")] +fn reject_unsupported() { + derive_type_info_impl(quote! { + struct Example { + field: T, + } + }); +} + +#[test] +fn lifetimes_are_supported() { + derive_type_info_impl(quote! { + struct Example<'a> { + field: &'a i64, + } + }); +} + +#[test] +fn const_params_are_supported() { + derive_type_info_impl(quote! { + struct Example { + field: [u8; N], + } + }); +} diff --git a/marrow-typeinfo/Cargo.toml b/marrow-typeinfo/Cargo.toml index c2d9f82..6dfdfda 100644 --- a/marrow-typeinfo/Cargo.toml +++ b/marrow-typeinfo/Cargo.toml @@ -4,5 +4,11 @@ version = "0.1.0" edition = "2024" [dependencies] -marrow = { path = "../marrow" } +marrow = { path = "../marrow", default-features = false } marrow-typeinfo-derive = { path = "../marrow-typeinfo-derive" } + +jiff = { version = "0.2", default-features = false } + +chrono = { version = "0.4", default-features = false } +bigdecimal = {version = "0.4", default-features = false } +uuid = { version = "1.10.0", default-features = false} diff --git a/marrow-typeinfo/src/ext/bigdecimal.rs b/marrow-typeinfo/src/ext/bigdecimal.rs new file mode 100644 index 0000000..342a1f7 --- /dev/null +++ b/marrow-typeinfo/src/ext/bigdecimal.rs @@ -0,0 +1,14 @@ +use marrow::datatypes::{DataType, Field}; + +use crate::TypeInfo; + +impl TypeInfo for bigdecimal::BigDecimal { + fn get_field(context: crate::Context<'_>) -> crate::Result { + Ok(Field { + name: String::from(context.get_name()), + // TODO: find better defaults + data_type: DataType::Decimal128(5, 5), + ..Default::default() + }) + } +} diff --git a/marrow-typeinfo/src/ext/chrono.rs b/marrow-typeinfo/src/ext/chrono.rs new file mode 100644 index 0000000..f21ca02 --- /dev/null +++ b/marrow-typeinfo/src/ext/chrono.rs @@ -0,0 +1,44 @@ +use chrono::Utc; +use marrow::datatypes::{DataType, Field, TimeUnit}; + +use crate::TypeInfo; + +impl TypeInfo for chrono::NaiveDate { + fn get_field(context: crate::Context<'_>) -> crate::Result { + Ok(Field { + name: String::from(context.get_name()), + data_type: DataType::Date32, + ..Default::default() + }) + } +} + +impl TypeInfo for chrono::NaiveTime { + fn get_field(context: crate::Context<'_>) -> crate::Result { + Ok(Field { + name: String::from(context.get_name()), + data_type: DataType::Time32(TimeUnit::Millisecond), + ..Default::default() + }) + } +} + +impl TypeInfo for chrono::NaiveDateTime { + fn get_field(context: crate::Context<'_>) -> crate::Result { + Ok(Field { + name: String::from(context.get_name()), + data_type: DataType::Timestamp(TimeUnit::Millisecond, None), + ..Default::default() + }) + } +} + +impl TypeInfo for chrono::DateTime { + fn get_field(context: crate::Context<'_>) -> crate::Result { + Ok(Field { + name: String::from(context.get_name()), + data_type: DataType::Timestamp(TimeUnit::Millisecond, Some(String::from("UTC"))), + ..Default::default() + }) + } +} diff --git a/marrow-typeinfo/src/ext/jiff.rs b/marrow-typeinfo/src/ext/jiff.rs new file mode 100644 index 0000000..9a75104 --- /dev/null +++ b/marrow-typeinfo/src/ext/jiff.rs @@ -0,0 +1,43 @@ +use marrow::datatypes::{DataType, Field, TimeUnit}; + +use crate::TypeInfo; + +impl TypeInfo for jiff::civil::Date { + fn get_field(context: crate::Context<'_>) -> crate::Result { + Ok(Field { + name: String::from(context.get_name()), + data_type: DataType::Date32, + ..Default::default() + }) + } +} + +impl TypeInfo for jiff::civil::Time { + fn get_field(context: crate::Context<'_>) -> crate::Result { + Ok(Field { + name: String::from(context.get_name()), + data_type: DataType::Time32(TimeUnit::Millisecond), + ..Default::default() + }) + } +} + +impl TypeInfo for jiff::Span { + fn get_field(context: crate::Context<'_>) -> crate::Result { + Ok(Field { + name: String::from(context.get_name()), + data_type: DataType::Duration(TimeUnit::Millisecond), + ..Default::default() + }) + } +} + +impl TypeInfo for jiff::Timestamp { + fn get_field(context: crate::Context<'_>) -> crate::Result { + Ok(Field { + name: String::from(context.get_name()), + data_type: DataType::Timestamp(TimeUnit::Millisecond, None), + ..Default::default() + }) + } +} diff --git a/marrow-typeinfo/src/ext/mod.rs b/marrow-typeinfo/src/ext/mod.rs new file mode 100644 index 0000000..b9f9ea4 --- /dev/null +++ b/marrow-typeinfo/src/ext/mod.rs @@ -0,0 +1,4 @@ +mod bigdecimal; +mod chrono; +mod jiff; +mod uuid; diff --git a/marrow-typeinfo/src/ext/uuid.rs b/marrow-typeinfo/src/ext/uuid.rs new file mode 100644 index 0000000..f37fb8b --- /dev/null +++ b/marrow-typeinfo/src/ext/uuid.rs @@ -0,0 +1,20 @@ +use std::collections::HashMap; + +use marrow::datatypes::{DataType, Field}; + +use crate::{Context, Result, TypeInfo}; + +impl TypeInfo for uuid::Uuid { + fn get_field(context: Context<'_>) -> Result { + let mut metadata = HashMap::new(); + metadata.insert("ARROW:extension:name".into(), "arrow.uuid".into()); + metadata.insert("ARROW:extension:metadata".into(), String::new()); + + Ok(Field { + name: String::from(context.get_name()), + data_type: DataType::FixedSizeBinary(16), + metadata, + ..Default::default() + }) + } +} diff --git a/marrow-typeinfo/src/lib.rs b/marrow-typeinfo/src/lib.rs index 9d9e308..8c30043 100644 --- a/marrow-typeinfo/src/lib.rs +++ b/marrow-typeinfo/src/lib.rs @@ -12,6 +12,8 @@ use marrow::{ types::f16, }; +mod ext; + /// Derive [TypeInfo] for a given type /// /// Currently structs and enums with any type of lifetime parameters are supported. diff --git a/marrow-typeinfo/tests/derive.rs b/marrow-typeinfo/tests/derive.rs index 8743b87..ec26fd5 100644 --- a/marrow-typeinfo/tests/derive.rs +++ b/marrow-typeinfo/tests/derive.rs @@ -19,14 +19,12 @@ fn example() { Field { name: String::from("a"), data_type: DataType::Int64, - nullable: false, - metadata: Default::default(), + ..Default::default() }, Field { name: String::from("b"), data_type: DataType::FixedSizeBinary(4), - nullable: false, - metadata: Default::default(), + ..Default::default() } ])) ); @@ -53,14 +51,12 @@ fn overwrites() { Field { name: String::from("a"), data_type: DataType::Int64, - nullable: false, - metadata: Default::default(), + ..Default::default() }, Field { name: String::from("b"), data_type: DataType::Binary, - nullable: false, - metadata: Default::default(), + ..Default::default() } ])) ); @@ -115,8 +111,7 @@ fn customize() { Ok(Field { name: String::from(context.get_name()), data_type: DataType::Timestamp(TimeUnit::Millisecond, None), - nullable: false, - metadata: Default::default(), + ..Default::default() }) } @@ -126,14 +121,12 @@ fn customize() { Field { name: String::from("a"), data_type: DataType::Timestamp(TimeUnit::Millisecond, None), - nullable: false, - metadata: Default::default(), + ..Default::default() }, Field { name: String::from("b"), data_type: DataType::FixedSizeBinary(4), - nullable: false, - metadata: Default::default(), + ..Default::default() } ])) ); @@ -214,14 +207,13 @@ fn new_type_enum() { Field { name: String::from("a"), data_type: DataType::Boolean, - nullable: false, - metadata: Default::default(), + ..Default::default() }, Field { name: String::from("b"), data_type: DataType::Null, nullable: true, - metadata: Default::default(), + ..Default::default() }, ]), nullable: false, @@ -233,8 +225,7 @@ fn new_type_enum() { Field { name: String::from("Int64"), data_type: DataType::Int64, - nullable: false, - metadata: Default::default(), + ..Default::default() } ), ], @@ -334,3 +325,236 @@ fn new_struct_enum() { )) ); } + +#[test] +fn const_generics() { + #[derive(TypeInfo)] + #[allow(unused)] + struct Struct { + data: [u8; N], + } + + assert_eq!( + marrow_typeinfo::get_data_type::>(&Options::default()), + Ok(DataType::Struct(vec![Field { + name: String::from("data"), + data_type: DataType::FixedSizeBinary(4), + nullable: false, + metadata: Default::default(), + },])) + ); +} + +#[test] +fn liftime_generics() { + #[derive(TypeInfo)] + #[allow(unused)] + struct Struct<'a, 'b> { + a: &'a u8, + b: &'b u16, + } + + assert_eq!( + marrow_typeinfo::get_data_type::(&Options::default()), + Ok(DataType::Struct(vec![ + Field { + name: String::from("a"), + data_type: DataType::UInt8, + ..Default::default() + }, + Field { + name: String::from("b"), + data_type: DataType::UInt16, + ..Default::default() + }, + ])) + ); +} + +#[test] +fn liftime_generics_with_bounds() { + #[derive(TypeInfo)] + #[allow(unused)] + struct Struct<'a, 'b: 'a> { + a: &'a u8, + b: &'b u16, + } + + assert_eq!( + marrow_typeinfo::get_data_type::(&Options::default()), + Ok(DataType::Struct(vec![ + Field { + name: String::from("a"), + data_type: DataType::UInt8, + ..Default::default() + }, + Field { + name: String::from("b"), + data_type: DataType::UInt16, + ..Default::default() + }, + ])) + ); +} + +#[test] +fn liftime_generics_with_where_clause() { + #[derive(TypeInfo)] + #[allow(unused)] + struct Struct<'a, 'b> + where + 'a: 'b, + { + a: &'a u8, + b: &'b u16, + } + + assert_eq!( + marrow_typeinfo::get_data_type::(&Options::default()), + Ok(DataType::Struct(vec![ + Field { + name: String::from("a"), + data_type: DataType::UInt8, + ..Default::default() + }, + Field { + name: String::from("b"), + data_type: DataType::UInt16, + ..Default::default() + }, + ])) + ); +} + +#[test] +fn enums_const_generics() { + #[derive(TypeInfo)] + #[allow(unused)] + enum Enum { + Data([u8; N]), + } + + assert_eq!( + marrow_typeinfo::get_data_type::>(&Options::default()), + Ok(DataType::Union( + vec![( + 0, + Field { + name: String::from("Data"), + data_type: DataType::FixedSizeBinary(4), + nullable: false, + metadata: Default::default(), + } + ),], + UnionMode::Dense + )), + ); +} + +#[test] +fn enums_with_liftime_generics() { + #[derive(TypeInfo)] + #[allow(unused)] + enum Enum<'a, 'b> { + A(&'a u8), + B(&'b u16), + } + + assert_eq!( + marrow_typeinfo::get_data_type::(&Options::default()), + Ok(DataType::Union( + vec![ + ( + 0, + Field { + name: String::from("A"), + data_type: DataType::UInt8, + ..Default::default() + } + ), + ( + 1, + Field { + name: String::from("B"), + data_type: DataType::UInt16, + ..Default::default() + } + ), + ], + UnionMode::Dense + )) + ); +} + +#[test] +fn enum_liftime_generics_with_bounds() { + #[derive(TypeInfo)] + #[allow(unused)] + enum Enum<'a, 'b: 'a> { + A(&'a u8), + B(&'b u16), + } + + assert_eq!( + marrow_typeinfo::get_data_type::(&Options::default()), + Ok(DataType::Union( + vec![ + ( + 0, + Field { + name: String::from("A"), + data_type: DataType::UInt8, + ..Default::default() + } + ), + ( + 1, + Field { + name: String::from("B"), + data_type: DataType::UInt16, + ..Default::default() + } + ), + ], + UnionMode::Dense + )) + ); +} + +#[test] +fn enum_liftime_generics_with_where_clause() { + #[derive(TypeInfo)] + #[allow(unused)] + enum Enum<'a, 'b> + where + 'a: 'b, + { + A(&'a u8), + B(&'b u16), + } + + assert_eq!( + marrow_typeinfo::get_data_type::(&Options::default()), + Ok(DataType::Union( + vec![ + ( + 0, + Field { + name: String::from("A"), + data_type: DataType::UInt8, + ..Default::default() + } + ), + ( + 1, + Field { + name: String::from("B"), + data_type: DataType::UInt16, + ..Default::default() + } + ), + ], + UnionMode::Dense + )) + ); +} From 4f5ea2a0299a7d8d71e9e8f47162ac740a301b50 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 12 Apr 2025 12:30:21 +0200 Subject: [PATCH 08/12] Refactor impls, add more impls --- marrow-typeinfo/src/impls/collections.rs | 63 +++++ marrow-typeinfo/src/impls/compounds.rs | 75 ++++++ .../src/{ => impls}/ext/bigdecimal.rs | 0 marrow-typeinfo/src/{ => impls}/ext/chrono.rs | 0 marrow-typeinfo/src/{ => impls}/ext/jiff.rs | 0 marrow-typeinfo/src/{ => impls}/ext/mod.rs | 0 marrow-typeinfo/src/{ => impls}/ext/uuid.rs | 0 marrow-typeinfo/src/impls/mod.rs | 7 + marrow-typeinfo/src/impls/primitives.rs | 70 +++++ marrow-typeinfo/src/impls/std.rs | 109 ++++++++ marrow-typeinfo/src/impls/utils.rs | 53 ++++ marrow-typeinfo/src/impls/wrappers.rs | 67 +++++ marrow-typeinfo/src/lib.rs | 240 +----------------- marrow-typeinfo/src/tests.rs | 15 ++ 14 files changed, 465 insertions(+), 234 deletions(-) create mode 100644 marrow-typeinfo/src/impls/collections.rs create mode 100644 marrow-typeinfo/src/impls/compounds.rs rename marrow-typeinfo/src/{ => impls}/ext/bigdecimal.rs (100%) rename marrow-typeinfo/src/{ => impls}/ext/chrono.rs (100%) rename marrow-typeinfo/src/{ => impls}/ext/jiff.rs (100%) rename marrow-typeinfo/src/{ => impls}/ext/mod.rs (100%) rename marrow-typeinfo/src/{ => impls}/ext/uuid.rs (100%) create mode 100644 marrow-typeinfo/src/impls/mod.rs create mode 100644 marrow-typeinfo/src/impls/primitives.rs create mode 100644 marrow-typeinfo/src/impls/std.rs create mode 100644 marrow-typeinfo/src/impls/utils.rs create mode 100644 marrow-typeinfo/src/impls/wrappers.rs create mode 100644 marrow-typeinfo/src/tests.rs diff --git a/marrow-typeinfo/src/impls/collections.rs b/marrow-typeinfo/src/impls/collections.rs new file mode 100644 index 0000000..3ed77b1 --- /dev/null +++ b/marrow-typeinfo/src/impls/collections.rs @@ -0,0 +1,63 @@ +use std::collections::{BTreeMap, BTreeSet, BinaryHeap, HashMap, HashSet, LinkedList, VecDeque}; + +use marrow::datatypes::Field; + +use crate::{Context, Result, TypeInfo}; + +use super::utils::{new_list_field, new_map_field}; + +/// Map a vec to an Arrow List +impl TypeInfo for Vec { + fn get_field(context: Context<'_>) -> Result { + new_list_field::(context) + } +} + +/// Map a `VecDeque` to an Arrow List +impl TypeInfo for VecDeque { + fn get_field(context: Context<'_>) -> Result { + new_list_field::(context) + } +} + +/// Map a `LinkedList` to an Arrow List +impl TypeInfo for LinkedList { + fn get_field(context: Context<'_>) -> Result { + new_list_field::(context) + } +} + +/// Map a `BinaryHeap` to an Arrow List +impl TypeInfo for BinaryHeap { + fn get_field(context: Context<'_>) -> Result { + new_list_field::(context) + } +} + +/// Map a `BTreeSet` to an Arrow List +impl TypeInfo for BTreeSet { + fn get_field(context: Context<'_>) -> Result { + new_list_field::(context) + } +} + +/// Map a `HashSet` to an Arrow List +impl TypeInfo for HashSet { + fn get_field(context: Context<'_>) -> Result { + new_list_field::(context) + } +} + +/// Map a `BTreeMap` to an Arrow Map +impl TypeInfo for BTreeMap { + fn get_field(context: Context<'_>) -> Result { + new_map_field::(context) + } +} + +/// Map a `HashMap` to an Arrow Map +impl TypeInfo for HashMap { + fn get_field(context: Context<'_>) -> Result { + new_map_field::(context) + } +} diff --git a/marrow-typeinfo/src/impls/compounds.rs b/marrow-typeinfo/src/impls/compounds.rs new file mode 100644 index 0000000..f36d3a9 --- /dev/null +++ b/marrow-typeinfo/src/impls/compounds.rs @@ -0,0 +1,75 @@ +use marrow::datatypes::{DataType, Field}; + +use crate::{Context, Result, TypeInfo}; + +use super::utils::new_list_field; + +impl TypeInfo for [T] { + fn get_field(context: Context<'_>) -> Result { + new_list_field::(context) + } +} + +impl TypeInfo for [T; N] { + fn get_field(context: Context<'_>) -> Result { + let base_field = context.get_field::("element")?; + let n = i32::try_from(N)?; + + // TODO: allow to customize + let data_type = if matches!(base_field.data_type, DataType::UInt8) { + DataType::FixedSizeBinary(n) + } else { + DataType::FixedSizeList(Box::new(base_field), n) + }; + + Ok(Field { + name: context.get_name().to_owned(), + data_type, + nullable: false, + metadata: Default::default(), + }) + } +} + +macro_rules! impl_tuples { + ($( ( $($name:ident,)* ), )*) => { + $( + impl<$($name: TypeInfo),*> TypeInfo for ( $($name,)* ) { + #[allow(unused_assignments, clippy::vec_init_then_push)] + fn get_field(context: Context<'_>) -> Result { + let mut idx = 0; + let mut fields = Vec::new(); + $( + fields.push(context.get_field::<$name>(&idx.to_string())?); + idx += 1; + )* + + Ok(Field { + name: context.get_name().to_owned(), + data_type: DataType::Struct(fields), + ..Field::default() + }) + } + } + )* + }; +} + +impl_tuples!( + (A,), + (A, B,), + (A, B, C,), + (A, B, C, D,), + (A, B, C, D, E,), + (A, B, C, D, E, F,), + (A, B, C, D, E, F, G,), + (A, B, C, D, E, F, G, H,), + (A, B, C, D, E, F, G, H, I,), + (A, B, C, D, E, F, G, H, I, J,), + (A, B, C, D, E, F, G, H, I, J, K,), + (A, B, C, D, E, F, G, H, I, J, K, L,), + (A, B, C, D, E, F, G, H, I, J, K, L, M,), + (A, B, C, D, E, F, G, H, I, J, K, L, M, N,), + (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O,), + (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P,), +); diff --git a/marrow-typeinfo/src/ext/bigdecimal.rs b/marrow-typeinfo/src/impls/ext/bigdecimal.rs similarity index 100% rename from marrow-typeinfo/src/ext/bigdecimal.rs rename to marrow-typeinfo/src/impls/ext/bigdecimal.rs diff --git a/marrow-typeinfo/src/ext/chrono.rs b/marrow-typeinfo/src/impls/ext/chrono.rs similarity index 100% rename from marrow-typeinfo/src/ext/chrono.rs rename to marrow-typeinfo/src/impls/ext/chrono.rs diff --git a/marrow-typeinfo/src/ext/jiff.rs b/marrow-typeinfo/src/impls/ext/jiff.rs similarity index 100% rename from marrow-typeinfo/src/ext/jiff.rs rename to marrow-typeinfo/src/impls/ext/jiff.rs diff --git a/marrow-typeinfo/src/ext/mod.rs b/marrow-typeinfo/src/impls/ext/mod.rs similarity index 100% rename from marrow-typeinfo/src/ext/mod.rs rename to marrow-typeinfo/src/impls/ext/mod.rs diff --git a/marrow-typeinfo/src/ext/uuid.rs b/marrow-typeinfo/src/impls/ext/uuid.rs similarity index 100% rename from marrow-typeinfo/src/ext/uuid.rs rename to marrow-typeinfo/src/impls/ext/uuid.rs diff --git a/marrow-typeinfo/src/impls/mod.rs b/marrow-typeinfo/src/impls/mod.rs new file mode 100644 index 0000000..797a4cd --- /dev/null +++ b/marrow-typeinfo/src/impls/mod.rs @@ -0,0 +1,7 @@ +mod collections; +mod compounds; +mod ext; +mod primitives; +mod std; +mod utils; +mod wrappers; diff --git a/marrow-typeinfo/src/impls/primitives.rs b/marrow-typeinfo/src/impls/primitives.rs new file mode 100644 index 0000000..5e667c6 --- /dev/null +++ b/marrow-typeinfo/src/impls/primitives.rs @@ -0,0 +1,70 @@ +use marrow::{ + datatypes::{DataType, Field}, + types::f16, +}; + +use crate::{Context, Result, TypeInfo}; + +use super::utils::new_string_field; + +macro_rules! define_primitive { + ($(($ty:ty, $dt:expr),)*) => { + $( + impl TypeInfo for $ty { + fn get_field(context: Context<'_>) -> Result { + Ok(Field { + name: context.get_name().to_owned(), + data_type: $dt, + ..Field::default() + }) + } + } + )* + }; +} + +define_primitive!( + (bool, DataType::Boolean), + (u8, DataType::UInt8), + (u16, DataType::UInt16), + (u32, DataType::UInt32), + (u64, DataType::UInt64), + (i8, DataType::Int8), + (i16, DataType::Int16), + (i32, DataType::Int32), + (i64, DataType::Int64), + (f16, DataType::Float16), + (f32, DataType::Float32), + (f64, DataType::Float64), + (char, DataType::UInt32), +); + +impl TypeInfo for () { + fn get_field(context: Context<'_>) -> Result { + let _ = context; + Ok(Field { + name: context.get_name().to_owned(), + data_type: DataType::Null, + nullable: true, + metadata: Default::default(), + }) + } +} + +impl TypeInfo for str { + fn get_field(context: Context<'_>) -> Result { + Ok(new_string_field(context)) + } +} + +impl TypeInfo for &T { + fn get_field(context: Context<'_>) -> Result { + T::get_field(context) + } +} + +impl TypeInfo for &mut T { + fn get_field(context: Context<'_>) -> Result { + T::get_field(context) + } +} diff --git a/marrow-typeinfo/src/impls/std.rs b/marrow-typeinfo/src/impls/std.rs new file mode 100644 index 0000000..81f0eb3 --- /dev/null +++ b/marrow-typeinfo/src/impls/std.rs @@ -0,0 +1,109 @@ +use std::{ + num::NonZero, + ops::Range, + sync::atomic::{ + AtomicBool, AtomicI8, AtomicI16, AtomicI32, AtomicI64, AtomicU8, AtomicU16, AtomicU32, + AtomicU64, + }, + time::{Duration, SystemTime}, +}; + +use marrow::datatypes::{DataType, Field, TimeUnit, UnionMode}; + +use crate::{Context, Result, TypeInfo}; + +use super::utils::new_string_field; + +impl TypeInfo for String { + fn get_field(context: Context<'_>) -> Result { + Ok(new_string_field(context)) + } +} + +/// Map an option to a nullable field +impl TypeInfo for Option { + fn get_field(context: Context<'_>) -> Result { + let mut base_field = T::get_field(context)?; + base_field.nullable = true; + Ok(base_field) + } +} + +/// Map a `Result` to an Arrow Union with `Ok` and `Err` variants +impl TypeInfo for Result { + fn get_field(context: Context<'_>) -> Result { + let ok = context.get_field::("Ok")?; + let err = context.get_field::("Err")?; + + Ok(Field { + name: String::from(context.get_name()), + data_type: DataType::Union(vec![(0, ok), (1, err)], UnionMode::Dense), + ..Default::default() + }) + } +} + +/// Map a `Range` to an Arrow `FixedSizeList(.., 2)` +impl TypeInfo for Range { + fn get_field(context: Context<'_>) -> Result { + <[T; 2]>::get_field(context) + } +} + +macro_rules! impl_nonzero { + ($($ty:ident),* $(,)?) => { + $( + impl TypeInfo for NonZero<$ty> { + fn get_field(context: Context<'_>) -> Result { + <$ty>::get_field(context) + } + } + )* + }; +} + +impl_nonzero!(u8, u16, u32, u64, i8, i16, i32, i64); + +macro_rules! impl_atomic { + ($(($atomic:ident, $ty:ident)),* $(,)?) => { + $( + impl TypeInfo for $atomic { + fn get_field(context: Context<'_>) -> Result { + $ty::get_field(context) + } + } + )* + }; +} + +impl_atomic!( + (AtomicBool, bool), + (AtomicI8, i8), + (AtomicI16, i16), + (AtomicI32, i32), + (AtomicI64, i64), + (AtomicU8, u8), + (AtomicU16, u16), + (AtomicU32, u32), + (AtomicU64, u64), +); + +impl TypeInfo for Duration { + fn get_field(context: Context<'_>) -> Result { + Ok(Field { + name: String::from(context.get_name()), + data_type: DataType::Duration(TimeUnit::Millisecond), + ..Default::default() + }) + } +} + +impl TypeInfo for SystemTime { + fn get_field(context: Context<'_>) -> Result { + Ok(Field { + name: String::from(context.get_name()), + data_type: DataType::Timestamp(TimeUnit::Millisecond, None), + ..Default::default() + }) + } +} diff --git a/marrow-typeinfo/src/impls/utils.rs b/marrow-typeinfo/src/impls/utils.rs new file mode 100644 index 0000000..6a9a534 --- /dev/null +++ b/marrow-typeinfo/src/impls/utils.rs @@ -0,0 +1,53 @@ +use marrow::datatypes::{DataType, Field}; + +use crate::{Context, DefaultStringType, LargeList, Result, TypeInfo}; + +pub fn new_field(name: &str, data_type: DataType) -> Field { + Field { + name: name.to_owned(), + data_type, + nullable: false, + metadata: Default::default(), + } +} + +pub fn new_string_field(context: Context<'_>) -> Field { + let ty = if let Some(DefaultStringType(ty)) = context.get_options().get() { + ty.clone() + } else { + DataType::LargeUtf8 + }; + new_field(context.get_name(), ty) +} + +pub fn new_list_field(context: Context<'_>) -> Result { + let larget_list = if let Some(LargeList(large_list)) = context.get_options().get() { + *large_list + } else { + false + }; + + let base_field = context.get_field::("element")?; + + Ok(Field { + name: context.get_name().to_owned(), + data_type: if larget_list { + DataType::LargeList(Box::new(base_field)) + } else { + DataType::List(Box::new(base_field)) + }, + nullable: false, + metadata: Default::default(), + }) +} + +pub fn new_map_field(context: Context<'_>) -> Result { + let key_field = context.get_field::("key")?; + let value_field = context.get_field::("value")?; + let entry_field = new_field("entry", DataType::Struct(vec![key_field, value_field])); + + Ok(new_field( + context.get_name(), + DataType::Map(Box::new(entry_field), false), + )) +} diff --git a/marrow-typeinfo/src/impls/wrappers.rs b/marrow-typeinfo/src/impls/wrappers.rs new file mode 100644 index 0000000..e23342f --- /dev/null +++ b/marrow-typeinfo/src/impls/wrappers.rs @@ -0,0 +1,67 @@ +use std::{ + borrow::Cow, + cell::{Cell, RefCell}, + marker::PhantomData, + rc::Rc, + sync::{Arc, Mutex, RwLock}, +}; + +use marrow::datatypes::Field; + +use crate::{Context, Result, TypeInfo}; + +impl TypeInfo for PhantomData { + fn get_field(context: Context<'_>) -> Result { + let mut field = T::get_field(context)?; + field.nullable = true; + Ok(field) + } +} + +impl TypeInfo for Box { + fn get_field(context: Context<'_>) -> Result { + T::get_field(context) + } +} + +impl TypeInfo for Cell { + fn get_field(context: Context<'_>) -> Result { + T::get_field(context) + } +} + +impl TypeInfo for RefCell { + fn get_field(context: Context<'_>) -> Result { + T::get_field(context) + } +} + +impl TypeInfo for Mutex { + fn get_field(context: Context<'_>) -> Result { + T::get_field(context) + } +} + +impl TypeInfo for RwLock { + fn get_field(context: Context<'_>) -> Result { + T::get_field(context) + } +} + +impl TypeInfo for Rc { + fn get_field(context: Context<'_>) -> Result { + T::get_field(context) + } +} + +impl TypeInfo for Arc { + fn get_field(context: Context<'_>) -> Result { + T::get_field(context) + } +} + +impl<'a, T: TypeInfo + ToOwned + ?Sized + 'a> TypeInfo for Cow<'a, T> { + fn get_field(context: Context<'_>) -> Result { + T::get_field(context) + } +} diff --git a/marrow-typeinfo/src/lib.rs b/marrow-typeinfo/src/lib.rs index 8c30043..1bf6d61 100644 --- a/marrow-typeinfo/src/lib.rs +++ b/marrow-typeinfo/src/lib.rs @@ -4,15 +4,14 @@ use std::{ convert::Infallible, num::TryFromIntError, rc::Rc, - sync::Arc, }; -use marrow::{ - datatypes::{DataType, Field}, - types::f16, -}; +use marrow::datatypes::{DataType, Field}; + +mod impls; -mod ext; +#[cfg(test)] +mod tests; /// Derive [TypeInfo] for a given type /// @@ -108,7 +107,7 @@ pub struct Context<'a> { options: &'a Options, } -impl<'a> Context<'a> { +impl Context<'_> { pub fn get_name(&self) -> &str { self.name } @@ -173,230 +172,3 @@ pub trait TypeInfo { /// See [crate::get_field] fn get_field(context: Context<'_>) -> Result; } - -macro_rules! define_primitive { - ($(($ty:ty, $dt:expr),)*) => { - $( - impl TypeInfo for $ty { - fn get_field(context: Context<'_>) -> Result { - Ok(Field { - name: context.get_name().to_owned(), - data_type: $dt, - ..Field::default() - }) - } - } - )* - }; -} - -define_primitive!( - (bool, DataType::Boolean), - (u8, DataType::UInt8), - (u16, DataType::UInt16), - (u32, DataType::UInt32), - (u64, DataType::UInt64), - (i8, DataType::Int8), - (i16, DataType::Int16), - (i32, DataType::Int32), - (i64, DataType::Int64), - (f16, DataType::Float16), - (f32, DataType::Float32), - (f64, DataType::Float64), -); - -impl TypeInfo for &T { - fn get_field(context: Context<'_>) -> Result { - T::get_field(context) - } -} - -impl TypeInfo for &mut T { - fn get_field(context: Context<'_>) -> Result { - T::get_field(context) - } -} - -impl TypeInfo for () { - fn get_field(context: Context<'_>) -> Result { - let _ = context; - Ok(Field { - name: context.get_name().to_owned(), - data_type: DataType::Null, - nullable: true, - metadata: Default::default(), - }) - } -} - -fn new_field(name: &str, data_type: DataType) -> Field { - Field { - name: name.to_owned(), - data_type, - nullable: false, - metadata: Default::default(), - } -} - -fn new_string_field(context: Context<'_>) -> Field { - let ty = if let Some(DefaultStringType(ty)) = context.get_options().get() { - ty.clone() - } else { - DataType::LargeUtf8 - }; - new_field(context.get_name(), ty) -} - -impl TypeInfo for &str { - fn get_field(context: Context<'_>) -> Result { - Ok(new_string_field(context)) - } -} - -impl TypeInfo for String { - fn get_field(context: Context<'_>) -> Result { - Ok(new_string_field(context)) - } -} - -impl TypeInfo for Box { - fn get_field(context: Context<'_>) -> Result { - Ok(new_string_field(context)) - } -} - -impl TypeInfo for Arc { - fn get_field(context: Context<'_>) -> Result { - Ok(new_string_field(context)) - } -} - -impl TypeInfo for [T; N] { - fn get_field(context: Context<'_>) -> Result { - let base_field = context.get_field::("element")?; - let n = i32::try_from(N)?; - - // TODO: allow to customize - let data_type = if matches!(base_field.data_type, DataType::UInt8) { - DataType::FixedSizeBinary(n) - } else { - DataType::FixedSizeList(Box::new(base_field), n) - }; - - Ok(Field { - name: context.get_name().to_owned(), - data_type, - nullable: false, - metadata: Default::default(), - }) - } -} - -fn new_list_field(context: Context<'_>) -> Result { - let larget_list = if let Some(LargeList(large_list)) = context.get_options().get() { - *large_list - } else { - false - }; - - let base_field = context.get_field::("element")?; - - Ok(Field { - name: context.get_name().to_owned(), - data_type: if larget_list { - DataType::LargeList(Box::new(base_field)) - } else { - DataType::List(Box::new(base_field)) - }, - nullable: false, - metadata: Default::default(), - }) -} - -impl TypeInfo for Vec { - fn get_field(context: Context<'_>) -> Result { - new_list_field::(context) - } -} - -impl TypeInfo for &[T] { - fn get_field(context: Context<'_>) -> Result { - new_list_field::(context) - } -} - -impl TypeInfo for Option { - fn get_field(context: Context<'_>) -> Result { - let mut base_field = T::get_field(context)?; - base_field.nullable = true; - Ok(base_field) - } -} - -impl TypeInfo for HashMap { - fn get_field(context: Context<'_>) -> Result { - let key_field = context.get_field::("key")?; - let value_field = context.get_field::("value")?; - let entry_field = new_field("entry", DataType::Struct(vec![key_field, value_field])); - - Ok(new_field( - context.get_name(), - DataType::Map(Box::new(entry_field), false), - )) - } -} - -macro_rules! impl_tuples { - ($( ( $($name:ident,)* ), )*) => { - $( - impl<$($name: TypeInfo),*> TypeInfo for ( $($name,)* ) { - #[allow(unused_assignments)] - fn get_field(context: Context<'_>) -> Result { - let mut idx = 0; - let mut fields = Vec::new(); - $( - fields.push(context.get_field::<$name>(&idx.to_string())?); - idx += 1; - )* - - Ok(Field { - name: context.get_name().to_owned(), - data_type: DataType::Struct(fields), - ..Field::default() - }) - } - } - )* - }; -} - -impl_tuples!( - (A,), - (A, B,), - (A, B, C,), - (A, B, C, D,), - (A, B, C, D, E,), - (A, B, C, D, E, F,), - (A, B, C, D, E, F, G,), - (A, B, C, D, E, F, G, H,), - (A, B, C, D, E, F, G, H, I,), - (A, B, C, D, E, F, G, H, I, J,), - (A, B, C, D, E, F, G, H, I, J, K,), - (A, B, C, D, E, F, G, H, I, J, K, L,), - (A, B, C, D, E, F, G, H, I, J, K, L, M,), - (A, B, C, D, E, F, G, H, I, J, K, L, M, N,), - (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O,), - (A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P,), -); - -#[test] -fn examples() { - assert_eq!( - get_data_type::(&Options::default()), - Ok(DataType::Int64) - ); - assert_eq!( - get_data_type::<[u8; 8]>(&Options::default()), - Ok(DataType::FixedSizeBinary(8)) - ); -} diff --git a/marrow-typeinfo/src/tests.rs b/marrow-typeinfo/src/tests.rs new file mode 100644 index 0000000..e5f7e01 --- /dev/null +++ b/marrow-typeinfo/src/tests.rs @@ -0,0 +1,15 @@ +use marrow::datatypes::DataType; + +use crate::{Options, get_data_type}; + +#[test] +fn examples() { + assert_eq!( + get_data_type::(&Options::default()), + Ok(DataType::Int64) + ); + assert_eq!( + get_data_type::<[u8; 8]>(&Options::default()), + Ok(DataType::FixedSizeBinary(8)) + ); +} From e772bad9e3ca4e776f2dd3bebb2fda4d1313f689 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 12 Apr 2025 13:59:06 +0200 Subject: [PATCH 09/12] Add more range related types --- marrow-typeinfo/src/impls/std.rs | 48 +++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/marrow-typeinfo/src/impls/std.rs b/marrow-typeinfo/src/impls/std.rs index 81f0eb3..2aec829 100644 --- a/marrow-typeinfo/src/impls/std.rs +++ b/marrow-typeinfo/src/impls/std.rs @@ -1,6 +1,6 @@ use std::{ num::NonZero, - ops::Range, + ops::{Bound, Range, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive}, sync::atomic::{ AtomicBool, AtomicI8, AtomicI16, AtomicI32, AtomicI64, AtomicU8, AtomicU16, AtomicU32, AtomicU64, @@ -50,6 +50,52 @@ impl TypeInfo for Range { } } +/// Map a `RangeInclusive` to an Arrow `FixedSizeList(.., 2)` +impl TypeInfo for RangeInclusive { + fn get_field(context: Context<'_>) -> Result { + <[T; 2]>::get_field(context) + } +} + +/// Map a `RangeTo` to the index type +impl TypeInfo for RangeTo { + fn get_field(context: Context<'_>) -> Result { + T::get_field(context) + } +} + +/// Map a `RangeToInclusive` to the index type +impl TypeInfo for RangeToInclusive { + fn get_field(context: Context<'_>) -> Result { + T::get_field(context) + } +} + +/// Map a `RangeFrom` to the index type +impl TypeInfo for RangeFrom { + fn get_field(context: Context<'_>) -> Result { + T::get_field(context) + } +} + +/// Map a `Bound` to an Arrow Union with variants `Included`, `Excluded`, `Unbounded` +impl TypeInfo for Bound { + fn get_field(context: Context<'_>) -> Result { + let included = context.get_field::("Included")?; + let excluded = context.get_field::("Excluded")?; + let unbounded = context.get_field::<()>("Unbounded")?; + + Ok(Field { + name: String::from(context.get_name()), + data_type: DataType::Union( + vec![(0, included), (1, excluded), (2, unbounded)], + UnionMode::Dense, + ), + ..Default::default() + }) + } +} + macro_rules! impl_nonzero { ($($ty:ident),* $(,)?) => { $( From 12b94a31a98b6f9033c346e9f4f592c299bc0da4 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 12 Apr 2025 15:36:12 +0200 Subject: [PATCH 10/12] Rename crates --- Cargo.lock | 6 +-- Cargo.toml | 2 +- .../Cargo.toml | 2 +- .../src/lib.rs | 12 +++--- .../Cargo.toml | 4 +- marrow-convert/src/error.rs | 26 ++++++++++++ .../src/impls/collections.rs | 0 .../src/impls/compounds.rs | 0 .../src/impls/ext/bigdecimal.rs | 0 .../src/impls/ext/chrono.rs | 0 .../src/impls/ext/jiff.rs | 0 .../src/impls/ext/mod.rs | 0 .../src/impls/ext/uuid.rs | 0 .../src/impls/mod.rs | 0 .../src/impls/primitives.rs | 0 .../src/impls/std.rs | 0 .../src/impls/utils.rs | 5 ++- .../src/impls/wrappers.rs | 0 marrow-convert/src/lib.rs | 14 +++++++ .../src/tests.rs | 0 .../lib.rs => marrow-convert/src/typeinfo.rs | 41 ++----------------- .../tests/derive.rs | 36 ++++++++-------- 22 files changed, 78 insertions(+), 70 deletions(-) rename {marrow-typeinfo-derive => marrow-convert-derive}/Cargo.toml (80%) rename {marrow-typeinfo-derive => marrow-convert-derive}/src/lib.rs (96%) rename {marrow-typeinfo => marrow-convert}/Cargo.toml (79%) create mode 100644 marrow-convert/src/error.rs rename {marrow-typeinfo => marrow-convert}/src/impls/collections.rs (100%) rename {marrow-typeinfo => marrow-convert}/src/impls/compounds.rs (100%) rename {marrow-typeinfo => marrow-convert}/src/impls/ext/bigdecimal.rs (100%) rename {marrow-typeinfo => marrow-convert}/src/impls/ext/chrono.rs (100%) rename {marrow-typeinfo => marrow-convert}/src/impls/ext/jiff.rs (100%) rename {marrow-typeinfo => marrow-convert}/src/impls/ext/mod.rs (100%) rename {marrow-typeinfo => marrow-convert}/src/impls/ext/uuid.rs (100%) rename {marrow-typeinfo => marrow-convert}/src/impls/mod.rs (100%) rename {marrow-typeinfo => marrow-convert}/src/impls/primitives.rs (100%) rename {marrow-typeinfo => marrow-convert}/src/impls/std.rs (100%) rename {marrow-typeinfo => marrow-convert}/src/impls/utils.rs (94%) rename {marrow-typeinfo => marrow-convert}/src/impls/wrappers.rs (100%) create mode 100644 marrow-convert/src/lib.rs rename {marrow-typeinfo => marrow-convert}/src/tests.rs (100%) rename marrow-typeinfo/src/lib.rs => marrow-convert/src/typeinfo.rs (79%) rename {marrow-typeinfo => marrow-convert}/tests/derive.rs (91%) diff --git a/Cargo.lock b/Cargo.lock index d50db5e..579e1c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1272,19 +1272,19 @@ dependencies = [ ] [[package]] -name = "marrow-typeinfo" +name = "marrow-convert" version = "0.1.0" dependencies = [ "bigdecimal", "chrono", "jiff", "marrow", - "marrow-typeinfo-derive", + "marrow-convert-derive", "uuid", ] [[package]] -name = "marrow-typeinfo-derive" +name = "marrow-convert-derive" version = "0.1.0" dependencies = [ "proc-macro2", diff --git a/Cargo.toml b/Cargo.toml index a272f7a..b610719 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["marrow", "marrow-typeinfo", "marrow-typeinfo-derive", "test_with_arrow"] +members = ["marrow", "marrow-convert", "marrow-convert-derive", "test_with_arrow"] resolver = "2" diff --git a/marrow-typeinfo-derive/Cargo.toml b/marrow-convert-derive/Cargo.toml similarity index 80% rename from marrow-typeinfo-derive/Cargo.toml rename to marrow-convert-derive/Cargo.toml index 550c8b7..d369e5f 100644 --- a/marrow-typeinfo-derive/Cargo.toml +++ b/marrow-convert-derive/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "marrow-typeinfo-derive" +name = "marrow-convert-derive" version = "0.1.0" edition = "2024" diff --git a/marrow-typeinfo-derive/src/lib.rs b/marrow-convert-derive/src/lib.rs similarity index 96% rename from marrow-typeinfo-derive/src/lib.rs rename to marrow-convert-derive/src/lib.rs index 88c7edc..09ceb30 100644 --- a/marrow-typeinfo-derive/src/lib.rs +++ b/marrow-convert-derive/src/lib.rs @@ -171,10 +171,10 @@ fn derive_for_struct(input: &DeriveInput, data: &DataStruct) -> proc_macro2::Tok quote! { const _: () = { - impl #generics_decl ::marrow_typeinfo::TypeInfo for #name #generics_use { + impl #generics_decl ::marrow_convert::TypeInfo for #name #generics_use { fn get_field( - context: ::marrow_typeinfo::Context<'_>, - ) -> ::marrow_typeinfo::Result<::marrow::datatypes::Field> { + context: ::marrow_convert::Context<'_>, + ) -> ::marrow_convert::Result<::marrow::datatypes::Field> { #body } } @@ -255,10 +255,10 @@ fn derive_for_enum(input: &DeriveInput, data: &DataEnum) -> proc_macro2::TokenSt quote! { const _: () = { - impl #generics_decl ::marrow_typeinfo::TypeInfo for #name #generics_use { + impl #generics_decl ::marrow_convert::TypeInfo for #name #generics_use { fn get_field( - context: ::marrow_typeinfo::Context<'_>, - ) -> ::marrow_typeinfo::Result<::marrow::datatypes::Field> { + context: ::marrow_convert::Context<'_>, + ) -> ::marrow_convert::Result<::marrow::datatypes::Field> { let mut variants = ::std::vec::Vec::<(::std::primitive::i8, ::marrow::datatypes::Field)>::new(); #( variants.push(#variant_exprs); )* diff --git a/marrow-typeinfo/Cargo.toml b/marrow-convert/Cargo.toml similarity index 79% rename from marrow-typeinfo/Cargo.toml rename to marrow-convert/Cargo.toml index 6dfdfda..adc86af 100644 --- a/marrow-typeinfo/Cargo.toml +++ b/marrow-convert/Cargo.toml @@ -1,11 +1,11 @@ [package] -name = "marrow-typeinfo" +name = "marrow-convert" version = "0.1.0" edition = "2024" [dependencies] marrow = { path = "../marrow", default-features = false } -marrow-typeinfo-derive = { path = "../marrow-typeinfo-derive" } +marrow-convert-derive = { path = "../marrow-convert-derive" } jiff = { version = "0.2", default-features = false } diff --git a/marrow-convert/src/error.rs b/marrow-convert/src/error.rs new file mode 100644 index 0000000..b2a3d2c --- /dev/null +++ b/marrow-convert/src/error.rs @@ -0,0 +1,26 @@ +use std::{convert::Infallible, num::TryFromIntError}; + +pub type Result = std::result::Result; + +#[derive(Debug, PartialEq)] +pub struct Error(pub(crate) String); + +impl std::error::Error for Error {} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Error({:?})", self.0) + } +} + +impl From for Error { + fn from(_: Infallible) -> Self { + unreachable!() + } +} + +impl From for Error { + fn from(value: TryFromIntError) -> Self { + Self(value.to_string()) + } +} diff --git a/marrow-typeinfo/src/impls/collections.rs b/marrow-convert/src/impls/collections.rs similarity index 100% rename from marrow-typeinfo/src/impls/collections.rs rename to marrow-convert/src/impls/collections.rs diff --git a/marrow-typeinfo/src/impls/compounds.rs b/marrow-convert/src/impls/compounds.rs similarity index 100% rename from marrow-typeinfo/src/impls/compounds.rs rename to marrow-convert/src/impls/compounds.rs diff --git a/marrow-typeinfo/src/impls/ext/bigdecimal.rs b/marrow-convert/src/impls/ext/bigdecimal.rs similarity index 100% rename from marrow-typeinfo/src/impls/ext/bigdecimal.rs rename to marrow-convert/src/impls/ext/bigdecimal.rs diff --git a/marrow-typeinfo/src/impls/ext/chrono.rs b/marrow-convert/src/impls/ext/chrono.rs similarity index 100% rename from marrow-typeinfo/src/impls/ext/chrono.rs rename to marrow-convert/src/impls/ext/chrono.rs diff --git a/marrow-typeinfo/src/impls/ext/jiff.rs b/marrow-convert/src/impls/ext/jiff.rs similarity index 100% rename from marrow-typeinfo/src/impls/ext/jiff.rs rename to marrow-convert/src/impls/ext/jiff.rs diff --git a/marrow-typeinfo/src/impls/ext/mod.rs b/marrow-convert/src/impls/ext/mod.rs similarity index 100% rename from marrow-typeinfo/src/impls/ext/mod.rs rename to marrow-convert/src/impls/ext/mod.rs diff --git a/marrow-typeinfo/src/impls/ext/uuid.rs b/marrow-convert/src/impls/ext/uuid.rs similarity index 100% rename from marrow-typeinfo/src/impls/ext/uuid.rs rename to marrow-convert/src/impls/ext/uuid.rs diff --git a/marrow-typeinfo/src/impls/mod.rs b/marrow-convert/src/impls/mod.rs similarity index 100% rename from marrow-typeinfo/src/impls/mod.rs rename to marrow-convert/src/impls/mod.rs diff --git a/marrow-typeinfo/src/impls/primitives.rs b/marrow-convert/src/impls/primitives.rs similarity index 100% rename from marrow-typeinfo/src/impls/primitives.rs rename to marrow-convert/src/impls/primitives.rs diff --git a/marrow-typeinfo/src/impls/std.rs b/marrow-convert/src/impls/std.rs similarity index 100% rename from marrow-typeinfo/src/impls/std.rs rename to marrow-convert/src/impls/std.rs diff --git a/marrow-typeinfo/src/impls/utils.rs b/marrow-convert/src/impls/utils.rs similarity index 94% rename from marrow-typeinfo/src/impls/utils.rs rename to marrow-convert/src/impls/utils.rs index 6a9a534..232e5ed 100644 --- a/marrow-typeinfo/src/impls/utils.rs +++ b/marrow-convert/src/impls/utils.rs @@ -1,6 +1,9 @@ use marrow::datatypes::{DataType, Field}; -use crate::{Context, DefaultStringType, LargeList, Result, TypeInfo}; +use crate::{ + Context, Result, TypeInfo, + typeinfo::{DefaultStringType, LargeList}, +}; pub fn new_field(name: &str, data_type: DataType) -> Field { Field { diff --git a/marrow-typeinfo/src/impls/wrappers.rs b/marrow-convert/src/impls/wrappers.rs similarity index 100% rename from marrow-typeinfo/src/impls/wrappers.rs rename to marrow-convert/src/impls/wrappers.rs diff --git a/marrow-convert/src/lib.rs b/marrow-convert/src/lib.rs new file mode 100644 index 0000000..5a4941e --- /dev/null +++ b/marrow-convert/src/lib.rs @@ -0,0 +1,14 @@ +mod error; +mod impls; +mod typeinfo; + +#[cfg(test)] +mod tests; + +/// Derive [TypeInfo] for a given type +/// +/// Currently structs and enums with any type of lifetime parameters are supported. +pub use marrow_convert_derive::TypeInfo; + +pub use error::{Error, Result}; +pub use typeinfo::{Context, Options, TypeInfo, get_data_type, get_field}; diff --git a/marrow-typeinfo/src/tests.rs b/marrow-convert/src/tests.rs similarity index 100% rename from marrow-typeinfo/src/tests.rs rename to marrow-convert/src/tests.rs diff --git a/marrow-typeinfo/src/lib.rs b/marrow-convert/src/typeinfo.rs similarity index 79% rename from marrow-typeinfo/src/lib.rs rename to marrow-convert/src/typeinfo.rs index 1bf6d61..add8ba8 100644 --- a/marrow-typeinfo/src/lib.rs +++ b/marrow-convert/src/typeinfo.rs @@ -1,47 +1,12 @@ use std::{ any::{Any, TypeId}, collections::HashMap, - convert::Infallible, - num::TryFromIntError, rc::Rc, }; use marrow::datatypes::{DataType, Field}; -mod impls; - -#[cfg(test)] -mod tests; - -/// Derive [TypeInfo] for a given type -/// -/// Currently structs and enums with any type of lifetime parameters are supported. -pub use marrow_typeinfo_derive::TypeInfo; - -pub type Result = std::result::Result; - -#[derive(Debug, PartialEq)] -pub struct Error(String); - -impl std::error::Error for Error {} - -impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Error({:?})", self.0) - } -} - -impl From for Error { - fn from(_: Infallible) -> Self { - unreachable!() - } -} - -impl From for Error { - fn from(value: TryFromIntError) -> Self { - Self(value.to_string()) - } -} +use crate::{Error, Result}; #[derive(Debug, Default)] pub struct Options { @@ -160,9 +125,9 @@ pub fn get_data_type(options: &Options) -> Result { Ok(get_field::("item", options)?.data_type) } -struct DefaultStringType(DataType); +pub struct DefaultStringType(pub DataType); -struct LargeList(bool); +pub struct LargeList(pub bool); /// Get the Arrow type information for a given Rust type /// diff --git a/marrow-typeinfo/tests/derive.rs b/marrow-convert/tests/derive.rs similarity index 91% rename from marrow-typeinfo/tests/derive.rs rename to marrow-convert/tests/derive.rs index ec26fd5..385dd90 100644 --- a/marrow-typeinfo/tests/derive.rs +++ b/marrow-convert/tests/derive.rs @@ -2,7 +2,7 @@ use marrow::{ datatypes::{DataType, Field, TimeUnit, UnionMode}, types::f16, }; -use marrow_typeinfo::{Context, Options, Result, TypeInfo}; +use marrow_convert::{Context, Options, Result, TypeInfo}; #[test] fn example() { @@ -14,7 +14,7 @@ fn example() { } assert_eq!( - marrow_typeinfo::get_data_type::(&Options::default()), + marrow_convert::get_data_type::(&Options::default()), Ok(DataType::Struct(vec![ Field { name: String::from("a"), @@ -40,7 +40,7 @@ fn overwrites() { } assert_eq!( - marrow_typeinfo::get_data_type::(&Options::default().overwrite( + marrow_convert::get_data_type::(&Options::default().overwrite( "$.b", Field { data_type: DataType::Binary, @@ -69,7 +69,7 @@ fn newtype() { struct S(f16); assert_eq!( - marrow_typeinfo::get_data_type::(&Options::default()), + marrow_convert::get_data_type::(&Options::default()), Ok(DataType::Float16) ); } @@ -81,7 +81,7 @@ fn tuple() { struct S(u8, [u8; 4]); assert_eq!( - marrow_typeinfo::get_data_type::(&Options::default()), + marrow_convert::get_data_type::(&Options::default()), Ok(DataType::Struct(vec![ Field { name: String::from("0"), @@ -116,7 +116,7 @@ fn customize() { } assert_eq!( - marrow_typeinfo::get_data_type::(&Options::default()), + marrow_convert::get_data_type::(&Options::default()), Ok(DataType::Struct(vec![ Field { name: String::from("a"), @@ -143,7 +143,7 @@ fn fieldless_union() { } assert_eq!( - marrow_typeinfo::get_data_type::(&Options::default()), + marrow_convert::get_data_type::(&Options::default()), Ok(DataType::Union( vec![ ( @@ -196,7 +196,7 @@ fn new_type_enum() { } assert_eq!( - marrow_typeinfo::get_data_type::(&Options::default()), + marrow_convert::get_data_type::(&Options::default()), Ok(DataType::Union( vec![ ( @@ -244,7 +244,7 @@ fn new_tuple_enum() { } assert_eq!( - marrow_typeinfo::get_data_type::(&Options::default()), + marrow_convert::get_data_type::(&Options::default()), Ok(DataType::Union( vec![ ( @@ -290,7 +290,7 @@ fn new_struct_enum() { } assert_eq!( - marrow_typeinfo::get_data_type::(&Options::default()), + marrow_convert::get_data_type::(&Options::default()), Ok(DataType::Union( vec![ ( @@ -335,7 +335,7 @@ fn const_generics() { } assert_eq!( - marrow_typeinfo::get_data_type::>(&Options::default()), + marrow_convert::get_data_type::>(&Options::default()), Ok(DataType::Struct(vec![Field { name: String::from("data"), data_type: DataType::FixedSizeBinary(4), @@ -355,7 +355,7 @@ fn liftime_generics() { } assert_eq!( - marrow_typeinfo::get_data_type::(&Options::default()), + marrow_convert::get_data_type::(&Options::default()), Ok(DataType::Struct(vec![ Field { name: String::from("a"), @@ -381,7 +381,7 @@ fn liftime_generics_with_bounds() { } assert_eq!( - marrow_typeinfo::get_data_type::(&Options::default()), + marrow_convert::get_data_type::(&Options::default()), Ok(DataType::Struct(vec![ Field { name: String::from("a"), @@ -410,7 +410,7 @@ fn liftime_generics_with_where_clause() { } assert_eq!( - marrow_typeinfo::get_data_type::(&Options::default()), + marrow_convert::get_data_type::(&Options::default()), Ok(DataType::Struct(vec![ Field { name: String::from("a"), @@ -435,7 +435,7 @@ fn enums_const_generics() { } assert_eq!( - marrow_typeinfo::get_data_type::>(&Options::default()), + marrow_convert::get_data_type::>(&Options::default()), Ok(DataType::Union( vec![( 0, @@ -461,7 +461,7 @@ fn enums_with_liftime_generics() { } assert_eq!( - marrow_typeinfo::get_data_type::(&Options::default()), + marrow_convert::get_data_type::(&Options::default()), Ok(DataType::Union( vec![ ( @@ -496,7 +496,7 @@ fn enum_liftime_generics_with_bounds() { } assert_eq!( - marrow_typeinfo::get_data_type::(&Options::default()), + marrow_convert::get_data_type::(&Options::default()), Ok(DataType::Union( vec![ ( @@ -534,7 +534,7 @@ fn enum_liftime_generics_with_where_clause() { } assert_eq!( - marrow_typeinfo::get_data_type::(&Options::default()), + marrow_convert::get_data_type::(&Options::default()), Ok(DataType::Union( vec![ ( From fc5629f77f68e2cabc0c186c7f506eb6f9f404c7 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 13 Apr 2025 15:20:36 +0200 Subject: [PATCH 11/12] Refactor start to implement builders --- marrow-convert-derive/src/array_push.rs | 1 + marrow-convert-derive/src/default_builder.rs | 1 + marrow-convert-derive/src/lib.rs | 338 +----------------- marrow-convert-derive/src/type_info.rs | 326 +++++++++++++++++ marrow-convert/src/internal/builder/list.rs | 152 ++++++++ marrow-convert/src/internal/builder/mod.rs | 52 +++ marrow-convert/src/internal/builder/option.rs | 69 ++++ .../src/internal/builder/primitive.rs | 162 +++++++++ marrow-convert/src/internal/builder/struct.rs | 163 +++++++++ marrow-convert/src/internal/builder/union.rs | 284 +++++++++++++++ marrow-convert/src/internal/mod.rs | 4 + .../{typeinfo.rs => internal/type_info.rs} | 11 +- .../type_info_impls}/collections.rs | 21 +- .../type_info_impls}/compounds.rs | 11 +- .../type_info_impls}/ext/bigdecimal.rs | 9 +- .../type_info_impls}/ext/chrono.rs | 21 +- .../type_info_impls}/ext/jiff.rs | 21 +- .../type_info_impls}/ext/mod.rs | 0 .../type_info_impls}/ext/uuid.rs | 7 +- .../type_info_impls}/mod.rs | 0 .../type_info_impls}/primitives.rs | 15 +- .../type_info_impls}/std.rs | 31 +- .../type_info_impls}/utils.rs | 11 +- .../type_info_impls}/wrappers.rs | 23 +- marrow-convert/src/internal/util.rs | 34 ++ marrow-convert/src/lib.rs | 48 ++- marrow-convert/src/tests.rs | 2 +- marrow-convert/tests/derive.rs | 77 ++-- marrow/src/array.rs | 70 ++++ 29 files changed, 1519 insertions(+), 445 deletions(-) create mode 100644 marrow-convert-derive/src/array_push.rs create mode 100644 marrow-convert-derive/src/default_builder.rs create mode 100644 marrow-convert-derive/src/type_info.rs create mode 100644 marrow-convert/src/internal/builder/list.rs create mode 100644 marrow-convert/src/internal/builder/mod.rs create mode 100644 marrow-convert/src/internal/builder/option.rs create mode 100644 marrow-convert/src/internal/builder/primitive.rs create mode 100644 marrow-convert/src/internal/builder/struct.rs create mode 100644 marrow-convert/src/internal/builder/union.rs create mode 100644 marrow-convert/src/internal/mod.rs rename marrow-convert/src/{typeinfo.rs => internal/type_info.rs} (90%) rename marrow-convert/src/{impls => internal/type_info_impls}/collections.rs (68%) rename marrow-convert/src/{impls => internal/type_info_impls}/compounds.rs (88%) rename marrow-convert/src/{impls => internal/type_info_impls}/ext/bigdecimal.rs (57%) rename marrow-convert/src/{impls => internal/type_info_impls}/ext/chrono.rs (60%) rename marrow-convert/src/{impls => internal/type_info_impls}/ext/jiff.rs (62%) rename marrow-convert/src/{impls => internal/type_info_impls}/ext/mod.rs (100%) rename marrow-convert/src/{impls => internal/type_info_impls}/ext/uuid.rs (83%) rename marrow-convert/src/{impls => internal/type_info_impls}/mod.rs (100%) rename marrow-convert/src/{impls => internal/type_info_impls}/primitives.rs (83%) rename marrow-convert/src/{impls => internal/type_info_impls}/std.rs (82%) rename marrow-convert/src/{impls => internal/type_info_impls}/utils.rs (81%) rename marrow-convert/src/{impls => internal/type_info_impls}/wrappers.rs (61%) create mode 100644 marrow-convert/src/internal/util.rs diff --git a/marrow-convert-derive/src/array_push.rs b/marrow-convert-derive/src/array_push.rs new file mode 100644 index 0000000..d3f5a12 --- /dev/null +++ b/marrow-convert-derive/src/array_push.rs @@ -0,0 +1 @@ + diff --git a/marrow-convert-derive/src/default_builder.rs b/marrow-convert-derive/src/default_builder.rs new file mode 100644 index 0000000..d3f5a12 --- /dev/null +++ b/marrow-convert-derive/src/default_builder.rs @@ -0,0 +1 @@ + diff --git a/marrow-convert-derive/src/lib.rs b/marrow-convert-derive/src/lib.rs index 09ceb30..9e1f782 100644 --- a/marrow-convert-derive/src/lib.rs +++ b/marrow-convert-derive/src/lib.rs @@ -1,332 +1,22 @@ use proc_macro::TokenStream; -use quote::{ToTokens, quote}; -use syn::{ - Attribute, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, GenericParam, Ident, - Lit, LitStr, Meta, Token, punctuated::Punctuated, spanned::Spanned, -}; -#[proc_macro_derive(TypeInfo, attributes(marrow_type_info))] -pub fn derive_type_info(input: TokenStream) -> TokenStream { - derive_type_info_impl(input.into()).into() -} - -fn derive_type_info_impl(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream { - let input: DeriveInput = syn::parse2(input).unwrap(); - - if input - .generics - .params - .iter() - .any(|p| matches!(p, GenericParam::Type(_))) - { - panic!("Deriving TypeInfo for generics with type parameters is not supported") - } - - match &input.data { - Data::Struct(data) => derive_for_struct(&input, data), - Data::Enum(data) => derive_for_enum(&input, data), - Data::Union(_) => panic!("Deriving TypeInfo for unions is not supported"), - } -} - -#[derive(Debug, Default)] -struct FieldArgs { - // TODO: use a path here - with: Option, -} - -impl FieldArgs { - pub fn from_attrs(attrs: &[Attribute]) -> Self { - let mut result = Self::default(); - - for attr in attrs { - if !attr.path().is_ident("marrow_type_info") { - continue; - } - - let nested = attr - .parse_args_with(Punctuated::::parse_terminated) - .unwrap(); - for meta in nested { - match meta { - Meta::NameValue(meta) => { - if !meta.path.is_ident("with") { - continue; - } - match meta.value { - Expr::Lit(lit) => match lit.lit { - Lit::Str(str) => { - result.with = Some(Ident::new(&str.value(), str.span())); - } - _ => unimplemented!(), - }, - _ => unimplemented!(), - } - } - _ => unimplemented!(), - } - } - } - result - } -} - -#[derive(Debug, Default)] -struct VariantArgs { - with: Option, -} - -impl VariantArgs { - pub fn from_attrs(attrs: &[Attribute]) -> Self { - let mut result = Self::default(); - - for attr in attrs { - if !attr.path().is_ident("marrow_type_info") { - continue; - } - - let nested = attr - .parse_args_with(Punctuated::::parse_terminated) - .unwrap(); - for meta in nested { - match meta { - Meta::NameValue(meta) => { - if !meta.path.is_ident("with") { - continue; - } - match meta.value { - Expr::Lit(lit) => match lit.lit { - Lit::Str(str) => { - result.with = Some(Ident::new(&str.value(), str.span())); - } - _ => unimplemented!(), - }, - _ => unimplemented!(), - } - } - _ => unimplemented!(), - } - } - } - result - } -} - -fn derive_for_struct(input: &DeriveInput, data: &DataStruct) -> proc_macro2::TokenStream { - let name = &input.ident; - - let generics_decl = &input.generics; - let generics_use = if !input.generics.params.is_empty() { - let generics_use = input.generics.params.iter().map(|p| match p { - GenericParam::Const(p) => p.ident.to_token_stream(), - GenericParam::Lifetime(p) => p.lifetime.to_token_stream(), - GenericParam::Type(_) => panic!(), - }); - quote! { - <#(#generics_use),*> - } - } else { - quote! {} - }; - - let fields = get_fields(&data.fields); - let body = match fields.as_slice() { - [] => panic!(), - [(NameSource::Index, _, field)] => { - // TODO: ensure no args - let field_ty = &field.ty; - quote! { <#field_ty>::get_field(context) } - } - fields => { - let mut field_exprs = Vec::new(); +mod array_push; +mod default_builder; +mod type_info; - for (_, field_name, field) in fields { - let ty = &field.ty; - let args = FieldArgs::from_attrs(&field.attrs); - - if let Some(func) = args.with.as_ref() { - field_exprs.push(quote! { - fields.push(context.nest(#field_name, #func::<#ty>)?); - }); - } else { - field_exprs.push(quote! { - fields.push(context.get_field::<#ty>(#field_name)?); - }) - } - } - - quote! { - let mut fields = ::std::vec::Vec::<::marrow::datatypes::Field>::new(); - #( #field_exprs; )* - - Ok(::marrow::datatypes::Field { - name: ::std::string::String::from(context.get_name()), - data_type: ::marrow::datatypes::DataType::Struct(fields), - nullable: false, - metadata: ::std::default::Default::default(), - }) - } - } - }; - - quote! { - const _: () = { - impl #generics_decl ::marrow_convert::TypeInfo for #name #generics_use { - fn get_field( - context: ::marrow_convert::Context<'_>, - ) -> ::marrow_convert::Result<::marrow::datatypes::Field> { - #body - } - } - }; - } -} - -fn derive_for_enum(input: &DeriveInput, data: &DataEnum) -> proc_macro2::TokenStream { - let mut variant_exprs = Vec::new(); - - let name = &input.ident; - let generics_decl = &input.generics; - let generics_use = if !input.generics.params.is_empty() { - let generics_use = input.generics.params.iter().map(|p| match p { - GenericParam::Const(p) => p.ident.to_token_stream(), - GenericParam::Lifetime(p) => p.lifetime.to_token_stream(), - GenericParam::Type(_) => panic!(), - }); - quote! { - <#(#generics_use),*> - } - } else { - quote! {} - }; - - for (idx, variant) in data.variants.iter().enumerate() { - let variant_name = &variant.ident; - let variant_name = LitStr::new(&variant_name.to_string(), variant_name.span()); - let variant_args = VariantArgs::from_attrs(&variant.attrs); - - if let Some(func) = variant_args.with.as_ref() { - variant_exprs.push(quote! { #func(stringify!(#variant_name)) }); - continue; - } - - let variant_idx = i8::try_from(idx).unwrap(); - - let fields = get_fields(&variant.fields); - match fields.as_slice() { - [] => { - // use nesting to allow overwrites - variant_exprs.push(quote! { - (#variant_idx, context.nest(#variant_name, |context| { - Ok(::marrow::datatypes::Field { - name: ::std::string::String::from(context.get_name()), - data_type: ::marrow::datatypes::DataType::Null, - nullable: true, - metadata: ::std::default::Default::default(), - }) - })?) - }); - } - [(NameSource::Index, _, field)] => { - let field_ty = &field.ty; - variant_exprs.push(quote! { - (#variant_idx, context.nest(#variant_name, <#field_ty>::get_field)?) - }); - } - fields => { - let mut field_exprs = Vec::new(); - for (_, field_name, field) in fields { - let field_ty = &field.ty; - field_exprs.push(quote! { - context.get_field::<#field_ty>(#field_name)? - }); - } - variant_exprs.push(quote! { - (#variant_idx, context.nest(#variant_name, |context| Ok(::marrow::datatypes::Field { - name: ::std::string::String::from(context.get_name()), - data_type: ::marrow::datatypes::DataType::Struct(vec![#(#field_exprs),*]), - nullable: false, - metadata: ::std::default::Default::default(), - }))?) - }); - } - } - } - - quote! { - const _: () = { - impl #generics_decl ::marrow_convert::TypeInfo for #name #generics_use { - fn get_field( - context: ::marrow_convert::Context<'_>, - ) -> ::marrow_convert::Result<::marrow::datatypes::Field> { - let mut variants = ::std::vec::Vec::<(::std::primitive::i8, ::marrow::datatypes::Field)>::new(); - #( variants.push(#variant_exprs); )* - - Ok(::marrow::datatypes::Field { - name: ::std::string::String::from(context.get_name()), - data_type: ::marrow::datatypes::DataType::Union(variants, ::marrow::datatypes::UnionMode::Dense), - nullable: false, - metadata: ::std::default::Default::default(), - }) - } - } - }; - } -} - -fn get_fields(fields: &Fields) -> Vec<(NameSource, LitStr, &Field)> { - let mut result = Vec::new(); - match fields { - Fields::Unit => {} - Fields::Named(fields) => { - for field in &fields.named { - let Some(name) = field.ident.as_ref() else { - unreachable!("Named field must have a name"); - }; - let name = LitStr::new(&name.to_string(), name.span()); - result.push((NameSource::Ident, name, field)); - } - } - Fields::Unnamed(fields) => { - for (idx, field) in fields.unnamed.iter().enumerate() { - let name = LitStr::new(&idx.to_string(), field.span()); - result.push((NameSource::Index, name, field)); - } - } - } - result -} - -#[derive(Debug, Clone, Copy, PartialEq)] -enum NameSource { - Ident, - Index, -} - -#[test] -#[should_panic(expected = "Deriving TypeInfo for generics with type parameters is not supported")] -fn reject_unsupported() { - derive_type_info_impl(quote! { - struct Example { - field: T, - } - }); +#[proc_macro_derive(DefaultArrayType, attributes(marrow))] +pub fn derive_type_info(input: TokenStream) -> TokenStream { + type_info::derive_type_info_impl(input.into()).into() } -#[test] -fn lifetimes_are_supported() { - derive_type_info_impl(quote! { - struct Example<'a> { - field: &'a i64, - } - }); +#[proc_macro_derive(ArrayPush, attributes(marrow))] +pub fn derive_array_push(input: TokenStream) -> TokenStream { + std::mem::drop(input); + unimplemented!() } -#[test] -fn const_params_are_supported() { - derive_type_info_impl(quote! { - struct Example { - field: [u8; N], - } - }); +#[proc_macro_derive(DefaultArrayBuilder, attributes(marrow))] +pub fn derive_default_builder(input: TokenStream) -> TokenStream { + std::mem::drop(input); + unimplemented!() } diff --git a/marrow-convert-derive/src/type_info.rs b/marrow-convert-derive/src/type_info.rs new file mode 100644 index 0000000..8570c46 --- /dev/null +++ b/marrow-convert-derive/src/type_info.rs @@ -0,0 +1,326 @@ +use quote::{ToTokens, quote}; +use syn::{ + Attribute, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, GenericParam, Ident, + Lit, LitStr, Meta, Token, punctuated::Punctuated, spanned::Spanned, +}; + +pub fn derive_type_info_impl(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream { + let input: DeriveInput = syn::parse2(input).unwrap(); + + if input + .generics + .params + .iter() + .any(|p| matches!(p, GenericParam::Type(_))) + { + panic!("Deriving TypeInfo for generics with type parameters is not supported") + } + + match &input.data { + Data::Struct(data) => derive_for_struct(&input, data), + Data::Enum(data) => derive_for_enum(&input, data), + Data::Union(_) => panic!("Deriving TypeInfo for unions is not supported"), + } +} + +#[derive(Debug, Default)] +struct FieldArgs { + // TODO: use a path here + with: Option, +} + +impl FieldArgs { + pub fn from_attrs(attrs: &[Attribute]) -> Self { + let mut result = Self::default(); + + for attr in attrs { + if !attr.path().is_ident("marrow") { + continue; + } + + let nested = attr + .parse_args_with(Punctuated::::parse_terminated) + .unwrap(); + for meta in nested { + match meta { + Meta::NameValue(meta) => { + if !meta.path.is_ident("with") { + continue; + } + match meta.value { + Expr::Lit(lit) => match lit.lit { + Lit::Str(str) => { + result.with = Some(Ident::new(&str.value(), str.span())); + } + _ => unimplemented!(), + }, + _ => unimplemented!(), + } + } + _ => unimplemented!(), + } + } + } + result + } +} + +#[derive(Debug, Default)] +struct VariantArgs { + with: Option, +} + +impl VariantArgs { + pub fn from_attrs(attrs: &[Attribute]) -> Self { + let mut result = Self::default(); + + for attr in attrs { + if !attr.path().is_ident("marrow_type_info") { + continue; + } + + let nested = attr + .parse_args_with(Punctuated::::parse_terminated) + .unwrap(); + for meta in nested { + match meta { + Meta::NameValue(meta) => { + if !meta.path.is_ident("with") { + continue; + } + match meta.value { + Expr::Lit(lit) => match lit.lit { + Lit::Str(str) => { + result.with = Some(Ident::new(&str.value(), str.span())); + } + _ => unimplemented!(), + }, + _ => unimplemented!(), + } + } + _ => unimplemented!(), + } + } + } + result + } +} + +fn derive_for_struct(input: &DeriveInput, data: &DataStruct) -> proc_macro2::TokenStream { + let name = &input.ident; + + let generics_decl = &input.generics; + let generics_use = if !input.generics.params.is_empty() { + let generics_use = input.generics.params.iter().map(|p| match p { + GenericParam::Const(p) => p.ident.to_token_stream(), + GenericParam::Lifetime(p) => p.lifetime.to_token_stream(), + GenericParam::Type(_) => panic!(), + }); + quote! { + <#(#generics_use),*> + } + } else { + quote! {} + }; + + let fields = get_fields(&data.fields); + let body = match fields.as_slice() { + [] => panic!(), + [(NameSource::Index, _, field)] => { + // TODO: ensure no args + let field_ty = &field.ty; + quote! { <#field_ty>::get_field(context) } + } + fields => { + let mut field_exprs = Vec::new(); + + for (_, field_name, field) in fields { + let ty = &field.ty; + let args = FieldArgs::from_attrs(&field.attrs); + + if let Some(func) = args.with.as_ref() { + field_exprs.push(quote! { + fields.push(context.nest(#field_name, #func::<#ty>)?); + }); + } else { + field_exprs.push(quote! { + fields.push(context.get_field::<#ty>(#field_name)?); + }) + } + } + + quote! { + let mut fields = ::std::vec::Vec::<::marrow::datatypes::Field>::new(); + #( #field_exprs; )* + + Ok(::marrow::datatypes::Field { + name: ::std::string::String::from(context.get_name()), + data_type: ::marrow::datatypes::DataType::Struct(fields), + nullable: false, + metadata: ::std::default::Default::default(), + }) + } + } + }; + + quote! { + const _: () = { + impl #generics_decl ::marrow_convert::types::DefaultArrayType for #name #generics_use { + fn get_field( + context: ::marrow_convert::types::Context<'_>, + ) -> ::marrow_convert::Result<::marrow::datatypes::Field> { + #body + } + } + }; + } +} + +fn derive_for_enum(input: &DeriveInput, data: &DataEnum) -> proc_macro2::TokenStream { + let mut variant_exprs = Vec::new(); + + let name = &input.ident; + let generics_decl = &input.generics; + let generics_use = if !input.generics.params.is_empty() { + let generics_use = input.generics.params.iter().map(|p| match p { + GenericParam::Const(p) => p.ident.to_token_stream(), + GenericParam::Lifetime(p) => p.lifetime.to_token_stream(), + GenericParam::Type(_) => panic!(), + }); + quote! { + <#(#generics_use),*> + } + } else { + quote! {} + }; + + for (idx, variant) in data.variants.iter().enumerate() { + let variant_name = &variant.ident; + let variant_name = LitStr::new(&variant_name.to_string(), variant_name.span()); + let variant_args = VariantArgs::from_attrs(&variant.attrs); + + if let Some(func) = variant_args.with.as_ref() { + variant_exprs.push(quote! { #func(stringify!(#variant_name)) }); + continue; + } + + let variant_idx = i8::try_from(idx).unwrap(); + + let fields = get_fields(&variant.fields); + match fields.as_slice() { + [] => { + // use nesting to allow overwrites + variant_exprs.push(quote! { + (#variant_idx, context.nest(#variant_name, |context| { + Ok(::marrow::datatypes::Field { + name: ::std::string::String::from(context.get_name()), + data_type: ::marrow::datatypes::DataType::Null, + nullable: true, + metadata: ::std::default::Default::default(), + }) + })?) + }); + } + [(NameSource::Index, _, field)] => { + let field_ty = &field.ty; + variant_exprs.push(quote! { + (#variant_idx, context.nest(#variant_name, <#field_ty>::get_field)?) + }); + } + fields => { + let mut field_exprs = Vec::new(); + for (_, field_name, field) in fields { + let field_ty = &field.ty; + field_exprs.push(quote! { + context.get_field::<#field_ty>(#field_name)? + }); + } + variant_exprs.push(quote! { + (#variant_idx, context.nest(#variant_name, |context| Ok(::marrow::datatypes::Field { + name: ::std::string::String::from(context.get_name()), + data_type: ::marrow::datatypes::DataType::Struct(vec![#(#field_exprs),*]), + nullable: false, + metadata: ::std::default::Default::default(), + }))?) + }); + } + } + } + + quote! { + const _: () = { + impl #generics_decl ::marrow_convert::types::DefaultArrayType for #name #generics_use { + fn get_field( + context: ::marrow_convert::types::Context<'_>, + ) -> ::marrow_convert::Result<::marrow::datatypes::Field> { + let mut variants = ::std::vec::Vec::<(::std::primitive::i8, ::marrow::datatypes::Field)>::new(); + #( variants.push(#variant_exprs); )* + + Ok(::marrow::datatypes::Field { + name: ::std::string::String::from(context.get_name()), + data_type: ::marrow::datatypes::DataType::Union(variants, ::marrow::datatypes::UnionMode::Dense), + nullable: false, + metadata: ::std::default::Default::default(), + }) + } + } + }; + } +} + +fn get_fields(fields: &Fields) -> Vec<(NameSource, LitStr, &Field)> { + let mut result = Vec::new(); + match fields { + Fields::Unit => {} + Fields::Named(fields) => { + for field in &fields.named { + let Some(name) = field.ident.as_ref() else { + unreachable!("Named field must have a name"); + }; + let name = LitStr::new(&name.to_string(), name.span()); + result.push((NameSource::Ident, name, field)); + } + } + Fields::Unnamed(fields) => { + for (idx, field) in fields.unnamed.iter().enumerate() { + let name = LitStr::new(&idx.to_string(), field.span()); + result.push((NameSource::Index, name, field)); + } + } + } + result +} + +#[derive(Debug, Clone, Copy, PartialEq)] +enum NameSource { + Ident, + Index, +} + +#[test] +#[should_panic(expected = "Deriving TypeInfo for generics with type parameters is not supported")] +fn reject_unsupported() { + derive_type_info_impl(quote! { + struct Example { + field: T, + } + }); +} + +#[test] +fn lifetimes_are_supported() { + derive_type_info_impl(quote! { + struct Example<'a> { + field: &'a i64, + } + }); +} + +#[test] +fn const_params_are_supported() { + derive_type_info_impl(quote! { + struct Example { + field: [u8; N], + } + }); +} diff --git a/marrow-convert/src/internal/builder/list.rs b/marrow-convert/src/internal/builder/list.rs new file mode 100644 index 0000000..93d8183 --- /dev/null +++ b/marrow-convert/src/internal/builder/list.rs @@ -0,0 +1,152 @@ +use marrow::{ + array::{Array, ListArray}, + datatypes::FieldMeta, +}; + +use crate::{Error, Result}; + +use super::{ArrayBuilder, ArrayPush, DefaultArrayBuilder}; + +struct GenericListBuilder { + offsets: Vec, + builder: B, +} + +impl GenericListBuilder { + pub fn new(builder: B) -> Self { + Self { + offsets: vec![O::default()], + builder, + } + } +} + +trait Offset: Default + Copy + std::ops::Add { + const ONE: Self; + const ARRAY_VARIANT: fn(ListArray) -> Array; +} + +impl Offset for i32 { + const ONE: Self = 1; + const ARRAY_VARIANT: fn(ListArray) -> Array = Array::List; +} + +impl Offset for i64 { + const ONE: Self = 1; + const ARRAY_VARIANT: fn(ListArray) -> Array = Array::LargeList; +} + +impl ArrayBuilder for GenericListBuilder { + fn push_default(&mut self) -> Result<()> { + let Some(last_offset) = self.offsets.last() else { + return Err(Error(String::from("invalid state"))); + }; + self.offsets.push(*last_offset); + Ok(()) + } + + fn build_array(&mut self) -> Result { + Ok(O::ARRAY_VARIANT(ListArray { + validity: None, + offsets: std::mem::replace(&mut self.offsets, vec![O::default()]), + meta: FieldMeta { + name: String::from("element"), + ..Default::default() + }, + elements: Box::new(self.builder.build_array()?), + })) + } +} + +impl> ArrayPush<[T]> for GenericListBuilder { + fn push_value(&mut self, value: &[T]) -> Result<()> { + let Some(last_offset) = self.offsets.last().copied() else { + return Err(Error(String::from("invalid state"))); + }; + + let mut pushed = O::default(); + for item in value { + self.builder.push_value(item)?; + pushed = pushed + O::ONE; + } + + self.offsets.push(last_offset + pushed); + Ok(()) + } +} + +pub struct ListBuilder(GenericListBuilder); + +impl ListBuilder { + pub fn new(builder: B) -> Self { + Self(GenericListBuilder::new(builder)) + } +} + +impl ArrayBuilder for ListBuilder { + fn push_default(&mut self) -> Result<()> { + self.0.push_default() + } + + fn build_array(&mut self) -> Result { + self.0.build_array() + } +} + +impl> ArrayPush<[T]> for ListBuilder { + fn push_value(&mut self, value: &[T]) -> Result<()> { + self.0.push_value(value) + } +} + +impl> ArrayPush> for ListBuilder { + fn push_value(&mut self, value: &Vec) -> Result<()> { + self.0.push_value(value.as_slice()) + } +} + +pub struct LargeListBuilder(GenericListBuilder); + +impl LargeListBuilder { + pub fn new(builder: B) -> Self { + Self(GenericListBuilder::new(builder)) + } +} + +impl ArrayBuilder for LargeListBuilder { + fn push_default(&mut self) -> Result<()> { + self.0.push_default() + } + + fn build_array(&mut self) -> Result { + self.0.build_array() + } +} + +impl> ArrayPush<[T]> for LargeListBuilder { + fn push_value(&mut self, value: &[T]) -> Result<()> { + self.0.push_value(value) + } +} + +impl> ArrayPush> for LargeListBuilder { + fn push_value(&mut self, value: &Vec) -> Result<()> { + self.0.push_value(value.as_slice()) + } +} + +impl DefaultArrayBuilder for Vec { + type ArrayBuilder = LargeListBuilder; + + fn default_builder() -> Self::ArrayBuilder { + LargeListBuilder::new(T::default_builder()) + } +} + +impl DefaultArrayBuilder for [T] { + type ArrayBuilder = LargeListBuilder; + + fn default_builder() -> Self::ArrayBuilder { + LargeListBuilder::new(T::default_builder()) + } +} diff --git a/marrow-convert/src/internal/builder/mod.rs b/marrow-convert/src/internal/builder/mod.rs new file mode 100644 index 0000000..6f096a7 --- /dev/null +++ b/marrow-convert/src/internal/builder/mod.rs @@ -0,0 +1,52 @@ +use marrow::array::Array; + +use crate::Result; + +pub mod list; +pub mod option; +pub mod primitive; +pub mod r#struct; +pub mod union; + +pub trait ArrayBuilder { + fn push_default(&mut self) -> Result<()>; + fn build_array(&mut self) -> Result; +} + +pub trait ArrayPush: ArrayBuilder { + fn push_value(&mut self, value: &T) -> Result<()>; +} + +impl> ArrayPush<&T> for B { + fn push_value(&mut self, value: &&T) -> Result<()> { + self.push_value(*value) + } +} + +impl> ArrayPush<&mut T> for B { + fn push_value(&mut self, value: &&mut T) -> Result<()> { + self.push_value(*value) + } +} + +pub trait DefaultArrayBuilder { + type ArrayBuilder: ArrayBuilder; + + fn default_builder() -> Self::ArrayBuilder; +} + +impl DefaultArrayBuilder for &T { + type ArrayBuilder = T::ArrayBuilder; + + fn default_builder() -> Self::ArrayBuilder { + T::default_builder() + } +} + +impl DefaultArrayBuilder for &mut T { + type ArrayBuilder = T::ArrayBuilder; + + fn default_builder() -> Self::ArrayBuilder { + T::default_builder() + } +} diff --git a/marrow-convert/src/internal/builder/option.rs b/marrow-convert/src/internal/builder/option.rs new file mode 100644 index 0000000..1ce8e88 --- /dev/null +++ b/marrow-convert/src/internal/builder/option.rs @@ -0,0 +1,69 @@ +use marrow::array::Array; + +use crate::{Error, Result}; + +use super::{ArrayBuilder, ArrayPush, DefaultArrayBuilder}; + +pub struct OptionBuilder { + len: usize, + validity: Vec, + builder: B, +} + +impl OptionBuilder { + pub fn new(builder: B) -> Self { + Self { + len: 0, + validity: Vec::new(), + builder, + } + } +} + +impl ArrayBuilder for OptionBuilder { + fn push_default(&mut self) -> Result<()> { + marrow::bits::push(&mut self.validity, &mut self.len, false); + self.builder.push_default()?; + Ok(()) + } + + fn build_array(&mut self) -> Result { + let array = self.builder.build_array()?; + let validity = std::mem::take(&mut self.validity); + let _ = std::mem::take(&mut self.len); + with_validity(array, validity) + } +} + +impl> ArrayPush> for OptionBuilder { + fn push_value(&mut self, value: &Option) -> Result<()> { + match value { + Some(value) => { + marrow::bits::push(&mut self.validity, &mut self.len, true); + self.builder.push_value(value) + } + None => self.push_default(), + } + } +} + +impl DefaultArrayBuilder for Option { + type ArrayBuilder = OptionBuilder; + + fn default_builder() -> Self::ArrayBuilder { + OptionBuilder::new(T::default_builder()) + } +} + +fn with_validity(array: Array, validity: Vec) -> Result { + // TODO: check compatibility + match array { + Array::Null(array) => Ok(Array::Null(array)), + Array::Boolean(mut array) => { + array.validity = Some(validity); + Ok(Array::Boolean(array)) + } + // TODO: add more .. + _ => Err(Error(String::from("Cannot set valditiy for array"))), + } +} diff --git a/marrow-convert/src/internal/builder/primitive.rs b/marrow-convert/src/internal/builder/primitive.rs new file mode 100644 index 0000000..edc69e7 --- /dev/null +++ b/marrow-convert/src/internal/builder/primitive.rs @@ -0,0 +1,162 @@ +use marrow::{ + array::{Array, BooleanArray, NullArray, PrimitiveArray}, + types::f16, +}; + +use crate::Result; + +use super::{ArrayBuilder, ArrayPush, DefaultArrayBuilder}; + +#[derive(Debug, Default)] +struct PrimitiveBuilder { + values: Vec, + build_impl: B, +} + +trait BuildPrimitiveArrayImpl { + fn build(&self, values: &mut Vec) -> Result; +} + +impl> ArrayBuilder for PrimitiveBuilder { + fn push_default(&mut self) -> Result<()> { + self.values.push(T::default()); + Ok(()) + } + + fn build_array(&mut self) -> Result { + self.build_impl.build(&mut self.values) + } +} + +#[derive(Debug, Default)] +struct BuildNative; + +macro_rules! impl_build_native { + ($(($ty:ident, $variant:ident),)*) => { + $( + impl BuildPrimitiveArrayImpl<$ty> for BuildNative { + fn build(&self, values: &mut Vec<$ty>) -> Result { + Ok(Array::$variant(PrimitiveArray { + validity: None, + values: std::mem::take(values), + })) + } + } + )* + }; +} + +impl_build_native!( + (i8, Int8), + (i16, Int16), + (i32, Int32), + (i64, Int64), + (u8, UInt8), + (u16, UInt16), + (u32, UInt32), + (u64, UInt64), + (f16, Float16), + (f32, Float32), + (f64, Float64), +); + +macro_rules! define_builder { + ($(($builder:ident, $ty:ident),)*) => { + $( + #[derive(Debug, Default)] + pub struct $builder(PrimitiveBuilder<$ty, BuildNative>); + + impl ArrayBuilder for $builder { + fn push_default(&mut self) -> Result<()> { + self.0.push_default() + } + + fn build_array(&mut self) -> Result { + self.0.build_array() + } + } + + impl ArrayPush<$ty> for $builder { + fn push_value(&mut self, value: &$ty) -> Result<()> { + self.0.values.push(*value); + Ok(()) + } + } + + impl DefaultArrayBuilder for $ty { + type ArrayBuilder = $builder; + + fn default_builder() -> Self::ArrayBuilder { + $builder::default() + } + } + )* + }; +} + +define_builder!( + (Int8Builder, i8), + (Int16Builder, i16), + (Int32Builder, i32), + (Int64Builder, i64), + (UInt8Builder, u8), + (UInt16Builder, u16), + (UInt32Builder, u32), + (UInt64Builder, u64), + (Float16Builder, f16), + (Float32Builder, f32), + (Float64Builder, f64), +); + +#[derive(Debug, Default)] +pub struct NullBuilder(usize); + +impl ArrayBuilder for NullBuilder { + fn push_default(&mut self) -> Result<()> { + self.0 += 1; + Ok(()) + } + + fn build_array(&mut self) -> Result { + Ok(Array::Null(NullArray { + len: std::mem::take(&mut self.0), + })) + } +} + +impl DefaultArrayBuilder for () { + type ArrayBuilder = NullBuilder; + + fn default_builder() -> Self::ArrayBuilder { + NullBuilder::default() + } +} + +#[derive(Debug, Default)] +pub struct BooleanBuilder { + len: usize, + values: Vec, +} + +impl ArrayBuilder for BooleanBuilder { + fn push_default(&mut self) -> Result<()> { + marrow::bits::push(&mut self.values, &mut self.len, false); + Ok(()) + } + + fn build_array(&mut self) -> Result { + Ok(Array::Boolean(BooleanArray { + len: std::mem::take(&mut self.len), + values: std::mem::take(&mut self.values), + validity: None, + })) + } +} + +impl DefaultArrayBuilder for bool { + type ArrayBuilder = BooleanBuilder; + + fn default_builder() -> Self::ArrayBuilder { + BooleanBuilder::default() + } +} diff --git a/marrow-convert/src/internal/builder/struct.rs b/marrow-convert/src/internal/builder/struct.rs new file mode 100644 index 0000000..8dcf2a6 --- /dev/null +++ b/marrow-convert/src/internal/builder/struct.rs @@ -0,0 +1,163 @@ +use marrow::{ + array::{Array, StructArray}, + datatypes::FieldMeta, +}; + +use crate::{Error, Result}; + +use super::ArrayBuilder; + +// TODO: add simple doc test showing how to implement a custom impl +/// Support to build struct builders +/// +/// When pushing a value the following invariants need to be observed: +/// +/// - A value must be pushed to each child field +/// - The `len` field must be incremented +/// +pub struct StructBuilder { + pub meta: Vec, + pub len: usize, + pub children: C, +} + +macro_rules! impl_struct_builder { + ($($el:ident,)*) => { + #[allow(non_snake_case, clippy::vec_init_then_push)] + impl<$($el: ArrayBuilder),*> ArrayBuilder for StructBuilder<($($el,)*)> { + fn push_default(&mut self) -> Result<()> { + let ($($el,)*) = &mut self.children; + self.len += 1; + $($el.push_default()?;)* + Ok(()) + } + + fn build_array(&mut self) -> Result { + let ($($el,)*) = &mut self.children; + let mut arrays = Vec::new(); + // TODO: ensure all builders are called? + $(arrays.push($el.build_array()?);)* + + if arrays.len() != self.meta.len() { + return Err(Error(String::from("Not matching number of meta and children"))); + } + + let fields = std::iter::zip(&self.meta, arrays).map(|(meta, array)| (meta.clone(), array)).collect(); + + Ok(Array::Struct(StructArray { + len: self.len, + validity: None, + fields, + })) + } + } + }; +} + +// TODO: is a struct with fields valid? +impl_struct_builder!(A,); +impl_struct_builder!(A, B,); +impl_struct_builder!(A, B, C,); +impl_struct_builder!(A, B, C, D,); +impl_struct_builder!(A, B, C, D, E,); +impl_struct_builder!(A, B, C, D, E, F,); +impl_struct_builder!(A, B, C, D, E, F, G,); +impl_struct_builder!(A, B, C, D, E, F, G, H,); +impl_struct_builder!(A, B, C, D, E, F, G, H, I,); +impl_struct_builder!(A, B, C, D, E, F, G, H, I, J,); +impl_struct_builder!(A, B, C, D, E, F, G, H, I, J, K,); +impl_struct_builder!(A, B, C, D, E, F, G, H, I, J, K, L,); +impl_struct_builder!(A, B, C, D, E, F, G, H, I, J, K, L, M,); +impl_struct_builder!(A, B, C, D, E, F, G, H, I, J, K, L, M, N,); +impl_struct_builder!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O,); +impl_struct_builder!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P,); + +#[test] +fn struct_example() { + use super::{ArrayPush, DefaultArrayBuilder}; + + struct S { + a: i8, + b: i32, + } + + // move into derive(ArrayPush) + // Allows to customize the builder + const _: () = { + impl, B: ArrayPush> ArrayPush for StructBuilder<(A, B)> { + fn push_value(&mut self, value: &S) -> Result<()> { + self.len += 1; + self.children.0.push_value(&value.a)?; + self.children.1.push_value(&value.b)?; + Ok(()) + } + } + }; + + // move into derive(DefaultBuilder) + const _: () = { + struct Builder( + StructBuilder<( + ::ArrayBuilder, + ::ArrayBuilder, + )>, + ); + + impl ArrayBuilder for Builder { + fn push_default(&mut self) -> Result<()> { + self.0.push_default() + } + + fn build_array(&mut self) -> Result { + self.0.build_array() + } + } + + impl DefaultArrayBuilder for S { + type ArrayBuilder = Builder; + + fn default_builder() -> Self::ArrayBuilder { + Builder(StructBuilder { + len: 0, + meta: vec![ + FieldMeta { + name: String::from("a"), + ..Default::default() + }, + FieldMeta { + name: String::from("b"), + ..Default::default() + }, + ], + children: ( + (::default_builder()), + (::default_builder()), + ), + }) + } + } + + // NOTE: implement separately to allow independent derives + impl ArrayPush for Builder { + fn push_value(&mut self, value: &S) -> Result<()> { + self.0.len += 1; + self.0.children.0.push_value(&value.a)?; + self.0.children.1.push_value(&value.b)?; + Ok(()) + } + } + }; + + // the public API + let mut builder = S::default_builder(); + builder.push_value(&S { a: 0, b: -21 }).unwrap(); + builder.push_value(&S { a: 1, b: -42 }).unwrap(); + let array = builder.build_array().unwrap(); + + let [(_, a), (_, b)] = array.into_struct_fields().expect("invalid array type"); + let a = a.into_int8().expect("invalid array type"); + let b = b.into_int32().expect("invalid array type"); + + assert_eq!(a.values, vec![0, 1]); + assert_eq!(b.values, vec![-21, -42]); +} diff --git a/marrow-convert/src/internal/builder/union.rs b/marrow-convert/src/internal/builder/union.rs new file mode 100644 index 0000000..7ccbbe7 --- /dev/null +++ b/marrow-convert/src/internal/builder/union.rs @@ -0,0 +1,284 @@ +use marrow::{ + array::{Array, UnionArray}, + datatypes::FieldMeta, +}; + +use crate::Result; + +use crate::internal::util::TupleLen; + +use super::ArrayBuilder; + +/// Helper struct to simplify implementing sparse Union builders +/// +/// When pushing a value the following invariants need to be observed: +/// +/// - A discriminator must be pushed to the `types` value +/// - A value must be pushed to each child field +#[derive(Debug)] +pub struct SparseUnionBuilder { + pub types: Vec, + pub meta: Vec, + pub children: C, +} + +macro_rules! impl_sparse_union_builder { + ($($el:ident,)*) => { + #[allow(non_snake_case, clippy::vec_init_then_push)] + impl<$($el: ArrayBuilder),*> ArrayBuilder for SparseUnionBuilder<($($el,)*)> { + fn push_default(&mut self) -> Result<()> { + let ($($el,)*) = &mut self.children; + $($el.push_default()?;)* + self.types.push(0); + Ok(()) + } + + fn build_array(&mut self) -> Result { + const { + assert!(<($($el,)*) as TupleLen>::LEN < (i8::MAX as usize)); + } + + let types = std::mem::take(&mut self.types); + let mut arrays = Vec::new(); + let ($($el,)*) = &mut self.children; + $(arrays.push($el.build_array()?);)* + + let fields = std::iter::zip(&self.meta, arrays) + .enumerate() + .map(|(i, (meta, array))| (i as i8, meta.clone(), array)) + .collect(); + + Ok(Array::Union(UnionArray { + types, + fields, + offsets: None, + })) + } + } + }; +} + +impl_sparse_union_builder!(A,); +impl_sparse_union_builder!(A, B,); +impl_sparse_union_builder!(A, B, C,); +impl_sparse_union_builder!(A, B, C, D,); +impl_sparse_union_builder!(A, B, C, D, E,); +impl_sparse_union_builder!(A, B, C, D, E, F,); +impl_sparse_union_builder!(A, B, C, D, E, F, G,); +impl_sparse_union_builder!(A, B, C, D, E, F, G, H,); +impl_sparse_union_builder!(A, B, C, D, E, F, G, H, I,); +impl_sparse_union_builder!(A, B, C, D, E, F, G, H, I, J,); +impl_sparse_union_builder!(A, B, C, D, E, F, G, H, I, J, K,); +impl_sparse_union_builder!(A, B, C, D, E, F, G, H, I, J, K, L,); +impl_sparse_union_builder!(A, B, C, D, E, F, G, H, I, J, K, L, M,); +impl_sparse_union_builder!(A, B, C, D, E, F, G, H, I, J, K, L, M, N,); +impl_sparse_union_builder!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O,); +impl_sparse_union_builder!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P,); + +/// Helper struct to simplify implementing dense Union builders +/// +/// When pushing a value the following invariants need to be observed: +/// +/// - A discriminator must be pushed to the `types` value +/// - A value must be pushed for the relevant variant +#[derive(Debug)] +pub struct DenseUnionBuilder { + pub types: DenseTypes, + pub offsets: Vec, + pub meta: Vec, + pub children: C, +} + +#[derive(Debug)] +pub struct DenseTypes { + types: Vec, + offsets: Vec, + current_offset: Vec, +} + +impl DenseTypes { + pub fn new(num_types: usize) -> Self { + Self { + types: Vec::new(), + offsets: Vec::new(), + current_offset: vec![0; num_types], + } + } + + pub fn take(&mut self) -> Self { + let num_types = self.current_offset.len(); + Self { + types: std::mem::take(&mut self.types), + offsets: std::mem::take(&mut self.offsets), + current_offset: std::mem::replace(&mut self.current_offset, vec![0; num_types]), + } + } + + pub fn push(&mut self, variant: i8) -> Result<()> { + assert!(variant >= 0); + + self.types.push(variant); + self.offsets.push(self.current_offset[variant as usize]); + self.current_offset[variant as usize] += 1; + Ok(()) + } +} + +macro_rules! impl_dense_union_builder { + ($first:ident, $($el:ident,)*) => { + #[allow(non_snake_case, clippy::vec_init_then_push)] + impl<$first: ArrayBuilder $(, $el: ArrayBuilder)*> ArrayBuilder for DenseUnionBuilder<($first, $($el,)*)> { + fn push_default(&mut self) -> Result<()> { + #[allow(unused_variables)] + let ($first, $($el,)*) = &mut self.children; + $first.push_default()?; + self.types.push(0)?; + Ok(()) + } + + fn build_array(&mut self) -> Result { + const { + assert!(<($first, $($el,)*) as TupleLen>::LEN < (i8::MAX as usize)); + } + + let DenseTypes { types, offsets, ..} = self.types.take(); + let mut arrays = Vec::new(); + let ($first, $($el,)*) = &mut self.children; + arrays.push($first.build_array()?); + $(arrays.push($el.build_array()?);)* + + let fields = std::iter::zip(&self.meta, arrays) + .enumerate() + .map(|(i, (meta, array))| (i as i8, meta.clone(), array)) + .collect(); + + Ok(Array::Union(UnionArray { + types, + offsets: Some(offsets), + fields, + })) + } + } + }; +} + +impl_dense_union_builder!(A,); +impl_dense_union_builder!(A, B,); +impl_dense_union_builder!(A, B, C,); +impl_dense_union_builder!(A, B, C, D,); +impl_dense_union_builder!(A, B, C, D, E,); +impl_dense_union_builder!(A, B, C, D, E, F,); +impl_dense_union_builder!(A, B, C, D, E, F, G,); +impl_dense_union_builder!(A, B, C, D, E, F, G, H,); +impl_dense_union_builder!(A, B, C, D, E, F, G, H, I,); +impl_dense_union_builder!(A, B, C, D, E, F, G, H, I, J,); +impl_dense_union_builder!(A, B, C, D, E, F, G, H, I, J, K,); +impl_dense_union_builder!(A, B, C, D, E, F, G, H, I, J, K, L,); +impl_dense_union_builder!(A, B, C, D, E, F, G, H, I, J, K, L, M,); +impl_dense_union_builder!(A, B, C, D, E, F, G, H, I, J, K, L, M, N,); +impl_dense_union_builder!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O,); +impl_dense_union_builder!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P,); + +#[test] +fn enum_example() { + use super::{ArrayPush, DefaultArrayBuilder}; + + enum Enum { + A(i32), + B(i64), + } + + // TODO: push into derive(ArrayPush) + const _: () = { + impl, B: ArrayPush> ArrayPush for SparseUnionBuilder<(A, B)> { + #[allow(non_snake_case)] + fn push_value(&mut self, value: &Enum) -> Result<()> { + match value { + Enum::A(inner) => { + self.types.push(0); + let (A, B) = &mut self.children; + A.push_value(inner)?; + B.push_default()?; + } + Enum::B(inner) => { + self.types.push(1); + let (A, B) = &mut self.children; + A.push_default()?; + B.push_value(inner)?; + } + } + Ok(()) + } + } + }; + + // TODO: push into derive(DefaultArrayBuilder) + const _: () = { + struct Builder( + SparseUnionBuilder<( + ::ArrayBuilder, + ::ArrayBuilder, + )>, + ); + + #[allow(non_snake_case)] + impl ArrayBuilder for Builder { + fn push_default(&mut self) -> Result<()> { + self.0.types.push(0); + + let (A, B) = &mut self.0.children; + A.push_default()?; + B.push_default()?; + Ok(()) + } + + fn build_array(&mut self) -> Result { + self.0.build_array() + } + } + + // TODO: in practice implement separately to allow indepdent derives + impl ArrayPush for Builder { + fn push_value(&mut self, value: &Enum) -> Result<()> { + self.0.push_value(value) + } + } + + impl DefaultArrayBuilder for Enum { + type ArrayBuilder = Builder; + + fn default_builder() -> Self::ArrayBuilder { + Builder(SparseUnionBuilder { + types: Vec::new(), + meta: vec![ + FieldMeta { + name: String::from("A"), + ..Default::default() + }, + FieldMeta { + name: String::from("B"), + ..Default::default() + }, + ], + children: ( + ::default_builder(), + ::default_builder(), + ), + }) + } + } + }; + + // the public API + let mut builder = Enum::default_builder(); + builder.push_value(&Enum::A(13)).unwrap(); + builder.push_value(&Enum::B(21)).unwrap(); + let array = builder.build_array().unwrap(); + + let [(_, _, a), (_, _, b)] = array.into_union_fields().expect("invalid array type"); + let a = a.into_int32().expect("invalid array type"); + let b = b.into_int64().expect("invalid array type"); + + assert_eq!(a.values, vec![13, 0]); + assert_eq!(b.values, vec![0, 21]); +} diff --git a/marrow-convert/src/internal/mod.rs b/marrow-convert/src/internal/mod.rs new file mode 100644 index 0000000..bf7bf7a --- /dev/null +++ b/marrow-convert/src/internal/mod.rs @@ -0,0 +1,4 @@ +pub mod builder; +pub mod type_info; +pub mod type_info_impls; +pub mod util; diff --git a/marrow-convert/src/typeinfo.rs b/marrow-convert/src/internal/type_info.rs similarity index 90% rename from marrow-convert/src/typeinfo.rs rename to marrow-convert/src/internal/type_info.rs index add8ba8..691aac8 100644 --- a/marrow-convert/src/typeinfo.rs +++ b/marrow-convert/src/internal/type_info.rs @@ -85,10 +85,11 @@ impl Context<'_> { self.options } - pub fn get_field(&self, name: &str) -> Result { + pub fn get_field(&self, name: &str) -> Result { self.nest(name, T::get_field) } + /// Call a function with a context for nested field pub fn nest) -> Result>( &self, name: &str, @@ -112,7 +113,7 @@ impl Context<'_> { } } -pub fn get_field(name: &str, options: &Options) -> Result { +pub fn get_field(name: &str, options: &Options) -> Result { let context = Context { path: "$", name, @@ -121,7 +122,7 @@ pub fn get_field(name: &str, options: &Options) -> Result { T::get_field(context) } -pub fn get_data_type(options: &Options) -> Result { +pub fn get_data_type(options: &Options) -> Result { Ok(get_field::("item", options)?.data_type) } @@ -133,7 +134,7 @@ pub struct LargeList(pub bool); /// /// The functions cannot be called directly. First construct a [Context], then call the /// corresponding methods. -pub trait TypeInfo { - /// See [crate::get_field] +pub trait DefaultArrayType { + /// See [get_field] fn get_field(context: Context<'_>) -> Result; } diff --git a/marrow-convert/src/impls/collections.rs b/marrow-convert/src/internal/type_info_impls/collections.rs similarity index 68% rename from marrow-convert/src/impls/collections.rs rename to marrow-convert/src/internal/type_info_impls/collections.rs index 3ed77b1..964d51f 100644 --- a/marrow-convert/src/impls/collections.rs +++ b/marrow-convert/src/internal/type_info_impls/collections.rs @@ -2,61 +2,64 @@ use std::collections::{BTreeMap, BTreeSet, BinaryHeap, HashMap, HashSet, LinkedL use marrow::datatypes::Field; -use crate::{Context, Result, TypeInfo}; +use crate::{ + Result, + types::{Context, DefaultArrayType}, +}; use super::utils::{new_list_field, new_map_field}; /// Map a vec to an Arrow List -impl TypeInfo for Vec { +impl DefaultArrayType for Vec { fn get_field(context: Context<'_>) -> Result { new_list_field::(context) } } /// Map a `VecDeque` to an Arrow List -impl TypeInfo for VecDeque { +impl DefaultArrayType for VecDeque { fn get_field(context: Context<'_>) -> Result { new_list_field::(context) } } /// Map a `LinkedList` to an Arrow List -impl TypeInfo for LinkedList { +impl DefaultArrayType for LinkedList { fn get_field(context: Context<'_>) -> Result { new_list_field::(context) } } /// Map a `BinaryHeap` to an Arrow List -impl TypeInfo for BinaryHeap { +impl DefaultArrayType for BinaryHeap { fn get_field(context: Context<'_>) -> Result { new_list_field::(context) } } /// Map a `BTreeSet` to an Arrow List -impl TypeInfo for BTreeSet { +impl DefaultArrayType for BTreeSet { fn get_field(context: Context<'_>) -> Result { new_list_field::(context) } } /// Map a `HashSet` to an Arrow List -impl TypeInfo for HashSet { +impl DefaultArrayType for HashSet { fn get_field(context: Context<'_>) -> Result { new_list_field::(context) } } /// Map a `BTreeMap` to an Arrow Map -impl TypeInfo for BTreeMap { +impl DefaultArrayType for BTreeMap { fn get_field(context: Context<'_>) -> Result { new_map_field::(context) } } /// Map a `HashMap` to an Arrow Map -impl TypeInfo for HashMap { +impl DefaultArrayType for HashMap { fn get_field(context: Context<'_>) -> Result { new_map_field::(context) } diff --git a/marrow-convert/src/impls/compounds.rs b/marrow-convert/src/internal/type_info_impls/compounds.rs similarity index 88% rename from marrow-convert/src/impls/compounds.rs rename to marrow-convert/src/internal/type_info_impls/compounds.rs index f36d3a9..baa9b72 100644 --- a/marrow-convert/src/impls/compounds.rs +++ b/marrow-convert/src/internal/type_info_impls/compounds.rs @@ -1,16 +1,19 @@ use marrow::datatypes::{DataType, Field}; -use crate::{Context, Result, TypeInfo}; +use crate::{ + Result, + types::{Context, DefaultArrayType}, +}; use super::utils::new_list_field; -impl TypeInfo for [T] { +impl DefaultArrayType for [T] { fn get_field(context: Context<'_>) -> Result { new_list_field::(context) } } -impl TypeInfo for [T; N] { +impl DefaultArrayType for [T; N] { fn get_field(context: Context<'_>) -> Result { let base_field = context.get_field::("element")?; let n = i32::try_from(N)?; @@ -34,7 +37,7 @@ impl TypeInfo for [T; N] { macro_rules! impl_tuples { ($( ( $($name:ident,)* ), )*) => { $( - impl<$($name: TypeInfo),*> TypeInfo for ( $($name,)* ) { + impl<$($name: DefaultArrayType),*> DefaultArrayType for ( $($name,)* ) { #[allow(unused_assignments, clippy::vec_init_then_push)] fn get_field(context: Context<'_>) -> Result { let mut idx = 0; diff --git a/marrow-convert/src/impls/ext/bigdecimal.rs b/marrow-convert/src/internal/type_info_impls/ext/bigdecimal.rs similarity index 57% rename from marrow-convert/src/impls/ext/bigdecimal.rs rename to marrow-convert/src/internal/type_info_impls/ext/bigdecimal.rs index 342a1f7..7e15e29 100644 --- a/marrow-convert/src/impls/ext/bigdecimal.rs +++ b/marrow-convert/src/internal/type_info_impls/ext/bigdecimal.rs @@ -1,9 +1,12 @@ use marrow::datatypes::{DataType, Field}; -use crate::TypeInfo; +use crate::{ + Result, + types::{Context, DefaultArrayType}, +}; -impl TypeInfo for bigdecimal::BigDecimal { - fn get_field(context: crate::Context<'_>) -> crate::Result { +impl DefaultArrayType for bigdecimal::BigDecimal { + fn get_field(context: Context<'_>) -> Result { Ok(Field { name: String::from(context.get_name()), // TODO: find better defaults diff --git a/marrow-convert/src/impls/ext/chrono.rs b/marrow-convert/src/internal/type_info_impls/ext/chrono.rs similarity index 60% rename from marrow-convert/src/impls/ext/chrono.rs rename to marrow-convert/src/internal/type_info_impls/ext/chrono.rs index f21ca02..771b88e 100644 --- a/marrow-convert/src/impls/ext/chrono.rs +++ b/marrow-convert/src/internal/type_info_impls/ext/chrono.rs @@ -1,10 +1,13 @@ use chrono::Utc; use marrow::datatypes::{DataType, Field, TimeUnit}; -use crate::TypeInfo; +use crate::{ + Result, + types::{Context, DefaultArrayType}, +}; -impl TypeInfo for chrono::NaiveDate { - fn get_field(context: crate::Context<'_>) -> crate::Result { +impl DefaultArrayType for chrono::NaiveDate { + fn get_field(context: Context<'_>) -> Result { Ok(Field { name: String::from(context.get_name()), data_type: DataType::Date32, @@ -13,8 +16,8 @@ impl TypeInfo for chrono::NaiveDate { } } -impl TypeInfo for chrono::NaiveTime { - fn get_field(context: crate::Context<'_>) -> crate::Result { +impl DefaultArrayType for chrono::NaiveTime { + fn get_field(context: Context<'_>) -> Result { Ok(Field { name: String::from(context.get_name()), data_type: DataType::Time32(TimeUnit::Millisecond), @@ -23,8 +26,8 @@ impl TypeInfo for chrono::NaiveTime { } } -impl TypeInfo for chrono::NaiveDateTime { - fn get_field(context: crate::Context<'_>) -> crate::Result { +impl DefaultArrayType for chrono::NaiveDateTime { + fn get_field(context: Context<'_>) -> Result { Ok(Field { name: String::from(context.get_name()), data_type: DataType::Timestamp(TimeUnit::Millisecond, None), @@ -33,8 +36,8 @@ impl TypeInfo for chrono::NaiveDateTime { } } -impl TypeInfo for chrono::DateTime { - fn get_field(context: crate::Context<'_>) -> crate::Result { +impl DefaultArrayType for chrono::DateTime { + fn get_field(context: Context<'_>) -> Result { Ok(Field { name: String::from(context.get_name()), data_type: DataType::Timestamp(TimeUnit::Millisecond, Some(String::from("UTC"))), diff --git a/marrow-convert/src/impls/ext/jiff.rs b/marrow-convert/src/internal/type_info_impls/ext/jiff.rs similarity index 62% rename from marrow-convert/src/impls/ext/jiff.rs rename to marrow-convert/src/internal/type_info_impls/ext/jiff.rs index 9a75104..8feea93 100644 --- a/marrow-convert/src/impls/ext/jiff.rs +++ b/marrow-convert/src/internal/type_info_impls/ext/jiff.rs @@ -1,9 +1,12 @@ use marrow::datatypes::{DataType, Field, TimeUnit}; -use crate::TypeInfo; +use crate::{ + Result, + types::{Context, DefaultArrayType}, +}; -impl TypeInfo for jiff::civil::Date { - fn get_field(context: crate::Context<'_>) -> crate::Result { +impl DefaultArrayType for jiff::civil::Date { + fn get_field(context: Context<'_>) -> Result { Ok(Field { name: String::from(context.get_name()), data_type: DataType::Date32, @@ -12,8 +15,8 @@ impl TypeInfo for jiff::civil::Date { } } -impl TypeInfo for jiff::civil::Time { - fn get_field(context: crate::Context<'_>) -> crate::Result { +impl DefaultArrayType for jiff::civil::Time { + fn get_field(context: Context<'_>) -> Result { Ok(Field { name: String::from(context.get_name()), data_type: DataType::Time32(TimeUnit::Millisecond), @@ -22,8 +25,8 @@ impl TypeInfo for jiff::civil::Time { } } -impl TypeInfo for jiff::Span { - fn get_field(context: crate::Context<'_>) -> crate::Result { +impl DefaultArrayType for jiff::Span { + fn get_field(context: Context<'_>) -> Result { Ok(Field { name: String::from(context.get_name()), data_type: DataType::Duration(TimeUnit::Millisecond), @@ -32,8 +35,8 @@ impl TypeInfo for jiff::Span { } } -impl TypeInfo for jiff::Timestamp { - fn get_field(context: crate::Context<'_>) -> crate::Result { +impl DefaultArrayType for jiff::Timestamp { + fn get_field(context: Context<'_>) -> Result { Ok(Field { name: String::from(context.get_name()), data_type: DataType::Timestamp(TimeUnit::Millisecond, None), diff --git a/marrow-convert/src/impls/ext/mod.rs b/marrow-convert/src/internal/type_info_impls/ext/mod.rs similarity index 100% rename from marrow-convert/src/impls/ext/mod.rs rename to marrow-convert/src/internal/type_info_impls/ext/mod.rs diff --git a/marrow-convert/src/impls/ext/uuid.rs b/marrow-convert/src/internal/type_info_impls/ext/uuid.rs similarity index 83% rename from marrow-convert/src/impls/ext/uuid.rs rename to marrow-convert/src/internal/type_info_impls/ext/uuid.rs index f37fb8b..8fd08e4 100644 --- a/marrow-convert/src/impls/ext/uuid.rs +++ b/marrow-convert/src/internal/type_info_impls/ext/uuid.rs @@ -2,9 +2,12 @@ use std::collections::HashMap; use marrow::datatypes::{DataType, Field}; -use crate::{Context, Result, TypeInfo}; +use crate::{ + Result, + types::{Context, DefaultArrayType}, +}; -impl TypeInfo for uuid::Uuid { +impl DefaultArrayType for uuid::Uuid { fn get_field(context: Context<'_>) -> Result { let mut metadata = HashMap::new(); metadata.insert("ARROW:extension:name".into(), "arrow.uuid".into()); diff --git a/marrow-convert/src/impls/mod.rs b/marrow-convert/src/internal/type_info_impls/mod.rs similarity index 100% rename from marrow-convert/src/impls/mod.rs rename to marrow-convert/src/internal/type_info_impls/mod.rs diff --git a/marrow-convert/src/impls/primitives.rs b/marrow-convert/src/internal/type_info_impls/primitives.rs similarity index 83% rename from marrow-convert/src/impls/primitives.rs rename to marrow-convert/src/internal/type_info_impls/primitives.rs index 5e667c6..e5d2bab 100644 --- a/marrow-convert/src/impls/primitives.rs +++ b/marrow-convert/src/internal/type_info_impls/primitives.rs @@ -3,14 +3,17 @@ use marrow::{ types::f16, }; -use crate::{Context, Result, TypeInfo}; +use crate::{ + Result, + types::{Context, DefaultArrayType}, +}; use super::utils::new_string_field; macro_rules! define_primitive { ($(($ty:ty, $dt:expr),)*) => { $( - impl TypeInfo for $ty { + impl DefaultArrayType for $ty { fn get_field(context: Context<'_>) -> Result { Ok(Field { name: context.get_name().to_owned(), @@ -39,7 +42,7 @@ define_primitive!( (char, DataType::UInt32), ); -impl TypeInfo for () { +impl DefaultArrayType for () { fn get_field(context: Context<'_>) -> Result { let _ = context; Ok(Field { @@ -51,19 +54,19 @@ impl TypeInfo for () { } } -impl TypeInfo for str { +impl DefaultArrayType for str { fn get_field(context: Context<'_>) -> Result { Ok(new_string_field(context)) } } -impl TypeInfo for &T { +impl DefaultArrayType for &T { fn get_field(context: Context<'_>) -> Result { T::get_field(context) } } -impl TypeInfo for &mut T { +impl DefaultArrayType for &mut T { fn get_field(context: Context<'_>) -> Result { T::get_field(context) } diff --git a/marrow-convert/src/impls/std.rs b/marrow-convert/src/internal/type_info_impls/std.rs similarity index 82% rename from marrow-convert/src/impls/std.rs rename to marrow-convert/src/internal/type_info_impls/std.rs index 2aec829..f10e5be 100644 --- a/marrow-convert/src/impls/std.rs +++ b/marrow-convert/src/internal/type_info_impls/std.rs @@ -10,18 +10,21 @@ use std::{ use marrow::datatypes::{DataType, Field, TimeUnit, UnionMode}; -use crate::{Context, Result, TypeInfo}; +use crate::{ + Result, + types::{Context, DefaultArrayType}, +}; use super::utils::new_string_field; -impl TypeInfo for String { +impl DefaultArrayType for String { fn get_field(context: Context<'_>) -> Result { Ok(new_string_field(context)) } } /// Map an option to a nullable field -impl TypeInfo for Option { +impl DefaultArrayType for Option { fn get_field(context: Context<'_>) -> Result { let mut base_field = T::get_field(context)?; base_field.nullable = true; @@ -30,7 +33,7 @@ impl TypeInfo for Option { } /// Map a `Result` to an Arrow Union with `Ok` and `Err` variants -impl TypeInfo for Result { +impl DefaultArrayType for Result { fn get_field(context: Context<'_>) -> Result { let ok = context.get_field::("Ok")?; let err = context.get_field::("Err")?; @@ -44,42 +47,42 @@ impl TypeInfo for Result { } /// Map a `Range` to an Arrow `FixedSizeList(.., 2)` -impl TypeInfo for Range { +impl DefaultArrayType for Range { fn get_field(context: Context<'_>) -> Result { <[T; 2]>::get_field(context) } } /// Map a `RangeInclusive` to an Arrow `FixedSizeList(.., 2)` -impl TypeInfo for RangeInclusive { +impl DefaultArrayType for RangeInclusive { fn get_field(context: Context<'_>) -> Result { <[T; 2]>::get_field(context) } } /// Map a `RangeTo` to the index type -impl TypeInfo for RangeTo { +impl DefaultArrayType for RangeTo { fn get_field(context: Context<'_>) -> Result { T::get_field(context) } } /// Map a `RangeToInclusive` to the index type -impl TypeInfo for RangeToInclusive { +impl DefaultArrayType for RangeToInclusive { fn get_field(context: Context<'_>) -> Result { T::get_field(context) } } /// Map a `RangeFrom` to the index type -impl TypeInfo for RangeFrom { +impl DefaultArrayType for RangeFrom { fn get_field(context: Context<'_>) -> Result { T::get_field(context) } } /// Map a `Bound` to an Arrow Union with variants `Included`, `Excluded`, `Unbounded` -impl TypeInfo for Bound { +impl DefaultArrayType for Bound { fn get_field(context: Context<'_>) -> Result { let included = context.get_field::("Included")?; let excluded = context.get_field::("Excluded")?; @@ -99,7 +102,7 @@ impl TypeInfo for Bound { macro_rules! impl_nonzero { ($($ty:ident),* $(,)?) => { $( - impl TypeInfo for NonZero<$ty> { + impl DefaultArrayType for NonZero<$ty> { fn get_field(context: Context<'_>) -> Result { <$ty>::get_field(context) } @@ -113,7 +116,7 @@ impl_nonzero!(u8, u16, u32, u64, i8, i16, i32, i64); macro_rules! impl_atomic { ($(($atomic:ident, $ty:ident)),* $(,)?) => { $( - impl TypeInfo for $atomic { + impl DefaultArrayType for $atomic { fn get_field(context: Context<'_>) -> Result { $ty::get_field(context) } @@ -134,7 +137,7 @@ impl_atomic!( (AtomicU64, u64), ); -impl TypeInfo for Duration { +impl DefaultArrayType for Duration { fn get_field(context: Context<'_>) -> Result { Ok(Field { name: String::from(context.get_name()), @@ -144,7 +147,7 @@ impl TypeInfo for Duration { } } -impl TypeInfo for SystemTime { +impl DefaultArrayType for SystemTime { fn get_field(context: Context<'_>) -> Result { Ok(Field { name: String::from(context.get_name()), diff --git a/marrow-convert/src/impls/utils.rs b/marrow-convert/src/internal/type_info_impls/utils.rs similarity index 81% rename from marrow-convert/src/impls/utils.rs rename to marrow-convert/src/internal/type_info_impls/utils.rs index 232e5ed..b5904de 100644 --- a/marrow-convert/src/impls/utils.rs +++ b/marrow-convert/src/internal/type_info_impls/utils.rs @@ -1,8 +1,9 @@ use marrow::datatypes::{DataType, Field}; use crate::{ - Context, Result, TypeInfo, - typeinfo::{DefaultStringType, LargeList}, + Result, + internal::type_info::{DefaultStringType, LargeList}, + types::{Context, DefaultArrayType}, }; pub fn new_field(name: &str, data_type: DataType) -> Field { @@ -23,7 +24,7 @@ pub fn new_string_field(context: Context<'_>) -> Field { new_field(context.get_name(), ty) } -pub fn new_list_field(context: Context<'_>) -> Result { +pub fn new_list_field(context: Context<'_>) -> Result { let larget_list = if let Some(LargeList(large_list)) = context.get_options().get() { *large_list } else { @@ -44,7 +45,9 @@ pub fn new_list_field(context: Context<'_>) -> Result { }) } -pub fn new_map_field(context: Context<'_>) -> Result { +pub fn new_map_field( + context: Context<'_>, +) -> Result { let key_field = context.get_field::("key")?; let value_field = context.get_field::("value")?; let entry_field = new_field("entry", DataType::Struct(vec![key_field, value_field])); diff --git a/marrow-convert/src/impls/wrappers.rs b/marrow-convert/src/internal/type_info_impls/wrappers.rs similarity index 61% rename from marrow-convert/src/impls/wrappers.rs rename to marrow-convert/src/internal/type_info_impls/wrappers.rs index e23342f..a0b0491 100644 --- a/marrow-convert/src/impls/wrappers.rs +++ b/marrow-convert/src/internal/type_info_impls/wrappers.rs @@ -8,9 +8,12 @@ use std::{ use marrow::datatypes::Field; -use crate::{Context, Result, TypeInfo}; +use crate::{ + Result, + types::{Context, DefaultArrayType}, +}; -impl TypeInfo for PhantomData { +impl DefaultArrayType for PhantomData { fn get_field(context: Context<'_>) -> Result { let mut field = T::get_field(context)?; field.nullable = true; @@ -18,49 +21,49 @@ impl TypeInfo for PhantomData { } } -impl TypeInfo for Box { +impl DefaultArrayType for Box { fn get_field(context: Context<'_>) -> Result { T::get_field(context) } } -impl TypeInfo for Cell { +impl DefaultArrayType for Cell { fn get_field(context: Context<'_>) -> Result { T::get_field(context) } } -impl TypeInfo for RefCell { +impl DefaultArrayType for RefCell { fn get_field(context: Context<'_>) -> Result { T::get_field(context) } } -impl TypeInfo for Mutex { +impl DefaultArrayType for Mutex { fn get_field(context: Context<'_>) -> Result { T::get_field(context) } } -impl TypeInfo for RwLock { +impl DefaultArrayType for RwLock { fn get_field(context: Context<'_>) -> Result { T::get_field(context) } } -impl TypeInfo for Rc { +impl DefaultArrayType for Rc { fn get_field(context: Context<'_>) -> Result { T::get_field(context) } } -impl TypeInfo for Arc { +impl DefaultArrayType for Arc { fn get_field(context: Context<'_>) -> Result { T::get_field(context) } } -impl<'a, T: TypeInfo + ToOwned + ?Sized + 'a> TypeInfo for Cow<'a, T> { +impl<'a, T: DefaultArrayType + ToOwned + ?Sized + 'a> DefaultArrayType for Cow<'a, T> { fn get_field(context: Context<'_>) -> Result { T::get_field(context) } diff --git a/marrow-convert/src/internal/util.rs b/marrow-convert/src/internal/util.rs new file mode 100644 index 0000000..61a24b0 --- /dev/null +++ b/marrow-convert/src/internal/util.rs @@ -0,0 +1,34 @@ +pub trait TupleLen { + const LEN: usize; +} + +macro_rules! impl_tuple_len { + ($head:ident, $($tail:ident,)*) => { + impl<$head, $($tail),*> TupleLen for ($head, $($tail,)*) { + const LEN: usize = 1 + <($($tail,)*) as TupleLen>::LEN; + } + }; + () => { + impl TupleLen for () { + const LEN: usize = 0; + } + }; +} + +impl_tuple_len!(); +impl_tuple_len!(A,); +impl_tuple_len!(A, B,); +impl_tuple_len!(A, B, C,); +impl_tuple_len!(A, B, C, D,); +impl_tuple_len!(A, B, C, D, E,); +impl_tuple_len!(A, B, C, D, E, F,); +impl_tuple_len!(A, B, C, D, E, F, G,); +impl_tuple_len!(A, B, C, D, E, F, G, H,); +impl_tuple_len!(A, B, C, D, E, F, G, H, I,); +impl_tuple_len!(A, B, C, D, E, F, G, H, I, J,); +impl_tuple_len!(A, B, C, D, E, F, G, H, I, J, K,); +impl_tuple_len!(A, B, C, D, E, F, G, H, I, J, K, L,); +impl_tuple_len!(A, B, C, D, E, F, G, H, I, J, K, L, M,); +impl_tuple_len!(A, B, C, D, E, F, G, H, I, J, K, L, M, N,); +impl_tuple_len!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O,); +impl_tuple_len!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P,); diff --git a/marrow-convert/src/lib.rs b/marrow-convert/src/lib.rs index 5a4941e..6635fb4 100644 --- a/marrow-convert/src/lib.rs +++ b/marrow-convert/src/lib.rs @@ -1,14 +1,46 @@ +#![deny(rustdoc::broken_intra_doc_links)] mod error; -mod impls; -mod typeinfo; +mod internal; #[cfg(test)] mod tests; -/// Derive [TypeInfo] for a given type -/// -/// Currently structs and enums with any type of lifetime parameters are supported. -pub use marrow_convert_derive::TypeInfo; - pub use error::{Error, Result}; -pub use typeinfo::{Context, Options, TypeInfo, get_data_type, get_field}; + +/// Traits to derive schema information from a type +pub mod types { + pub use crate::internal::type_info::{ + Context, DefaultArrayType, Options, get_data_type, get_field, + }; + + /// Derive [DefaultArrayType] for a given Rust type + /// + /// Currently structs and enums without type generic are supported. + pub use marrow_convert_derive::DefaultArrayType; +} + +/// Traits to allow constructing arrays from Rust objects +pub mod builder { + pub use crate::internal::builder::list::{LargeListBuilder, ListBuilder}; + pub use crate::internal::builder::primitive::{ + BooleanBuilder, Float16Builder, Float32Builder, Float64Builder, Int8Builder, Int16Builder, + Int32Builder, Int64Builder, NullBuilder, UInt8Builder, UInt16Builder, UInt32Builder, + UInt64Builder, + }; + pub use crate::internal::builder::{ArrayBuilder, ArrayPush, DefaultArrayBuilder}; + + /// Collect builders to simplify implementing custom builders for compound types (structs and + /// enums) + pub mod compound { + pub use crate::internal::builder::{ + r#struct::StructBuilder, + union::{DenseTypes, DenseUnionBuilder, SparseUnionBuilder}, + }; + } + + /// Derive [ArrayPush] for a given type + pub use marrow_convert_derive::ArrayPush; + + /// Derive [DefaultArrayBuilder] for a given type + pub use marrow_convert_derive::DefaultArrayBuilder; +} diff --git a/marrow-convert/src/tests.rs b/marrow-convert/src/tests.rs index e5f7e01..3a0efab 100644 --- a/marrow-convert/src/tests.rs +++ b/marrow-convert/src/tests.rs @@ -1,6 +1,6 @@ use marrow::datatypes::DataType; -use crate::{Options, get_data_type}; +use crate::types::{Options, get_data_type}; #[test] fn examples() { diff --git a/marrow-convert/tests/derive.rs b/marrow-convert/tests/derive.rs index 385dd90..92ddda0 100644 --- a/marrow-convert/tests/derive.rs +++ b/marrow-convert/tests/derive.rs @@ -2,11 +2,14 @@ use marrow::{ datatypes::{DataType, Field, TimeUnit, UnionMode}, types::f16, }; -use marrow_convert::{Context, Options, Result, TypeInfo}; +use marrow_convert::{ + Result, + types::{Context, DefaultArrayType, Options}, +}; #[test] fn example() { - #[derive(TypeInfo)] + #[derive(DefaultArrayType)] #[allow(dead_code)] struct S { a: i64, @@ -14,7 +17,7 @@ fn example() { } assert_eq!( - marrow_convert::get_data_type::(&Options::default()), + marrow_convert::types::get_data_type::(&Options::default()), Ok(DataType::Struct(vec![ Field { name: String::from("a"), @@ -32,7 +35,7 @@ fn example() { #[test] fn overwrites() { - #[derive(TypeInfo)] + #[derive(DefaultArrayType)] #[allow(dead_code)] struct S { a: i64, @@ -40,7 +43,7 @@ fn overwrites() { } assert_eq!( - marrow_convert::get_data_type::(&Options::default().overwrite( + marrow_convert::types::get_data_type::(&Options::default().overwrite( "$.b", Field { data_type: DataType::Binary, @@ -64,24 +67,24 @@ fn overwrites() { #[test] fn newtype() { - #[derive(TypeInfo)] + #[derive(DefaultArrayType)] #[allow(dead_code)] struct S(f16); assert_eq!( - marrow_convert::get_data_type::(&Options::default()), + marrow_convert::types::get_data_type::(&Options::default()), Ok(DataType::Float16) ); } #[test] fn tuple() { - #[derive(TypeInfo)] + #[derive(DefaultArrayType)] #[allow(dead_code)] struct S(u8, [u8; 4]); assert_eq!( - marrow_convert::get_data_type::(&Options::default()), + marrow_convert::types::get_data_type::(&Options::default()), Ok(DataType::Struct(vec![ Field { name: String::from("0"), @@ -99,10 +102,10 @@ fn tuple() { #[test] fn customize() { - #[derive(TypeInfo)] + #[derive(DefaultArrayType)] #[allow(dead_code)] struct S { - #[marrow_type_info(with = "timestamp_field")] + #[marrow(with = "timestamp_field")] a: i64, b: [u8; 4], } @@ -116,7 +119,7 @@ fn customize() { } assert_eq!( - marrow_convert::get_data_type::(&Options::default()), + marrow_convert::types::get_data_type::(&Options::default()), Ok(DataType::Struct(vec![ Field { name: String::from("a"), @@ -134,7 +137,7 @@ fn customize() { #[test] fn fieldless_union() { - #[derive(TypeInfo)] + #[derive(DefaultArrayType)] #[allow(dead_code)] enum E { A, @@ -143,7 +146,7 @@ fn fieldless_union() { } assert_eq!( - marrow_convert::get_data_type::(&Options::default()), + marrow_convert::types::get_data_type::(&Options::default()), Ok(DataType::Union( vec![ ( @@ -181,14 +184,14 @@ fn fieldless_union() { #[test] fn new_type_enum() { - #[derive(TypeInfo)] + #[derive(DefaultArrayType)] #[allow(dead_code)] enum Enum { Struct(Struct), Int64(i64), } - #[derive(TypeInfo)] + #[derive(DefaultArrayType)] #[allow(dead_code)] struct Struct { a: bool, @@ -196,7 +199,7 @@ fn new_type_enum() { } assert_eq!( - marrow_convert::get_data_type::(&Options::default()), + marrow_convert::types::get_data_type::(&Options::default()), Ok(DataType::Union( vec![ ( @@ -236,7 +239,7 @@ fn new_type_enum() { #[test] fn new_tuple_enum() { - #[derive(TypeInfo)] + #[derive(DefaultArrayType)] #[allow(dead_code)] enum Enum { Int64(i64), @@ -244,7 +247,7 @@ fn new_tuple_enum() { } assert_eq!( - marrow_convert::get_data_type::(&Options::default()), + marrow_convert::types::get_data_type::(&Options::default()), Ok(DataType::Union( vec![ ( @@ -282,7 +285,7 @@ fn new_tuple_enum() { #[test] fn new_struct_enum() { - #[derive(TypeInfo)] + #[derive(DefaultArrayType)] #[allow(dead_code)] enum Enum { Int64(i64), @@ -290,7 +293,7 @@ fn new_struct_enum() { } assert_eq!( - marrow_convert::get_data_type::(&Options::default()), + marrow_convert::types::get_data_type::(&Options::default()), Ok(DataType::Union( vec![ ( @@ -328,14 +331,14 @@ fn new_struct_enum() { #[test] fn const_generics() { - #[derive(TypeInfo)] + #[derive(DefaultArrayType)] #[allow(unused)] struct Struct { data: [u8; N], } assert_eq!( - marrow_convert::get_data_type::>(&Options::default()), + marrow_convert::types::get_data_type::>(&Options::default()), Ok(DataType::Struct(vec![Field { name: String::from("data"), data_type: DataType::FixedSizeBinary(4), @@ -347,7 +350,7 @@ fn const_generics() { #[test] fn liftime_generics() { - #[derive(TypeInfo)] + #[derive(DefaultArrayType)] #[allow(unused)] struct Struct<'a, 'b> { a: &'a u8, @@ -355,7 +358,7 @@ fn liftime_generics() { } assert_eq!( - marrow_convert::get_data_type::(&Options::default()), + marrow_convert::types::get_data_type::(&Options::default()), Ok(DataType::Struct(vec![ Field { name: String::from("a"), @@ -373,7 +376,7 @@ fn liftime_generics() { #[test] fn liftime_generics_with_bounds() { - #[derive(TypeInfo)] + #[derive(DefaultArrayType)] #[allow(unused)] struct Struct<'a, 'b: 'a> { a: &'a u8, @@ -381,7 +384,7 @@ fn liftime_generics_with_bounds() { } assert_eq!( - marrow_convert::get_data_type::(&Options::default()), + marrow_convert::types::get_data_type::(&Options::default()), Ok(DataType::Struct(vec![ Field { name: String::from("a"), @@ -399,7 +402,7 @@ fn liftime_generics_with_bounds() { #[test] fn liftime_generics_with_where_clause() { - #[derive(TypeInfo)] + #[derive(DefaultArrayType)] #[allow(unused)] struct Struct<'a, 'b> where @@ -410,7 +413,7 @@ fn liftime_generics_with_where_clause() { } assert_eq!( - marrow_convert::get_data_type::(&Options::default()), + marrow_convert::types::get_data_type::(&Options::default()), Ok(DataType::Struct(vec![ Field { name: String::from("a"), @@ -428,14 +431,14 @@ fn liftime_generics_with_where_clause() { #[test] fn enums_const_generics() { - #[derive(TypeInfo)] + #[derive(DefaultArrayType)] #[allow(unused)] enum Enum { Data([u8; N]), } assert_eq!( - marrow_convert::get_data_type::>(&Options::default()), + marrow_convert::types::get_data_type::>(&Options::default()), Ok(DataType::Union( vec![( 0, @@ -453,7 +456,7 @@ fn enums_const_generics() { #[test] fn enums_with_liftime_generics() { - #[derive(TypeInfo)] + #[derive(DefaultArrayType)] #[allow(unused)] enum Enum<'a, 'b> { A(&'a u8), @@ -461,7 +464,7 @@ fn enums_with_liftime_generics() { } assert_eq!( - marrow_convert::get_data_type::(&Options::default()), + marrow_convert::types::get_data_type::(&Options::default()), Ok(DataType::Union( vec![ ( @@ -488,7 +491,7 @@ fn enums_with_liftime_generics() { #[test] fn enum_liftime_generics_with_bounds() { - #[derive(TypeInfo)] + #[derive(DefaultArrayType)] #[allow(unused)] enum Enum<'a, 'b: 'a> { A(&'a u8), @@ -496,7 +499,7 @@ fn enum_liftime_generics_with_bounds() { } assert_eq!( - marrow_convert::get_data_type::(&Options::default()), + marrow_convert::types::get_data_type::(&Options::default()), Ok(DataType::Union( vec![ ( @@ -523,7 +526,7 @@ fn enum_liftime_generics_with_bounds() { #[test] fn enum_liftime_generics_with_where_clause() { - #[derive(TypeInfo)] + #[derive(DefaultArrayType)] #[allow(unused)] enum Enum<'a, 'b> where @@ -534,7 +537,7 @@ fn enum_liftime_generics_with_where_clause() { } assert_eq!( - marrow_convert::get_data_type::(&Options::default()), + marrow_convert::types::get_data_type::(&Options::default()), Ok(DataType::Union( vec![ ( diff --git a/marrow/src/array.rs b/marrow/src/array.rs index 7f79908..f21cd84 100644 --- a/marrow/src/array.rs +++ b/marrow/src/array.rs @@ -255,6 +255,76 @@ impl Array { Self::Union(array) => View::Union(array.as_view()), } } + + /// Extract the underlying primitive array if the array is of type Int8 + pub fn into_int8(self) -> Result, Array> { + match self { + Self::Int8(res) => Ok(res), + this => Err(this), + } + } + + /// Extract the underlying primitive array if the array is of type Int32 + pub fn into_int32(self) -> Result, Array> { + match self { + Self::Int32(res) => Ok(res), + this => Err(this), + } + } + + /// Extract the underlying primitive array if the array is of type Int64 + pub fn into_int64(self) -> Result, Array> { + match self { + Self::Int64(res) => Ok(res), + this => Err(this), + } + } + + /// Extract the underlying arrays of a struct array + pub fn into_struct_fields(self) -> Result<[(FieldMeta, Array); N], Array> { + match self { + Array::Struct(this) => { + let StructArray { + len, + validity, + fields, + } = this; + + match <[(FieldMeta, Array); N]>::try_from(fields) { + Ok(fields) => Ok(fields), + // rebuild the original array + Err(fields) => Err(Array::Struct(StructArray { + len, + validity, + fields, + })), + } + } + this => Err(this), + } + } + + /// Extract the underlying arrays of a union array + pub fn into_union_fields(self) -> Result<[(i8, FieldMeta, Array); N], Array> { + match self { + Array::Union(this) => { + let UnionArray { + types, + offsets, + fields, + } = this; + match <[(i8, FieldMeta, Array); N]>::try_from(fields) { + Ok(fields) => Ok(fields), + Err(fields) => Err(Array::Union(UnionArray { + types, + offsets, + fields, + })), + } + } + this => Err(this), + } + } } /// An array without data From 8bfc8d12b1784822bfdab02d70f5760628bb9932 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Tue, 15 Apr 2025 21:58:05 +0200 Subject: [PATCH 12/12] Implement initial version of derive for ArrayPush, DefaultArrayBuilder --- marrow-convert-derive/src/array_push.rs | 93 ++- marrow-convert-derive/src/default_builder.rs | 105 +++- marrow-convert-derive/src/lib.rs | 8 +- marrow-convert-derive/src/type_info.rs | 8 +- marrow-convert/Design.md | 9 + marrow-convert/src/internal/builder/struct.rs | 2 +- marrow-convert/src/lib.rs | 6 + marrow-convert/tests/derive.rs | 564 +----------------- marrow-convert/tests/derive_tests/mod.rs | 3 + .../tests/derive_tests/test_array_push.rs | 37 ++ .../test_default_array_builder.rs | 20 + .../tests/derive_tests/test_type_info.rs | 563 +++++++++++++++++ 12 files changed, 843 insertions(+), 575 deletions(-) create mode 100644 marrow-convert/Design.md create mode 100644 marrow-convert/tests/derive_tests/mod.rs create mode 100644 marrow-convert/tests/derive_tests/test_array_push.rs create mode 100644 marrow-convert/tests/derive_tests/test_default_array_builder.rs create mode 100644 marrow-convert/tests/derive_tests/test_type_info.rs diff --git a/marrow-convert-derive/src/array_push.rs b/marrow-convert-derive/src/array_push.rs index d3f5a12..c3c5043 100644 --- a/marrow-convert-derive/src/array_push.rs +++ b/marrow-convert-derive/src/array_push.rs @@ -1 +1,92 @@ - +use quote::{format_ident, quote}; +use syn::{ + Data, DataEnum, DataStruct, DeriveInput, Fields, GenericParam, Ident, Type, spanned::Spanned, +}; + +pub fn derive_array_push(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream { + let input: DeriveInput = syn::parse2(input).unwrap(); + + if input + .generics + .params + .iter() + .any(|p| matches!(p, GenericParam::Type(_))) + { + panic!("Deriving TypeInfo for generics with type parameters is not supported") + } + + match &input.data { + Data::Struct(data) => derive_for_struct(&input, data), + Data::Enum(data) => derive_for_enum(&input, data), + Data::Union(_) => panic!("Deriving TypeInfo for unions is not supported"), + } +} + +fn derive_for_struct(input: &DeriveInput, data: &DataStruct) -> proc_macro2::TokenStream { + if data.fields.len() >= 16 { + panic!("Only structs with at most 16 fields are supported"); + } + + let ident = &input.ident; + let generics = &input.generics; + let generics = if generics.params.is_empty() { + quote! {} + } else { + quote! { #generics, } + }; + + let mut field_defs = Vec::new(); + let mut field_uses = Vec::new(); + let mut field_push = Vec::new(); + + for (idx, (name, ty)) in get_fields_and_names(&data.fields).into_iter().enumerate() { + let ident = format_ident!("T{idx}"); + field_defs.push(quote! { #ident : ::marrow_convert::builder::ArrayPush<#ty> }); + field_uses.push(quote! { #ident }); + field_push.push( + quote! { ::marrow_convert::builder::ArrayPush::push_value(#ident, &value.#name )?; }, + ); + } + + quote! { + const _: () = { + impl<#generics #(#field_defs),*> ::marrow_convert::builder::ArrayPush<#ident> for ::marrow_convert::builder::compound::StructBuilder<(#(#field_uses,)*)> { + fn push_value(&mut self, value: &#ident) -> ::marrow_convert::Result<()> { + self.len += 1; + let (#(#field_uses,)*) = &mut self.children; + #(#field_push)* + Ok(()) + } + } + }; + } +} + +pub fn get_fields_and_names(fields: &Fields) -> Vec<(Ident, Type)> { + let mut result = Vec::new(); + match fields { + Fields::Named(fields) => { + for field in &fields.named { + let ident = field.ident.clone().expect("Named field without ident"); + let ty = field.ty.clone(); + result.push((ident, ty)); + } + } + Fields::Unnamed(fields) => { + for (idx, field) in fields.unnamed.iter().enumerate() { + result.push(( + Ident::new(&idx.to_string(), field.ty.span()), + field.ty.clone(), + )); + } + } + Fields::Unit => unimplemented!("Unit structs are currently not implemented"), + } + + result +} + +fn derive_for_enum(input: &DeriveInput, data: &DataEnum) -> proc_macro2::TokenStream { + let _ = (input, data); + todo!() +} diff --git a/marrow-convert-derive/src/default_builder.rs b/marrow-convert-derive/src/default_builder.rs index d3f5a12..d825cfa 100644 --- a/marrow-convert-derive/src/default_builder.rs +++ b/marrow-convert-derive/src/default_builder.rs @@ -1 +1,104 @@ - +use quote::{format_ident, quote}; +use syn::{Data, DataEnum, DataStruct, DeriveInput, GenericParam, LitStr}; + +use super::array_push::get_fields_and_names; + +pub fn derive_default_builder(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream { + let input: DeriveInput = syn::parse2(input).unwrap(); + + if input + .generics + .params + .iter() + .any(|p| matches!(p, GenericParam::Type(_))) + { + panic!("Deriving TypeInfo for generics with type parameters is not supported") + } + + match &input.data { + Data::Struct(data) => derive_for_struct(&input, data), + Data::Enum(data) => derive_for_enum(&input, data), + Data::Union(_) => panic!("Deriving TypeInfo for unions is not supported"), + } +} + +fn derive_for_struct(input: &DeriveInput, data: &DataStruct) -> proc_macro2::TokenStream { + if data.fields.len() >= 16 { + panic!("Only structs with at most 16 fields are supported"); + } + + let ident = &input.ident; + + let builder_ident = format_ident!("{ident}Builder"); + + let mut field_uses = Vec::new(); + let mut field_push = Vec::new(); + let mut field_builders = Vec::new(); + let mut field_metas = Vec::new(); + let mut field_inits = Vec::new(); + + for (idx, (name, ty)) in get_fields_and_names(&data.fields).into_iter().enumerate() { + let ident = format_ident!("t{idx}"); + field_uses.push(quote! { #ident }); + field_push.push( + quote! { ::marrow_convert::builder::ArrayPush::push_value(#ident, &value.#name )?; }, + ); + + let field_name = LitStr::new(&name.to_string(), name.span()); + + field_builders + .push(quote! { <#ty as ::marrow_convert::builder::DefaultArrayBuilder>::ArrayBuilder }); + field_metas.push(quote! { + ::marrow::datatypes::FieldMeta { + name: String::from(#field_name), + ..::std::default::Default::default() + } + }); + field_inits.push(quote! { + <#ty as ::marrow_convert::builder::DefaultArrayBuilder>::default_builder() + }) + } + + return quote! { + const _: () = { + pub struct #builder_ident(::marrow_convert::builder::compound::StructBuilder<(#(#field_builders,)*)>); + + impl ::marrow_convert::builder::DefaultArrayBuilder for #ident { + type ArrayBuilder = #builder_ident; + + fn default_builder() -> Self::ArrayBuilder { + #builder_ident(::marrow_convert::builder::compound::StructBuilder { + len: 0, + meta: vec![#(#field_metas),*], + children: (#(#field_inits,)*), + }) + } + } + + impl ::marrow_convert::builder::ArrayBuilder for #builder_ident { + fn push_default(&mut self) -> ::marrow_convert::Result<()> { + self.0.push_default() + } + + fn build_array(&mut self) -> ::marrow_convert::Result<::marrow::array::Array> { + self.0.build_array() + } + } + + impl ::marrow_convert::builder::ArrayPush<#ident> for #builder_ident { + fn push_value(&mut self, value: &#ident) -> ::marrow_convert::Result<()> { + self.0.len += 1; + let (#(#field_uses,)*) = &mut self.0.children; + #(#field_push)* + Ok(()) + } + } + + }; + }; +} + +fn derive_for_enum(input: &DeriveInput, data: &DataEnum) -> proc_macro2::TokenStream { + let _ = (input, data); + todo!() +} diff --git a/marrow-convert-derive/src/lib.rs b/marrow-convert-derive/src/lib.rs index 9e1f782..9dc0e0d 100644 --- a/marrow-convert-derive/src/lib.rs +++ b/marrow-convert-derive/src/lib.rs @@ -6,17 +6,15 @@ mod type_info; #[proc_macro_derive(DefaultArrayType, attributes(marrow))] pub fn derive_type_info(input: TokenStream) -> TokenStream { - type_info::derive_type_info_impl(input.into()).into() + type_info::derive_type_info(input.into()).into() } #[proc_macro_derive(ArrayPush, attributes(marrow))] pub fn derive_array_push(input: TokenStream) -> TokenStream { - std::mem::drop(input); - unimplemented!() + array_push::derive_array_push(input.into()).into() } #[proc_macro_derive(DefaultArrayBuilder, attributes(marrow))] pub fn derive_default_builder(input: TokenStream) -> TokenStream { - std::mem::drop(input); - unimplemented!() + default_builder::derive_default_builder(input.into()).into() } diff --git a/marrow-convert-derive/src/type_info.rs b/marrow-convert-derive/src/type_info.rs index 8570c46..4cb8a3f 100644 --- a/marrow-convert-derive/src/type_info.rs +++ b/marrow-convert-derive/src/type_info.rs @@ -4,7 +4,7 @@ use syn::{ Lit, LitStr, Meta, Token, punctuated::Punctuated, spanned::Spanned, }; -pub fn derive_type_info_impl(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream { +pub fn derive_type_info(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream { let input: DeriveInput = syn::parse2(input).unwrap(); if input @@ -300,7 +300,7 @@ enum NameSource { #[test] #[should_panic(expected = "Deriving TypeInfo for generics with type parameters is not supported")] fn reject_unsupported() { - derive_type_info_impl(quote! { + derive_type_info(quote! { struct Example { field: T, } @@ -309,7 +309,7 @@ fn reject_unsupported() { #[test] fn lifetimes_are_supported() { - derive_type_info_impl(quote! { + derive_type_info(quote! { struct Example<'a> { field: &'a i64, } @@ -318,7 +318,7 @@ fn lifetimes_are_supported() { #[test] fn const_params_are_supported() { - derive_type_info_impl(quote! { + derive_type_info(quote! { struct Example { field: [u8; N], } diff --git a/marrow-convert/Design.md b/marrow-convert/Design.md new file mode 100644 index 0000000..684e637 --- /dev/null +++ b/marrow-convert/Design.md @@ -0,0 +1,9 @@ +# Overall design + +- M:N relationship between Rust and Arrow types + - A single Rust type can be converted into different Arrow types + - Different Rust types can be converted into the same Arrow type + - E.g., `jiff::Timestamp` and `chrono::DateTime` can both be converted to the Arrow + `Timestamp` type + - E.g., `jiff::Timestamp` can both be converted to the Arrow `Timestamp` and the Arrow `Utf8` typ +- Allow to fully specify the builders at compile time \ No newline at end of file diff --git a/marrow-convert/src/internal/builder/struct.rs b/marrow-convert/src/internal/builder/struct.rs index 8dcf2a6..86e55fe 100644 --- a/marrow-convert/src/internal/builder/struct.rs +++ b/marrow-convert/src/internal/builder/struct.rs @@ -54,7 +54,7 @@ macro_rules! impl_struct_builder { }; } -// TODO: is a struct with fields valid? +// TODO: is a struct without fields valid? impl_struct_builder!(A,); impl_struct_builder!(A, B,); impl_struct_builder!(A, B, C,); diff --git a/marrow-convert/src/lib.rs b/marrow-convert/src/lib.rs index 6635fb4..a50b010 100644 --- a/marrow-convert/src/lib.rs +++ b/marrow-convert/src/lib.rs @@ -44,3 +44,9 @@ pub mod builder { /// Derive [DefaultArrayBuilder] for a given type pub use marrow_convert_derive::DefaultArrayBuilder; } + +/// Additional documentation +pub mod docs { + #[doc = include_str!("../Design.md")] + pub mod design {} +} diff --git a/marrow-convert/tests/derive.rs b/marrow-convert/tests/derive.rs index 92ddda0..b68352d 100644 --- a/marrow-convert/tests/derive.rs +++ b/marrow-convert/tests/derive.rs @@ -1,563 +1 @@ -use marrow::{ - datatypes::{DataType, Field, TimeUnit, UnionMode}, - types::f16, -}; -use marrow_convert::{ - Result, - types::{Context, DefaultArrayType, Options}, -}; - -#[test] -fn example() { - #[derive(DefaultArrayType)] - #[allow(dead_code)] - struct S { - a: i64, - b: [u8; 4], - } - - assert_eq!( - marrow_convert::types::get_data_type::(&Options::default()), - Ok(DataType::Struct(vec![ - Field { - name: String::from("a"), - data_type: DataType::Int64, - ..Default::default() - }, - Field { - name: String::from("b"), - data_type: DataType::FixedSizeBinary(4), - ..Default::default() - } - ])) - ); -} - -#[test] -fn overwrites() { - #[derive(DefaultArrayType)] - #[allow(dead_code)] - struct S { - a: i64, - b: [u8; 4], - } - - assert_eq!( - marrow_convert::types::get_data_type::(&Options::default().overwrite( - "$.b", - Field { - data_type: DataType::Binary, - ..Field::default() - } - )), - Ok(DataType::Struct(vec![ - Field { - name: String::from("a"), - data_type: DataType::Int64, - ..Default::default() - }, - Field { - name: String::from("b"), - data_type: DataType::Binary, - ..Default::default() - } - ])) - ); -} - -#[test] -fn newtype() { - #[derive(DefaultArrayType)] - #[allow(dead_code)] - struct S(f16); - - assert_eq!( - marrow_convert::types::get_data_type::(&Options::default()), - Ok(DataType::Float16) - ); -} - -#[test] -fn tuple() { - #[derive(DefaultArrayType)] - #[allow(dead_code)] - struct S(u8, [u8; 4]); - - assert_eq!( - marrow_convert::types::get_data_type::(&Options::default()), - Ok(DataType::Struct(vec![ - Field { - name: String::from("0"), - data_type: DataType::UInt8, - ..Field::default() - }, - Field { - name: String::from("1"), - data_type: DataType::FixedSizeBinary(4), - ..Field::default() - }, - ])) - ); -} - -#[test] -fn customize() { - #[derive(DefaultArrayType)] - #[allow(dead_code)] - struct S { - #[marrow(with = "timestamp_field")] - a: i64, - b: [u8; 4], - } - - fn timestamp_field(context: Context<'_>) -> Result { - Ok(Field { - name: String::from(context.get_name()), - data_type: DataType::Timestamp(TimeUnit::Millisecond, None), - ..Default::default() - }) - } - - assert_eq!( - marrow_convert::types::get_data_type::(&Options::default()), - Ok(DataType::Struct(vec![ - Field { - name: String::from("a"), - data_type: DataType::Timestamp(TimeUnit::Millisecond, None), - ..Default::default() - }, - Field { - name: String::from("b"), - data_type: DataType::FixedSizeBinary(4), - ..Default::default() - } - ])) - ); -} - -#[test] -fn fieldless_union() { - #[derive(DefaultArrayType)] - #[allow(dead_code)] - enum E { - A, - B, - C, - } - - assert_eq!( - marrow_convert::types::get_data_type::(&Options::default()), - Ok(DataType::Union( - vec![ - ( - 0, - Field { - name: String::from("A"), - data_type: DataType::Null, - nullable: true, - metadata: Default::default(), - } - ), - ( - 1, - Field { - name: String::from("B"), - data_type: DataType::Null, - nullable: true, - metadata: Default::default(), - } - ), - ( - 2, - Field { - name: String::from("C"), - data_type: DataType::Null, - nullable: true, - metadata: Default::default(), - } - ), - ], - UnionMode::Dense - )) - ); -} - -#[test] -fn new_type_enum() { - #[derive(DefaultArrayType)] - #[allow(dead_code)] - enum Enum { - Struct(Struct), - Int64(i64), - } - - #[derive(DefaultArrayType)] - #[allow(dead_code)] - struct Struct { - a: bool, - b: (), - } - - assert_eq!( - marrow_convert::types::get_data_type::(&Options::default()), - Ok(DataType::Union( - vec![ - ( - 0, - Field { - name: String::from("Struct"), - data_type: DataType::Struct(vec![ - Field { - name: String::from("a"), - data_type: DataType::Boolean, - ..Default::default() - }, - Field { - name: String::from("b"), - data_type: DataType::Null, - nullable: true, - ..Default::default() - }, - ]), - nullable: false, - metadata: Default::default(), - } - ), - ( - 1, - Field { - name: String::from("Int64"), - data_type: DataType::Int64, - ..Default::default() - } - ), - ], - UnionMode::Dense - )) - ); -} - -#[test] -fn new_tuple_enum() { - #[derive(DefaultArrayType)] - #[allow(dead_code)] - enum Enum { - Int64(i64), - Tuple(i8, u32), - } - - assert_eq!( - marrow_convert::types::get_data_type::(&Options::default()), - Ok(DataType::Union( - vec![ - ( - 0, - Field { - name: String::from("Int64"), - data_type: DataType::Int64, - ..Field::default() - } - ), - ( - 1, - Field { - name: String::from("Tuple"), - data_type: DataType::Struct(vec![ - Field { - name: String::from("0"), - data_type: DataType::Int8, - ..Field::default() - }, - Field { - name: String::from("1"), - data_type: DataType::UInt32, - ..Field::default() - }, - ]), - ..Field::default() - } - ), - ], - UnionMode::Dense - )) - ); -} - -#[test] -fn new_struct_enum() { - #[derive(DefaultArrayType)] - #[allow(dead_code)] - enum Enum { - Int64(i64), - Struct { a: f32, b: String }, - } - - assert_eq!( - marrow_convert::types::get_data_type::(&Options::default()), - Ok(DataType::Union( - vec![ - ( - 0, - Field { - name: String::from("Int64"), - data_type: DataType::Int64, - ..Field::default() - } - ), - ( - 1, - Field { - name: String::from("Struct"), - data_type: DataType::Struct(vec![ - Field { - name: String::from("a"), - data_type: DataType::Float32, - ..Field::default() - }, - Field { - name: String::from("b"), - data_type: DataType::LargeUtf8, - ..Field::default() - }, - ]), - ..Field::default() - } - ), - ], - UnionMode::Dense - )) - ); -} - -#[test] -fn const_generics() { - #[derive(DefaultArrayType)] - #[allow(unused)] - struct Struct { - data: [u8; N], - } - - assert_eq!( - marrow_convert::types::get_data_type::>(&Options::default()), - Ok(DataType::Struct(vec![Field { - name: String::from("data"), - data_type: DataType::FixedSizeBinary(4), - nullable: false, - metadata: Default::default(), - },])) - ); -} - -#[test] -fn liftime_generics() { - #[derive(DefaultArrayType)] - #[allow(unused)] - struct Struct<'a, 'b> { - a: &'a u8, - b: &'b u16, - } - - assert_eq!( - marrow_convert::types::get_data_type::(&Options::default()), - Ok(DataType::Struct(vec![ - Field { - name: String::from("a"), - data_type: DataType::UInt8, - ..Default::default() - }, - Field { - name: String::from("b"), - data_type: DataType::UInt16, - ..Default::default() - }, - ])) - ); -} - -#[test] -fn liftime_generics_with_bounds() { - #[derive(DefaultArrayType)] - #[allow(unused)] - struct Struct<'a, 'b: 'a> { - a: &'a u8, - b: &'b u16, - } - - assert_eq!( - marrow_convert::types::get_data_type::(&Options::default()), - Ok(DataType::Struct(vec![ - Field { - name: String::from("a"), - data_type: DataType::UInt8, - ..Default::default() - }, - Field { - name: String::from("b"), - data_type: DataType::UInt16, - ..Default::default() - }, - ])) - ); -} - -#[test] -fn liftime_generics_with_where_clause() { - #[derive(DefaultArrayType)] - #[allow(unused)] - struct Struct<'a, 'b> - where - 'a: 'b, - { - a: &'a u8, - b: &'b u16, - } - - assert_eq!( - marrow_convert::types::get_data_type::(&Options::default()), - Ok(DataType::Struct(vec![ - Field { - name: String::from("a"), - data_type: DataType::UInt8, - ..Default::default() - }, - Field { - name: String::from("b"), - data_type: DataType::UInt16, - ..Default::default() - }, - ])) - ); -} - -#[test] -fn enums_const_generics() { - #[derive(DefaultArrayType)] - #[allow(unused)] - enum Enum { - Data([u8; N]), - } - - assert_eq!( - marrow_convert::types::get_data_type::>(&Options::default()), - Ok(DataType::Union( - vec![( - 0, - Field { - name: String::from("Data"), - data_type: DataType::FixedSizeBinary(4), - nullable: false, - metadata: Default::default(), - } - ),], - UnionMode::Dense - )), - ); -} - -#[test] -fn enums_with_liftime_generics() { - #[derive(DefaultArrayType)] - #[allow(unused)] - enum Enum<'a, 'b> { - A(&'a u8), - B(&'b u16), - } - - assert_eq!( - marrow_convert::types::get_data_type::(&Options::default()), - Ok(DataType::Union( - vec![ - ( - 0, - Field { - name: String::from("A"), - data_type: DataType::UInt8, - ..Default::default() - } - ), - ( - 1, - Field { - name: String::from("B"), - data_type: DataType::UInt16, - ..Default::default() - } - ), - ], - UnionMode::Dense - )) - ); -} - -#[test] -fn enum_liftime_generics_with_bounds() { - #[derive(DefaultArrayType)] - #[allow(unused)] - enum Enum<'a, 'b: 'a> { - A(&'a u8), - B(&'b u16), - } - - assert_eq!( - marrow_convert::types::get_data_type::(&Options::default()), - Ok(DataType::Union( - vec![ - ( - 0, - Field { - name: String::from("A"), - data_type: DataType::UInt8, - ..Default::default() - } - ), - ( - 1, - Field { - name: String::from("B"), - data_type: DataType::UInt16, - ..Default::default() - } - ), - ], - UnionMode::Dense - )) - ); -} - -#[test] -fn enum_liftime_generics_with_where_clause() { - #[derive(DefaultArrayType)] - #[allow(unused)] - enum Enum<'a, 'b> - where - 'a: 'b, - { - A(&'a u8), - B(&'b u16), - } - - assert_eq!( - marrow_convert::types::get_data_type::(&Options::default()), - Ok(DataType::Union( - vec![ - ( - 0, - Field { - name: String::from("A"), - data_type: DataType::UInt8, - ..Default::default() - } - ), - ( - 1, - Field { - name: String::from("B"), - data_type: DataType::UInt16, - ..Default::default() - } - ), - ], - UnionMode::Dense - )) - ); -} +mod derive_tests; diff --git a/marrow-convert/tests/derive_tests/mod.rs b/marrow-convert/tests/derive_tests/mod.rs new file mode 100644 index 0000000..d511cdb --- /dev/null +++ b/marrow-convert/tests/derive_tests/mod.rs @@ -0,0 +1,3 @@ +mod test_array_push; +mod test_default_array_builder; +mod test_type_info; diff --git a/marrow-convert/tests/derive_tests/test_array_push.rs b/marrow-convert/tests/derive_tests/test_array_push.rs new file mode 100644 index 0000000..18dea3d --- /dev/null +++ b/marrow-convert/tests/derive_tests/test_array_push.rs @@ -0,0 +1,37 @@ +use marrow::datatypes::FieldMeta; +use marrow_convert::builder::{ArrayBuilder, ArrayPush}; + +#[test] +fn example() { + #[derive(marrow_convert::builder::ArrayPush)] + struct S { + a: i32, + b: i64, + } + + let mut builder = marrow_convert::builder::compound::StructBuilder { + len: 0, + meta: vec![ + FieldMeta { + name: String::from("a"), + ..Default::default() + }, + FieldMeta { + name: String::from("b"), + ..Default::default() + }, + ], + children: ( + marrow_convert::builder::Int32Builder::default(), + marrow_convert::builder::Int64Builder::default(), + ), + }; + + builder.push_value(&S { a: 1, b: -1 }).unwrap(); + builder.push_value(&S { a: 2, b: -2 }).unwrap(); + builder.push_value(&S { a: 3, b: -3 }).unwrap(); + + let array = builder.build_array().unwrap(); + // TODO: check resulting array + std::mem::drop(array); +} diff --git a/marrow-convert/tests/derive_tests/test_default_array_builder.rs b/marrow-convert/tests/derive_tests/test_default_array_builder.rs new file mode 100644 index 0000000..406f070 --- /dev/null +++ b/marrow-convert/tests/derive_tests/test_default_array_builder.rs @@ -0,0 +1,20 @@ +use marrow_convert::builder::{ArrayBuilder, ArrayPush, DefaultArrayBuilder}; + +#[test] +fn example() { + #[derive(DefaultArrayBuilder)] + struct S { + a: i32, + b: i64, + } + + let mut builder = S::default_builder(); + + builder.push_value(&S { a: 1, b: -1 }).unwrap(); + builder.push_value(&S { a: 2, b: -2 }).unwrap(); + builder.push_value(&S { a: 3, b: -3 }).unwrap(); + + let array = builder.build_array().unwrap(); + // TODO: check resulting array + std::mem::drop(array); +} diff --git a/marrow-convert/tests/derive_tests/test_type_info.rs b/marrow-convert/tests/derive_tests/test_type_info.rs new file mode 100644 index 0000000..92ddda0 --- /dev/null +++ b/marrow-convert/tests/derive_tests/test_type_info.rs @@ -0,0 +1,563 @@ +use marrow::{ + datatypes::{DataType, Field, TimeUnit, UnionMode}, + types::f16, +}; +use marrow_convert::{ + Result, + types::{Context, DefaultArrayType, Options}, +}; + +#[test] +fn example() { + #[derive(DefaultArrayType)] + #[allow(dead_code)] + struct S { + a: i64, + b: [u8; 4], + } + + assert_eq!( + marrow_convert::types::get_data_type::(&Options::default()), + Ok(DataType::Struct(vec![ + Field { + name: String::from("a"), + data_type: DataType::Int64, + ..Default::default() + }, + Field { + name: String::from("b"), + data_type: DataType::FixedSizeBinary(4), + ..Default::default() + } + ])) + ); +} + +#[test] +fn overwrites() { + #[derive(DefaultArrayType)] + #[allow(dead_code)] + struct S { + a: i64, + b: [u8; 4], + } + + assert_eq!( + marrow_convert::types::get_data_type::(&Options::default().overwrite( + "$.b", + Field { + data_type: DataType::Binary, + ..Field::default() + } + )), + Ok(DataType::Struct(vec![ + Field { + name: String::from("a"), + data_type: DataType::Int64, + ..Default::default() + }, + Field { + name: String::from("b"), + data_type: DataType::Binary, + ..Default::default() + } + ])) + ); +} + +#[test] +fn newtype() { + #[derive(DefaultArrayType)] + #[allow(dead_code)] + struct S(f16); + + assert_eq!( + marrow_convert::types::get_data_type::(&Options::default()), + Ok(DataType::Float16) + ); +} + +#[test] +fn tuple() { + #[derive(DefaultArrayType)] + #[allow(dead_code)] + struct S(u8, [u8; 4]); + + assert_eq!( + marrow_convert::types::get_data_type::(&Options::default()), + Ok(DataType::Struct(vec![ + Field { + name: String::from("0"), + data_type: DataType::UInt8, + ..Field::default() + }, + Field { + name: String::from("1"), + data_type: DataType::FixedSizeBinary(4), + ..Field::default() + }, + ])) + ); +} + +#[test] +fn customize() { + #[derive(DefaultArrayType)] + #[allow(dead_code)] + struct S { + #[marrow(with = "timestamp_field")] + a: i64, + b: [u8; 4], + } + + fn timestamp_field(context: Context<'_>) -> Result { + Ok(Field { + name: String::from(context.get_name()), + data_type: DataType::Timestamp(TimeUnit::Millisecond, None), + ..Default::default() + }) + } + + assert_eq!( + marrow_convert::types::get_data_type::(&Options::default()), + Ok(DataType::Struct(vec![ + Field { + name: String::from("a"), + data_type: DataType::Timestamp(TimeUnit::Millisecond, None), + ..Default::default() + }, + Field { + name: String::from("b"), + data_type: DataType::FixedSizeBinary(4), + ..Default::default() + } + ])) + ); +} + +#[test] +fn fieldless_union() { + #[derive(DefaultArrayType)] + #[allow(dead_code)] + enum E { + A, + B, + C, + } + + assert_eq!( + marrow_convert::types::get_data_type::(&Options::default()), + Ok(DataType::Union( + vec![ + ( + 0, + Field { + name: String::from("A"), + data_type: DataType::Null, + nullable: true, + metadata: Default::default(), + } + ), + ( + 1, + Field { + name: String::from("B"), + data_type: DataType::Null, + nullable: true, + metadata: Default::default(), + } + ), + ( + 2, + Field { + name: String::from("C"), + data_type: DataType::Null, + nullable: true, + metadata: Default::default(), + } + ), + ], + UnionMode::Dense + )) + ); +} + +#[test] +fn new_type_enum() { + #[derive(DefaultArrayType)] + #[allow(dead_code)] + enum Enum { + Struct(Struct), + Int64(i64), + } + + #[derive(DefaultArrayType)] + #[allow(dead_code)] + struct Struct { + a: bool, + b: (), + } + + assert_eq!( + marrow_convert::types::get_data_type::(&Options::default()), + Ok(DataType::Union( + vec![ + ( + 0, + Field { + name: String::from("Struct"), + data_type: DataType::Struct(vec![ + Field { + name: String::from("a"), + data_type: DataType::Boolean, + ..Default::default() + }, + Field { + name: String::from("b"), + data_type: DataType::Null, + nullable: true, + ..Default::default() + }, + ]), + nullable: false, + metadata: Default::default(), + } + ), + ( + 1, + Field { + name: String::from("Int64"), + data_type: DataType::Int64, + ..Default::default() + } + ), + ], + UnionMode::Dense + )) + ); +} + +#[test] +fn new_tuple_enum() { + #[derive(DefaultArrayType)] + #[allow(dead_code)] + enum Enum { + Int64(i64), + Tuple(i8, u32), + } + + assert_eq!( + marrow_convert::types::get_data_type::(&Options::default()), + Ok(DataType::Union( + vec![ + ( + 0, + Field { + name: String::from("Int64"), + data_type: DataType::Int64, + ..Field::default() + } + ), + ( + 1, + Field { + name: String::from("Tuple"), + data_type: DataType::Struct(vec![ + Field { + name: String::from("0"), + data_type: DataType::Int8, + ..Field::default() + }, + Field { + name: String::from("1"), + data_type: DataType::UInt32, + ..Field::default() + }, + ]), + ..Field::default() + } + ), + ], + UnionMode::Dense + )) + ); +} + +#[test] +fn new_struct_enum() { + #[derive(DefaultArrayType)] + #[allow(dead_code)] + enum Enum { + Int64(i64), + Struct { a: f32, b: String }, + } + + assert_eq!( + marrow_convert::types::get_data_type::(&Options::default()), + Ok(DataType::Union( + vec![ + ( + 0, + Field { + name: String::from("Int64"), + data_type: DataType::Int64, + ..Field::default() + } + ), + ( + 1, + Field { + name: String::from("Struct"), + data_type: DataType::Struct(vec![ + Field { + name: String::from("a"), + data_type: DataType::Float32, + ..Field::default() + }, + Field { + name: String::from("b"), + data_type: DataType::LargeUtf8, + ..Field::default() + }, + ]), + ..Field::default() + } + ), + ], + UnionMode::Dense + )) + ); +} + +#[test] +fn const_generics() { + #[derive(DefaultArrayType)] + #[allow(unused)] + struct Struct { + data: [u8; N], + } + + assert_eq!( + marrow_convert::types::get_data_type::>(&Options::default()), + Ok(DataType::Struct(vec![Field { + name: String::from("data"), + data_type: DataType::FixedSizeBinary(4), + nullable: false, + metadata: Default::default(), + },])) + ); +} + +#[test] +fn liftime_generics() { + #[derive(DefaultArrayType)] + #[allow(unused)] + struct Struct<'a, 'b> { + a: &'a u8, + b: &'b u16, + } + + assert_eq!( + marrow_convert::types::get_data_type::(&Options::default()), + Ok(DataType::Struct(vec![ + Field { + name: String::from("a"), + data_type: DataType::UInt8, + ..Default::default() + }, + Field { + name: String::from("b"), + data_type: DataType::UInt16, + ..Default::default() + }, + ])) + ); +} + +#[test] +fn liftime_generics_with_bounds() { + #[derive(DefaultArrayType)] + #[allow(unused)] + struct Struct<'a, 'b: 'a> { + a: &'a u8, + b: &'b u16, + } + + assert_eq!( + marrow_convert::types::get_data_type::(&Options::default()), + Ok(DataType::Struct(vec![ + Field { + name: String::from("a"), + data_type: DataType::UInt8, + ..Default::default() + }, + Field { + name: String::from("b"), + data_type: DataType::UInt16, + ..Default::default() + }, + ])) + ); +} + +#[test] +fn liftime_generics_with_where_clause() { + #[derive(DefaultArrayType)] + #[allow(unused)] + struct Struct<'a, 'b> + where + 'a: 'b, + { + a: &'a u8, + b: &'b u16, + } + + assert_eq!( + marrow_convert::types::get_data_type::(&Options::default()), + Ok(DataType::Struct(vec![ + Field { + name: String::from("a"), + data_type: DataType::UInt8, + ..Default::default() + }, + Field { + name: String::from("b"), + data_type: DataType::UInt16, + ..Default::default() + }, + ])) + ); +} + +#[test] +fn enums_const_generics() { + #[derive(DefaultArrayType)] + #[allow(unused)] + enum Enum { + Data([u8; N]), + } + + assert_eq!( + marrow_convert::types::get_data_type::>(&Options::default()), + Ok(DataType::Union( + vec![( + 0, + Field { + name: String::from("Data"), + data_type: DataType::FixedSizeBinary(4), + nullable: false, + metadata: Default::default(), + } + ),], + UnionMode::Dense + )), + ); +} + +#[test] +fn enums_with_liftime_generics() { + #[derive(DefaultArrayType)] + #[allow(unused)] + enum Enum<'a, 'b> { + A(&'a u8), + B(&'b u16), + } + + assert_eq!( + marrow_convert::types::get_data_type::(&Options::default()), + Ok(DataType::Union( + vec![ + ( + 0, + Field { + name: String::from("A"), + data_type: DataType::UInt8, + ..Default::default() + } + ), + ( + 1, + Field { + name: String::from("B"), + data_type: DataType::UInt16, + ..Default::default() + } + ), + ], + UnionMode::Dense + )) + ); +} + +#[test] +fn enum_liftime_generics_with_bounds() { + #[derive(DefaultArrayType)] + #[allow(unused)] + enum Enum<'a, 'b: 'a> { + A(&'a u8), + B(&'b u16), + } + + assert_eq!( + marrow_convert::types::get_data_type::(&Options::default()), + Ok(DataType::Union( + vec![ + ( + 0, + Field { + name: String::from("A"), + data_type: DataType::UInt8, + ..Default::default() + } + ), + ( + 1, + Field { + name: String::from("B"), + data_type: DataType::UInt16, + ..Default::default() + } + ), + ], + UnionMode::Dense + )) + ); +} + +#[test] +fn enum_liftime_generics_with_where_clause() { + #[derive(DefaultArrayType)] + #[allow(unused)] + enum Enum<'a, 'b> + where + 'a: 'b, + { + A(&'a u8), + B(&'b u16), + } + + assert_eq!( + marrow_convert::types::get_data_type::(&Options::default()), + Ok(DataType::Union( + vec![ + ( + 0, + Field { + name: String::from("A"), + data_type: DataType::UInt8, + ..Default::default() + } + ), + ( + 1, + Field { + name: String::from("B"), + data_type: DataType::UInt16, + ..Default::default() + } + ), + ], + UnionMode::Dense + )) + ); +}