From 8845331d873a79115a9616432d1f77e89756343a Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 30 Jan 2026 04:19:59 +0000 Subject: [PATCH] Implement Rust SeaORM code generator FEATURE: Add code generator for converting PostgreSQL schemas to SeaORM entities Implements the Rust SeaORM code generator following the design document. Key features: - Type mapping from PostgreSQL types to Rust/SeaORM types - Identifier sanitization and Rust reserved word handling - Primary key detection (single and composite) - Column attribute generation (unique, indexed, nullable) - Foreign key relationship analysis (belongs_to, has_many, has_one) - Both compact (DeriveEntityModel) and expanded entity formats - Multi-file output with mod.rs and prelude.rs - Comprehensive unit and snapshot tests https://claude.ai/code/session_0149wURw3hU1v2bsnDwyijZo --- crates/codegen/src/rust/column.rs | 532 ++++++++++++ crates/codegen/src/rust/entity.rs | 656 +++++++++++++++ crates/codegen/src/rust/generator.rs | 499 +++++++++++ crates/codegen/src/rust/imports.rs | 225 +++++ crates/codegen/src/rust/mod.rs | 230 +++++- crates/codegen/src/rust/naming.rs | 522 ++++++++++++ crates/codegen/src/rust/primary_key.rs | 322 ++++++++ crates/codegen/src/rust/relation.rs | 702 ++++++++++++++++ crates/codegen/src/rust/tests/mod.rs | 4 + .../codegen/src/rust/tests/snapshot_tests.rs | 233 ++++++ ...__snapshot_tests__post_with_relations.snap | 38 + ...napshot_tests__snapshot_complex_types.snap | 33 + ...ts__snapshot_tests__snapshot_mod_file.snap | 12 + ...snapshot_tests__snapshot_prelude_file.snap | 10 + ...tests__snapshot_simple_entity_compact.snap | 23 + ...ests__snapshot_simple_entity_expanded.snap | 66 ++ ...__snapshot_tests__user_with_relations.snap | 31 + crates/codegen/src/rust/tests/unit_tests.rs | 494 +++++++++++ crates/codegen/src/rust/type_mapping.rs | 775 ++++++++++++++++++ 19 files changed, 5406 insertions(+), 1 deletion(-) create mode 100644 crates/codegen/src/rust/column.rs create mode 100644 crates/codegen/src/rust/entity.rs create mode 100644 crates/codegen/src/rust/generator.rs create mode 100644 crates/codegen/src/rust/imports.rs create mode 100644 crates/codegen/src/rust/naming.rs create mode 100644 crates/codegen/src/rust/primary_key.rs create mode 100644 crates/codegen/src/rust/relation.rs create mode 100644 crates/codegen/src/rust/tests/mod.rs create mode 100644 crates/codegen/src/rust/tests/snapshot_tests.rs create mode 100644 crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__post_with_relations.snap create mode 100644 crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_complex_types.snap create mode 100644 crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_mod_file.snap create mode 100644 crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_prelude_file.snap create mode 100644 crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_simple_entity_compact.snap create mode 100644 crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_simple_entity_expanded.snap create mode 100644 crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__user_with_relations.snap create mode 100644 crates/codegen/src/rust/tests/unit_tests.rs create mode 100644 crates/codegen/src/rust/type_mapping.rs diff --git a/crates/codegen/src/rust/column.rs b/crates/codegen/src/rust/column.rs new file mode 100644 index 0000000..6721dad --- /dev/null +++ b/crates/codegen/src/rust/column.rs @@ -0,0 +1,532 @@ +//! Column attribute generation for SeaORM code generation. +//! +//! This module provides utilities for generating column attributes and +//! detecting column properties from PostgreSQL table definitions. + +use tern_ddl::{Column, Constraint, ConstraintKind, Index}; + +use super::ReservedWordStrategy; +use super::naming::{SanitizedName, to_field_name}; +use super::primary_key::{is_column_auto_increment, is_primary_key_column}; +use super::type_mapping::{RustType, map_pg_type}; + +/// Information about a column needed for code generation. +#[derive(Debug, Clone)] +pub struct ColumnInfo { + /// The sanitized field name. + pub field_name: SanitizedName, + /// The Rust type for this column. + pub rust_type: RustType, + /// Whether this column is nullable. + pub is_nullable: bool, + /// Whether this column is a primary key. + pub is_primary_key: bool, + /// Whether this column has auto-increment behavior. + pub is_auto_increment: bool, + /// Whether this column has a unique constraint. + pub is_unique: bool, + /// Whether this column is indexed. + pub is_indexed: bool, + /// Whether this is a generated column. + pub is_generated: bool, + /// The doc comment for this column (from database comment). + pub doc_comment: Option, + /// The original column name from the database. + pub original_name: String, +} + +impl ColumnInfo { + /// Creates column info from a DDL column and table constraints. + pub fn from_column( + column: &Column, + constraints: &[Constraint], + indexes: &[Index], + strategy: &ReservedWordStrategy, + ) -> Self { + let field_name = to_field_name(column.name.as_ref(), strategy); + let rust_type = map_pg_type(&column.type_info); + let is_pk = is_primary_key_column(column, constraints); + let is_auto = if is_pk { + is_column_auto_increment(column) + } else { + false + }; + + Self { + field_name, + rust_type, + is_nullable: column.is_nullable, + is_primary_key: is_pk, + is_auto_increment: is_auto, + is_unique: has_unique_constraint(column, constraints), + is_indexed: has_single_column_index(column, indexes), + is_generated: column.generated.is_some(), + doc_comment: column.comment.as_ref().map(|c| c.as_ref().to_string()), + original_name: column.name.as_ref().to_string(), + } + } + + /// Returns the Rust type annotation, wrapped in Option if nullable. + pub fn type_annotation(&self) -> String { + if self.is_nullable { + self.rust_type.as_optional().annotation + } else { + self.rust_type.annotation.clone() + } + } + + /// Generates the `#[sea_orm(...)]` attributes for this column. + /// + /// Returns `None` if no attributes are needed. + pub fn generate_sea_orm_attrs(&self) -> Option { + let mut attrs = Vec::new(); + + // Primary key attribute + if self.is_primary_key { + attrs.push("primary_key".to_string()); + if !self.is_auto_increment { + attrs.push("auto_increment = false".to_string()); + } + } + + // Unique attribute (only for single-column unique, not PKs) + if self.is_unique && !self.is_primary_key { + attrs.push("unique".to_string()); + } + + // Indexed attribute (only for non-constraint indexes) + if self.is_indexed && !self.is_primary_key && !self.is_unique { + attrs.push("indexed".to_string()); + } + + // Column name attribute (if field name differs from original) + if self.field_name.needs_rename_attr { + attrs.push(format!("column_name = \"{}\"", self.field_name.original)); + } + + // Column type attribute (if needed) + if let Some(col_type) = self.get_column_type_attr() { + attrs.push(format!("column_type = \"{col_type}\"")); + } + + // Nullable attribute (explicit for types that need it) + if self.is_nullable && self.rust_type.needs_column_type_attr { + attrs.push("nullable".to_string()); + } + + // Ignore attribute for generated columns + if self.is_generated { + // Generated columns should be marked as ignore + // Clear other attrs and just use ignore + return Some("#[sea_orm(ignore)]".to_string()); + } + + if attrs.is_empty() { + None + } else { + Some(format!("#[sea_orm({})]", attrs.join(", "))) + } + } + + /// Gets the column_type attribute value if needed. + fn get_column_type_attr(&self) -> Option { + // For compact format, we need explicit column_type for some types + if self.rust_type.needs_column_type_attr { + // Extract short form from ColumnType::X + let col_type = &self.rust_type.column_type; + if col_type.starts_with("ColumnType::") { + let rest = &col_type["ColumnType::".len()..]; + // For types with arguments, just return the variant name + if let Some(paren_pos) = rest.find('(') { + return Some(rest[..paren_pos].to_string()); + } + return Some(rest.to_string()); + } + } + None + } +} + +/// Checks if a column has a single-column unique constraint. +pub fn has_unique_constraint(column: &Column, constraints: &[Constraint]) -> bool { + constraints.iter().any(|c| { + if let ConstraintKind::Unique(unique) = &c.kind { + unique.columns.len() == 1 && unique.columns.contains(&column.name) + } else { + false + } + }) +} + +/// Checks if a column has a single-column index (non-constraint). +pub fn has_single_column_index(column: &Column, indexes: &[Index]) -> bool { + indexes.iter().any(|idx| { + // Only consider non-constraint indexes + !idx.is_constraint_index + && idx.columns.len() == 1 + && idx.columns[0] + .column + .as_ref() + .is_some_and(|c| c == &column.name) + }) +} + +/// Gets all unique constraints that span multiple columns. +pub fn get_composite_unique_constraints( + constraints: &[Constraint], +) -> Vec<(&Constraint, &[tern_ddl::ColumnName])> { + constraints + .iter() + .filter_map(|c| { + if let ConstraintKind::Unique(unique) = &c.kind { + if unique.columns.len() > 1 { + return Some((c, unique.columns.as_slice())); + } + } + None + }) + .collect() +} + +/// Gets all check constraints. +pub fn get_check_constraints(constraints: &[Constraint]) -> Vec<(&Constraint, &str)> { + constraints + .iter() + .filter_map(|c| { + if let ConstraintKind::Check(check) = &c.kind { + Some((c, check.expression.as_ref())) + } else { + None + } + }) + .collect() +} + +/// Gets all exclusion constraints. +pub fn get_exclusion_constraints(constraints: &[Constraint]) -> Vec<&Constraint> { + constraints + .iter() + .filter(|c| matches!(c.kind, ConstraintKind::Exclusion(_))) + .collect() +} + +/// Gets all non-constraint indexes. +#[allow(dead_code)] +pub fn get_non_constraint_indexes(indexes: &[Index]) -> Vec<&Index> { + indexes + .iter() + .filter(|idx| !idx.is_constraint_index) + .collect() +} + +/// Gets composite indexes (more than one column). +pub fn get_composite_indexes(indexes: &[Index]) -> Vec<&Index> { + indexes + .iter() + .filter(|idx| !idx.is_constraint_index && idx.columns.len() > 1) + .collect() +} + +/// Generates a comment for a composite unique constraint. +pub fn format_composite_unique_comment( + constraint: &Constraint, + columns: &[tern_ddl::ColumnName], +) -> String { + let col_names: Vec<_> = columns.iter().map(|c| c.as_ref()).collect(); + format!( + "// Composite unique constraint: {} ({})\n// Note: Composite unique constraints are enforced at database level", + constraint.name.as_ref(), + col_names.join(", ") + ) +} + +/// Generates a comment for a check constraint. +pub fn format_check_constraint_comment(constraint: &Constraint, expression: &str) -> String { + format!( + "// Check constraint: {} ({})\n// Note: Check constraints are enforced at database level", + constraint.name.as_ref(), + expression + ) +} + +/// Generates a warning comment for an exclusion constraint. +pub fn format_exclusion_constraint_warning(constraint: &Constraint) -> String { + format!( + "// WARNING: Exclusion constraint '{}' not supported by SeaORM.", + constraint.name.as_ref() + ) +} + +/// Generates a comment for a complex index. +pub fn format_index_comment(index: &Index) -> String { + let col_names: Vec<_> = index + .columns + .iter() + .map(|ic| { + let name = ic + .column + .as_ref() + .map(|c| c.as_ref().to_string()) + .unwrap_or_else(|| { + ic.expression + .as_ref() + .map(|e| e.as_ref().to_string()) + .unwrap_or_else(|| "?".to_string()) + }); + + let order = ic.order.as_sql().unwrap_or(""); + if order.is_empty() { + name + } else { + format!("{name} {order}") + } + }) + .collect(); + + let mut comment = format!( + "// Index: {} ({})", + index.name.as_ref(), + col_names.join(", ") + ); + + if let Some(pred) = &index.predicate { + comment.push_str(&format!(" WHERE {}", pred.as_ref())); + } + + comment +} + +#[cfg(test)] +mod tests { + use super::*; + use tern_ddl::types::{IndexMethod, QualifiedCollationName}; + use tern_ddl::{ + CollationName, ColumnName, ConstraintName, IndexColumn, IndexName, NullsOrder, Oid, + PrimaryKeyConstraint, SchemaName, SortOrder, TypeInfo, TypeName, UniqueConstraint, + }; + + fn make_column(name: &str, type_name: &str, is_nullable: bool) -> Column { + Column { + name: ColumnName::try_new(name.to_string()).unwrap(), + position: 1, + type_info: TypeInfo { + name: TypeName::try_new(type_name.to_string()).unwrap(), + schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(), + formatted: type_name.to_string(), + is_array: false, + }, + is_nullable, + default: None, + generated: None, + identity: None, + collation: QualifiedCollationName::new( + SchemaName::try_new("pg_catalog".to_string()).unwrap(), + CollationName::try_new("default".to_string()).unwrap(), + ), + comment: None, + } + } + + fn make_pk_constraint(name: &str, columns: &[&str]) -> Constraint { + Constraint { + name: ConstraintName::try_new(name.to_string()).unwrap(), + kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint { + columns: columns + .iter() + .map(|c| ColumnName::try_new(c.to_string()).unwrap()) + .collect(), + index_name: IndexName::try_new(name.to_string()).unwrap(), + }), + comment: None, + } + } + + fn make_unique_constraint(name: &str, columns: &[&str]) -> Constraint { + Constraint { + name: ConstraintName::try_new(name.to_string()).unwrap(), + kind: ConstraintKind::Unique(UniqueConstraint { + columns: columns + .iter() + .map(|c| ColumnName::try_new(c.to_string()).unwrap()) + .collect(), + index_name: IndexName::try_new(name.to_string()).unwrap(), + nulls_not_distinct: false, + }), + comment: None, + } + } + + fn make_index(name: &str, columns: &[&str], is_constraint: bool) -> Index { + Index { + oid: Oid::new(1), + name: IndexName::try_new(name.to_string()).unwrap(), + method: IndexMethod::BTree, + is_unique: false, + is_constraint_index: is_constraint, + columns: columns + .iter() + .map(|c| IndexColumn { + column: Some(ColumnName::try_new(c.to_string()).unwrap()), + expression: None, + order: SortOrder::Ascending, + nulls: NullsOrder::Last, + }) + .collect(), + predicate: None, + comment: None, + } + } + + #[test] + fn test_column_info_simple() { + let col = make_column("id", "int4", false); + let constraints = vec![make_pk_constraint("pk", &["id"])]; + let indexes = vec![]; + let strategy = ReservedWordStrategy::AppendUnderscore; + + let info = ColumnInfo::from_column(&col, &constraints, &indexes, &strategy); + assert_eq!(info.field_name.identifier, "id"); + assert!(info.is_primary_key); + assert!(!info.is_nullable); + } + + #[test] + fn test_column_info_nullable() { + let col = make_column("email", "text", true); + let constraints = vec![]; + let indexes = vec![]; + let strategy = ReservedWordStrategy::AppendUnderscore; + + let info = ColumnInfo::from_column(&col, &constraints, &indexes, &strategy); + assert!(info.is_nullable); + assert_eq!(info.type_annotation(), "Option"); + } + + #[test] + fn test_column_info_unique() { + let col = make_column("email", "text", false); + let constraints = vec![make_unique_constraint("uq_email", &["email"])]; + let indexes = vec![]; + let strategy = ReservedWordStrategy::AppendUnderscore; + + let info = ColumnInfo::from_column(&col, &constraints, &indexes, &strategy); + assert!(info.is_unique); + } + + #[test] + fn test_column_info_indexed() { + let col = make_column("name", "text", false); + let constraints = vec![]; + let indexes = vec![make_index("idx_name", &["name"], false)]; + let strategy = ReservedWordStrategy::AppendUnderscore; + + let info = ColumnInfo::from_column(&col, &constraints, &indexes, &strategy); + assert!(info.is_indexed); + } + + #[test] + fn test_generate_sea_orm_attrs_pk() { + let col = make_column("id", "int4", false); + let mut col_info = + ColumnInfo::from_column(&col, &[], &[], &ReservedWordStrategy::AppendUnderscore); + col_info.is_primary_key = true; + col_info.is_auto_increment = true; + + let attrs = col_info.generate_sea_orm_attrs().unwrap(); + assert!(attrs.contains("primary_key")); + assert!(!attrs.contains("auto_increment = false")); + } + + #[test] + fn test_generate_sea_orm_attrs_pk_no_auto() { + let col = make_column("id", "uuid", false); + let mut col_info = + ColumnInfo::from_column(&col, &[], &[], &ReservedWordStrategy::AppendUnderscore); + col_info.is_primary_key = true; + col_info.is_auto_increment = false; + + let attrs = col_info.generate_sea_orm_attrs().unwrap(); + assert!(attrs.contains("primary_key")); + assert!(attrs.contains("auto_increment = false")); + } + + #[test] + fn test_generate_sea_orm_attrs_unique() { + let col = make_column("email", "text", false); + let mut col_info = + ColumnInfo::from_column(&col, &[], &[], &ReservedWordStrategy::AppendUnderscore); + col_info.is_unique = true; + + let attrs = col_info.generate_sea_orm_attrs().unwrap(); + assert!(attrs.contains("unique")); + } + + #[test] + fn test_generate_sea_orm_attrs_column_name() { + let col = make_column("type", "text", false); + let strategy = ReservedWordStrategy::AppendUnderscore; + let col_info = ColumnInfo::from_column(&col, &[], &[], &strategy); + + let attrs = col_info.generate_sea_orm_attrs().unwrap(); + assert!(attrs.contains("column_name = \"type\"")); + } + + #[test] + fn test_has_unique_constraint() { + let col = make_column("email", "text", false); + let constraints = vec![make_unique_constraint("uq_email", &["email"])]; + assert!(has_unique_constraint(&col, &constraints)); + + let constraints_composite = vec![make_unique_constraint( + "uq_composite", + &["email", "tenant_id"], + )]; + assert!(!has_unique_constraint(&col, &constraints_composite)); + } + + #[test] + fn test_get_composite_unique_constraints() { + let constraints = vec![ + make_unique_constraint("uq_email", &["email"]), + make_unique_constraint("uq_composite", &["email", "tenant_id"]), + ]; + + let composite = get_composite_unique_constraints(&constraints); + assert_eq!(composite.len(), 1); + assert_eq!(composite[0].1.len(), 2); + } + + #[test] + fn test_has_single_column_index() { + let col = make_column("name", "text", false); + let indexes = vec![ + make_index("idx_name", &["name"], false), + make_index("idx_multi", &["name", "email"], false), + make_index("idx_pk", &["id"], true), // constraint index + ]; + + assert!(has_single_column_index(&col, &indexes)); + + let col2 = make_column("id", "int4", false); + assert!(!has_single_column_index(&col2, &indexes)); // constraint indexes don't count + } + + #[test] + fn test_format_composite_unique_comment() { + let constraint = make_unique_constraint("uq_email_tenant", &["email", "tenant_id"]); + if let ConstraintKind::Unique(u) = &constraint.kind { + let comment = format_composite_unique_comment(&constraint, &u.columns); + assert!(comment.contains("uq_email_tenant")); + assert!(comment.contains("email, tenant_id")); + } + } + + #[test] + fn test_format_index_comment() { + let index = make_index("idx_name_email", &["name", "email"], false); + let comment = format_index_comment(&index); + assert!(comment.contains("idx_name_email")); + assert!(comment.contains("name")); + assert!(comment.contains("email")); + } +} diff --git a/crates/codegen/src/rust/entity.rs b/crates/codegen/src/rust/entity.rs new file mode 100644 index 0000000..3e77698 --- /dev/null +++ b/crates/codegen/src/rust/entity.rs @@ -0,0 +1,656 @@ +//! Entity generation for SeaORM code generation. +//! +//! This module provides utilities for generating SeaORM entity structs and +//! their associated code in both compact and expanded formats. + +use tern_ddl::Table; + +use super::column::{ + ColumnInfo, format_check_constraint_comment, format_composite_unique_comment, + format_exclusion_constraint_warning, format_index_comment, get_check_constraints, + get_composite_indexes, get_composite_unique_constraints, get_exclusion_constraints, +}; +use super::imports::ImportCollector; +use super::naming::{to_enum_variant, to_module_name, to_struct_name}; +use super::primary_key::{PrimaryKeyInfo, get_primary_key}; +use super::relation::{RelationInfo, analyze_relations}; +use super::{EntityFormat, RustSeaOrmCodegenConfig}; + +/// Information about a generated entity. +#[derive(Debug, Clone)] +pub struct EntityInfo { + /// The table name (original). + pub table_name: String, + /// The Rust struct name (PascalCase, singular). + pub struct_name: String, + /// The module name (snake_case, singular). + pub module_name: String, + /// Column information for all columns. + pub columns: Vec, + /// Primary key information. + pub primary_key: Option, + /// Relations for this entity. + pub relations: Vec, + /// Imports required for this entity. + pub imports: ImportCollector, + /// Warning comments to include. + pub warnings: Vec, + /// Constraint comments to include. + pub constraint_comments: Vec, + /// Index comments to include. + pub index_comments: Vec, + /// Doc comment for the entity (from table comment). + pub doc_comment: Option, + /// The schema name (if not public). + pub schema_name: Option, +} + +impl EntityInfo { + /// Creates entity info from a table definition. + pub fn from_table( + table: &Table, + config: &RustSeaOrmCodegenConfig, + all_tables: &[Table], + ) -> Self { + let table_name = table.name.as_ref().to_string(); + let struct_name = to_struct_name(&table_name); + let module_name = to_module_name(&table_name); + + // Process columns + let columns: Vec<_> = table + .columns + .iter() + .map(|c| { + ColumnInfo::from_column( + c, + &table.constraints, + &table.indexes, + &config.reserved_word_strategy, + ) + }) + .collect(); + + // Get primary key info + let primary_key = get_primary_key(table); + + // Collect imports + let mut imports = ImportCollector::with_sea_orm_prelude(); + for col in &columns { + imports.add_imports(&col.rust_type.imports); + imports.add_features(&col.rust_type.required_features); + } + + // Get relations if enabled + let relations = if config.generate_relations { + let all_relations = analyze_relations(all_tables); + all_relations.get(&table_name).cloned().unwrap_or_default() + } else { + Vec::new() + }; + + // Collect warnings + let mut warnings = Vec::new(); + + // Check for missing primary key + if primary_key.is_none() { + warnings.push(format!( + "// WARNING: Table '{}' has no primary key. SeaORM requires a primary key for entity operations.", + table_name + )); + } + + // Check for generated columns + for col in &columns { + if col.is_generated { + warnings.push(format!( + "// Note: Column '{}' is a generated column and will be ignored by SeaORM.", + col.original_name + )); + } + } + + // Collect constraint comments + let mut constraint_comments = Vec::new(); + + // Composite unique constraints + for (constraint, cols) in get_composite_unique_constraints(&table.constraints) { + constraint_comments.push(format_composite_unique_comment(constraint, cols)); + } + + // Check constraints + for (constraint, expr) in get_check_constraints(&table.constraints) { + constraint_comments.push(format_check_constraint_comment(constraint, expr)); + } + + // Exclusion constraints + for constraint in get_exclusion_constraints(&table.constraints) { + constraint_comments.push(format_exclusion_constraint_warning(constraint)); + } + + // Collect index comments + let mut index_comments = Vec::new(); + for index in get_composite_indexes(&table.indexes) { + index_comments.push(format_index_comment(index)); + } + + // Doc comment + let doc_comment = if config.include_doc_comments { + table.comment.as_ref().map(|c| c.as_ref().to_string()) + } else { + None + }; + + Self { + table_name, + struct_name, + module_name, + columns, + primary_key, + relations, + imports, + warnings, + constraint_comments, + index_comments, + doc_comment, + schema_name: config.schema_name.clone(), + } + } +} + +/// Generates a complete entity module file (compact format). +pub fn generate_entity_compact(entity: &EntityInfo, config: &RustSeaOrmCodegenConfig) -> String { + let mut lines = Vec::new(); + + // Module doc comment + lines.push(format!( + "//! SeaORM entity for `{}` table.", + entity.table_name + )); + lines.push("//!".to_string()); + lines.push("//! Generated by Tern.".to_string()); + lines.push(String::new()); + + // Imports + let import_block = entity.imports.generate(); + if !import_block.is_empty() { + lines.push(import_block); + lines.push(String::new()); + } + + // Warnings + for warning in &entity.warnings { + lines.push(warning.clone()); + } + if !entity.warnings.is_empty() { + lines.push(String::new()); + } + + // Constraint comments + for comment in &entity.constraint_comments { + lines.push(comment.clone()); + } + if !entity.constraint_comments.is_empty() { + lines.push(String::new()); + } + + // Index comments + for comment in &entity.index_comments { + lines.push(comment.clone()); + } + if !entity.index_comments.is_empty() { + lines.push(String::new()); + } + + // Entity doc comment + if let Some(doc) = &entity.doc_comment { + for line in doc.lines() { + lines.push(format!("/// {line}")); + } + } + + // Model struct with DeriveEntityModel + lines.push("#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]".to_string()); + + // Table name attribute + let mut sea_orm_attrs = vec![format!("table_name = \"{}\"", entity.table_name)]; + if let Some(schema) = &entity.schema_name { + sea_orm_attrs.push(format!("schema_name = \"{schema}\"")); + } + lines.push(format!("#[sea_orm({})]", sea_orm_attrs.join(", "))); + + lines.push("pub struct Model {".to_string()); + + // Fields + for col in &entity.columns { + // Skip generated columns (they're marked with ignore) + let field_name = &col.field_name.identifier; + let type_ann = col.type_annotation(); + + // Doc comment for field + if config.include_doc_comments { + if let Some(doc) = &col.doc_comment { + for line in doc.lines() { + lines.push(format!(" /// {line}")); + } + } + } + + // SeaORM attributes + if let Some(attrs) = col.generate_sea_orm_attrs() { + lines.push(format!(" {attrs}")); + } + + // Field declaration + lines.push(format!(" pub {field_name}: {type_ann},")); + } + + lines.push("}".to_string()); + lines.push(String::new()); + + // Relation enum + if config.generate_relations { + lines.push(generate_relation_enum(&entity.relations)); + lines.push(String::new()); + + // Related impls + for relation in &entity.relations { + lines.push(relation.generate_related_impl(&entity.table_name)); + lines.push(String::new()); + } + } else { + // Empty relation enum + lines.push("#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]".to_string()); + lines.push("pub enum Relation {}".to_string()); + lines.push(String::new()); + } + + // ActiveModelBehavior impl + lines.push("impl ActiveModelBehavior for ActiveModel {}".to_string()); + lines.push(String::new()); + + lines.join("\n") +} + +/// Generates the Relation enum. +fn generate_relation_enum(relations: &[RelationInfo]) -> String { + let mut lines = Vec::new(); + + lines.push("#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]".to_string()); + + if relations.is_empty() { + lines.push("pub enum Relation {}".to_string()); + } else { + lines.push("pub enum Relation {".to_string()); + + for relation in relations { + let attr = relation.generate_relation_attr(); + let variant = relation.variant_name(); + lines.push(format!(" {attr}")); + lines.push(format!(" {variant},")); + } + + lines.push("}".to_string()); + } + + lines.join("\n") +} + +/// Generates a complete entity module file (expanded format). +pub fn generate_entity_expanded(entity: &EntityInfo, config: &RustSeaOrmCodegenConfig) -> String { + let mut lines = Vec::new(); + + // Module doc comment + lines.push(format!( + "//! SeaORM entity for `{}` table.", + entity.table_name + )); + lines.push("//!".to_string()); + lines.push("//! Generated by Tern.".to_string()); + lines.push(String::new()); + + // Imports + let import_block = entity.imports.generate(); + if !import_block.is_empty() { + lines.push(import_block); + lines.push(String::new()); + } + + // Warnings + for warning in &entity.warnings { + lines.push(warning.clone()); + } + if !entity.warnings.is_empty() { + lines.push(String::new()); + } + + // Entity struct + lines.push("#[derive(Copy, Clone, Default, Debug, DeriveEntity)]".to_string()); + lines.push("pub struct Entity;".to_string()); + lines.push(String::new()); + + // EntityName impl + lines.push("impl EntityName for Entity {".to_string()); + if let Some(schema) = &entity.schema_name { + lines.push(format!(" fn schema_name(&self) -> Option<&str> {{")); + lines.push(format!(" Some(\"{schema}\")")); + lines.push(" }".to_string()); + lines.push(String::new()); + } else { + lines.push(" fn schema_name(&self) -> Option<&str> {".to_string()); + lines.push(" None".to_string()); + lines.push(" }".to_string()); + lines.push(String::new()); + } + lines.push(" fn table_name(&self) -> &str {".to_string()); + lines.push(format!(" \"{}\"", entity.table_name)); + lines.push(" }".to_string()); + lines.push("}".to_string()); + lines.push(String::new()); + + // Column enum + lines.push("#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]".to_string()); + lines.push("pub enum Column {".to_string()); + for col in &entity.columns { + let variant = to_enum_variant(&col.original_name); + lines.push(format!(" {variant},")); + } + lines.push("}".to_string()); + lines.push(String::new()); + + // ColumnTrait impl + lines.push("impl ColumnTrait for Column {".to_string()); + lines.push(" type EntityName = Entity;".to_string()); + lines.push(String::new()); + lines.push(" fn def(&self) -> ColumnDef {".to_string()); + lines.push(" match self {".to_string()); + for col in &entity.columns { + let variant = to_enum_variant(&col.original_name); + let col_type = &col.rust_type.column_type; + let mut def = format!("{col_type}.def()"); + if col.is_nullable { + def.push_str(".nullable()"); + } + if col.is_unique && !col.is_primary_key { + def.push_str(".unique()"); + } + lines.push(format!(" Self::{variant} => {def},")); + } + lines.push(" }".to_string()); + lines.push(" }".to_string()); + lines.push("}".to_string()); + lines.push(String::new()); + + // PrimaryKey enum + lines.push("#[derive(Copy, Clone, Debug, EnumIter, DerivePrimaryKey)]".to_string()); + lines.push("pub enum PrimaryKey {".to_string()); + if let Some(pk) = &entity.primary_key { + for col_name in &pk.columns { + let variant = to_enum_variant(col_name.as_ref()); + lines.push(format!(" {variant},")); + } + } + lines.push("}".to_string()); + lines.push(String::new()); + + // PrimaryKeyTrait impl + lines.push("impl PrimaryKeyTrait for PrimaryKey {".to_string()); + + // Determine the PK value type + let pk_type = if let Some(pk) = &entity.primary_key { + if pk.columns.len() == 1 { + // Find the column type + entity + .columns + .iter() + .find(|c| c.original_name == pk.columns[0].as_ref()) + .map(|c| c.rust_type.annotation.clone()) + .unwrap_or_else(|| "i32".to_string()) + } else { + // Composite PK - tuple type + let types: Vec<_> = pk + .columns + .iter() + .filter_map(|col_name| { + entity + .columns + .iter() + .find(|c| c.original_name == col_name.as_ref()) + .map(|c| c.rust_type.annotation.clone()) + }) + .collect(); + format!("({})", types.join(", ")) + } + } else { + "i32".to_string() + }; + + lines.push(format!(" type ValueType = {pk_type};")); + lines.push(String::new()); + lines.push(" fn auto_increment() -> bool {".to_string()); + let auto_inc = entity + .primary_key + .as_ref() + .map(|pk| pk.is_auto_increment) + .unwrap_or(false); + lines.push(format!(" {auto_inc}")); + lines.push(" }".to_string()); + lines.push("}".to_string()); + lines.push(String::new()); + + // Model struct + if let Some(doc) = &entity.doc_comment { + for line in doc.lines() { + lines.push(format!("/// {line}")); + } + } + lines + .push("#[derive(Clone, Debug, PartialEq, Eq, DeriveModel, DeriveActiveModel)]".to_string()); + lines.push("pub struct Model {".to_string()); + for col in &entity.columns { + if col.is_generated { + continue; // Skip generated columns in expanded format too + } + let field_name = &col.field_name.identifier; + let type_ann = col.type_annotation(); + + if config.include_doc_comments { + if let Some(doc) = &col.doc_comment { + for line in doc.lines() { + lines.push(format!(" /// {line}")); + } + } + } + + lines.push(format!(" pub {field_name}: {type_ann},")); + } + lines.push("}".to_string()); + lines.push(String::new()); + + // Relation enum + if config.generate_relations { + lines.push(generate_relation_enum(&entity.relations)); + lines.push(String::new()); + + // Related impls + for relation in &entity.relations { + lines.push(relation.generate_related_impl(&entity.table_name)); + lines.push(String::new()); + } + } else { + lines.push("#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]".to_string()); + lines.push("pub enum Relation {}".to_string()); + lines.push(String::new()); + } + + // ActiveModelBehavior impl + lines.push("impl ActiveModelBehavior for ActiveModel {}".to_string()); + lines.push(String::new()); + + lines.join("\n") +} + +/// Generates an entity based on the configured format. +pub fn generate_entity(entity: &EntityInfo, config: &RustSeaOrmCodegenConfig) -> String { + match config.entity_format { + EntityFormat::Compact => generate_entity_compact(entity, config), + EntityFormat::Expanded => generate_entity_expanded(entity, config), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tern_ddl::types::QualifiedCollationName; + use tern_ddl::{ + CollationName, Column, ColumnName, Constraint, ConstraintKind, ConstraintName, + IdentityKind, IndexName, Oid, PrimaryKeyConstraint, SchemaName, TableKind, TableName, + TypeInfo, TypeName, + }; + + fn make_column( + name: &str, + type_name: &str, + is_nullable: bool, + identity: Option, + ) -> Column { + Column { + name: ColumnName::try_new(name.to_string()).unwrap(), + position: 1, + type_info: TypeInfo { + name: TypeName::try_new(type_name.to_string()).unwrap(), + schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(), + formatted: type_name.to_string(), + is_array: false, + }, + is_nullable, + default: None, + generated: None, + identity, + collation: QualifiedCollationName::new( + SchemaName::try_new("pg_catalog".to_string()).unwrap(), + CollationName::try_new("default".to_string()).unwrap(), + ), + comment: None, + } + } + + fn make_table_with_pk(name: &str, columns: Vec, pk_columns: &[&str]) -> Table { + let pk = Constraint { + name: ConstraintName::try_new(format!("{name}_pkey")).unwrap(), + kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint { + columns: pk_columns + .iter() + .map(|c| ColumnName::try_new(c.to_string()).unwrap()) + .collect(), + index_name: IndexName::try_new(format!("{name}_pkey")).unwrap(), + }), + comment: None, + }; + + Table { + oid: Oid::new(1), + name: TableName::try_new(name.to_string()).unwrap(), + kind: TableKind::Regular, + columns, + constraints: vec![pk], + indexes: vec![], + comment: None, + } + } + + #[test] + fn test_entity_info_from_table() { + let columns = vec![ + make_column("id", "int4", false, Some(IdentityKind::Always)), + make_column("name", "text", false, None), + make_column("email", "text", true, None), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + let config = RustSeaOrmCodegenConfig::default(); + + let entity = EntityInfo::from_table(&table, &config, &[table.clone()]); + + assert_eq!(entity.table_name, "users"); + assert_eq!(entity.struct_name, "User"); + assert_eq!(entity.module_name, "user"); + assert_eq!(entity.columns.len(), 3); + assert!(entity.primary_key.is_some()); + } + + #[test] + fn test_generate_entity_compact() { + let columns = vec![ + make_column("id", "int4", false, Some(IdentityKind::Always)), + make_column("name", "text", false, None), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + let config = RustSeaOrmCodegenConfig::default(); + + let entity = EntityInfo::from_table(&table, &config, &[table.clone()]); + let code = generate_entity_compact(&entity, &config); + + assert!(code.contains("DeriveEntityModel")); + assert!(code.contains("table_name = \"users\"")); + assert!(code.contains("pub struct Model")); + assert!(code.contains("pub id: i32")); + assert!(code.contains("pub name: String")); + assert!(code.contains("primary_key")); + } + + #[test] + fn test_generate_entity_expanded() { + let columns = vec![ + make_column("id", "int4", false, Some(IdentityKind::Always)), + make_column("name", "text", false, None), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + let mut config = RustSeaOrmCodegenConfig::default(); + config.entity_format = EntityFormat::Expanded; + + let entity = EntityInfo::from_table(&table, &config, &[table.clone()]); + let code = generate_entity_expanded(&entity, &config); + + assert!(code.contains("DeriveEntity")); + assert!(code.contains("pub struct Entity;")); + assert!(code.contains("impl EntityName for Entity")); + assert!(code.contains("pub enum Column")); + assert!(code.contains("pub enum PrimaryKey")); + assert!(code.contains("impl ColumnTrait for Column")); + assert!(code.contains("impl PrimaryKeyTrait for PrimaryKey")); + assert!(code.contains("DeriveModel")); + } + + #[test] + fn test_generate_entity_with_nullable() { + let columns = vec![ + make_column("id", "int4", false, Some(IdentityKind::Always)), + make_column("bio", "text", true, None), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + let config = RustSeaOrmCodegenConfig::default(); + + let entity = EntityInfo::from_table(&table, &config, &[table.clone()]); + let code = generate_entity_compact(&entity, &config); + + assert!(code.contains("pub bio: Option")); + } + + #[test] + fn test_generate_entity_without_pk_warning() { + let columns = vec![make_column("data", "text", false, None)]; + let table = Table { + oid: Oid::new(1), + name: TableName::try_new("log_entries".to_string()).unwrap(), + kind: TableKind::Regular, + columns, + constraints: vec![], + indexes: vec![], + comment: None, + }; + let config = RustSeaOrmCodegenConfig::default(); + + let entity = EntityInfo::from_table(&table, &config, &[table.clone()]); + + assert!(entity.warnings.iter().any(|w| w.contains("no primary key"))); + } +} diff --git a/crates/codegen/src/rust/generator.rs b/crates/codegen/src/rust/generator.rs new file mode 100644 index 0000000..61af117 --- /dev/null +++ b/crates/codegen/src/rust/generator.rs @@ -0,0 +1,499 @@ +//! Rust SeaORM code generator implementation. +//! +//! This module provides the main `RustSeaOrmCodegen` struct that implements the `Codegen` trait +//! for generating Rust SeaORM entities from PostgreSQL table definitions. + +use std::collections::HashMap; + +use tern_ddl::Table; + +use super::entity::{EntityInfo, generate_entity}; +use super::imports::ImportCollector; +use super::{OutputMode, RustSeaOrmCodegenConfig}; +use crate::Codegen; + +/// Rust SeaORM code generator. +/// +/// Generates Rust SeaORM entity definitions from PostgreSQL table schemas. +/// +/// # Example +/// +/// ```ignore +/// use tern_codegen::rust::{RustSeaOrmCodegen, RustSeaOrmCodegenConfig}; +/// use tern_codegen::Codegen; +/// +/// let codegen = RustSeaOrmCodegen::new(RustSeaOrmCodegenConfig::default()); +/// let tables = vec![/* ... */]; +/// let output = codegen.generate(tables); +/// +/// // output["user.rs"] contains the generated entity code +/// ``` +#[derive(Debug, Clone)] +pub struct RustSeaOrmCodegen { + config: RustSeaOrmCodegenConfig, +} + +impl RustSeaOrmCodegen { + /// Creates a new Rust SeaORM code generator with the given configuration. + pub fn new(config: RustSeaOrmCodegenConfig) -> Self { + Self { config } + } + + /// Creates a new Rust SeaORM code generator with default configuration. + pub fn with_defaults() -> Self { + Self::new(RustSeaOrmCodegenConfig::default()) + } + + /// Generates a single entities.rs file containing all entities. + fn generate_single_file(&self, tables: Vec) -> HashMap { + let mut output = HashMap::new(); + + if tables.is_empty() { + output.insert("entities.rs".to_string(), generate_empty_entities_file()); + return output; + } + + // Generate all entities + let entities: Vec = tables + .iter() + .map(|t| EntityInfo::from_table(t, &self.config, &tables)) + .collect(); + + // Collect all imports + let mut imports = ImportCollector::with_sea_orm_prelude(); + for entity in &entities { + imports.merge(&entity.imports); + } + + // Build file content + let mut content = Vec::new(); + + // Module doc comment + content.push("//! SeaORM entity definitions generated by Tern.".to_string()); + content.push("//!".to_string()); + content + .push("//! This file was automatically generated. Do not edit manually.".to_string()); + + // Add required features comment + if let Some(features_comment) = imports.generate_features_comment() { + content.push("//!".to_string()); + for line in features_comment.lines() { + content.push(line.to_string()); + } + } + + content.push(String::new()); + + // Imports + let import_block = imports.generate(); + if !import_block.is_empty() { + content.push(import_block); + content.push(String::new()); + } + + // Generate each entity as a submodule + for (i, entity) in entities.iter().enumerate() { + if i > 0 { + content.push(String::new()); + } + + content.push(format!("pub mod {} {{", entity.module_name)); + + // Generate entity code and indent it + let entity_code = generate_entity(entity, &self.config); + for line in entity_code.lines() { + if line.is_empty() { + content.push(String::new()); + } else { + content.push(format!(" {line}")); + } + } + + content.push("}".to_string()); + } + + content.push(String::new()); + + output.insert("entities.rs".to_string(), content.join("\n")); + output + } + + /// Generates multiple files, one per entity, with a mod.rs. + fn generate_multi_file(&self, tables: Vec
) -> HashMap { + let mut output = HashMap::new(); + + if tables.is_empty() { + output.insert("mod.rs".to_string(), generate_empty_mod_file()); + return output; + } + + // Generate all entities + let entities: Vec = tables + .iter() + .map(|t| EntityInfo::from_table(t, &self.config, &tables)) + .collect(); + + // Collect all features + let mut all_imports = ImportCollector::new(); + for entity in &entities { + all_imports.merge(&entity.imports); + } + + // Generate individual entity files + let mut module_names = Vec::new(); + let mut struct_names = Vec::new(); + + for entity in &entities { + let filename = format!("{}.rs", entity.module_name); + let entity_code = generate_entity(entity, &self.config); + output.insert(filename, entity_code); + + module_names.push(entity.module_name.clone()); + struct_names.push(entity.struct_name.clone()); + } + + // Generate mod.rs + let mod_content = generate_mod_file(&module_names, &all_imports); + output.insert("mod.rs".to_string(), mod_content); + + // Generate prelude.rs + let prelude_content = generate_prelude_file(&module_names, &struct_names); + output.insert("prelude.rs".to_string(), prelude_content); + + output + } +} + +impl Default for RustSeaOrmCodegen { + fn default() -> Self { + Self::with_defaults() + } +} + +impl Codegen for RustSeaOrmCodegen { + fn generate(&self, tables: Vec
) -> HashMap { + match self.config.output_mode { + OutputMode::SingleFile => self.generate_single_file(tables), + OutputMode::MultiFile => self.generate_multi_file(tables), + } + } +} + +/// Generates an empty entities.rs file. +fn generate_empty_entities_file() -> String { + r#"//! SeaORM entity definitions generated by Tern. +//! +//! This file was automatically generated. Do not edit manually. + +use sea_orm::entity::prelude::*; + +// No tables to generate +"# + .to_string() +} + +/// Generates an empty mod.rs file. +fn generate_empty_mod_file() -> String { + r#"//! SeaORM entity definitions generated by Tern. +//! +//! This file was automatically generated. Do not edit manually. + +// No entities to export +"# + .to_string() +} + +/// Generates the mod.rs file with module declarations. +fn generate_mod_file(module_names: &[String], imports: &ImportCollector) -> String { + let mut content = Vec::new(); + + content.push("//! SeaORM entity definitions generated by Tern.".to_string()); + content.push("//!".to_string()); + content.push("//! This file was automatically generated. Do not edit manually.".to_string()); + + // Add required features comment + if let Some(features_comment) = imports.generate_features_comment() { + content.push("//!".to_string()); + for line in features_comment.lines() { + content.push(line.to_string()); + } + } + + content.push(String::new()); + + // Module declarations + content.push("pub mod prelude;".to_string()); + content.push(String::new()); + + for module in module_names { + content.push(format!("pub mod {module};")); + } + + content.push(String::new()); + + content.join("\n") +} + +/// Generates the prelude.rs file with re-exports. +fn generate_prelude_file(module_names: &[String], struct_names: &[String]) -> String { + let mut content = Vec::new(); + + content.push("//! Re-exports commonly used entity types.".to_string()); + content.push("//!".to_string()); + content.push("//! This file was automatically generated. Do not edit manually.".to_string()); + content.push(String::new()); + + // Re-export Entity types with aliases + for (module, struct_name) in module_names.iter().zip(struct_names.iter()) { + content.push(format!("pub use super::{module}::Entity as {struct_name};")); + } + + content.push(String::new()); + + content.join("\n") +} + +#[cfg(test)] +mod tests { + use super::*; + use tern_ddl::types::QualifiedCollationName; + use tern_ddl::{ + CollationName, Column, ColumnName, Constraint, ConstraintKind, ConstraintName, + ForeignKeyAction, ForeignKeyConstraint, IdentityKind, IndexName, Oid, PrimaryKeyConstraint, + QualifiedTableName, SchemaName, TableKind, TableName, TypeInfo, TypeName, + }; + + fn make_column(name: &str, type_name: &str, is_nullable: bool) -> Column { + Column { + name: ColumnName::try_new(name.to_string()).unwrap(), + position: 1, + type_info: TypeInfo { + name: TypeName::try_new(type_name.to_string()).unwrap(), + schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(), + formatted: type_name.to_string(), + is_array: false, + }, + is_nullable, + default: None, + generated: None, + identity: Some(IdentityKind::Always), + collation: QualifiedCollationName::new( + SchemaName::try_new("pg_catalog".to_string()).unwrap(), + CollationName::try_new("default".to_string()).unwrap(), + ), + comment: None, + } + } + + fn make_column_no_identity(name: &str, type_name: &str, is_nullable: bool) -> Column { + let mut col = make_column(name, type_name, is_nullable); + col.identity = None; + col + } + + fn make_table_with_pk(name: &str, columns: Vec, pk_columns: &[&str]) -> Table { + let pk = Constraint { + name: ConstraintName::try_new(format!("{name}_pkey")).unwrap(), + kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint { + columns: pk_columns + .iter() + .map(|c| ColumnName::try_new(c.to_string()).unwrap()) + .collect(), + index_name: IndexName::try_new(format!("{name}_pkey")).unwrap(), + }), + comment: None, + }; + + Table { + oid: Oid::new(1), + name: TableName::try_new(name.to_string()).unwrap(), + kind: TableKind::Regular, + columns, + constraints: vec![pk], + indexes: vec![], + comment: None, + } + } + + fn make_table_with_fk( + name: &str, + columns: Vec, + pk_columns: &[&str], + fk_column: &str, + fk_target_table: &str, + fk_target_column: &str, + ) -> Table { + let pk = Constraint { + name: ConstraintName::try_new(format!("{name}_pkey")).unwrap(), + kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint { + columns: pk_columns + .iter() + .map(|c| ColumnName::try_new(c.to_string()).unwrap()) + .collect(), + index_name: IndexName::try_new(format!("{name}_pkey")).unwrap(), + }), + comment: None, + }; + + let fk = Constraint { + name: ConstraintName::try_new(format!("{name}_{fk_column}_fkey")).unwrap(), + kind: ConstraintKind::ForeignKey(ForeignKeyConstraint { + columns: vec![ColumnName::try_new(fk_column.to_string()).unwrap()], + referenced_table: QualifiedTableName::new( + SchemaName::try_new("public".to_string()).unwrap(), + TableName::try_new(fk_target_table.to_string()).unwrap(), + ), + referenced_columns: vec![ + ColumnName::try_new(fk_target_column.to_string()).unwrap(), + ], + on_delete: ForeignKeyAction::NoAction, + on_update: ForeignKeyAction::NoAction, + is_deferrable: false, + is_initially_deferred: false, + }), + comment: None, + }; + + Table { + oid: Oid::new(1), + name: TableName::try_new(name.to_string()).unwrap(), + kind: TableKind::Regular, + columns, + constraints: vec![pk, fk], + indexes: vec![], + comment: None, + } + } + + #[test] + fn test_generate_empty_tables() { + let codegen = RustSeaOrmCodegen::with_defaults(); + let output = codegen.generate(vec![]); + + assert!(output.contains_key("mod.rs")); + let content = &output["mod.rs"]; + assert!(content.contains("No entities to export")); + } + + #[test] + fn test_generate_single_table_multi_file() { + let columns = vec![ + make_column("id", "int4", false), + make_column_no_identity("name", "text", false), + make_column_no_identity("email", "text", true), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + + let codegen = RustSeaOrmCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + assert!(output.contains_key("mod.rs")); + assert!(output.contains_key("prelude.rs")); + assert!(output.contains_key("user.rs")); + + let user_rs = &output["user.rs"]; + assert!(user_rs.contains("pub struct Model")); + assert!(user_rs.contains("table_name = \"users\"")); + assert!(user_rs.contains("pub id: i32")); + assert!(user_rs.contains("pub name: String")); + assert!(user_rs.contains("pub email: Option")); + + let mod_rs = &output["mod.rs"]; + assert!(mod_rs.contains("pub mod user;")); + assert!(mod_rs.contains("pub mod prelude;")); + + let prelude_rs = &output["prelude.rs"]; + assert!(prelude_rs.contains("pub use super::user::Entity as User;")); + } + + #[test] + fn test_generate_single_file_mode() { + let columns = vec![ + make_column("id", "int4", false), + make_column_no_identity("name", "text", false), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + + let config = RustSeaOrmCodegenConfig { + output_mode: OutputMode::SingleFile, + ..Default::default() + }; + let codegen = RustSeaOrmCodegen::new(config); + let output = codegen.generate(vec![table]); + + assert!(output.contains_key("entities.rs")); + let content = &output["entities.rs"]; + assert!(content.contains("pub mod user")); + assert!(content.contains("pub struct Model")); + } + + #[test] + fn test_generate_multiple_tables() { + let user_columns = vec![ + make_column("id", "int4", false), + make_column_no_identity("name", "text", false), + ]; + let user_table = make_table_with_pk("users", user_columns, &["id"]); + + let post_columns = vec![ + make_column("id", "int4", false), + make_column_no_identity("title", "text", false), + make_column_no_identity("user_id", "int4", false), + ]; + let post_table = + make_table_with_fk("posts", post_columns, &["id"], "user_id", "users", "id"); + + let codegen = RustSeaOrmCodegen::with_defaults(); + let output = codegen.generate(vec![user_table, post_table]); + + assert!(output.contains_key("user.rs")); + assert!(output.contains_key("post.rs")); + + let mod_rs = &output["mod.rs"]; + assert!(mod_rs.contains("pub mod user;")); + assert!(mod_rs.contains("pub mod post;")); + + let prelude = &output["prelude.rs"]; + assert!(prelude.contains("User")); + assert!(prelude.contains("Post")); + } + + #[test] + fn test_generate_with_relations() { + let user_columns = vec![ + make_column("id", "int4", false), + make_column_no_identity("name", "text", false), + ]; + let user_table = make_table_with_pk("users", user_columns, &["id"]); + + let post_columns = vec![ + make_column("id", "int4", false), + make_column_no_identity("user_id", "int4", false), + make_column_no_identity("title", "text", false), + ]; + let post_table = + make_table_with_fk("posts", post_columns, &["id"], "user_id", "users", "id"); + + let config = RustSeaOrmCodegenConfig { + generate_relations: true, + ..Default::default() + }; + let codegen = RustSeaOrmCodegen::new(config); + let output = codegen.generate(vec![user_table, post_table]); + + let user_rs = &output["user.rs"]; + assert!(user_rs.contains("has_many")); + assert!(user_rs.contains("Posts")); + + let post_rs = &output["post.rs"]; + assert!(post_rs.contains("belongs_to")); + assert!(post_rs.contains("User")); + } + + #[test] + fn test_default_codegen_implements_trait() { + let codegen = RustSeaOrmCodegen::default(); + let output = codegen.generate(vec![]); + assert!(output.contains_key("mod.rs")); + } +} diff --git a/crates/codegen/src/rust/imports.rs b/crates/codegen/src/rust/imports.rs new file mode 100644 index 0000000..315be94 --- /dev/null +++ b/crates/codegen/src/rust/imports.rs @@ -0,0 +1,225 @@ +//! Import management for Rust code generation. +//! +//! This module handles collecting and generating use statements for SeaORM entities. + +use std::collections::{BTreeSet, HashMap}; + +use super::type_mapping::RustImport; + +/// Collects and organizes imports for generated Rust code. +#[derive(Debug, Clone, Default)] +pub struct ImportCollector { + /// Standard library imports. + std_imports: BTreeSet, + /// External crate imports (crate -> items). + external_imports: HashMap>, + /// Required SeaORM features. + required_features: BTreeSet, +} + +impl ImportCollector { + /// Creates a new import collector. + pub fn new() -> Self { + Self::default() + } + + /// Creates a collector with the standard SeaORM entity prelude. + pub fn with_sea_orm_prelude() -> Self { + let mut collector = Self::new(); + collector.add_sea_orm_prelude(); + collector + } + + /// Adds the standard SeaORM entity prelude import. + pub fn add_sea_orm_prelude(&mut self) { + self.external_imports + .entry("sea_orm::entity::prelude".to_string()) + .or_default() + .insert("*".to_string()); + } + + /// Adds a Rust import. + pub fn add_import(&mut self, import: &RustImport) { + self.external_imports + .entry(import.module.clone()) + .or_default() + .insert(import.name.clone()); + } + + /// Adds multiple imports. + pub fn add_imports(&mut self, imports: &[RustImport]) { + for import in imports { + self.add_import(import); + } + } + + /// Adds a required feature flag. + pub fn add_feature(&mut self, feature: &str) { + self.required_features.insert(feature.to_string()); + } + + /// Adds multiple feature flags. + pub fn add_features(&mut self, features: &[&str]) { + for feature in features { + self.required_features.insert((*feature).to_string()); + } + } + + /// Merges another collector into this one. + pub fn merge(&mut self, other: &ImportCollector) { + self.std_imports.extend(other.std_imports.iter().cloned()); + for (module, items) in &other.external_imports { + self.external_imports + .entry(module.clone()) + .or_default() + .extend(items.iter().cloned()); + } + self.required_features + .extend(other.required_features.iter().cloned()); + } + + /// Returns the required SeaORM features. + pub fn required_features(&self) -> Vec { + self.required_features.iter().cloned().collect() + } + + /// Generates the import block as a string. + pub fn generate(&self) -> String { + let mut lines = Vec::new(); + + // Standard library imports (sorted) + for import in &self.std_imports { + lines.push(format!("use std::{import};")); + } + + if !self.std_imports.is_empty() && !self.external_imports.is_empty() { + lines.push(String::new()); + } + + // External imports (sorted by crate name) + let mut sorted_modules: Vec<_> = self.external_imports.keys().collect(); + sorted_modules.sort(); + + for module in sorted_modules { + let items = &self.external_imports[module]; + + // Special handling for wildcard imports + if items.contains("*") { + lines.push(format!("use {module}::*;")); + continue; + } + + // Format imports + if items.len() == 1 { + let item = items.iter().next().unwrap(); + lines.push(format!("use {module}::{item};")); + } else { + let mut sorted_items: Vec<&str> = items.iter().map(|s| s.as_str()).collect(); + sorted_items.sort(); + let items_str = sorted_items.join(", "); + lines.push(format!("use {module}::{{{items_str}}};")); + } + } + + lines.join("\n") + } + + /// Generates a feature flags comment block. + pub fn generate_features_comment(&self) -> Option { + if self.required_features.is_empty() { + return None; + } + + let mut lines = Vec::new(); + lines.push("//! Required SeaORM features:".to_string()); + for feature in &self.required_features { + lines.push(format!("//! - `{feature}`")); + } + + Some(lines.join("\n")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_collector() { + let collector = ImportCollector::new(); + assert_eq!(collector.generate(), ""); + } + + #[test] + fn test_sea_orm_prelude() { + let collector = ImportCollector::with_sea_orm_prelude(); + let imports = collector.generate(); + assert!(imports.contains("use sea_orm::entity::prelude::*;")); + } + + #[test] + fn test_add_import() { + let mut collector = ImportCollector::new(); + collector.add_import(&RustImport::new("chrono", "DateTime")); + let imports = collector.generate(); + assert!(imports.contains("use chrono::DateTime;")); + } + + #[test] + fn test_multiple_imports_same_module() { + let mut collector = ImportCollector::new(); + collector.add_import(&RustImport::new("chrono", "DateTime")); + collector.add_import(&RustImport::new("chrono", "FixedOffset")); + let imports = collector.generate(); + assert!(imports.contains("use chrono::{DateTime, FixedOffset};")); + } + + #[test] + fn test_multiple_modules() { + let mut collector = ImportCollector::new(); + collector.add_import(&RustImport::new("chrono", "DateTime")); + collector.add_import(&RustImport::new("uuid", "Uuid")); + let imports = collector.generate(); + assert!(imports.contains("use chrono::DateTime;")); + assert!(imports.contains("use uuid::Uuid;")); + } + + #[test] + fn test_add_features() { + let mut collector = ImportCollector::new(); + collector.add_feature("with-chrono"); + collector.add_feature("with-uuid"); + let features = collector.required_features(); + assert!(features.contains(&"with-chrono".to_string())); + assert!(features.contains(&"with-uuid".to_string())); + } + + #[test] + fn test_generate_features_comment() { + let mut collector = ImportCollector::new(); + collector.add_feature("with-chrono"); + let comment = collector.generate_features_comment().unwrap(); + assert!(comment.contains("with-chrono")); + } + + #[test] + fn test_merge() { + let mut collector1 = ImportCollector::new(); + collector1.add_import(&RustImport::new("chrono", "DateTime")); + collector1.add_feature("with-chrono"); + + let mut collector2 = ImportCollector::new(); + collector2.add_import(&RustImport::new("uuid", "Uuid")); + collector2.add_feature("with-uuid"); + + collector1.merge(&collector2); + + let imports = collector1.generate(); + assert!(imports.contains("chrono")); + assert!(imports.contains("uuid")); + + let features = collector1.required_features(); + assert!(features.contains(&"with-chrono".to_string())); + assert!(features.contains(&"with-uuid".to_string())); + } +} diff --git a/crates/codegen/src/rust/mod.rs b/crates/codegen/src/rust/mod.rs index 893fb98..d2d344d 100644 --- a/crates/codegen/src/rust/mod.rs +++ b/crates/codegen/src/rust/mod.rs @@ -1 +1,229 @@ -//! Rust code generation module. +//! Rust SeaORM code generation module. +//! +//! This module provides code generation for Rust SeaORM entities from PostgreSQL +//! table definitions. It supports: +//! +//! - Idiomatic SeaORM entities with proper type mappings +//! - Primary keys (single and composite), foreign keys, unique constraints, and indexes +//! - Both compact (DeriveEntityModel) and expanded entity formats +//! - Proper handling of Rust reserved words and invalid identifiers +//! - Relationship generation for foreign keys +//! - ActiveEnum generation for PostgreSQL enums +//! +//! # Example +//! +//! ```ignore +//! use tern_codegen::{Codegen, rust::RustSeaOrmCodegen}; +//! +//! let codegen = RustSeaOrmCodegen::new(RustSeaOrmCodegenConfig::default()); +//! let output = codegen.generate(tables); +//! // output contains generated SeaORM entity files +//! ``` + +mod column; +mod entity; +mod generator; +mod imports; +mod naming; +mod primary_key; +mod relation; +mod type_mapping; + +#[cfg(test)] +mod tests; + +pub use generator::RustSeaOrmCodegen; + +use std::fmt; + +/// Configuration for Rust SeaORM code generation. +#[derive(Debug, Clone)] +pub struct RustSeaOrmCodegenConfig { + /// Whether to generate compact format (DeriveEntityModel) or expanded format. + /// Compact is recommended and is the default. + pub entity_format: EntityFormat, + + /// Whether to generate relationship attributes and Related trait implementations. + pub generate_relations: bool, + + /// Whether to include doc comments from table/column comments. + pub include_doc_comments: bool, + + /// How to handle Rust reserved words in identifiers. + pub reserved_word_strategy: ReservedWordStrategy, + + /// Module name for the generated entities (used for cross-module references). + pub module_name: Option, + + /// Whether to generate ActiveEnum types for PostgreSQL enums. + /// Note: Requires enum definitions to be passed separately or inferred. + pub generate_active_enums: bool, + + /// The schema name to use (if not public). + pub schema_name: Option, + + /// Output mode: single file or multiple files. + pub output_mode: OutputMode, +} + +impl Default for RustSeaOrmCodegenConfig { + fn default() -> Self { + Self { + entity_format: EntityFormat::default(), + generate_relations: false, + include_doc_comments: true, + reserved_word_strategy: ReservedWordStrategy::default(), + module_name: None, + generate_active_enums: false, + schema_name: None, + output_mode: OutputMode::default(), + } + } +} + +/// Entity format for generated code. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub enum EntityFormat { + /// Uses DeriveEntityModel macro - recommended, less boilerplate. + #[default] + Compact, + /// Generates explicit Column enum, PrimaryKey enum, and trait implementations. + Expanded, +} + +impl fmt::Display for EntityFormat { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Compact => write!(f, "compact"), + Self::Expanded => write!(f, "expanded"), + } + } +} + +/// Strategy for handling Rust reserved words in identifiers. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub enum ReservedWordStrategy { + /// Append an underscore: `type` -> `type_` with column_name attribute. + #[default] + AppendUnderscore, + /// Use raw identifier: `type` -> `r#type`. + RawIdentifier, + /// Prepend with custom prefix: `type` -> `field_type`. + PrependPrefix(String), +} + +impl fmt::Display for ReservedWordStrategy { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::AppendUnderscore => write!(f, "append_underscore"), + Self::RawIdentifier => write!(f, "raw_identifier"), + Self::PrependPrefix(prefix) => write!(f, "prepend_prefix({prefix})"), + } + } +} + +/// Output mode for generated code. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub enum OutputMode { + /// Generate all entities in a single `entities.rs` file. + SingleFile, + /// Generate separate files for each entity with a `mod.rs`. + #[default] + MultiFile, +} + +impl fmt::Display for OutputMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::SingleFile => write!(f, "single_file"), + Self::MultiFile => write!(f, "multi_file"), + } + } +} + +/// Errors that can occur during Rust SeaORM code generation. +#[derive(Debug, thiserror::Error)] +pub enum RustSeaOrmCodegenError { + /// An unsupported PostgreSQL type was encountered. + #[error("unsupported PostgreSQL type: {type_name} (formatted: {formatted})")] + UnsupportedType { + type_name: String, + formatted: String, + }, + + /// A table has no columns. + #[error("table '{table_name}' has no columns")] + EmptyTable { table_name: String }, + + /// A table has no primary key. + #[error("table '{table_name}' has no primary key")] + NoPrimaryKey { table_name: String }, + + /// An identifier is invalid after sanitization. + #[error("invalid identifier after sanitization: '{original}' -> '{sanitized}'")] + InvalidIdentifier { original: String, sanitized: String }, + + /// Code generation failed. + #[error("code generation failed: {message}")] + GenerationError { message: String }, +} + +/// Warnings that don't prevent generation but should be reported. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum RustSeaOrmCodegenWarning { + /// An unsupported constraint was encountered. + UnsupportedConstraint { + table: String, + constraint: String, + kind: String, + }, + /// A generated column was ignored. + GeneratedColumnIgnored { table: String, column: String }, + /// A feature flag is required for the generated code. + FeatureFlagRequired { feature: String, reason: String }, + /// A table has no primary key. + TableWithoutPrimaryKey { table: String }, +} + +impl fmt::Display for RustSeaOrmCodegenWarning { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::UnsupportedConstraint { + table, + constraint, + kind, + } => { + write!( + f, + "Table '{table}': Unsupported {kind} constraint '{constraint}'" + ) + } + Self::GeneratedColumnIgnored { table, column } => { + write!( + f, + "Table '{table}': Generated column '{column}' ignored (not supported by SeaORM)" + ) + } + Self::FeatureFlagRequired { feature, reason } => { + write!(f, "Required SeaORM feature: '{feature}' ({reason})") + } + Self::TableWithoutPrimaryKey { table } => { + write!( + f, + "Table '{table}': No primary key defined (SeaORM requires a primary key)" + ) + } + } + } +} + +/// Definition of a PostgreSQL enum type for ActiveEnum generation. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EnumDefinition { + /// The enum type name. + pub name: String, + /// The schema containing the enum (if not public). + pub schema: Option, + /// The enum values in order. + pub values: Vec, +} diff --git a/crates/codegen/src/rust/naming.rs b/crates/codegen/src/rust/naming.rs new file mode 100644 index 0000000..d2ed6ec --- /dev/null +++ b/crates/codegen/src/rust/naming.rs @@ -0,0 +1,522 @@ +//! Rust identifier naming and sanitization. +//! +//! This module handles conversion of PostgreSQL identifiers to valid Rust identifiers, +//! including handling of reserved words, invalid characters, and naming conventions. + +use super::ReservedWordStrategy; + +/// Rust reserved keywords that cannot be used as identifiers. +/// +/// These include strict keywords and reserved keywords from all editions. +const RUST_KEYWORDS: &[&str] = &[ + // Strict keywords + "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum", "extern", + "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", + "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true", "type", + "unsafe", "use", "where", "while", + // Reserved keywords (may become keywords in future) + "abstract", "become", "box", "do", "final", "macro", "override", "priv", "try", "typeof", + "unsized", "virtual", "yield", // 2018+ edition keywords + "union", +]; + +/// Rust weak keywords that are reserved in certain contexts. +const RUST_WEAK_KEYWORDS: &[&str] = &[ + // These are only keywords in specific contexts + "macro_rules", + "raw", + "safe", // safe is reserved for future use +]; + +/// Result of sanitizing an identifier. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SanitizedName { + /// The sanitized identifier safe for use in Rust code. + pub identifier: String, + /// Whether a `column_name` or `rename` attribute is needed because the + /// identifier differs from the original database name. + pub needs_rename_attr: bool, + /// The original database name. + pub original: String, +} + +impl SanitizedName { + /// Creates a new sanitized name where no renaming is needed. + pub fn unchanged(name: &str) -> Self { + Self { + identifier: name.to_string(), + needs_rename_attr: false, + original: name.to_string(), + } + } + + /// Creates a new sanitized name with renaming. + pub fn renamed(identifier: String, original: &str) -> Self { + Self { + identifier, + needs_rename_attr: true, + original: original.to_string(), + } + } +} + +/// Checks if a name is a Rust keyword. +pub fn is_rust_keyword(name: &str) -> bool { + RUST_KEYWORDS.contains(&name) +} + +/// Checks if a name is a Rust weak keyword. +pub fn is_rust_weak_keyword(name: &str) -> bool { + RUST_WEAK_KEYWORDS.contains(&name) +} + +/// Checks if a name is any kind of Rust reserved word. +pub fn is_reserved_word(name: &str) -> bool { + is_rust_keyword(name) || is_rust_weak_keyword(name) +} + +/// Sanitizes a database identifier for use as a Rust identifier. +/// +/// This function: +/// 1. Replaces invalid characters with underscores +/// 2. Ensures the name doesn't start with a digit +/// 3. Handles reserved words according to the strategy +/// 4. Returns the sanitized name with information about whether renaming is needed +pub fn sanitize_identifier(name: &str, strategy: &ReservedWordStrategy) -> SanitizedName { + if name.is_empty() { + return SanitizedName::renamed("_empty".to_string(), name); + } + + let mut sanitized = String::with_capacity(name.len() + 1); + let mut needs_rename = false; + + // Process each character + for (i, c) in name.chars().enumerate() { + if i == 0 { + // First character must be a letter or underscore + if c.is_ascii_alphabetic() || c == '_' { + sanitized.push(c); + } else if c.is_ascii_digit() { + // Prefix with underscore if starts with digit + sanitized.push('_'); + sanitized.push(c); + needs_rename = true; + } else { + // Replace invalid first character with underscore + sanitized.push('_'); + needs_rename = true; + } + } else if c.is_ascii_alphanumeric() || c == '_' { + sanitized.push(c); + } else { + // Replace invalid characters with underscore + sanitized.push('_'); + needs_rename = true; + } + } + + // Collapse multiple consecutive underscores + let mut collapsed = String::with_capacity(sanitized.len()); + let mut prev_underscore = false; + for c in sanitized.chars() { + if c == '_' { + if !prev_underscore { + collapsed.push(c); + } else { + needs_rename = true; + } + prev_underscore = true; + } else { + collapsed.push(c); + prev_underscore = false; + } + } + sanitized = collapsed; + + // Remove trailing underscores (unless it's the only character or was in original) + while sanitized.len() > 1 && sanitized.ends_with('_') && !name.ends_with('_') { + sanitized.pop(); + needs_rename = true; + } + + // Handle reserved words + if is_reserved_word(&sanitized) { + needs_rename = true; + sanitized = apply_reserved_word_strategy(&sanitized, strategy); + } + + // Final validation - ensure we have a valid identifier + if sanitized.is_empty() || sanitized == "_" { + return SanitizedName::renamed("_field".to_string(), name); + } + + if needs_rename { + SanitizedName::renamed(sanitized, name) + } else { + SanitizedName::unchanged(name) + } +} + +/// Applies the reserved word strategy to transform a reserved word. +fn apply_reserved_word_strategy(name: &str, strategy: &ReservedWordStrategy) -> String { + match strategy { + ReservedWordStrategy::AppendUnderscore => format!("{name}_"), + ReservedWordStrategy::RawIdentifier => format!("r#{name}"), + ReservedWordStrategy::PrependPrefix(prefix) => format!("{prefix}{name}"), + } +} + +/// Converts a table name to a Rust struct name (PascalCase, singular). +/// +/// Examples: +/// - "users" -> "User" +/// - "user_accounts" -> "UserAccount" +/// - "order_items" -> "OrderItem" +/// - "categories" -> "Category" +pub fn to_struct_name(table_name: &str) -> String { + // First singularize, then convert to PascalCase + let singular = singularize(table_name); + to_pascal_case(&singular) +} + +/// Converts a table name to a Rust module name (snake_case, singular). +/// +/// Examples: +/// - "users" -> "user" +/// - "user_accounts" -> "user_account" +/// - "OrderItems" -> "order_item" +pub fn to_module_name(table_name: &str) -> String { + let snake = to_snake_case(table_name); + singularize(&snake) +} + +/// Converts a string to PascalCase. +fn to_pascal_case(s: &str) -> String { + let mut result = String::with_capacity(s.len()); + let mut capitalize_next = true; + + for c in s.chars() { + if c == '_' || c == '-' || c == ' ' { + capitalize_next = true; + } else if capitalize_next { + result.push(c.to_ascii_uppercase()); + capitalize_next = false; + } else { + result.push(c.to_ascii_lowercase()); + } + } + + if result.is_empty() { + result = "Model".to_string(); + } + + result +} + +/// Converts a string to snake_case. +fn to_snake_case(s: &str) -> String { + let mut result = String::with_capacity(s.len() + 4); + let mut prev_was_upper = false; + let mut prev_was_underscore = true; + + for (i, c) in s.chars().enumerate() { + if c.is_ascii_uppercase() { + if i > 0 && !prev_was_upper && !prev_was_underscore { + result.push('_'); + } + result.push(c.to_ascii_lowercase()); + prev_was_upper = true; + prev_was_underscore = false; + } else if c == '-' || c == ' ' { + if !prev_was_underscore { + result.push('_'); + } + prev_was_upper = false; + prev_was_underscore = true; + } else if c == '_' { + if !prev_was_underscore { + result.push(c); + } + prev_was_upper = false; + prev_was_underscore = true; + } else { + result.push(c); + prev_was_upper = false; + prev_was_underscore = false; + } + } + + result +} + +/// Simple singularization of English words. +fn singularize(s: &str) -> String { + // Handle "ies" -> "y" (e.g., "categories" -> "category") + if s.ends_with("ies") && s.len() > 3 { + let base = &s[..s.len() - 3]; + return format!("{base}y"); + } + + // Handle "es" -> "" for certain endings (e.g., "addresses" -> "address") + if s.ends_with("sses") && s.len() > 4 { + return s[..s.len() - 2].to_string(); + } + if s.ends_with("xes") && s.len() > 3 { + return s[..s.len() - 2].to_string(); + } + if s.ends_with("ches") && s.len() > 4 { + return s[..s.len() - 2].to_string(); + } + if s.ends_with("shes") && s.len() > 4 { + return s[..s.len() - 2].to_string(); + } + + // Handle simple "s" -> "" (e.g., "users" -> "user") + if s.ends_with('s') + && !s.ends_with("ss") + && !s.ends_with("us") + && !s.ends_with("is") + && !s.ends_with("as") + && s.len() > 1 + { + return s[..s.len() - 1].to_string(); + } + + s.to_string() +} + +/// Converts a column name to a Rust field name. +/// +/// PostgreSQL column names are typically already in snake_case, but this +/// handles edge cases like mixed case or reserved words. +pub fn to_field_name(column_name: &str, strategy: &ReservedWordStrategy) -> SanitizedName { + // First sanitize the identifier + let sanitized = sanitize_identifier(column_name, strategy); + + // Convert to snake_case if not already + let snake = to_snake_case(&sanitized.identifier); + + if snake != sanitized.identifier || sanitized.needs_rename_attr { + SanitizedName::renamed(snake, column_name) + } else { + sanitized + } +} + +/// Converts a column name to a Rust enum variant name (PascalCase). +/// +/// Examples: +/// - "user_id" -> "UserId" +/// - "created_at" -> "CreatedAt" +pub fn to_enum_variant(column_name: &str) -> String { + to_pascal_case(column_name) +} + +/// Converts a foreign key relation to a relation enum variant name. +/// +/// Examples: +/// - FK from "posts" to "users" via "user_id" -> "User" +/// - FK from "comments" to "posts" via "post_id" -> "Post" +pub fn to_relation_name(target_table: &str) -> String { + to_pascal_case(&singularize(target_table)) +} + +/// Converts a table name to a plural relation name for has_many. +/// +/// Examples: +/// - "users" -> "Posts" (for has_many from users to posts) +/// - "categories" -> "Products" (for has_many) +pub fn to_plural_relation_name(target_table: &str) -> String { + let pascal = to_pascal_case(target_table); + // If already plural, return as-is; otherwise add 's' + if target_table.ends_with('s') || target_table.ends_with("es") { + pascal + } else { + format!("{pascal}s") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_rust_keyword() { + assert!(is_rust_keyword("type")); + assert!(is_rust_keyword("fn")); + assert!(is_rust_keyword("struct")); + assert!(is_rust_keyword("async")); + assert!(is_rust_keyword("self")); + assert!(is_rust_keyword("Self")); + assert!(!is_rust_keyword("user")); + assert!(!is_rust_keyword("name")); + } + + #[test] + fn test_is_reserved_word() { + assert!(is_reserved_word("type")); + assert!(is_reserved_word("abstract")); + assert!(!is_reserved_word("user")); + } + + #[test] + fn test_sanitize_identifier_valid() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let result = sanitize_identifier("user_id", &strategy); + assert_eq!(result.identifier, "user_id"); + assert!(!result.needs_rename_attr); + } + + #[test] + fn test_sanitize_identifier_reserved_word() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let result = sanitize_identifier("type", &strategy); + assert_eq!(result.identifier, "type_"); + assert!(result.needs_rename_attr); + assert_eq!(result.original, "type"); + } + + #[test] + fn test_sanitize_identifier_raw_identifier() { + let strategy = ReservedWordStrategy::RawIdentifier; + let result = sanitize_identifier("type", &strategy); + assert_eq!(result.identifier, "r#type"); + assert!(result.needs_rename_attr); + } + + #[test] + fn test_sanitize_identifier_prefix() { + let strategy = ReservedWordStrategy::PrependPrefix("field_".to_string()); + let result = sanitize_identifier("type", &strategy); + assert_eq!(result.identifier, "field_type"); + assert!(result.needs_rename_attr); + } + + #[test] + fn test_sanitize_identifier_starts_with_digit() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let result = sanitize_identifier("1column", &strategy); + assert_eq!(result.identifier, "_1column"); + assert!(result.needs_rename_attr); + } + + #[test] + fn test_sanitize_identifier_special_characters() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let result = sanitize_identifier("column-name", &strategy); + assert_eq!(result.identifier, "column_name"); + assert!(result.needs_rename_attr); + } + + #[test] + fn test_sanitize_identifier_multiple_underscores() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let result = sanitize_identifier("column__name", &strategy); + assert_eq!(result.identifier, "column_name"); + assert!(result.needs_rename_attr); + } + + #[test] + fn test_sanitize_identifier_empty() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let result = sanitize_identifier("", &strategy); + assert_eq!(result.identifier, "_empty"); + assert!(result.needs_rename_attr); + } + + #[test] + fn test_to_struct_name_simple() { + assert_eq!(to_struct_name("users"), "User"); + assert_eq!(to_struct_name("user"), "User"); + } + + #[test] + fn test_to_struct_name_compound() { + assert_eq!(to_struct_name("user_accounts"), "UserAccount"); + assert_eq!(to_struct_name("order_items"), "OrderItem"); + } + + #[test] + fn test_to_struct_name_categories() { + assert_eq!(to_struct_name("categories"), "Category"); + } + + #[test] + fn test_to_struct_name_preserves_non_plural() { + assert_eq!(to_struct_name("status"), "Status"); + assert_eq!(to_struct_name("address"), "Address"); + } + + #[test] + fn test_to_module_name_simple() { + assert_eq!(to_module_name("users"), "user"); + assert_eq!(to_module_name("User"), "user"); + } + + #[test] + fn test_to_module_name_compound() { + assert_eq!(to_module_name("user_accounts"), "user_account"); + assert_eq!(to_module_name("OrderItems"), "order_item"); + } + + #[test] + fn test_to_field_name_simple() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let result = to_field_name("user_id", &strategy); + assert_eq!(result.identifier, "user_id"); + assert!(!result.needs_rename_attr); + } + + #[test] + fn test_to_field_name_reserved() { + let strategy = ReservedWordStrategy::AppendUnderscore; + let result = to_field_name("type", &strategy); + assert_eq!(result.identifier, "type_"); + assert!(result.needs_rename_attr); + } + + #[test] + fn test_to_enum_variant() { + assert_eq!(to_enum_variant("user_id"), "UserId"); + assert_eq!(to_enum_variant("created_at"), "CreatedAt"); + assert_eq!(to_enum_variant("id"), "Id"); + } + + #[test] + fn test_to_relation_name() { + assert_eq!(to_relation_name("users"), "User"); + assert_eq!(to_relation_name("user_accounts"), "UserAccount"); + assert_eq!(to_relation_name("categories"), "Category"); + } + + #[test] + fn test_to_plural_relation_name() { + assert_eq!(to_plural_relation_name("post"), "Posts"); + assert_eq!(to_plural_relation_name("posts"), "Posts"); + assert_eq!(to_plural_relation_name("user_account"), "UserAccounts"); + } + + #[test] + fn test_singularize() { + assert_eq!(singularize("users"), "user"); + assert_eq!(singularize("categories"), "category"); + assert_eq!(singularize("addresses"), "address"); + assert_eq!(singularize("boxes"), "box"); + assert_eq!(singularize("status"), "status"); + assert_eq!(singularize("analysis"), "analysis"); + } + + #[test] + fn test_to_pascal_case() { + assert_eq!(to_pascal_case("user_account"), "UserAccount"); + assert_eq!(to_pascal_case("some-thing"), "SomeThing"); + assert_eq!(to_pascal_case("camelCase"), "Camelcase"); + } + + #[test] + fn test_to_snake_case() { + assert_eq!(to_snake_case("UserAccount"), "user_account"); + assert_eq!(to_snake_case("someThing"), "some_thing"); + assert_eq!(to_snake_case("already_snake"), "already_snake"); + } +} diff --git a/crates/codegen/src/rust/primary_key.rs b/crates/codegen/src/rust/primary_key.rs new file mode 100644 index 0000000..53782af --- /dev/null +++ b/crates/codegen/src/rust/primary_key.rs @@ -0,0 +1,322 @@ +//! Primary key detection and handling for SeaORM code generation. +//! +//! This module provides utilities for detecting and analyzing primary keys +//! from PostgreSQL table constraints. + +use tern_ddl::{Column, ColumnName, Constraint, ConstraintKind, IdentityKind, Table}; + +/// Information about a table's primary key. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PrimaryKeyInfo { + /// The columns that make up the primary key. + pub columns: Vec, + /// Whether this is a composite primary key (more than one column). + pub is_composite: bool, + /// Whether the primary key is auto-increment (identity or serial). + pub is_auto_increment: bool, +} + +impl PrimaryKeyInfo { + /// Checks if a column is part of this primary key. + pub fn contains_column(&self, column_name: &ColumnName) -> bool { + self.columns.iter().any(|c| c == column_name) + } + + /// Returns the number of columns in the primary key. + #[allow(dead_code)] + pub fn column_count(&self) -> usize { + self.columns.len() + } +} + +/// Extracts primary key information from a table. +/// +/// Returns `None` if the table has no primary key constraint. +pub fn get_primary_key(table: &Table) -> Option { + // Find the primary key constraint + let pk_constraint = table.constraints.iter().find_map(|c| { + if let ConstraintKind::PrimaryKey(pk) = &c.kind { + Some(pk) + } else { + None + } + }); + + let pk = pk_constraint?; + + // Determine if auto-increment + let is_auto_increment = if pk.columns.len() == 1 { + // Check if the single PK column is an identity or serial column + let pk_column_name = &pk.columns[0]; + table + .columns + .iter() + .find(|c| c.name == *pk_column_name) + .map(|c| is_column_auto_increment(c)) + .unwrap_or(false) + } else { + // Composite PKs are never auto-increment + false + }; + + Some(PrimaryKeyInfo { + columns: pk.columns.clone(), + is_composite: pk.columns.len() > 1, + is_auto_increment, + }) +} + +/// Determines if a column has auto-increment behavior. +/// +/// This is true for: +/// - Identity columns (GENERATED ALWAYS AS IDENTITY or GENERATED BY DEFAULT AS IDENTITY) +/// - Serial columns (detected by having a sequence default) +pub fn is_column_auto_increment(column: &Column) -> bool { + // Check for identity columns + if column.identity.is_some() { + return true; + } + + // Check for serial columns (identified by nextval sequence default) + if let Some(default) = &column.default { + let default_str = default.as_ref().to_lowercase(); + if default_str.contains("nextval(") { + return true; + } + } + + false +} + +/// Determines if a column is an identity column with ALWAYS generation. +pub fn is_identity_always(column: &Column) -> bool { + matches!(column.identity, Some(IdentityKind::Always)) +} + +/// Determines if a column is an identity column with BY DEFAULT generation. +#[allow(dead_code)] +pub fn is_identity_by_default(column: &Column) -> bool { + matches!(column.identity, Some(IdentityKind::ByDefault)) +} + +/// Checks if a column is part of the primary key. +pub fn is_primary_key_column(column: &Column, constraints: &[Constraint]) -> bool { + constraints.iter().any(|c| { + if let ConstraintKind::PrimaryKey(pk) = &c.kind { + pk.columns.contains(&column.name) + } else { + false + } + }) +} + +/// Gets the primary key columns for a table. +pub fn get_primary_key_columns(table: &Table) -> Vec<&Column> { + let pk_column_names: Vec<_> = table + .constraints + .iter() + .filter_map(|c| { + if let ConstraintKind::PrimaryKey(pk) = &c.kind { + Some(&pk.columns) + } else { + None + } + }) + .flatten() + .collect(); + + table + .columns + .iter() + .filter(|c| pk_column_names.contains(&&c.name)) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use tern_ddl::types::QualifiedCollationName; + use tern_ddl::{ + CollationName, ConstraintName, IndexName, Oid, PrimaryKeyConstraint, SchemaName, SqlExpr, + TableKind, TableName, TypeInfo, TypeName, + }; + + fn make_column(name: &str, type_name: &str, identity: Option) -> Column { + Column { + name: ColumnName::try_new(name.to_string()).unwrap(), + position: 1, + type_info: TypeInfo { + name: TypeName::try_new(type_name.to_string()).unwrap(), + schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(), + formatted: type_name.to_string(), + is_array: false, + }, + is_nullable: false, + default: None, + generated: None, + identity, + collation: QualifiedCollationName::new( + SchemaName::try_new("pg_catalog".to_string()).unwrap(), + CollationName::try_new("default".to_string()).unwrap(), + ), + comment: None, + } + } + + fn make_column_with_default(name: &str, type_name: &str, default: &str) -> Column { + let mut col = make_column(name, type_name, None); + col.default = Some(SqlExpr::new(default.to_string())); + col + } + + fn make_table_with_pk(name: &str, columns: Vec, pk_columns: &[&str]) -> Table { + let pk = Constraint { + name: ConstraintName::try_new(format!("{name}_pkey")).unwrap(), + kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint { + columns: pk_columns + .iter() + .map(|c| ColumnName::try_new(c.to_string()).unwrap()) + .collect(), + index_name: IndexName::try_new(format!("{name}_pkey")).unwrap(), + }), + comment: None, + }; + + Table { + oid: Oid::new(1), + name: TableName::try_new(name.to_string()).unwrap(), + kind: TableKind::Regular, + columns, + constraints: vec![pk], + indexes: vec![], + comment: None, + } + } + + fn make_table_without_pk(name: &str, columns: Vec) -> Table { + Table { + oid: Oid::new(1), + name: TableName::try_new(name.to_string()).unwrap(), + kind: TableKind::Regular, + columns, + constraints: vec![], + indexes: vec![], + comment: None, + } + } + + #[test] + fn test_single_column_pk() { + let columns = vec![ + make_column("id", "int4", Some(IdentityKind::Always)), + make_column("name", "text", None), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + + let pk_info = get_primary_key(&table).unwrap(); + assert_eq!(pk_info.columns.len(), 1); + assert_eq!(pk_info.columns[0].as_ref(), "id"); + assert!(!pk_info.is_composite); + assert!(pk_info.is_auto_increment); + } + + #[test] + fn test_composite_pk() { + let columns = vec![ + make_column("order_id", "int4", None), + make_column("product_id", "int4", None), + make_column("quantity", "int4", None), + ]; + let table = make_table_with_pk("order_items", columns, &["order_id", "product_id"]); + + let pk_info = get_primary_key(&table).unwrap(); + assert_eq!(pk_info.columns.len(), 2); + assert!(pk_info.is_composite); + assert!(!pk_info.is_auto_increment); + } + + #[test] + fn test_no_pk() { + let columns = vec![make_column("name", "text", None)]; + let table = make_table_without_pk("log_entries", columns); + + assert!(get_primary_key(&table).is_none()); + } + + #[test] + fn test_serial_pk() { + let columns = vec![ + make_column_with_default("id", "int4", "nextval('users_id_seq'::regclass)"), + make_column("name", "text", None), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + + let pk_info = get_primary_key(&table).unwrap(); + assert!(pk_info.is_auto_increment); + } + + #[test] + fn test_uuid_pk() { + let columns = vec![ + make_column("id", "uuid", None), + make_column("name", "text", None), + ]; + let table = make_table_with_pk("items", columns, &["id"]); + + let pk_info = get_primary_key(&table).unwrap(); + assert!(!pk_info.is_auto_increment); + } + + #[test] + fn test_pk_contains_column() { + let columns = vec![ + make_column("id", "int4", Some(IdentityKind::Always)), + make_column("name", "text", None), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + + let pk_info = get_primary_key(&table).unwrap(); + assert!(pk_info.contains_column(&ColumnName::try_new("id".to_string()).unwrap())); + assert!(!pk_info.contains_column(&ColumnName::try_new("name".to_string()).unwrap())); + } + + #[test] + fn test_is_primary_key_column() { + let columns = vec![ + make_column("id", "int4", Some(IdentityKind::Always)), + make_column("name", "text", None), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + + assert!(is_primary_key_column(&table.columns[0], &table.constraints)); + assert!(!is_primary_key_column( + &table.columns[1], + &table.constraints + )); + } + + #[test] + fn test_get_primary_key_columns() { + let columns = vec![ + make_column("order_id", "int4", None), + make_column("product_id", "int4", None), + make_column("quantity", "int4", None), + ]; + let table = make_table_with_pk("order_items", columns, &["order_id", "product_id"]); + + let pk_cols = get_primary_key_columns(&table); + assert_eq!(pk_cols.len(), 2); + } + + #[test] + fn test_is_identity_always() { + let col = make_column("id", "int4", Some(IdentityKind::Always)); + assert!(is_identity_always(&col)); + + let col2 = make_column("id", "int4", Some(IdentityKind::ByDefault)); + assert!(!is_identity_always(&col2)); + + let col3 = make_column("id", "int4", None); + assert!(!is_identity_always(&col3)); + } +} diff --git a/crates/codegen/src/rust/relation.rs b/crates/codegen/src/rust/relation.rs new file mode 100644 index 0000000..5bee4b0 --- /dev/null +++ b/crates/codegen/src/rust/relation.rs @@ -0,0 +1,702 @@ +//! Relation generation for SeaORM code generation. +//! +//! This module provides utilities for analyzing foreign keys and generating +//! SeaORM relation enums and Related trait implementations. + +use std::collections::HashMap; + +use tern_ddl::{ColumnName, Constraint, ConstraintKind, ForeignKeyConstraint, Table}; + +use super::naming::{to_enum_variant, to_module_name, to_plural_relation_name, to_relation_name}; + +/// The kind of relation between tables. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RelationKind { + /// This table has a foreign key to another table (many-to-one). + BelongsTo, + /// Another table has a foreign key to this table (one-to-many). + HasMany, + /// Another table has a unique foreign key to this table (one-to-one). + HasOne, +} + +/// Information about a relation for code generation. +#[derive(Debug, Clone)] +pub struct RelationInfo { + /// The kind of relation. + pub kind: RelationKind, + /// The target table name (without schema). + pub target_table: String, + /// The target table's schema (if not public). + #[allow(dead_code)] + pub target_schema: Option, + /// The columns in this table that form the FK. + pub from_columns: Vec, + /// The columns in the target table that are referenced. + pub to_columns: Vec, + /// The FK constraint name (for belongs_to relations). + #[allow(dead_code)] + pub constraint_name: Option, + /// Whether this is a self-referential relation. + pub is_self_referential: bool, + /// The on_delete action. + pub on_delete: Option, + /// The on_update action. + pub on_update: Option, +} + +impl RelationInfo { + /// Returns the relation enum variant name. + pub fn variant_name(&self) -> String { + match self.kind { + RelationKind::BelongsTo | RelationKind::HasOne => to_relation_name(&self.target_table), + RelationKind::HasMany => to_plural_relation_name(&self.target_table), + } + } + + /// Returns the target entity path for the relation attribute. + pub fn target_entity_path(&self, is_self_ref: bool) -> String { + if is_self_ref { + "Entity".to_string() + } else { + let module = to_module_name(&self.target_table); + format!("super::{module}::Entity") + } + } + + /// Returns the from column path for belongs_to relations. + pub fn from_column_path(&self) -> String { + if self.from_columns.len() == 1 { + let variant = to_enum_variant(self.from_columns[0].as_ref()); + format!("Column::{variant}") + } else { + // Composite FK + let variants: Vec<_> = self + .from_columns + .iter() + .map(|c| format!("Column::{}", to_enum_variant(c.as_ref()))) + .collect(); + format!("({})", variants.join(", ")) + } + } + + /// Returns the to column path for belongs_to relations. + pub fn to_column_path(&self) -> String { + if self.is_self_referential { + if self.to_columns.len() == 1 { + let variant = to_enum_variant(self.to_columns[0].as_ref()); + format!("Column::{variant}") + } else { + let variants: Vec<_> = self + .to_columns + .iter() + .map(|c| format!("Column::{}", to_enum_variant(c.as_ref()))) + .collect(); + format!("({})", variants.join(", ")) + } + } else { + let module = to_module_name(&self.target_table); + if self.to_columns.len() == 1 { + let variant = to_enum_variant(self.to_columns[0].as_ref()); + format!("super::{module}::Column::{variant}") + } else { + let variants: Vec<_> = self + .to_columns + .iter() + .map(|c| format!("super::{module}::Column::{}", to_enum_variant(c.as_ref()))) + .collect(); + format!("({})", variants.join(", ")) + } + } + } + + /// Generates the `#[sea_orm(...)]` attribute for this relation. + pub fn generate_relation_attr(&self) -> String { + match self.kind { + RelationKind::HasMany => { + let target = self.target_entity_path(self.is_self_referential); + format!("#[sea_orm(has_many = \"{target}\")]") + } + RelationKind::HasOne => { + let target = self.target_entity_path(self.is_self_referential); + format!("#[sea_orm(has_one = \"{target}\")]") + } + RelationKind::BelongsTo => { + let target = self.target_entity_path(self.is_self_referential); + let from = self.from_column_path(); + let to = self.to_column_path(); + + let mut attrs = vec![ + format!("belongs_to = \"{target}\""), + format!("from = \"{from}\""), + format!("to = \"{to}\""), + ]; + + // Add on_delete and on_update if not the default (NoAction) + if let Some(on_delete) = &self.on_delete { + if on_delete != "NoAction" { + attrs.push(format!("on_delete = \"{on_delete}\"")); + } + } + if let Some(on_update) = &self.on_update { + if on_update != "NoAction" { + attrs.push(format!("on_update = \"{on_update}\"")); + } + } + + format!("#[sea_orm(\n {}\n )]", attrs.join(",\n ")) + } + } + } + + /// Generates the Related trait implementation for this relation. + pub fn generate_related_impl(&self, _self_table: &str) -> String { + let target = self.target_entity_path(self.is_self_referential); + let variant = self.variant_name(); + + if self.is_self_referential { + format!( + r#"impl Related for Entity {{ + fn to() -> RelationDef {{ + Relation::{variant}.def() + }} +}}"# + ) + } else { + format!( + r#"impl Related<{target}> for Entity {{ + fn to() -> RelationDef {{ + Relation::{variant}.def() + }} +}}"# + ) + } + } +} + +/// Analyzes foreign keys to determine relations for all tables. +/// +/// Returns a map from table name to list of relations. +pub fn analyze_relations(tables: &[Table]) -> HashMap> { + let mut relations: HashMap> = HashMap::new(); + + // Build a set of table names for self-reference detection + let table_names: std::collections::HashSet<_> = + tables.iter().map(|t| t.name.as_ref().to_string()).collect(); + + for table in tables { + let table_name = table.name.as_ref().to_string(); + + for constraint in &table.constraints { + if let ConstraintKind::ForeignKey(fk) = &constraint.kind { + let target_table = fk.referenced_table.name.as_ref().to_string(); + let is_self_ref = target_table == table_name; + + // Table with FK: belongs_to + let belongs_to = RelationInfo { + kind: RelationKind::BelongsTo, + target_table: target_table.clone(), + target_schema: if fk.referenced_table.schema.as_ref() != "public" { + Some(fk.referenced_table.schema.as_ref().to_string()) + } else { + None + }, + from_columns: fk.columns.clone(), + to_columns: fk.referenced_columns.clone(), + constraint_name: Some(constraint.name.as_ref().to_string()), + is_self_referential: is_self_ref, + on_delete: Some(fk_action_to_sea_orm(fk.on_delete)), + on_update: Some(fk_action_to_sea_orm(fk.on_update)), + }; + + relations + .entry(table_name.clone()) + .or_default() + .push(belongs_to); + + // If the target table exists in our table list, add inverse relation + if table_names.contains(&target_table) && !is_self_ref { + let inverse_kind = if is_one_to_one_fk(fk, &table.constraints) { + RelationKind::HasOne + } else { + RelationKind::HasMany + }; + + let inverse = RelationInfo { + kind: inverse_kind, + target_table: table_name.clone(), + target_schema: None, // Same schema + from_columns: fk.referenced_columns.clone(), + to_columns: fk.columns.clone(), + constraint_name: None, + is_self_referential: false, + on_delete: None, + on_update: None, + }; + + relations + .entry(target_table.clone()) + .or_default() + .push(inverse); + } + } + } + } + + relations +} + +/// Determines if a foreign key implies a one-to-one relationship. +/// +/// This is true if the FK columns have a unique constraint. +fn is_one_to_one_fk(fk: &ForeignKeyConstraint, constraints: &[Constraint]) -> bool { + constraints.iter().any(|c| match &c.kind { + ConstraintKind::Unique(u) => u.columns == fk.columns, + ConstraintKind::PrimaryKey(pk) => pk.columns == fk.columns, + _ => false, + }) +} + +/// Converts a ForeignKeyAction to SeaORM string representation. +fn fk_action_to_sea_orm(action: tern_ddl::ForeignKeyAction) -> String { + match action { + tern_ddl::ForeignKeyAction::NoAction => "NoAction".to_string(), + tern_ddl::ForeignKeyAction::Restrict => "Restrict".to_string(), + tern_ddl::ForeignKeyAction::Cascade => "Cascade".to_string(), + tern_ddl::ForeignKeyAction::SetNull => "SetNull".to_string(), + tern_ddl::ForeignKeyAction::SetDefault => "SetDefault".to_string(), + } +} + +/// Checks if a table is likely a junction (many-to-many) table. +/// +/// A junction table typically: +/// 1. Has exactly two foreign keys +/// 2. Those foreign keys together form the primary key +/// 3. Has few or no other non-FK columns +pub fn is_junction_table(table: &Table) -> bool { + // Get FK constraints + let fks: Vec<_> = table + .constraints + .iter() + .filter_map(|c| { + if let ConstraintKind::ForeignKey(fk) = &c.kind { + Some(fk) + } else { + None + } + }) + .collect(); + + if fks.len() != 2 { + return false; + } + + // Get PK constraint + let pk = table.constraints.iter().find_map(|c| { + if let ConstraintKind::PrimaryKey(pk) = &c.kind { + Some(pk) + } else { + None + } + }); + + let Some(pk) = pk else { + return false; + }; + + // Check if PK columns match FK columns + let mut fk_columns: Vec<_> = fks.iter().flat_map(|fk| &fk.columns).collect(); + fk_columns.sort(); + + let mut pk_columns: Vec<_> = pk.columns.iter().collect(); + pk_columns.sort(); + + if fk_columns != pk_columns { + return false; + } + + // Check if there are few other columns + let non_pk_columns = table + .columns + .iter() + .filter(|c| !pk.columns.contains(&c.name)) + .count(); + + non_pk_columns <= 1 // Allow one additional column (like created_at) +} + +/// Gets the two target tables for a junction table. +pub fn get_junction_targets(table: &Table) -> Option<(String, String)> { + if !is_junction_table(table) { + return None; + } + + let fks: Vec<_> = table + .constraints + .iter() + .filter_map(|c| { + if let ConstraintKind::ForeignKey(fk) = &c.kind { + Some(fk) + } else { + None + } + }) + .collect(); + + if fks.len() == 2 { + Some(( + fks[0].referenced_table.name.as_ref().to_string(), + fks[1].referenced_table.name.as_ref().to_string(), + )) + } else { + None + } +} + +/// Generates Related impl with via() for many-to-many relations. +#[allow(dead_code)] +pub fn generate_many_to_many_related( + junction_table: &str, + from_table: &str, + to_table: &str, +) -> String { + let junction_module = to_module_name(junction_table); + let to_module = to_module_name(to_table); + let from_relation = to_relation_name(from_table); + let to_relation = to_relation_name(to_table); + + format!( + r#"impl Related for Entity {{ + fn to() -> RelationDef {{ + super::{junction_module}::Relation::{to_relation}.def() + }} + + fn via() -> Option {{ + Some(super::{junction_module}::Relation::{from_relation}.def().rev()) + }} +}}"# + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use tern_ddl::Column; + use tern_ddl::types::QualifiedCollationName; + use tern_ddl::{ + CollationName, ColumnName, ConstraintName, ForeignKeyAction, IndexName, Oid, + PrimaryKeyConstraint, QualifiedTableName, SchemaName, TableKind, TableName, TypeInfo, + TypeName, + }; + + fn make_column(name: &str, type_name: &str) -> Column { + Column { + name: ColumnName::try_new(name.to_string()).unwrap(), + position: 1, + type_info: TypeInfo { + name: TypeName::try_new(type_name.to_string()).unwrap(), + schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(), + formatted: type_name.to_string(), + is_array: false, + }, + is_nullable: false, + default: None, + generated: None, + identity: None, + collation: QualifiedCollationName::new( + SchemaName::try_new("pg_catalog".to_string()).unwrap(), + CollationName::try_new("default".to_string()).unwrap(), + ), + comment: None, + } + } + + fn make_pk_constraint(name: &str, columns: &[&str]) -> Constraint { + Constraint { + name: ConstraintName::try_new(name.to_string()).unwrap(), + kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint { + columns: columns + .iter() + .map(|c| ColumnName::try_new(c.to_string()).unwrap()) + .collect(), + index_name: IndexName::try_new(name.to_string()).unwrap(), + }), + comment: None, + } + } + + fn make_fk_constraint( + name: &str, + from_columns: &[&str], + to_table: &str, + to_columns: &[&str], + ) -> Constraint { + Constraint { + name: ConstraintName::try_new(name.to_string()).unwrap(), + kind: ConstraintKind::ForeignKey(ForeignKeyConstraint { + columns: from_columns + .iter() + .map(|c| ColumnName::try_new(c.to_string()).unwrap()) + .collect(), + referenced_table: QualifiedTableName::new( + SchemaName::try_new("public".to_string()).unwrap(), + TableName::try_new(to_table.to_string()).unwrap(), + ), + referenced_columns: to_columns + .iter() + .map(|c| ColumnName::try_new(c.to_string()).unwrap()) + .collect(), + on_delete: ForeignKeyAction::NoAction, + on_update: ForeignKeyAction::NoAction, + is_deferrable: false, + is_initially_deferred: false, + }), + comment: None, + } + } + + fn make_table(name: &str, columns: Vec, constraints: Vec) -> Table { + Table { + oid: Oid::new(1), + name: TableName::try_new(name.to_string()).unwrap(), + kind: TableKind::Regular, + columns, + constraints, + indexes: vec![], + comment: None, + } + } + + #[test] + fn test_relation_info_variant_name() { + let relation = RelationInfo { + kind: RelationKind::BelongsTo, + target_table: "users".to_string(), + target_schema: None, + from_columns: vec![ColumnName::try_new("user_id".to_string()).unwrap()], + to_columns: vec![ColumnName::try_new("id".to_string()).unwrap()], + constraint_name: None, + is_self_referential: false, + on_delete: None, + on_update: None, + }; + assert_eq!(relation.variant_name(), "User"); + + let has_many = RelationInfo { + kind: RelationKind::HasMany, + target_table: "posts".to_string(), + ..relation.clone() + }; + assert_eq!(has_many.variant_name(), "Posts"); + } + + #[test] + fn test_relation_info_target_entity_path() { + let relation = RelationInfo { + kind: RelationKind::BelongsTo, + target_table: "users".to_string(), + target_schema: None, + from_columns: vec![ColumnName::try_new("user_id".to_string()).unwrap()], + to_columns: vec![ColumnName::try_new("id".to_string()).unwrap()], + constraint_name: None, + is_self_referential: false, + on_delete: None, + on_update: None, + }; + assert_eq!(relation.target_entity_path(false), "super::user::Entity"); + assert_eq!(relation.target_entity_path(true), "Entity"); + } + + #[test] + fn test_relation_info_column_paths() { + let relation = RelationInfo { + kind: RelationKind::BelongsTo, + target_table: "users".to_string(), + target_schema: None, + from_columns: vec![ColumnName::try_new("user_id".to_string()).unwrap()], + to_columns: vec![ColumnName::try_new("id".to_string()).unwrap()], + constraint_name: None, + is_self_referential: false, + on_delete: None, + on_update: None, + }; + assert_eq!(relation.from_column_path(), "Column::UserId"); + assert_eq!(relation.to_column_path(), "super::user::Column::Id"); + } + + #[test] + fn test_relation_info_composite_fk() { + let relation = RelationInfo { + kind: RelationKind::BelongsTo, + target_table: "items".to_string(), + target_schema: None, + from_columns: vec![ + ColumnName::try_new("left_id".to_string()).unwrap(), + ColumnName::try_new("right_id".to_string()).unwrap(), + ], + to_columns: vec![ + ColumnName::try_new("left_id".to_string()).unwrap(), + ColumnName::try_new("right_id".to_string()).unwrap(), + ], + constraint_name: None, + is_self_referential: false, + on_delete: None, + on_update: None, + }; + assert_eq!( + relation.from_column_path(), + "(Column::LeftId, Column::RightId)" + ); + } + + #[test] + fn test_analyze_relations_simple() { + let users = make_table( + "users", + vec![make_column("id", "int4"), make_column("name", "text")], + vec![make_pk_constraint("users_pkey", &["id"])], + ); + + let posts = make_table( + "posts", + vec![ + make_column("id", "int4"), + make_column("user_id", "int4"), + make_column("title", "text"), + ], + vec![ + make_pk_constraint("posts_pkey", &["id"]), + make_fk_constraint("posts_user_id_fkey", &["user_id"], "users", &["id"]), + ], + ); + + let relations = analyze_relations(&[users, posts]); + + // posts should have belongs_to users + let post_relations = relations.get("posts").unwrap(); + assert_eq!(post_relations.len(), 1); + assert_eq!(post_relations[0].kind, RelationKind::BelongsTo); + assert_eq!(post_relations[0].target_table, "users"); + + // users should have has_many posts + let user_relations = relations.get("users").unwrap(); + assert_eq!(user_relations.len(), 1); + assert_eq!(user_relations[0].kind, RelationKind::HasMany); + assert_eq!(user_relations[0].target_table, "posts"); + } + + #[test] + fn test_analyze_relations_self_referential() { + let employees = make_table( + "employees", + vec![ + make_column("id", "int4"), + make_column("manager_id", "int4"), + make_column("name", "text"), + ], + vec![ + make_pk_constraint("employees_pkey", &["id"]), + make_fk_constraint( + "employees_manager_id_fkey", + &["manager_id"], + "employees", + &["id"], + ), + ], + ); + + let relations = analyze_relations(&[employees]); + + let employee_relations = relations.get("employees").unwrap(); + assert_eq!(employee_relations.len(), 1); + assert!(employee_relations[0].is_self_referential); + assert_eq!(employee_relations[0].kind, RelationKind::BelongsTo); + } + + #[test] + fn test_is_junction_table() { + let post_tags = make_table( + "post_tags", + vec![ + make_column("post_id", "int4"), + make_column("tag_id", "int4"), + ], + vec![ + make_pk_constraint("post_tags_pkey", &["post_id", "tag_id"]), + make_fk_constraint("post_tags_post_id_fkey", &["post_id"], "posts", &["id"]), + make_fk_constraint("post_tags_tag_id_fkey", &["tag_id"], "tags", &["id"]), + ], + ); + + assert!(is_junction_table(&post_tags)); + + let regular = make_table( + "posts", + vec![make_column("id", "int4"), make_column("title", "text")], + vec![make_pk_constraint("posts_pkey", &["id"])], + ); + + assert!(!is_junction_table(®ular)); + } + + #[test] + fn test_get_junction_targets() { + let post_tags = make_table( + "post_tags", + vec![ + make_column("post_id", "int4"), + make_column("tag_id", "int4"), + ], + vec![ + make_pk_constraint("post_tags_pkey", &["post_id", "tag_id"]), + make_fk_constraint("post_tags_post_id_fkey", &["post_id"], "posts", &["id"]), + make_fk_constraint("post_tags_tag_id_fkey", &["tag_id"], "tags", &["id"]), + ], + ); + + let targets = get_junction_targets(&post_tags).unwrap(); + assert_eq!(targets.0, "posts"); + assert_eq!(targets.1, "tags"); + } + + #[test] + fn test_generate_relation_attr_has_many() { + let relation = RelationInfo { + kind: RelationKind::HasMany, + target_table: "posts".to_string(), + target_schema: None, + from_columns: vec![], + to_columns: vec![], + constraint_name: None, + is_self_referential: false, + on_delete: None, + on_update: None, + }; + + let attr = relation.generate_relation_attr(); + assert!(attr.contains("has_many")); + assert!(attr.contains("super::post::Entity")); + } + + #[test] + fn test_generate_relation_attr_belongs_to() { + let relation = RelationInfo { + kind: RelationKind::BelongsTo, + target_table: "users".to_string(), + target_schema: None, + from_columns: vec![ColumnName::try_new("user_id".to_string()).unwrap()], + to_columns: vec![ColumnName::try_new("id".to_string()).unwrap()], + constraint_name: None, + is_self_referential: false, + on_delete: Some("Cascade".to_string()), + on_update: None, + }; + + let attr = relation.generate_relation_attr(); + assert!(attr.contains("belongs_to")); + assert!(attr.contains("from")); + assert!(attr.contains("to")); + assert!(attr.contains("on_delete")); + } +} diff --git a/crates/codegen/src/rust/tests/mod.rs b/crates/codegen/src/rust/tests/mod.rs new file mode 100644 index 0000000..6632d4f --- /dev/null +++ b/crates/codegen/src/rust/tests/mod.rs @@ -0,0 +1,4 @@ +//! Tests for Rust SeaORM code generation. + +mod snapshot_tests; +mod unit_tests; diff --git a/crates/codegen/src/rust/tests/snapshot_tests.rs b/crates/codegen/src/rust/tests/snapshot_tests.rs new file mode 100644 index 0000000..d63f010 --- /dev/null +++ b/crates/codegen/src/rust/tests/snapshot_tests.rs @@ -0,0 +1,233 @@ +//! Snapshot tests for Rust SeaORM code generation. +//! +//! These tests use insta to capture and verify generated code output. + +use tern_ddl::types::QualifiedCollationName; +use tern_ddl::{ + CollationName, Column, ColumnName, Constraint, ConstraintKind, ConstraintName, + ForeignKeyAction, ForeignKeyConstraint, IdentityKind, IndexName, Oid, PrimaryKeyConstraint, + QualifiedTableName, SchemaName, Table, TableKind, TableName, TypeInfo, TypeName, +}; + +use crate::Codegen; +use crate::rust::{EntityFormat, RustSeaOrmCodegen, RustSeaOrmCodegenConfig}; + +fn make_column(name: &str, type_name: &str, is_nullable: bool) -> Column { + Column { + name: ColumnName::try_new(name.to_string()).unwrap(), + position: 1, + type_info: TypeInfo { + name: TypeName::try_new(type_name.to_string()).unwrap(), + schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(), + formatted: type_name.to_string(), + is_array: false, + }, + is_nullable, + default: None, + generated: None, + identity: None, + collation: QualifiedCollationName::new( + SchemaName::try_new("pg_catalog".to_string()).unwrap(), + CollationName::try_new("default".to_string()).unwrap(), + ), + comment: None, + } +} + +fn make_column_with_identity(name: &str, type_name: &str, identity: IdentityKind) -> Column { + let mut col = make_column(name, type_name, false); + col.identity = Some(identity); + col +} + +fn make_table_with_pk(name: &str, columns: Vec, pk_columns: &[&str]) -> Table { + let pk = Constraint { + name: ConstraintName::try_new(format!("{name}_pkey")).unwrap(), + kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint { + columns: pk_columns + .iter() + .map(|c| ColumnName::try_new(c.to_string()).unwrap()) + .collect(), + index_name: IndexName::try_new(format!("{name}_pkey")).unwrap(), + }), + comment: None, + }; + + Table { + oid: Oid::new(1), + name: TableName::try_new(name.to_string()).unwrap(), + kind: TableKind::Regular, + columns, + constraints: vec![pk], + indexes: vec![], + comment: None, + } +} + +fn make_table_with_fk( + name: &str, + columns: Vec, + pk_columns: &[&str], + fk_column: &str, + fk_target_table: &str, + fk_target_column: &str, +) -> Table { + let pk = Constraint { + name: ConstraintName::try_new(format!("{name}_pkey")).unwrap(), + kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint { + columns: pk_columns + .iter() + .map(|c| ColumnName::try_new(c.to_string()).unwrap()) + .collect(), + index_name: IndexName::try_new(format!("{name}_pkey")).unwrap(), + }), + comment: None, + }; + + let fk = Constraint { + name: ConstraintName::try_new(format!("{name}_{fk_column}_fkey")).unwrap(), + kind: ConstraintKind::ForeignKey(ForeignKeyConstraint { + columns: vec![ColumnName::try_new(fk_column.to_string()).unwrap()], + referenced_table: QualifiedTableName::new( + SchemaName::try_new("public".to_string()).unwrap(), + TableName::try_new(fk_target_table.to_string()).unwrap(), + ), + referenced_columns: vec![ColumnName::try_new(fk_target_column.to_string()).unwrap()], + on_delete: ForeignKeyAction::Cascade, + on_update: ForeignKeyAction::NoAction, + is_deferrable: false, + is_initially_deferred: false, + }), + comment: None, + }; + + Table { + oid: Oid::new(1), + name: TableName::try_new(name.to_string()).unwrap(), + kind: TableKind::Regular, + columns, + constraints: vec![pk, fk], + indexes: vec![], + comment: None, + } +} + +#[test] +fn snapshot_simple_entity_compact() { + let columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("name", "text", false), + make_column("email", "text", true), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + + let codegen = RustSeaOrmCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + insta::assert_snapshot!(output.get("user.rs").unwrap()); +} + +#[test] +fn snapshot_simple_entity_expanded() { + let columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("name", "text", false), + make_column("email", "text", true), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + + let config = RustSeaOrmCodegenConfig { + entity_format: EntityFormat::Expanded, + ..Default::default() + }; + let codegen = RustSeaOrmCodegen::new(config); + let output = codegen.generate(vec![table]); + + insta::assert_snapshot!(output.get("user.rs").unwrap()); +} + +#[test] +fn snapshot_with_relations() { + let user_columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("name", "text", false), + ]; + let user_table = make_table_with_pk("users", user_columns, &["id"]); + + let post_columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("user_id", "int4", false), + make_column("title", "text", false), + make_column("body", "text", true), + ]; + let post_table = make_table_with_fk("posts", post_columns, &["id"], "user_id", "users", "id"); + + let config = RustSeaOrmCodegenConfig { + generate_relations: true, + ..Default::default() + }; + let codegen = RustSeaOrmCodegen::new(config); + let output = codegen.generate(vec![user_table, post_table]); + + insta::assert_snapshot!("user_with_relations", output.get("user.rs").unwrap()); + insta::assert_snapshot!("post_with_relations", output.get("post.rs").unwrap()); +} + +#[test] +fn snapshot_mod_file() { + let user_columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("name", "text", false), + ]; + let user_table = make_table_with_pk("users", user_columns, &["id"]); + + let post_columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("title", "text", false), + ]; + let post_table = make_table_with_pk("posts", post_columns, &["id"]); + + let codegen = RustSeaOrmCodegen::with_defaults(); + let output = codegen.generate(vec![user_table, post_table]); + + insta::assert_snapshot!(output.get("mod.rs").unwrap()); +} + +#[test] +fn snapshot_prelude_file() { + let user_columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("name", "text", false), + ]; + let user_table = make_table_with_pk("users", user_columns, &["id"]); + + let post_columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("title", "text", false), + ]; + let post_table = make_table_with_pk("posts", post_columns, &["id"]); + + let codegen = RustSeaOrmCodegen::with_defaults(); + let output = codegen.generate(vec![user_table, post_table]); + + insta::assert_snapshot!(output.get("prelude.rs").unwrap()); +} + +#[test] +fn snapshot_complex_types() { + let columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("uuid_col", "uuid", false), + make_column("created_at", "timestamptz", false), + make_column("updated_at", "timestamp", true), + make_column("metadata", "jsonb", true), + make_column("price", "numeric", false), + make_column("is_active", "bool", false), + ]; + let table = make_table_with_pk("complex_table", columns, &["id"]); + + let codegen = RustSeaOrmCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + insta::assert_snapshot!(output.get("complex_table.rs").unwrap()); +} diff --git a/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__post_with_relations.snap b/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__post_with_relations.snap new file mode 100644 index 0000000..97a8677 --- /dev/null +++ b/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__post_with_relations.snap @@ -0,0 +1,38 @@ +--- +source: crates/codegen/src/rust/tests/snapshot_tests.rs +expression: "output.get(\"post.rs\").unwrap()" +--- +//! SeaORM entity for `posts` table. +//! +//! Generated by Tern. + +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "posts")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub user_id: i32, + pub title: String, + pub body: Option, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id", + on_delete = "Cascade" + )] + User, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::User.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_complex_types.snap b/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_complex_types.snap new file mode 100644 index 0000000..e125ef7 --- /dev/null +++ b/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_complex_types.snap @@ -0,0 +1,33 @@ +--- +source: crates/codegen/src/rust/tests/snapshot_tests.rs +expression: "output.get(\"complex_table.rs\").unwrap()" +--- +//! SeaORM entity for `complex_table` table. +//! +//! Generated by Tern. + +use chrono::{DateTime, FixedOffset, NaiveDateTime}; +use rust_decimal::Decimal; +use sea_orm::entity::prelude::*; +use serde_json::Value; +use uuid::Uuid; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "complex_table")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub uuid_col: uuid::Uuid, + #[sea_orm(column_type = "TimestampWithTimeZone")] + pub created_at: chrono::DateTime, + pub updated_at: Option, + #[sea_orm(column_type = "JsonBinary", nullable)] + pub metadata: Option, + pub price: rust_decimal::Decimal, + pub is_active: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_mod_file.snap b/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_mod_file.snap new file mode 100644 index 0000000..0ac7739 --- /dev/null +++ b/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_mod_file.snap @@ -0,0 +1,12 @@ +--- +source: crates/codegen/src/rust/tests/snapshot_tests.rs +expression: "output.get(\"mod.rs\").unwrap()" +--- +//! SeaORM entity definitions generated by Tern. +//! +//! This file was automatically generated. Do not edit manually. + +pub mod prelude; + +pub mod user; +pub mod post; diff --git a/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_prelude_file.snap b/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_prelude_file.snap new file mode 100644 index 0000000..01e0596 --- /dev/null +++ b/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_prelude_file.snap @@ -0,0 +1,10 @@ +--- +source: crates/codegen/src/rust/tests/snapshot_tests.rs +expression: "output.get(\"prelude.rs\").unwrap()" +--- +//! Re-exports commonly used entity types. +//! +//! This file was automatically generated. Do not edit manually. + +pub use super::user::Entity as User; +pub use super::post::Entity as Post; diff --git a/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_simple_entity_compact.snap b/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_simple_entity_compact.snap new file mode 100644 index 0000000..f42bb6d --- /dev/null +++ b/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_simple_entity_compact.snap @@ -0,0 +1,23 @@ +--- +source: crates/codegen/src/rust/tests/snapshot_tests.rs +expression: "output.get(\"user.rs\").unwrap()" +--- +//! SeaORM entity for `users` table. +//! +//! Generated by Tern. + +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "users")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub name: String, + pub email: Option, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_simple_entity_expanded.snap b/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_simple_entity_expanded.snap new file mode 100644 index 0000000..7d17ad6 --- /dev/null +++ b/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__snapshot_simple_entity_expanded.snap @@ -0,0 +1,66 @@ +--- +source: crates/codegen/src/rust/tests/snapshot_tests.rs +expression: "output.get(\"user.rs\").unwrap()" +--- +//! SeaORM entity for `users` table. +//! +//! Generated by Tern. + +use sea_orm::entity::prelude::*; + +#[derive(Copy, Clone, Default, Debug, DeriveEntity)] +pub struct Entity; + +impl EntityName for Entity { + fn schema_name(&self) -> Option<&str> { + None + } + + fn table_name(&self) -> &str { + "users" + } +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] +pub enum Column { + Id, + Name, + Email, +} + +impl ColumnTrait for Column { + type EntityName = Entity; + + fn def(&self) -> ColumnDef { + match self { + Self::Id => ColumnType::Integer.def(), + Self::Name => ColumnType::Text.def(), + Self::Email => ColumnType::Text.def().nullable(), + } + } +} + +#[derive(Copy, Clone, Debug, EnumIter, DerivePrimaryKey)] +pub enum PrimaryKey { + Id, +} + +impl PrimaryKeyTrait for PrimaryKey { + type ValueType = i32; + + fn auto_increment() -> bool { + true + } +} + +#[derive(Clone, Debug, PartialEq, Eq, DeriveModel, DeriveActiveModel)] +pub struct Model { + pub id: i32, + pub name: String, + pub email: Option, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__user_with_relations.snap b/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__user_with_relations.snap new file mode 100644 index 0000000..5b66ffa --- /dev/null +++ b/crates/codegen/src/rust/tests/snapshots/tern_codegen__rust__tests__snapshot_tests__user_with_relations.snap @@ -0,0 +1,31 @@ +--- +source: crates/codegen/src/rust/tests/snapshot_tests.rs +expression: "output.get(\"user.rs\").unwrap()" +--- +//! SeaORM entity for `users` table. +//! +//! Generated by Tern. + +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "users")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub name: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::post::Entity")] + Posts, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Posts.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/codegen/src/rust/tests/unit_tests.rs b/crates/codegen/src/rust/tests/unit_tests.rs new file mode 100644 index 0000000..1239bc6 --- /dev/null +++ b/crates/codegen/src/rust/tests/unit_tests.rs @@ -0,0 +1,494 @@ +//! Unit tests for Rust SeaORM code generation components. + +use tern_ddl::types::QualifiedCollationName; +use tern_ddl::{ + CollationName, Column, ColumnName, Constraint, ConstraintKind, ConstraintName, + ForeignKeyAction, ForeignKeyConstraint, IdentityKind, IndexName, Oid, PrimaryKeyConstraint, + QualifiedTableName, SchemaName, Table, TableKind, TableName, TypeInfo, TypeName, + UniqueConstraint, +}; + +use crate::Codegen; +use crate::rust::{ + EntityFormat, OutputMode, ReservedWordStrategy, RustSeaOrmCodegen, RustSeaOrmCodegenConfig, +}; + +/// Test utility: creates a basic column. +fn make_column(name: &str, type_name: &str, is_nullable: bool) -> Column { + Column { + name: ColumnName::try_new(name.to_string()).unwrap(), + position: 1, + type_info: TypeInfo { + name: TypeName::try_new(type_name.to_string()).unwrap(), + schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(), + formatted: type_name.to_string(), + is_array: false, + }, + is_nullable, + default: None, + generated: None, + identity: None, + collation: QualifiedCollationName::new( + SchemaName::try_new("pg_catalog".to_string()).unwrap(), + CollationName::try_new("default".to_string()).unwrap(), + ), + comment: None, + } +} + +/// Test utility: creates a column with identity. +fn make_column_with_identity(name: &str, type_name: &str, identity: IdentityKind) -> Column { + let mut col = make_column(name, type_name, false); + col.identity = Some(identity); + col +} + +/// Test utility: creates a column with a specific formatted type. +fn make_column_formatted( + name: &str, + type_name: &str, + formatted: &str, + is_nullable: bool, +) -> Column { + let mut col = make_column(name, type_name, is_nullable); + col.type_info.formatted = formatted.to_string(); + col +} + +/// Test utility: creates a table with a primary key. +fn make_table_with_pk(name: &str, columns: Vec, pk_columns: &[&str]) -> Table { + let pk = Constraint { + name: ConstraintName::try_new(format!("{name}_pkey")).unwrap(), + kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint { + columns: pk_columns + .iter() + .map(|c| ColumnName::try_new(c.to_string()).unwrap()) + .collect(), + index_name: IndexName::try_new(format!("{name}_pkey")).unwrap(), + }), + comment: None, + }; + + Table { + oid: Oid::new(1), + name: TableName::try_new(name.to_string()).unwrap(), + kind: TableKind::Regular, + columns, + constraints: vec![pk], + indexes: vec![], + comment: None, + } +} + +/// Test utility: creates a table with FK constraint. +fn make_table_with_fk( + name: &str, + columns: Vec, + pk_columns: &[&str], + fk_column: &str, + fk_target_table: &str, + fk_target_column: &str, +) -> Table { + let pk = Constraint { + name: ConstraintName::try_new(format!("{name}_pkey")).unwrap(), + kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint { + columns: pk_columns + .iter() + .map(|c| ColumnName::try_new(c.to_string()).unwrap()) + .collect(), + index_name: IndexName::try_new(format!("{name}_pkey")).unwrap(), + }), + comment: None, + }; + + let fk = Constraint { + name: ConstraintName::try_new(format!("{name}_{fk_column}_fkey")).unwrap(), + kind: ConstraintKind::ForeignKey(ForeignKeyConstraint { + columns: vec![ColumnName::try_new(fk_column.to_string()).unwrap()], + referenced_table: QualifiedTableName::new( + SchemaName::try_new("public".to_string()).unwrap(), + TableName::try_new(fk_target_table.to_string()).unwrap(), + ), + referenced_columns: vec![ColumnName::try_new(fk_target_column.to_string()).unwrap()], + on_delete: ForeignKeyAction::NoAction, + on_update: ForeignKeyAction::NoAction, + is_deferrable: false, + is_initially_deferred: false, + }), + comment: None, + }; + + Table { + oid: Oid::new(1), + name: TableName::try_new(name.to_string()).unwrap(), + kind: TableKind::Regular, + columns, + constraints: vec![pk, fk], + indexes: vec![], + comment: None, + } +} + +// ============================================================================= +// Integration Tests +// ============================================================================= + +#[test] +fn test_simple_table() { + let columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("name", "text", false), + make_column("email", "text", true), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + + let codegen = RustSeaOrmCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + assert!(output.contains_key("user.rs")); + let code = &output["user.rs"]; + + // Check model struct + assert!(code.contains("pub struct Model")); + assert!(code.contains("table_name = \"users\"")); + + // Check fields + assert!(code.contains("pub id: i32")); + assert!(code.contains("pub name: String")); + assert!(code.contains("pub email: Option")); + + // Check primary key + assert!(code.contains("#[sea_orm(primary_key)]")); +} + +#[test] +fn test_composite_primary_key() { + let columns = vec![ + make_column("order_id", "int4", false), + make_column("product_id", "int4", false), + make_column("quantity", "int4", false), + ]; + let table = make_table_with_pk("order_items", columns, &["order_id", "product_id"]); + + let codegen = RustSeaOrmCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let code = &output["order_item.rs"]; + + // Both PK columns should have primary_key and auto_increment = false + assert!(code.contains("primary_key")); + assert!(code.contains("auto_increment = false")); +} + +#[test] +fn test_reserved_word_column() { + let columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("type", "text", false), // "type" is a reserved word + make_column("match", "text", true), // "match" is also reserved + ]; + let table = make_table_with_pk("items", columns, &["id"]); + + let codegen = RustSeaOrmCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let code = &output["item.rs"]; + + // Should have column_name attributes for reserved words + assert!(code.contains("column_name = \"type\"")); + assert!(code.contains("type_:")); // Escaped field name +} + +#[test] +fn test_raw_identifier_strategy() { + let columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("type", "text", false), + ]; + let table = make_table_with_pk("items", columns, &["id"]); + + let config = RustSeaOrmCodegenConfig { + reserved_word_strategy: ReservedWordStrategy::RawIdentifier, + ..Default::default() + }; + let codegen = RustSeaOrmCodegen::new(config); + let output = codegen.generate(vec![table]); + + let code = &output["item.rs"]; + assert!(code.contains("r#type")); +} + +#[test] +fn test_datetime_types() { + let columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("created_at", "timestamptz", false), + make_column("updated_at", "timestamp", true), + make_column("birth_date", "date", true), + ]; + let table = make_table_with_pk("events", columns, &["id"]); + + let codegen = RustSeaOrmCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let code = &output["event.rs"]; + + assert!(code.contains("chrono::DateTime")); + assert!(code.contains("Option")); + assert!(code.contains("Option")); +} + +#[test] +fn test_json_types() { + let columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("config", "jsonb", false), + make_column("metadata", "json", true), + ]; + let table = make_table_with_pk("settings", columns, &["id"]); + + let codegen = RustSeaOrmCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let code = &output["setting.rs"]; + + assert!(code.contains("serde_json::Value")); + assert!(code.contains("column_type = \"JsonBinary\"") || code.contains("JsonBinary")); +} + +#[test] +fn test_array_types() { + let mut tags_column = make_column("tags", "text", false); + tags_column.type_info.is_array = true; + + let columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + tags_column, + ]; + let table = make_table_with_pk("posts", columns, &["id"]); + + let codegen = RustSeaOrmCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let code = &output["post.rs"]; + assert!(code.contains("Vec")); +} + +#[test] +fn test_relations_generation() { + let user_columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("name", "text", false), + ]; + let user_table = make_table_with_pk("users", user_columns, &["id"]); + + let post_columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("user_id", "int4", false), + make_column("title", "text", false), + ]; + let post_table = make_table_with_fk("posts", post_columns, &["id"], "user_id", "users", "id"); + + let config = RustSeaOrmCodegenConfig { + generate_relations: true, + ..Default::default() + }; + let codegen = RustSeaOrmCodegen::new(config); + let output = codegen.generate(vec![user_table, post_table]); + + // User should have has_many Posts + let user_code = &output["user.rs"]; + assert!(user_code.contains("has_many")); + assert!(user_code.contains("Posts")); + + // Post should have belongs_to User + let post_code = &output["post.rs"]; + assert!(user_code.contains("has_many")); + assert!(post_code.contains("belongs_to")); + assert!(post_code.contains("User")); +} + +#[test] +fn test_expanded_format() { + let columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("name", "text", false), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + + let config = RustSeaOrmCodegenConfig { + entity_format: EntityFormat::Expanded, + ..Default::default() + }; + let codegen = RustSeaOrmCodegen::new(config); + let output = codegen.generate(vec![table]); + + let code = &output["user.rs"]; + + // Expanded format has explicit types + assert!(code.contains("DeriveEntity")); + assert!(code.contains("pub struct Entity;")); + assert!(code.contains("impl EntityName for Entity")); + assert!(code.contains("pub enum Column")); + assert!(code.contains("pub enum PrimaryKey")); + assert!(code.contains("impl ColumnTrait for Column")); + assert!(code.contains("impl PrimaryKeyTrait for PrimaryKey")); +} + +#[test] +fn test_single_file_output() { + let columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("name", "text", false), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + + let config = RustSeaOrmCodegenConfig { + output_mode: OutputMode::SingleFile, + ..Default::default() + }; + let codegen = RustSeaOrmCodegen::new(config); + let output = codegen.generate(vec![table]); + + assert!(output.contains_key("entities.rs")); + assert!(!output.contains_key("user.rs")); + assert!(!output.contains_key("mod.rs")); + + let code = &output["entities.rs"]; + assert!(code.contains("pub mod user")); +} + +#[test] +fn test_multi_file_output() { + let columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("name", "text", false), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + + let config = RustSeaOrmCodegenConfig { + output_mode: OutputMode::MultiFile, + ..Default::default() + }; + let codegen = RustSeaOrmCodegen::new(config); + let output = codegen.generate(vec![table]); + + assert!(output.contains_key("mod.rs")); + assert!(output.contains_key("prelude.rs")); + assert!(output.contains_key("user.rs")); +} + +#[test] +fn test_uuid_primary_key() { + let columns = vec![ + make_column("id", "uuid", false), + make_column("name", "text", false), + ]; + let table = make_table_with_pk("items", columns, &["id"]); + + let codegen = RustSeaOrmCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let code = &output["item.rs"]; + + assert!(code.contains("uuid::Uuid")); + assert!(code.contains("auto_increment = false")); +} + +#[test] +fn test_numeric_types() { + let columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column_formatted("price", "numeric", "numeric(10,2)", false), + make_column("quantity", "int2", false), + make_column("big_value", "int8", true), + ]; + let table = make_table_with_pk("products", columns, &["id"]); + + let codegen = RustSeaOrmCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let code = &output["product.rs"]; + + assert!(code.contains("rust_decimal::Decimal")); + assert!(code.contains("i16")); // smallint + assert!(code.contains("Option")); // bigint nullable +} + +#[test] +fn test_varchar_with_length() { + let columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column_formatted("name", "varchar", "character varying(255)", false), + ]; + let table = make_table_with_pk("users", columns, &["id"]); + + let codegen = RustSeaOrmCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let code = &output["user.rs"]; + assert!(code.contains("pub name: String")); +} + +#[test] +fn test_table_without_primary_key() { + let columns = vec![make_column("data", "text", false)]; + let table = Table { + oid: Oid::new(1), + name: TableName::try_new("log_entries".to_string()).unwrap(), + kind: TableKind::Regular, + columns, + constraints: vec![], + indexes: vec![], + comment: None, + }; + + let codegen = RustSeaOrmCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let code = &output["log_entry.rs"]; + assert!(code.contains("WARNING") || code.contains("no primary key")); +} + +#[test] +fn test_unique_constraint() { + let columns = vec![ + make_column_with_identity("id", "int4", IdentityKind::Always), + make_column("email", "text", false), + ]; + + let pk = Constraint { + name: ConstraintName::try_new("users_pkey".to_string()).unwrap(), + kind: ConstraintKind::PrimaryKey(PrimaryKeyConstraint { + columns: vec![ColumnName::try_new("id".to_string()).unwrap()], + index_name: IndexName::try_new("users_pkey".to_string()).unwrap(), + }), + comment: None, + }; + + let unique = Constraint { + name: ConstraintName::try_new("users_email_key".to_string()).unwrap(), + kind: ConstraintKind::Unique(UniqueConstraint { + columns: vec![ColumnName::try_new("email".to_string()).unwrap()], + index_name: IndexName::try_new("users_email_key".to_string()).unwrap(), + nulls_not_distinct: false, + }), + comment: None, + }; + + let table = Table { + oid: Oid::new(1), + name: TableName::try_new("users".to_string()).unwrap(), + kind: TableKind::Regular, + columns, + constraints: vec![pk, unique], + indexes: vec![], + comment: None, + }; + + let codegen = RustSeaOrmCodegen::with_defaults(); + let output = codegen.generate(vec![table]); + + let code = &output["user.rs"]; + assert!(code.contains("#[sea_orm(unique)]")); +} diff --git a/crates/codegen/src/rust/type_mapping.rs b/crates/codegen/src/rust/type_mapping.rs new file mode 100644 index 0000000..617e8eb --- /dev/null +++ b/crates/codegen/src/rust/type_mapping.rs @@ -0,0 +1,775 @@ +//! PostgreSQL to Rust/SeaORM type mapping. +//! +//! This module handles the conversion of PostgreSQL types to their Rust equivalents +//! for use in SeaORM entity definitions. + +use tern_ddl::TypeInfo; + +/// Represents a Rust type with its SeaORM mapping and import requirements. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RustType { + /// The Rust type annotation (e.g., "i32", "String", "chrono::DateTime"). + pub annotation: String, + /// The SeaORM ColumnType expression (e.g., "ColumnType::Integer"). + pub column_type: String, + /// Required SeaORM feature flags. + pub required_features: Vec<&'static str>, + /// Required imports (module path, type name). + pub imports: Vec, + /// Whether this type requires explicit column_type attribute in compact format. + pub needs_column_type_attr: bool, +} + +/// A Rust import statement. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct RustImport { + /// The full module path (e.g., "chrono", "serde_json"). + pub module: String, + /// The type name to import (e.g., "DateTime", "Value"). + pub name: String, +} + +impl RustImport { + /// Creates a new Rust import. + pub fn new(module: impl Into, name: impl Into) -> Self { + Self { + module: module.into(), + name: name.into(), + } + } + + /// Creates a chrono import. + pub fn chrono(name: &str) -> Self { + Self::new("chrono", name) + } + + /// Creates a serde_json import. + pub fn serde_json(name: &str) -> Self { + Self::new("serde_json", name) + } + + /// Creates a uuid import. + pub fn uuid() -> Self { + Self::new("uuid", "Uuid") + } + + /// Creates a rust_decimal import. + pub fn rust_decimal() -> Self { + Self::new("rust_decimal", "Decimal") + } + + /// Creates an ipnetwork import. + pub fn ipnetwork() -> Self { + Self::new("ipnetwork", "IpNetwork") + } + + /// Creates a mac_address import. + pub fn mac_address() -> Self { + Self::new("mac_address", "MacAddress") + } +} + +impl RustType { + /// Creates a simple Rust type with no imports or special requirements. + pub fn simple(annotation: &str, column_type: &str) -> Self { + Self { + annotation: annotation.to_string(), + column_type: column_type.to_string(), + required_features: Vec::new(), + imports: Vec::new(), + needs_column_type_attr: false, + } + } + + /// Creates a Rust type with imports. + pub fn with_imports( + annotation: &str, + column_type: &str, + imports: Vec, + features: Vec<&'static str>, + ) -> Self { + Self { + annotation: annotation.to_string(), + column_type: column_type.to_string(), + required_features: features, + imports, + needs_column_type_attr: false, + } + } + + /// Creates a Rust type that needs explicit column_type attribute. + pub fn with_column_type_attr( + annotation: &str, + column_type: &str, + imports: Vec, + features: Vec<&'static str>, + ) -> Self { + Self { + annotation: annotation.to_string(), + column_type: column_type.to_string(), + required_features: features, + imports, + needs_column_type_attr: true, + } + } + + /// Wraps this type in Option for nullable columns. + pub fn as_optional(&self) -> Self { + Self { + annotation: format!("Option<{}>", self.annotation), + column_type: self.column_type.clone(), + required_features: self.required_features.clone(), + imports: self.imports.clone(), + needs_column_type_attr: self.needs_column_type_attr, + } + } +} + +/// Maps a PostgreSQL type to a Rust type and SeaORM column type. +/// +/// This function handles both the raw type name and the formatted type with modifiers. +/// The `formatted` field is used to extract precision/scale for numeric types. +pub fn map_pg_type(type_info: &TypeInfo) -> RustType { + let type_name = type_info.name.as_ref(); + let formatted = &type_info.formatted; + + // Handle array types first + if type_info.is_array { + return map_array_type(type_name); + } + + // Map based on type name (canonical PostgreSQL type names) + match type_name { + // Integer types + "int2" | "smallint" => RustType::simple("i16", "ColumnType::SmallInteger"), + "int4" | "integer" | "int" => RustType::simple("i32", "ColumnType::Integer"), + "int8" | "bigint" => RustType::simple("i64", "ColumnType::BigInteger"), + + // Serial types (same Rust types as integers) + "serial" | "serial4" => RustType::simple("i32", "ColumnType::Integer"), + "bigserial" | "serial8" => RustType::simple("i64", "ColumnType::BigInteger"), + "smallserial" | "serial2" => RustType::simple("i16", "ColumnType::SmallInteger"), + + // Floating point types + "float4" | "real" => RustType::simple("f32", "ColumnType::Float"), + "float8" | "double precision" => RustType::simple("f64", "ColumnType::Double"), + + // Numeric/decimal types + "numeric" | "decimal" => { + let column_type = extract_numeric_column_type(formatted); + RustType::with_imports( + "rust_decimal::Decimal", + &column_type, + vec![RustImport::rust_decimal()], + vec!["with-rust_decimal"], + ) + } + + // Money type + "money" => { + let column_type = "ColumnType::Money(Some((19, 2)))"; + RustType::with_imports( + "rust_decimal::Decimal", + column_type, + vec![RustImport::rust_decimal()], + vec!["with-rust_decimal"], + ) + } + + // Boolean + "bool" | "boolean" => RustType::simple("bool", "ColumnType::Boolean"), + + // Text types + "text" => RustType::simple("String", "ColumnType::Text"), + "varchar" | "character varying" => { + let column_type = extract_varchar_column_type(formatted); + RustType::simple("String", &column_type) + } + "char" | "character" | "bpchar" => { + let column_type = extract_char_column_type(formatted); + RustType::simple("String", &column_type) + } + "name" => RustType::simple("String", "ColumnType::String(StringLen::N(64))"), + + // Date/time types + "date" => RustType::with_imports( + "chrono::NaiveDate", + "ColumnType::Date", + vec![RustImport::chrono("NaiveDate")], + vec!["with-chrono"], + ), + "time" | "time without time zone" => RustType::with_imports( + "chrono::NaiveTime", + "ColumnType::Time", + vec![RustImport::chrono("NaiveTime")], + vec!["with-chrono"], + ), + "timetz" | "time with time zone" => { + // Time with timezone - SeaORM doesn't have great support, use NaiveTime + RustType::with_imports( + "chrono::NaiveTime", + "ColumnType::Time", + vec![RustImport::chrono("NaiveTime")], + vec!["with-chrono"], + ) + } + "timestamp" | "timestamp without time zone" => RustType::with_imports( + "chrono::NaiveDateTime", + "ColumnType::DateTime", + vec![RustImport::chrono("NaiveDateTime")], + vec!["with-chrono"], + ), + "timestamptz" | "timestamp with time zone" => RustType::with_column_type_attr( + "chrono::DateTime", + "ColumnType::TimestampWithTimeZone", + vec![ + RustImport::chrono("DateTime"), + RustImport::chrono("FixedOffset"), + ], + vec!["with-chrono"], + ), + "interval" => { + // Interval doesn't have a native Rust equivalent, use String + RustType::with_column_type_attr( + "String", + "ColumnType::Interval(None, None)", + vec![], + vec![], + ) + } + + // UUID + "uuid" => RustType::with_imports( + "uuid::Uuid", + "ColumnType::Uuid", + vec![RustImport::uuid()], + vec!["with-uuid"], + ), + + // JSON types + "json" => RustType::with_column_type_attr( + "serde_json::Value", + "ColumnType::Json", + vec![RustImport::serde_json("Value")], + vec!["with-json"], + ), + "jsonb" => RustType::with_column_type_attr( + "serde_json::Value", + "ColumnType::JsonBinary", + vec![RustImport::serde_json("Value")], + vec!["with-json"], + ), + + // Binary + "bytea" => RustType::with_column_type_attr( + "Vec", + "ColumnType::Binary(BlobSize::Blob(None))", + vec![], + vec![], + ), + + // Network types + "inet" => RustType::with_imports( + "ipnetwork::IpNetwork", + "ColumnType::Inet", + vec![RustImport::ipnetwork()], + vec!["with-ipnetwork"], + ), + "cidr" => RustType::with_imports( + "ipnetwork::IpNetwork", + "ColumnType::Cidr", + vec![RustImport::ipnetwork()], + vec!["with-ipnetwork"], + ), + "macaddr" | "macaddr8" => RustType::with_imports( + "mac_address::MacAddress", + "ColumnType::MacAddr", + vec![RustImport::mac_address()], + vec!["with-mac_address"], + ), + + // Geometric types (stored as strings) + "point" => RustType::with_column_type_attr( + "String", + "ColumnType::Custom(\"point\".into())", + vec![], + vec![], + ), + "line" => RustType::with_column_type_attr( + "String", + "ColumnType::Custom(\"line\".into())", + vec![], + vec![], + ), + "lseg" => RustType::with_column_type_attr( + "String", + "ColumnType::Custom(\"lseg\".into())", + vec![], + vec![], + ), + "box" => RustType::with_column_type_attr( + "String", + "ColumnType::Custom(\"box\".into())", + vec![], + vec![], + ), + "path" => RustType::with_column_type_attr( + "String", + "ColumnType::Custom(\"path\".into())", + vec![], + vec![], + ), + "polygon" => RustType::with_column_type_attr( + "String", + "ColumnType::Custom(\"polygon\".into())", + vec![], + vec![], + ), + "circle" => RustType::with_column_type_attr( + "String", + "ColumnType::Custom(\"circle\".into())", + vec![], + vec![], + ), + + // Bit strings + "bit" | "varbit" | "bit varying" => { + let column_type = extract_bit_column_type(formatted); + RustType::with_column_type_attr("String", &column_type, vec![], vec![]) + } + + // XML + "xml" => RustType::with_column_type_attr( + "String", + "ColumnType::Custom(\"xml\".into())", + vec![], + vec![], + ), + + // Full-text search types + "tsvector" => RustType::with_column_type_attr( + "String", + "ColumnType::Custom(\"tsvector\".into())", + vec![], + vec![], + ), + "tsquery" => RustType::with_column_type_attr( + "String", + "ColumnType::Custom(\"tsquery\".into())", + vec![], + vec![], + ), + + // OID types (internal PostgreSQL types) + "oid" | "regproc" | "regprocedure" | "regoper" | "regoperator" | "regclass" | "regtype" + | "regrole" | "regnamespace" | "regconfig" | "regdictionary" => { + RustType::simple("u32", "ColumnType::Unsigned") + } + + // Range types (as string, complex to map) + "int4range" | "int8range" | "numrange" | "tsrange" | "tstzrange" | "daterange" => { + RustType::with_column_type_attr( + "String", + &format!("ColumnType::Custom(\"{type_name}\".into())"), + vec![], + vec![], + ) + } + + // Default fallback - use String for unknown types + _ => { + // Check if it might be a user-defined enum + // User-defined types typically don't have pg_catalog schema + if type_info.schema.as_ref() != "pg_catalog" { + // Likely a user-defined enum or composite type + RustType::with_column_type_attr( + "String", + &format!("ColumnType::Custom(\"{type_name}\".into())"), + vec![], + vec![], + ) + } else { + // Unknown pg_catalog type, use String + RustType::with_column_type_attr( + "String", + &format!("ColumnType::Custom(\"{type_name}\".into())"), + vec![], + vec![], + ) + } + } + } +} + +/// Maps a PostgreSQL array type to a Rust Vec type. +fn map_array_type(element_type_name: &str) -> RustType { + // Get the element type first + let element_type = match element_type_name { + "int2" | "smallint" => ("i16", "ColumnType::SmallInteger"), + "int4" | "integer" | "int" => ("i32", "ColumnType::Integer"), + "int8" | "bigint" => ("i64", "ColumnType::BigInteger"), + "float4" | "real" => ("f32", "ColumnType::Float"), + "float8" | "double precision" => ("f64", "ColumnType::Double"), + "bool" | "boolean" => ("bool", "ColumnType::Boolean"), + "text" => ("String", "ColumnType::Text"), + "varchar" | "character varying" => ("String", "ColumnType::String(StringLen::None)"), + "uuid" => ("uuid::Uuid", "ColumnType::Uuid"), + _ => ("String", "ColumnType::Text"), + }; + + let annotation = format!("Vec<{}>", element_type.0); + let column_type = format!("ColumnType::Array(RcOrArc::new({}))", element_type.1); + + let mut imports = Vec::new(); + let mut features = Vec::new(); + + if element_type.0 == "uuid::Uuid" { + imports.push(RustImport::uuid()); + features.push("with-uuid"); + } + + RustType { + annotation, + column_type, + required_features: features, + imports, + needs_column_type_attr: true, + } +} + +/// Extracts numeric column type with precision and scale from formatted type. +fn extract_numeric_column_type(formatted: &str) -> String { + if let Some((precision, scale)) = extract_numeric_precision(formatted) { + format!("ColumnType::Decimal(Some(({precision}, {scale})))") + } else { + "ColumnType::Decimal(None)".to_string() + } +} + +/// Extracts varchar column type with length from formatted type. +fn extract_varchar_column_type(formatted: &str) -> String { + if let Some(length) = extract_string_length(formatted) { + format!("ColumnType::String(StringLen::N({length}))") + } else { + "ColumnType::String(StringLen::None)".to_string() + } +} + +/// Extracts char column type with length from formatted type. +fn extract_char_column_type(formatted: &str) -> String { + if let Some(length) = extract_string_length(formatted) { + format!("ColumnType::Char(Some({length}))") + } else { + "ColumnType::Char(None)".to_string() + } +} + +/// Extracts bit column type with length from formatted type. +fn extract_bit_column_type(formatted: &str) -> String { + if let Some(length) = extract_bit_length(formatted) { + format!("ColumnType::Bit(Some({length}))") + } else { + "ColumnType::Bit(None)".to_string() + } +} + +/// Extracts the varchar/char length constraint from a formatted type. +/// +/// Returns `Some(length)` for types like "character varying(255)" or "character(10)". +pub fn extract_string_length(formatted: &str) -> Option { + let formatted_lower = formatted.to_lowercase(); + + if formatted_lower.starts_with("character varying(") + || formatted_lower.starts_with("varchar(") + || formatted_lower.starts_with("character(") + || formatted_lower.starts_with("char(") + { + extract_paren_number(formatted) + } else { + None + } +} + +/// Extracts the bit length from a formatted type. +fn extract_bit_length(formatted: &str) -> Option { + let formatted_lower = formatted.to_lowercase(); + + if formatted_lower.starts_with("bit(") || formatted_lower.starts_with("bit varying(") { + extract_paren_number(formatted) + } else { + None + } +} + +/// Extracts numeric precision and scale from a formatted type. +/// +/// Returns `Some((precision, scale))` for types like "numeric(10,2)". +pub fn extract_numeric_precision(formatted: &str) -> Option<(u32, u32)> { + let formatted_lower = formatted.to_lowercase(); + + if formatted_lower.starts_with("numeric(") || formatted_lower.starts_with("decimal(") { + if let Some(start) = formatted.find('(') { + if let Some(end) = formatted.find(')') { + let params = &formatted[start + 1..end]; + let parts: Vec<&str> = params.split(',').collect(); + if parts.len() == 2 { + let precision: u32 = parts[0].trim().parse().ok()?; + let scale: u32 = parts[1].trim().parse().ok()?; + return Some((precision, scale)); + } else if parts.len() == 1 { + let precision: u32 = parts[0].trim().parse().ok()?; + return Some((precision, 0)); + } + } + } + } + + None +} + +/// Extracts a number from parentheses in a type string. +fn extract_paren_number(s: &str) -> Option { + if let Some(start) = s.find('(') { + if let Some(end) = s.find(')') { + let num_str = &s[start + 1..end]; + return num_str.trim().parse().ok(); + } + } + None +} + +/// Returns the SeaORM column type attribute value for use in `#[sea_orm(column_type = "...")]`. +/// +/// This returns the short form used in derive macro attributes. +#[allow(dead_code)] +pub fn get_column_type_attr(type_info: &TypeInfo) -> Option { + let rust_type = map_pg_type(type_info); + if rust_type.needs_column_type_attr { + // Extract the short form from ColumnType::X + let col_type = &rust_type.column_type; + if col_type.starts_with("ColumnType::") { + let short_form = &col_type["ColumnType::".len()..]; + // For simple types, just return the variant name + if let Some(paren_pos) = short_form.find('(') { + let variant = &short_form[..paren_pos]; + // Return just the variant name for the attribute + Some(variant.to_string()) + } else { + Some(short_form.to_string()) + } + } else { + None + } + } else { + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tern_ddl::{SchemaName, TypeName}; + + fn make_type_info(name: &str, formatted: &str, is_array: bool) -> TypeInfo { + TypeInfo { + name: TypeName::try_new(name.to_string()).unwrap(), + schema: SchemaName::try_new("pg_catalog".to_string()).unwrap(), + formatted: formatted.to_string(), + is_array, + } + } + + #[test] + fn test_integer_types() { + let type_info = make_type_info("int4", "integer", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "i32"); + assert_eq!(rust_type.column_type, "ColumnType::Integer"); + assert!(rust_type.imports.is_empty()); + assert!(rust_type.required_features.is_empty()); + + let type_info = make_type_info("int8", "bigint", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "i64"); + assert_eq!(rust_type.column_type, "ColumnType::BigInteger"); + + let type_info = make_type_info("int2", "smallint", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "i16"); + assert_eq!(rust_type.column_type, "ColumnType::SmallInteger"); + } + + #[test] + fn test_float_types() { + let type_info = make_type_info("float8", "double precision", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "f64"); + assert_eq!(rust_type.column_type, "ColumnType::Double"); + + let type_info = make_type_info("float4", "real", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "f32"); + assert_eq!(rust_type.column_type, "ColumnType::Float"); + } + + #[test] + fn test_numeric_type() { + let type_info = make_type_info("numeric", "numeric(10,2)", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "rust_decimal::Decimal"); + assert_eq!(rust_type.column_type, "ColumnType::Decimal(Some((10, 2)))"); + assert_eq!(rust_type.required_features, vec!["with-rust_decimal"]); + assert_eq!(rust_type.imports.len(), 1); + assert_eq!(rust_type.imports[0].module, "rust_decimal"); + } + + #[test] + fn test_boolean_type() { + let type_info = make_type_info("bool", "boolean", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "bool"); + assert_eq!(rust_type.column_type, "ColumnType::Boolean"); + } + + #[test] + fn test_text_types() { + let type_info = make_type_info("text", "text", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "String"); + assert_eq!(rust_type.column_type, "ColumnType::Text"); + + let type_info = make_type_info("varchar", "character varying(255)", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "String"); + assert_eq!( + rust_type.column_type, + "ColumnType::String(StringLen::N(255))" + ); + } + + #[test] + fn test_datetime_types() { + let type_info = make_type_info("timestamp", "timestamp without time zone", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "chrono::NaiveDateTime"); + assert_eq!(rust_type.column_type, "ColumnType::DateTime"); + assert_eq!(rust_type.required_features, vec!["with-chrono"]); + + let type_info = make_type_info("timestamptz", "timestamp with time zone", false); + let rust_type = map_pg_type(&type_info); + assert_eq!( + rust_type.annotation, + "chrono::DateTime" + ); + assert_eq!(rust_type.column_type, "ColumnType::TimestampWithTimeZone"); + assert!(rust_type.needs_column_type_attr); + + let type_info = make_type_info("date", "date", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "chrono::NaiveDate"); + assert_eq!(rust_type.column_type, "ColumnType::Date"); + } + + #[test] + fn test_uuid_type() { + let type_info = make_type_info("uuid", "uuid", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "uuid::Uuid"); + assert_eq!(rust_type.column_type, "ColumnType::Uuid"); + assert_eq!(rust_type.required_features, vec!["with-uuid"]); + } + + #[test] + fn test_json_types() { + let type_info = make_type_info("jsonb", "jsonb", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "serde_json::Value"); + assert_eq!(rust_type.column_type, "ColumnType::JsonBinary"); + assert!(rust_type.needs_column_type_attr); + assert_eq!(rust_type.required_features, vec!["with-json"]); + + let type_info = make_type_info("json", "json", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "serde_json::Value"); + assert_eq!(rust_type.column_type, "ColumnType::Json"); + } + + #[test] + fn test_bytea_type() { + let type_info = make_type_info("bytea", "bytea", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "Vec"); + assert!(rust_type.needs_column_type_attr); + } + + #[test] + fn test_array_type() { + let type_info = make_type_info("int4", "integer[]", true); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "Vec"); + assert!(rust_type.column_type.contains("Array")); + assert!(rust_type.needs_column_type_attr); + } + + #[test] + fn test_text_array_type() { + let type_info = make_type_info("text", "text[]", true); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "Vec"); + assert!(rust_type.column_type.contains("Array")); + } + + #[test] + fn test_network_types() { + let type_info = make_type_info("inet", "inet", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "ipnetwork::IpNetwork"); + assert_eq!(rust_type.required_features, vec!["with-ipnetwork"]); + + let type_info = make_type_info("macaddr", "macaddr", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "mac_address::MacAddress"); + assert_eq!(rust_type.required_features, vec!["with-mac_address"]); + } + + #[test] + fn test_extract_string_length() { + assert_eq!(extract_string_length("character varying(255)"), Some(255)); + assert_eq!(extract_string_length("varchar(100)"), Some(100)); + assert_eq!(extract_string_length("character(10)"), Some(10)); + assert_eq!(extract_string_length("char(5)"), Some(5)); + assert_eq!(extract_string_length("text"), None); + } + + #[test] + fn test_extract_numeric_precision() { + assert_eq!(extract_numeric_precision("numeric(10,2)"), Some((10, 2))); + assert_eq!(extract_numeric_precision("numeric(5)"), Some((5, 0))); + assert_eq!(extract_numeric_precision("decimal(8,3)"), Some((8, 3))); + assert_eq!(extract_numeric_precision("integer"), None); + } + + #[test] + fn test_optional_type() { + let type_info = make_type_info("int4", "integer", false); + let rust_type = map_pg_type(&type_info); + let optional = rust_type.as_optional(); + assert_eq!(optional.annotation, "Option"); + } + + #[test] + fn test_geometric_types() { + let type_info = make_type_info("point", "point", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "String"); + assert!(rust_type.column_type.contains("Custom")); + } + + #[test] + fn test_interval_type() { + let type_info = make_type_info("interval", "interval", false); + let rust_type = map_pg_type(&type_info); + assert_eq!(rust_type.annotation, "String"); + assert!(rust_type.column_type.contains("Interval")); + } +}