diff --git a/crates/oxyde-codec/src/lib.rs b/crates/oxyde-codec/src/lib.rs index b04775e..7d04c46 100644 --- a/crates/oxyde-codec/src/lib.rs +++ b/crates/oxyde-codec/src/lib.rs @@ -130,6 +130,10 @@ pub enum ColumnTypeSpec { Json, /// JSONB on Postgres; identical to Json elsewhere. JsonBinary, + Enum { + name: String, + values: Vec, + }, Array { item: Box, }, @@ -613,6 +617,19 @@ mod tests { ); } + #[test] + fn test_spec_enum() { + assert_eq!( + spec_from_json( + r#"{"kind": "enum", "name": "post_status_enum", "values": ["draft", "published"]}"# + ), + ColumnTypeSpec::Enum { + name: "post_status_enum".into(), + values: vec!["draft".into(), "published".into()], + } + ); + } + #[test] fn test_spec_unknown_kind_is_error() { let result: std::result::Result = @@ -632,6 +649,10 @@ mod tests { ColumnTypeSpec::Array { item: Box::new(ColumnTypeSpec::Uuid), }, + ColumnTypeSpec::Enum { + name: "post_status_enum".into(), + values: vec!["draft".into(), "published".into()], + }, ColumnTypeSpec::Unknown, ]; for spec in specs { diff --git a/crates/oxyde-driver/src/convert/mysql.rs b/crates/oxyde-driver/src/convert/mysql.rs index b8f093d..019e332 100644 --- a/crates/oxyde-driver/src/convert/mysql.rs +++ b/crates/oxyde-driver/src/convert/mysql.rs @@ -31,7 +31,7 @@ impl CellEncoder for MySqlEncoder { } true } - ColumnTypeSpec::Text | ColumnTypeSpec::String { .. } => { + ColumnTypeSpec::Text | ColumnTypeSpec::String { .. } | ColumnTypeSpec::Enum { .. } => { match row.try_get::, _>(idx) { Ok(Some(v)) => write_str(buf, &v), Ok(None) => write_nil(buf), diff --git a/crates/oxyde-driver/src/convert/postgres.rs b/crates/oxyde-driver/src/convert/postgres.rs index df283d1..c983660 100644 --- a/crates/oxyde-driver/src/convert/postgres.rs +++ b/crates/oxyde-driver/src/convert/postgres.rs @@ -7,13 +7,46 @@ use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; use oxyde_codec::ColumnTypeSpec; use rust_decimal::Decimal; -use sqlx::{postgres::PgRow, Row}; +use sqlx::{ + error::BoxDynError, + postgres::{PgHasArrayType, PgRow, PgTypeInfo, PgTypeKind, PgValueRef, Postgres}, + Decode, Row, Type, +}; use uuid::Uuid; use super::encoder::*; pub struct PgEncoder; +#[derive(Debug, Clone)] +struct PgEnumText(String); + +impl Type for PgEnumText { + fn type_info() -> PgTypeInfo { + >::type_info() + } + + fn compatible(ty: &PgTypeInfo) -> bool { + >::compatible(ty) || matches!(ty.kind(), PgTypeKind::Enum(_)) + } +} + +impl PgHasArrayType for PgEnumText { + fn array_type_info() -> PgTypeInfo { + ::array_type_info() + } + + fn array_compatible(ty: &PgTypeInfo) -> bool { + matches!(ty.kind(), PgTypeKind::Array(element) if Self::compatible(element)) + } +} + +impl<'r> Decode<'r, Postgres> for PgEnumText { + fn decode(value: PgValueRef<'r>) -> Result { + Ok(Self(value.as_str()?.to_string())) + } +} + impl CellEncoder for PgEncoder { type Row = PgRow; @@ -44,6 +77,14 @@ impl CellEncoder for PgEncoder { } true } + ColumnTypeSpec::Enum { .. } => { + match row.try_get::, _>(idx) { + Ok(Some(v)) => write_str(buf, &v.0), + Ok(None) => write_nil(buf), + Err(_) => write_nil(buf), + } + true + } ColumnTypeSpec::Double => { match row.try_get::, _>(idx) { Ok(Some(v)) => write_f64(buf, v), @@ -259,6 +300,9 @@ fn encode_pg_array(buf: &mut Vec, row: &PgRow, idx: usize, item: &ColumnType ColumnTypeSpec::Text | ColumnTypeSpec::String { .. } => { encode_pg_array_ref::(buf, row, idx, |b, v| write_str(b, v)); } + ColumnTypeSpec::Enum { .. } => { + encode_pg_array_ref::(buf, row, idx, |b, v| write_str(b, &v.0)); + } ColumnTypeSpec::Uuid => { encode_pg_array_ref::(buf, row, idx, |b, v| { write_str(b, &v.to_string()); diff --git a/crates/oxyde-driver/src/convert/sqlite.rs b/crates/oxyde-driver/src/convert/sqlite.rs index 416efad..c70b0c3 100644 --- a/crates/oxyde-driver/src/convert/sqlite.rs +++ b/crates/oxyde-driver/src/convert/sqlite.rs @@ -32,6 +32,7 @@ impl CellEncoder for SqliteEncoder { // datetime, date, time, decimal, uuid — stored as TEXT in SQLite ColumnTypeSpec::Text | ColumnTypeSpec::String { .. } + | ColumnTypeSpec::Enum { .. } | ColumnTypeSpec::DateTime | ColumnTypeSpec::DateTimeUtc | ColumnTypeSpec::Date diff --git a/crates/oxyde-migrate/src/diff.rs b/crates/oxyde-migrate/src/diff.rs index 50c4fc1..c254fc0 100644 --- a/crates/oxyde-migrate/src/diff.rs +++ b/crates/oxyde-migrate/src/diff.rs @@ -2,8 +2,9 @@ use std::collections::{HashMap, HashSet}; -use crate::op::MigrationOp; +use crate::op::{EnumFieldRef, MigrationOp}; use crate::types::{Dialect, MigrateError, Result, Snapshot, TableDef}; +use oxyde_codec::ColumnTypeSpec; use serde::{Deserialize, Serialize}; /// Topologically sort table names so that referenced tables come before @@ -86,6 +87,136 @@ fn topo_sort_table_names(tables: &HashMap) -> Result Result>> { + let mut defs = HashMap::new(); + for table in snapshot.tables.values() { + for field in &table.fields { + collect_enum_def_from_spec(&field.column_type, &mut defs)?; + } + } + Ok(defs) +} + +fn collect_enum_def_from_spec( + spec: &ColumnTypeSpec, + defs: &mut HashMap>, +) -> Result<()> { + match spec { + ColumnTypeSpec::Enum { name, values } => { + if let Some(existing) = defs.get(name) { + if existing != values { + return Err(MigrateError::DiffError(format!( + "enum type '{}' has conflicting value sets", + name + ))); + } + } else { + defs.insert(name.clone(), values.clone()); + } + } + ColumnTypeSpec::Array { item } => collect_enum_def_from_spec(item, defs)?, + _ => {} + } + Ok(()) +} + +fn sorted_keys(map: &HashMap>) -> Vec { + let mut keys = map.keys().cloned().collect::>(); + keys.sort(); + keys +} + +fn enum_values_are_append_only(old_values: &[String], new_values: &[String]) -> bool { + new_values.len() >= old_values.len() && &new_values[..old_values.len()] == old_values +} + +fn column_type_requires_alter(old: &ColumnTypeSpec, new: &ColumnTypeSpec) -> bool { + match (old, new) { + ( + ColumnTypeSpec::Enum { name: old_name, .. }, + ColumnTypeSpec::Enum { name: new_name, .. }, + ) => old_name != new_name, + (ColumnTypeSpec::Array { item: old_item }, ColumnTypeSpec::Array { item: new_item }) => { + column_type_requires_alter(old_item, new_item) + } + _ => old != new, + } +} + +fn db_type_requires_alter( + old_type: &ColumnTypeSpec, + new_type: &ColumnTypeSpec, + old_db_type: &Option, + new_db_type: &Option, +) -> bool { + if !column_type_requires_alter(old_type, new_type) && contains_enum_type(old_type) { + return false; + } + old_db_type != new_db_type +} + +fn contains_enum_type(spec: &ColumnTypeSpec) -> bool { + match spec { + ColumnTypeSpec::Enum { .. } => true, + ColumnTypeSpec::Array { item } => contains_enum_type(item), + _ => false, + } +} + +fn scalar_enum_name(spec: &ColumnTypeSpec) -> Option<&str> { + match spec { + ColumnTypeSpec::Enum { name, .. } => Some(name), + _ => None, + } +} + +fn existing_scalar_enum_fields( + old: &Snapshot, + new: &Snapshot, + enum_name: &str, + values: &[String], +) -> Vec { + let mut fields = Vec::new(); + let mut table_names = old + .tables + .keys() + .filter(|name| new.tables.contains_key(*name)) + .cloned() + .collect::>(); + table_names.sort(); + + for table_name in table_names { + let old_table = &old.tables[&table_name]; + let new_table = &new.tables[&table_name]; + for old_field in &old_table.fields { + if scalar_enum_name(&old_field.column_type) != Some(enum_name) { + continue; + } + if let Some(new_field) = new_table + .fields + .iter() + .find(|field| field.name == old_field.name) + .filter(|field| scalar_enum_name(&field.column_type) == Some(enum_name)) + { + let mut field = new_field.clone(); + if let ColumnTypeSpec::Enum { + values: field_values, + .. + } = &mut field.column_type + { + *field_values = values.to_vec(); + } + fields.push(EnumFieldRef { + table: table_name.clone(), + field, + }); + } + } + } + + fields +} + /// Migration file #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Migration { @@ -124,19 +255,23 @@ impl Migration { /// This ensures referenced tables exist before FK constraints are added /// (PG/MySQL emit FK as separate ALTER TABLE, not inline in CREATE TABLE). pub fn to_sql(&self, dialect: Dialect) -> Result> { - let mut all_sql = Vec::new(); + let mut all_sql: Vec<(u8, String)> = Vec::new(); for op in &self.operations { let sqls = op.to_sql(dialect)?; - all_sql.extend(sqls); - } - all_sql.sort_by_key(|s| { - if s.trim_start().starts_with("ALTER") { - 1 - } else { - 0 + for sql in sqls { + let bucket = match op { + MigrationOp::CreateEnumType { .. } => 0, + MigrationOp::AddEnumValue { .. } => 1, + MigrationOp::DropEnumType { .. } => 4, + MigrationOp::AlterEnumType { .. } => 5, + _ if sql.trim_start().starts_with("ALTER TABLE") => 3, + _ => 2, + }; + all_sql.push((bucket, sql)); } - }); - Ok(all_sql) + } + all_sql.sort_by_key(|(bucket, _)| *bucket); + Ok(all_sql.into_iter().map(|(_, sql)| sql).collect()) } } @@ -147,6 +282,39 @@ impl Migration { /// Cycles among unchanged tables are irrelevant and pass through silently. pub fn compute_diff(old: &Snapshot, new: &Snapshot) -> Result> { let mut ops = Vec::new(); + let old_enums = collect_enum_defs(old)?; + let new_enums = collect_enum_defs(new)?; + + for name in sorted_keys(&new_enums) { + if !old_enums.contains_key(&name) { + ops.push(MigrationOp::CreateEnumType { + name: name.clone(), + values: new_enums[&name].clone(), + }); + } + } + + for name in sorted_keys(&new_enums) { + if let Some(old_values) = old_enums.get(&name) { + let new_values = &new_enums[&name]; + if enum_values_are_append_only(old_values, new_values) { + for (index, value) in new_values[old_values.len()..].iter().cloned().enumerate() { + let values = &new_values[..old_values.len() + index + 1]; + ops.push(MigrationOp::AddEnumValue { + name: name.clone(), + value, + fields: existing_scalar_enum_fields(old, new, &name, values), + }); + } + } else { + ops.push(MigrationOp::AlterEnumType { + name: name.clone(), + old_values: old_values.clone(), + new_values: new_values.clone(), + }); + } + } + } // Topo-sort only the subset of tables that are actually being created. // FKs from this subset to tables that already exist in `old` are not @@ -215,8 +383,14 @@ pub fn compute_diff(old: &Snapshot, new: &Snapshot) -> Result> if let Some(old_field) = old_table.fields.iter().find(|f| f.name == new_field.name) { // Check if type changed using column_type or db_type - let type_changed = old_field.column_type != new_field.column_type - || old_field.db_type != new_field.db_type; + let type_changed = + column_type_requires_alter(&old_field.column_type, &new_field.column_type) + || db_type_requires_alter( + &old_field.column_type, + &new_field.column_type, + &old_field.db_type, + &new_field.db_type, + ); let nullable_changed = old_field.nullable != new_field.nullable; let default_changed = old_field.default != new_field.default; @@ -350,5 +524,14 @@ pub fn compute_diff(old: &Snapshot, new: &Snapshot) -> Result> } } + for name in sorted_keys(&old_enums) { + if !new_enums.contains_key(&name) { + ops.push(MigrationOp::DropEnumType { + name: name.clone(), + values: Some(old_enums[&name].clone()), + }); + } + } + Ok(ops) } diff --git a/crates/oxyde-migrate/src/op.rs b/crates/oxyde-migrate/src/op.rs index 281f527..1863e6d 100644 --- a/crates/oxyde-migrate/src/op.rs +++ b/crates/oxyde-migrate/src/op.rs @@ -3,10 +3,36 @@ use crate::types::{CheckDef, FieldDef, ForeignKeyDef, IndexDef, TableDef}; use serde::{Deserialize, Serialize}; +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EnumFieldRef { + pub table: String, + pub field: FieldDef, +} + /// Migration operation #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum MigrationOp { + CreateEnumType { + name: String, + values: Vec, + }, + DropEnumType { + name: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + values: Option>, + }, + AddEnumValue { + name: String, + value: String, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + fields: Vec, + }, + AlterEnumType { + name: String, + old_values: Vec, + new_values: Vec, + }, CreateTable { table: TableDef, }, diff --git a/crates/oxyde-migrate/src/spec_sql.rs b/crates/oxyde-migrate/src/spec_sql.rs index 1912e52..a2e33a2 100644 --- a/crates/oxyde-migrate/src/spec_sql.rs +++ b/crates/oxyde-migrate/src/spec_sql.rs @@ -27,12 +27,23 @@ pub fn resolve_spec_type( dialect: Dialect, is_pk: bool, ) -> String { + if contains_enum_spec(spec) { + return canonical_type(spec, dialect, is_pk); + } if let Some(db_type) = db_type { return translate_user_db_type(db_type, dialect); } canonical_type(spec, dialect, is_pk) } +fn contains_enum_spec(spec: &ColumnTypeSpec) -> bool { + match spec { + ColumnTypeSpec::Enum { .. } => true, + ColumnTypeSpec::Array { item } => contains_enum_spec(item), + _ => false, + } +} + /// SERIAL/BIGSERIAL are PostgreSQL-specific — translate for other dialects. /// Everything else renders verbatim: the user owns the string. fn translate_user_db_type(db_type: &str, dialect: Dialect) -> String { @@ -117,6 +128,18 @@ fn canonical_type(spec: &ColumnTypeSpec, dialect: Dialect, is_pk: bool) -> Strin D::Mysql => "JSON".to_string(), D::Sqlite => "TEXT".to_string(), }, + S::Enum { name, values } => match dialect { + D::Postgres => quote_postgres_type_name(name), + D::Mysql => format!( + "ENUM({})", + values + .iter() + .map(|value| quote_sql_string(value)) + .collect::>() + .join(",") + ), + D::Sqlite => "TEXT".to_string(), + }, S::Array { item } => match dialect { D::Postgres => format!("{}[]", canonical_type(item, dialect, false)), D::Mysql => "JSON".to_string(), @@ -126,6 +149,17 @@ fn canonical_type(spec: &ColumnTypeSpec, dialect: Dialect, is_pk: bool) -> Strin } } +pub(crate) fn quote_postgres_type_name(name: &str) -> String { + name.split('.') + .map(|part| format!("\"{}\"", part.replace('"', "\"\""))) + .collect::>() + .join(".") +} + +pub(crate) fn quote_sql_string(value: &str) -> String { + format!("'{}'", value.replace('\'', "''")) +} + #[cfg(test)] mod tests { use super::*; @@ -242,6 +276,26 @@ mod tests { ); } + #[test] + fn enum_type_rendering() { + let spec = ColumnTypeSpec::Enum { + name: "post_status_enum".to_string(), + values: vec!["draft".to_string(), "published".to_string()], + }; + assert_eq!( + resolve_spec_type(&spec, None, Dialect::Postgres, false), + r#""post_status_enum""# + ); + assert_eq!( + resolve_spec_type(&spec, None, Dialect::Mysql, false), + "ENUM('draft','published')" + ); + assert_eq!( + resolve_spec_type(&spec, None, Dialect::Sqlite, false), + "TEXT" + ); + } + #[test] fn pk_str_and_uuid_do_not_become_serial() { assert_eq!( diff --git a/crates/oxyde-migrate/src/sql.rs b/crates/oxyde-migrate/src/sql.rs index e894b0d..dc9e6c3 100644 --- a/crates/oxyde-migrate/src/sql.rs +++ b/crates/oxyde-migrate/src/sql.rs @@ -263,6 +263,61 @@ impl MigrationOp { /// (e.g., ALTER COLUMN on SQLite without table schema). pub fn to_sql(&self, dialect: Dialect) -> Result> { match self { + MigrationOp::CreateEnumType { name, values } => { + if dialect != Dialect::Postgres { + return Ok(Vec::new()); + } + let labels = values + .iter() + .map(|value| crate::spec_sql::quote_sql_string(value)) + .collect::>() + .join(", "); + Ok(vec![format!( + "CREATE TYPE {} AS ENUM ({})", + crate::spec_sql::quote_postgres_type_name(name), + labels + )]) + } + + MigrationOp::DropEnumType { name, values: _ } => { + if dialect != Dialect::Postgres { + return Ok(Vec::new()); + } + Ok(vec![format!( + "DROP TYPE {}", + crate::spec_sql::quote_postgres_type_name(name) + )]) + } + + MigrationOp::AddEnumValue { + name, + value, + fields, + } => { + if dialect == Dialect::Mysql { + return Ok(fields + .iter() + .map(|field| { + format!( + "ALTER TABLE `{}` MODIFY COLUMN {}", + field.table, + mysql_column_def(&field.field) + ) + }) + .collect()); + } + if dialect == Dialect::Sqlite { + return Ok(Vec::new()); + } + Ok(vec![format!( + "ALTER TYPE {} ADD VALUE IF NOT EXISTS {}", + crate::spec_sql::quote_postgres_type_name(name), + crate::spec_sql::quote_sql_string(value) + )]) + } + + MigrationOp::AlterEnumType { .. } => Ok(Vec::new()), + MigrationOp::CreateTable { table } => { let mut create = SeaTable::create(); create.table(Alias::new(&table.name)); diff --git a/crates/oxyde-migrate/tests/migration_tests.rs b/crates/oxyde-migrate/tests/migration_tests.rs index 4b354c3..945be86 100644 --- a/crates/oxyde-migrate/tests/migration_tests.rs +++ b/crates/oxyde-migrate/tests/migration_tests.rs @@ -4,8 +4,8 @@ use oxyde_codec::ColumnTypeSpec; use oxyde_migrate::{ - compute_diff, CheckDef, Dialect, FieldDef, ForeignKeyDef, IndexDef, MigrationOp, Snapshot, - TableDef, + compute_diff, CheckDef, Dialect, FieldDef, ForeignKeyDef, IndexDef, Migration, MigrationOp, + Snapshot, TableDef, }; fn sample_field(name: &str) -> FieldDef { @@ -56,6 +56,51 @@ fn sample_table() -> TableDef { } } +fn enum_spec(values: &[&str]) -> ColumnTypeSpec { + ColumnTypeSpec::Enum { + name: "post_status_enum".into(), + values: values.iter().map(|value| (*value).into()).collect(), + } +} + +fn enum_table(values: &[&str]) -> TableDef { + TableDef { + name: "posts".into(), + fields: vec![ + FieldDef { + name: "id".into(), + column_type: ColumnTypeSpec::BigInteger, + db_type: None, + nullable: false, + primary_key: true, + unique: false, + default: None, + auto_increment: false, + max_length: None, + max_digits: None, + decimal_places: None, + }, + FieldDef { + name: "status".into(), + column_type: enum_spec(values), + db_type: None, + nullable: false, + primary_key: false, + unique: false, + default: None, + auto_increment: false, + max_length: None, + max_digits: None, + decimal_places: None, + }, + ], + indexes: vec![], + foreign_keys: vec![], + checks: vec![], + comment: None, + } +} + #[test] fn test_snapshot_serialization_roundtrip() { let mut snapshot = Snapshot::new(); @@ -122,6 +167,170 @@ fn test_migration_create_table_generates_sql() { assert!(sql[1].contains(r#"CREATE UNIQUE INDEX "users_email_idx""#)); } +#[test] +fn test_postgres_enum_type_is_created_before_table() { + let migration = Migration { + name: "enum".into(), + operations: vec![ + MigrationOp::CreateEnumType { + name: "post_status_enum".into(), + values: vec!["draft".into(), "published".into()], + }, + MigrationOp::CreateTable { + table: enum_table(&["draft", "published"]), + }, + ], + }; + + let sql = migration.to_sql(Dialect::Postgres).unwrap(); + assert_eq!( + sql[0], + r#"CREATE TYPE "post_status_enum" AS ENUM ('draft', 'published')"# + ); + assert!( + sql[1].contains(r#""status" "post_status_enum" NOT NULL"#), + "{}", + sql[1] + ); +} + +#[test] +fn test_mysql_enum_is_inline_column_type() { + let sql = MigrationOp::CreateTable { + table: enum_table(&["draft", "published"]), + } + .to_sql(Dialect::Mysql) + .unwrap(); + + assert!(sql[0].contains("`status` ENUM('draft','published') NOT NULL")); +} + +#[test] +fn test_compute_diff_adds_enum_value_without_altering_column() { + let mut old = Snapshot::new(); + old.add_table(enum_table(&["draft", "published"])); + let mut new = Snapshot::new(); + new.add_table(enum_table(&["draft", "published", "archived"])); + + let ops = compute_diff(&old, &new).unwrap(); + assert_eq!(ops.len(), 1); + match &ops[0] { + MigrationOp::AddEnumValue { + name, + value, + fields, + } => { + assert_eq!(name, "post_status_enum"); + assert_eq!(value, "archived"); + assert_eq!(fields.len(), 1); + assert_eq!(fields[0].table, "posts"); + assert_eq!(fields[0].field.name, "status"); + } + op => panic!("expected AddEnumValue, got {:?}", op), + } +} + +#[test] +fn test_mysql_enum_value_append_modifies_inline_enum_columns() { + let mut old = Snapshot::new(); + old.add_table(enum_table(&["draft", "published"])); + let mut new = Snapshot::new(); + new.add_table(enum_table(&["draft", "published", "archived"])); + + let ops = compute_diff(&old, &new).unwrap(); + let migration = Migration { + name: "enum_value".into(), + operations: ops, + }; + + let sql = migration.to_sql(Dialect::Mysql).unwrap(); + assert_eq!(sql.len(), 1); + assert_eq!( + sql[0], + "ALTER TABLE `posts` MODIFY COLUMN `status` ENUM('draft','published','archived') NOT NULL" + ); +} + +#[test] +fn test_mysql_multiple_enum_value_append_modifies_inline_enum_columns_progressively() { + let mut old = Snapshot::new(); + old.add_table(enum_table(&["draft", "published"])); + let mut new = Snapshot::new(); + new.add_table(enum_table(&["draft", "published", "archived", "deleted"])); + + let ops = compute_diff(&old, &new).unwrap(); + let migration = Migration { + name: "enum_value".into(), + operations: ops, + }; + + let sql = migration.to_sql(Dialect::Mysql).unwrap(); + assert_eq!( + sql, + vec![ + "ALTER TABLE `posts` MODIFY COLUMN `status` ENUM('draft','published','archived') NOT NULL", + "ALTER TABLE `posts` MODIFY COLUMN `status` ENUM('draft','published','archived','deleted') NOT NULL", + ] + ); +} + +#[test] +fn test_postgres_enum_value_is_added_before_dependent_ddl() { + let migration = Migration { + name: "enum_value".into(), + operations: vec![ + MigrationOp::CreateIndex { + table: "posts".into(), + index: IndexDef { + name: "posts_archived_idx".into(), + fields: vec!["status".into()], + unique: false, + method: Some("btree".into()), + where_clause: Some("status = 'archived'".into()), + }, + }, + MigrationOp::AddEnumValue { + name: "post_status_enum".into(), + value: "archived".into(), + fields: vec![], + }, + ], + }; + + let sql = migration.to_sql(Dialect::Postgres).unwrap(); + assert_eq!( + sql[0], + r#"ALTER TYPE "post_status_enum" ADD VALUE IF NOT EXISTS 'archived'"# + ); + assert!(sql[1].contains(r#"CREATE INDEX "posts_archived_idx""#)); +} + +#[test] +fn test_compute_diff_emits_manual_enum_alter_for_value_removal() { + let mut old = Snapshot::new(); + old.add_table(enum_table(&["draft", "published"])); + let mut new = Snapshot::new(); + new.add_table(enum_table(&["draft"])); + + let ops = compute_diff(&old, &new).unwrap(); + assert_eq!(ops.len(), 1); + match &ops[0] { + MigrationOp::AlterEnumType { + name, + old_values, + new_values, + } => { + assert_eq!(name, "post_status_enum"); + assert_eq!( + old_values, + &vec!["draft".to_string(), "published".to_string()] + ); + assert_eq!(new_values, &vec!["draft".to_string()]); + } + op => panic!("expected AlterEnumType, got {:?}", op), + } +} + #[test] fn test_sqlite_create_table_with_fk_inline() { // SQLite should have FK constraints inline in CREATE TABLE, not as ALTER TABLE diff --git a/crates/oxyde-query/src/builder/bulk.rs b/crates/oxyde-query/src/builder/bulk.rs index 9a57ca3..efd518c 100644 --- a/crates/oxyde-query/src/builder/bulk.rs +++ b/crates/oxyde-query/src/builder/bulk.rs @@ -10,7 +10,7 @@ use sea_query::{ use crate::error::{QueryError, Result}; use crate::filter::build_filter_node; -use crate::utils::{bind_value, ColumnIdent, TableIdent}; +use crate::utils::{bind_value, typed_value_expr, ColumnIdent, TableIdent}; use crate::Dialect; /// Build bulk UPDATE query using CASE WHEN statements. @@ -45,7 +45,7 @@ pub fn build_bulk_update( for column in row.values.keys() { update_columns.insert(column.clone()); } - let cond = build_bulk_row_condition(row, col_types)?; + let cond = build_bulk_row_condition(row, col_types, dialect)?; row_conditions.push(cond); } @@ -66,7 +66,10 @@ pub fn build_bulk_update( .unwrap_or(&ColumnTypeSpec::Unknown); for (row, cond) in bulk.rows.iter().zip(&row_conditions) { if let Some(value) = row.values.get(&column) { - case_stmt = case_stmt.case(cond.clone(), Expr::val(bind_value(value, spec))); + case_stmt = case_stmt.case( + cond.clone(), + typed_value_expr(bind_value(value, spec), spec, dialect), + ); } } // ELSE keeps current value for rows not matched by this column @@ -82,7 +85,7 @@ pub fn build_bulk_update( query.cond_where(filter_cond); if let Some(filter_tree) = &ir.filter_tree { - let expr = build_filter_node(filter_tree, None, col_types, None)?; + let expr = build_filter_node(filter_tree, None, col_types, None, dialect)?; query.and_where(expr); } @@ -103,10 +106,11 @@ pub fn build_bulk_update( fn build_bulk_row_condition( row: &BulkUpdateRow, col_types: Option<&HashMap>, + dialect: Dialect, ) -> Result { let mut cond = Cond::all(); for (column, value) in &row.filters { - cond = cond.add(build_match_expression(column, value, col_types)); + cond = cond.add(build_match_expression(column, value, col_types, dialect)); } Ok(cond) } @@ -116,6 +120,7 @@ fn build_match_expression( column: &str, value: &rmpv::Value, col_types: Option<&HashMap>, + dialect: Dialect, ) -> SimpleExpr { if value.is_nil() { Expr::col(ColumnIdent(column.to_string())).is_null() @@ -123,6 +128,10 @@ fn build_match_expression( let spec = col_types .and_then(|ct| ct.get(column)) .unwrap_or(&ColumnTypeSpec::Unknown); - Expr::col(ColumnIdent(column.to_string())).eq(bind_value(value, spec)) + Expr::col(ColumnIdent(column.to_string())).eq(typed_value_expr( + bind_value(value, spec), + spec, + dialect, + )) } } diff --git a/crates/oxyde-query/src/builder/delete.rs b/crates/oxyde-query/src/builder/delete.rs index ea78ec8..e97228e 100644 --- a/crates/oxyde-query/src/builder/delete.rs +++ b/crates/oxyde-query/src/builder/delete.rs @@ -16,7 +16,7 @@ pub fn build_delete(ir: &QueryIR, dialect: Dialect) -> Result<(String, Vec Result<(String, Vec Result<(String, Vec Result<(String, Vec Result<(String, Vec Result { +fn build_select_statement(ir: &QueryIR, dialect: Dialect) -> Result { let table = TableIdent(ir.table.clone()); let mut query = Query::select(); query.from(table.clone()); @@ -61,7 +61,13 @@ fn build_select_statement(ir: &QueryIR) -> Result { }; if let Some(filter_tree) = &ir.filter_tree { - let expr = build_filter_node(filter_tree, default_table, ir.column_types.as_ref(), None)?; + let expr = build_filter_node( + filter_tree, + default_table, + ir.column_types.as_ref(), + None, + dialect, + )?; query.and_where(expr); } @@ -85,6 +91,7 @@ fn build_select_statement(ir: &QueryIR) -> Result { default_table, ir.column_types.as_ref(), ir.aggregates.as_deref(), + dialect, )?; query.and_having(expr); } @@ -95,7 +102,7 @@ fn build_select_statement(ir: &QueryIR) -> Result { // UNION via sea-query (recursive) if let Some(union_query_ir) = &ir.union_query { - let union_stmt = build_select_statement(union_query_ir)?; + let union_stmt = build_select_statement(union_query_ir, dialect)?; let union_type = if ir.union_all.unwrap_or(false) { UnionType::All } else { @@ -172,8 +179,13 @@ pub fn build_select(ir: &QueryIR, dialect: Dialect) -> Result<(String, Vec Result<(String, Vec Result<(String, Vec Result<(String, Vec, col_types: Option<&HashMap>, aggregates: Option<&[Aggregate]>, + dialect: Dialect, ) -> Result { match node { - FilterNode::Condition(filter) => apply_filter(filter, default_table, col_types, aggregates), + FilterNode::Condition(filter) => { + apply_filter(filter, default_table, col_types, aggregates, dialect) + } FilterNode::And { conditions } => { if conditions.is_empty() { return Err(QueryError::InvalidQuery( "AND node must have at least one condition".into(), )); } - let first = build_filter_node(&conditions[0], default_table, col_types, aggregates)?; + let first = build_filter_node( + &conditions[0], + default_table, + col_types, + aggregates, + dialect, + )?; let mut result = first; for cond in &conditions[1..] { - let next = build_filter_node(cond, default_table, col_types, aggregates)?; + let next = build_filter_node(cond, default_table, col_types, aggregates, dialect)?; result = result.and(next); } Ok(result) @@ -90,16 +100,23 @@ pub fn build_filter_node( "OR node must have at least one condition".into(), )); } - let first = build_filter_node(&conditions[0], default_table, col_types, aggregates)?; + let first = build_filter_node( + &conditions[0], + default_table, + col_types, + aggregates, + dialect, + )?; let mut result = first; for cond in &conditions[1..] { - let next = build_filter_node(cond, default_table, col_types, aggregates)?; + let next = build_filter_node(cond, default_table, col_types, aggregates, dialect)?; result = result.or(next); } Ok(result) } FilterNode::Not { condition } => { - let inner = build_filter_node(condition, default_table, col_types, aggregates)?; + let inner = + build_filter_node(condition, default_table, col_types, aggregates, dialect)?; Ok(inner.not()) } } @@ -113,6 +130,7 @@ fn apply_filter( default_table: Option<&str>, col_types: Option<&HashMap>, aggregates: Option<&[Aggregate]>, + dialect: Dialect, ) -> Result { let col_name = filter.column.as_ref().unwrap_or(&filter.field); @@ -130,12 +148,12 @@ fn apply_filter( let val = bind_value(&filter.value, spec); let expr = match filter.operator.as_str() { - "=" => col.eq(val), - "!=" => col.ne(val), - ">" => col.gt(val), - ">=" => col.gte(val), - "<" => col.lt(val), - "<=" => col.lte(val), + "=" => col.eq(typed_value_expr(val, spec, dialect)), + "!=" => col.ne(typed_value_expr(val, spec, dialect)), + ">" => col.gt(typed_value_expr(val, spec, dialect)), + ">=" => col.gte(typed_value_expr(val, spec, dialect)), + "<" => col.lt(typed_value_expr(val, spec, dialect)), + "<=" => col.lte(typed_value_expr(val, spec, dialect)), "LIKE" => { let text = filter.value.as_str().ok_or_else(|| { QueryError::InvalidQuery("LIKE operator requires string value".into()) @@ -152,7 +170,10 @@ fn apply_filter( } "IN" => { if let rmpv::Value::Array(arr) = &filter.value { - let values: Vec = arr.iter().map(|v| bind_value(v, spec)).collect(); + let values: Vec = arr + .iter() + .map(|v| typed_value_expr(bind_value(v, spec), spec, dialect)) + .collect(); col.is_in(values) } else { return Err(QueryError::InvalidQuery( @@ -167,8 +188,8 @@ fn apply_filter( "BETWEEN operator requires exactly two values".to_string(), )); } - let start = Expr::val(bind_value(&arr[0], spec)); - let end = Expr::val(bind_value(&arr[1], spec)); + let start = typed_value_expr(bind_value(&arr[0], spec), spec, dialect); + let end = typed_value_expr(bind_value(&arr[1], spec), spec, dialect); col.between(start, end) } else { return Err(QueryError::InvalidQuery( diff --git a/crates/oxyde-query/src/lib.rs b/crates/oxyde-query/src/lib.rs index dbf0245..a2826a2 100644 --- a/crates/oxyde-query/src/lib.rs +++ b/crates/oxyde-query/src/lib.rs @@ -121,7 +121,8 @@ pub fn build_sql(ir: &QueryIR, dialect: Dialect) -> Result<(String, Vec)> mod tests { use super::*; use oxyde_codec::{ - ConflictAction, Filter, FilterNode, JoinColumn, JoinSpec, OnConflict, Operation, QueryIR, + ColumnTypeSpec, ConflictAction, Filter, FilterNode, JoinColumn, JoinSpec, OnConflict, + Operation, QueryIR, }; use std::collections::HashMap; @@ -141,6 +142,13 @@ mod tests { rmpv::Value::Array(vals) } + fn enum_spec() -> ColumnTypeSpec { + ColumnTypeSpec::Enum { + name: "post_status_enum".into(), + values: vec!["draft".into(), "published".into()], + } + } + /// Helper to create a simple condition filter node fn filter_cond(field: &str, operator: &str, value: rmpv::Value) -> FilterNode { FilterNode::Condition(Filter { @@ -214,6 +222,39 @@ mod tests { } } + #[test] + fn test_insert_enum_value_is_cast_on_postgres() { + let ir = QueryIR { + op: Operation::Insert, + table: "posts".into(), + values: Some(HashMap::from([("status".into(), rmpv_str("draft"))])), + column_types: Some(HashMap::from([("status".into(), enum_spec())])), + ..Default::default() + }; + let (sql, params) = build_sql(&ir, Dialect::Postgres).unwrap(); + assert!(sql.contains("$1::\"post_status_enum\""), "{sql}"); + assert_eq!(params.len(), 1); + } + + #[test] + fn test_filter_enum_in_values_are_cast_on_postgres() { + let ir = QueryIR { + table: "posts".into(), + cols: Some(vec!["id".into()]), + filter_tree: Some(filter_cond( + "status", + "IN", + rmpv_arr(vec![rmpv_str("draft"), rmpv_str("published")]), + )), + column_types: Some(HashMap::from([("status".into(), enum_spec())])), + ..Default::default() + }; + let (sql, params) = build_sql(&ir, Dialect::Postgres).unwrap(); + assert!(sql.contains("$1::\"post_status_enum\""), "{sql}"); + assert!(sql.contains("$2::\"post_status_enum\""), "{sql}"); + assert_eq!(params.len(), 2); + } + #[test] fn test_mysql_placeholders() { let ir = QueryIR { diff --git a/crates/oxyde-query/src/utils/bind.rs b/crates/oxyde-query/src/utils/bind.rs index 7505f53..de08069 100644 --- a/crates/oxyde-query/src/utils/bind.rs +++ b/crates/oxyde-query/src/utils/bind.rs @@ -12,9 +12,11 @@ use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; use oxyde_codec::ColumnTypeSpec; use rust_decimal::Decimal; use sea_query::value::ArrayType; -use sea_query::Value; +use sea_query::{Expr, SimpleExpr, Value}; use uuid::Uuid; +use crate::Dialect; + /// Convert an rmpv value without a column spec (raw SQL parameters). /// Identical to `bind_value(value, &ColumnTypeSpec::Unknown)`. pub fn rmpv_to_value(value: &rmpv::Value) -> Value { @@ -70,6 +72,15 @@ pub fn bind_value(value: &rmpv::Value, spec: &ColumnTypeSpec) -> Value { } } +pub fn typed_value_expr(value: Value, spec: &ColumnTypeSpec, dialect: Dialect) -> SimpleExpr { + if dialect == Dialect::Postgres { + if let Some(type_name) = postgres_enum_cast_type(spec) { + return Expr::cust_with_values(format!("$1::{type_name}"), vec![value]); + } + } + Expr::val(value).into() +} + /// Bind a string payload according to the spec. /// /// Temporal specs and `Unknown` fall back to the RFC3339 heuristic on parse @@ -120,6 +131,7 @@ fn bind_string(s: &str, spec: &ColumnTypeSpec) -> Value { Err(_) => string_value(s), } } + ColumnTypeSpec::Enum { .. } => string_value(s), // Known non-temporal specs: definitely not a datetime — no heuristic. ColumnTypeSpec::BigInteger | ColumnTypeSpec::Double @@ -139,7 +151,9 @@ fn typed_null(spec: &ColumnTypeSpec) -> Value { ColumnTypeSpec::BigInteger | ColumnTypeSpec::Timedelta => Value::BigInt(None), ColumnTypeSpec::Double => Value::Double(None), ColumnTypeSpec::Boolean => Value::Bool(None), - ColumnTypeSpec::Text | ColumnTypeSpec::String { .. } => Value::String(None), + ColumnTypeSpec::Text | ColumnTypeSpec::String { .. } | ColumnTypeSpec::Enum { .. } => { + Value::String(None) + } ColumnTypeSpec::Blob => Value::Bytes(None), ColumnTypeSpec::DateTime => Value::ChronoDateTime(None), ColumnTypeSpec::DateTimeUtc => Value::ChronoDateTimeUtc(None), @@ -163,7 +177,9 @@ fn element_array_type(item: &ColumnTypeSpec) -> Option { ColumnTypeSpec::BigInteger | ColumnTypeSpec::Timedelta => Some(ArrayType::BigInt), ColumnTypeSpec::Double => Some(ArrayType::Double), ColumnTypeSpec::Boolean => Some(ArrayType::Bool), - ColumnTypeSpec::Text | ColumnTypeSpec::String { .. } => Some(ArrayType::String), + ColumnTypeSpec::Text | ColumnTypeSpec::String { .. } | ColumnTypeSpec::Enum { .. } => { + Some(ArrayType::String) + } ColumnTypeSpec::Blob => Some(ArrayType::Bytes), ColumnTypeSpec::DateTime => Some(ArrayType::ChronoDateTime), ColumnTypeSpec::DateTimeUtc => Some(ArrayType::ChronoDateTimeUtc), @@ -176,6 +192,24 @@ fn element_array_type(item: &ColumnTypeSpec) -> Option { } } +fn postgres_enum_cast_type(spec: &ColumnTypeSpec) -> Option { + match spec { + ColumnTypeSpec::Enum { name, .. } => Some(quote_pg_type_path(name)), + ColumnTypeSpec::Array { item } => match item.as_ref() { + ColumnTypeSpec::Enum { name, .. } => Some(format!("{}[]", quote_pg_type_path(name))), + _ => None, + }, + _ => None, + } +} + +fn quote_pg_type_path(name: &str) -> String { + name.split('.') + .map(|part| format!("\"{}\"", part.replace('"', "\"\""))) + .collect::>() + .join(".") +} + fn string_value(s: &str) -> Value { Value::String(Some(Box::new(s.to_string()))) } diff --git a/crates/oxyde-query/src/utils/mod.rs b/crates/oxyde-query/src/utils/mod.rs index 7d0fb7c..15848a4 100644 --- a/crates/oxyde-query/src/utils/mod.rs +++ b/crates/oxyde-query/src/utils/mod.rs @@ -5,6 +5,6 @@ pub mod identifier; pub mod value; // Re-exports for convenience -pub use bind::{bind_value, rmpv_to_value}; +pub use bind::{bind_value, rmpv_to_value, typed_value_expr}; pub use identifier::{ColumnIdent, TableIdent}; pub use value::{parse_expression, rmpv_to_simple_expr}; diff --git a/python/oxyde/cli/migrations.py b/python/oxyde/cli/migrations.py index cc91dd5..cf67eca 100644 --- a/python/oxyde/cli/migrations.py +++ b/python/oxyde/cli/migrations.py @@ -108,9 +108,23 @@ def makemigrations( typer.echo(f" - Add column: {op['table']}.{op['field']['name']}") elif op_type == "drop_column": typer.echo(f" - Drop column: {op['table']}.{op['field']}") + elif op_type == "alter_enum_type": + typer.secho( + f" - Manual enum change: {op['name']} " + f"{op['old_values']} -> {op['new_values']}", + fg=typer.colors.YELLOW, + ) else: typer.echo(f" - {op_type}") + if any(op.get("type") == "alter_enum_type" for op in operations): + typer.secho( + " ⚠️ One or more enum types changed in a way that requires " + "manual SQL. The migration file will include a ctx.require_manual(...) " + "guard; replace it with ctx.execute(...) and keep ctx.alter_enum_type(...).", + fg=typer.colors.YELLOW, + ) + except Exception as e: typer.secho(f" ❌ Error computing diff: {e}", fg=typer.colors.RED) raise typer.Exit(1) diff --git a/python/oxyde/core/column_types.py b/python/oxyde/core/column_types.py index b28641f..5f17ca3 100644 --- a/python/oxyde/core/column_types.py +++ b/python/oxyde/core/column_types.py @@ -24,6 +24,7 @@ from datetime import date, datetime, time, timedelta from decimal import Decimal +from enum import Enum from typing import Any, get_args, get_origin from uuid import UUID @@ -107,6 +108,9 @@ def compute_column_type( otherwise the Python annotation. Returns None when nothing is known — the column is then omitted from ``column_types``. """ + enum_spec = _spec_from_enum_annotation(python_type, db_type) + if enum_spec is not None: + return enum_spec if db_type: return _spec_from_db_type(db_type) return _spec_from_annotation( @@ -124,6 +128,8 @@ def spec_for_literal(value_type: type) -> ColumnSpec | None: that msgpack-encode as strings. Returns None for types msgpack carries natively with correct binding. """ + if _is_enum_type(value_type): + return _enum_spec(value_type) kind = _PY_SCALAR_KINDS.get(value_type) return {"kind": kind} if kind is not None else None @@ -194,6 +200,8 @@ def _spec_from_annotation( kind = _PY_SCALAR_KINDS.get(python_type) if kind is None: + if _is_enum_type(python_type): + return _enum_spec(python_type) return None spec: ColumnSpec = {"kind": kind} @@ -207,6 +215,72 @@ def _spec_from_annotation( return spec +def _spec_from_enum_annotation( + python_type: Any, + db_type: str | None, +) -> ColumnSpec | None: + enum_type = _unwrap_enum_annotation(python_type) + if enum_type is None: + return None + if db_type: + db_spec = _spec_from_db_type(db_type) + if db_spec is not None: + return db_spec + return _enum_spec(enum_type, db_type) + + +def _unwrap_enum_annotation(python_type: Any) -> type[Enum] | None: + origin = get_origin(python_type) + if origin is list: + return None + if origin is not None: + for arg in get_args(python_type): + if arg is type(None): + continue + enum_type = _unwrap_enum_annotation(arg) + if enum_type is not None: + return enum_type + return None + return python_type if _is_enum_type(python_type) else None + + +def _is_enum_type(value: Any) -> bool: + return isinstance(value, type) and issubclass(value, Enum) + + +def _enum_spec(enum_type: type[Enum], db_type: str | None = None) -> ColumnSpec: + return { + "kind": "enum", + "name": db_type or _default_enum_type_name(enum_type), + "values": _enum_values(enum_type), + } + + +def _enum_values(enum_type: type[Enum]) -> list[str]: + values = [] + for member in enum_type: + value = member.value + if not isinstance(value, str): + raise TypeError( + f"Enum field '{enum_type.__name__}' must define string values" + ) + values.append(value) + return values + + +def _default_enum_type_name(enum_type: type[Enum]) -> str: + name = enum_type.__name__ + parts: list[str] = [] + for index, char in enumerate(name): + if char.isupper() and index > 0: + prev = name[index - 1] + next_char = name[index + 1] if index + 1 < len(name) else "" + if prev != "_" and (not prev.isupper() or next_char.islower()): + parts.append("_") + parts.append(char.lower()) + return f"{''.join(parts)}_enum" + + def _split_type_params(upper: str) -> tuple[str, list[int]]: """Split "NUMERIC(10,2)" → ("NUMERIC", [10, 2]); no parens → (name, []).""" if "(" not in upper: diff --git a/python/oxyde/core/types.py b/python/oxyde/core/types.py index b905382..d0bb601 100644 --- a/python/oxyde/core/types.py +++ b/python/oxyde/core/types.py @@ -15,6 +15,7 @@ from dataclasses import dataclass from datetime import date, datetime, time, timedelta from decimal import Decimal +from enum import Enum from typing import Any from uuid import UUID @@ -55,6 +56,8 @@ def serialize_value(value: Any) -> Any: return [serialize_value(v) for v in value] if isinstance(value, dict): return {k: serialize_value(v) for k, v in value.items()} + if isinstance(value, Enum): + return value.value desc = TYPE_REGISTRY.get(type(value)) if desc is not None: return desc.serialize(value) diff --git a/python/oxyde/migrations/context.py b/python/oxyde/migrations/context.py index 3afb869..98a13f6 100644 --- a/python/oxyde/migrations/context.py +++ b/python/oxyde/migrations/context.py @@ -21,6 +21,11 @@ from oxyde.migrations.replay import SchemaState +def _is_postgres_enum_add_value_sql(sql: str) -> bool: + upper = sql.strip().upper() + return upper.startswith("ALTER TYPE ") and " ADD VALUE " in upper + + class MigrationContext: """Context for executing migrations. @@ -63,6 +68,67 @@ def dialect(self) -> str: """ return self._dialect + def create_enum_type(self, name: str, values: list[str]) -> None: + op = { + "type": "create_enum_type", + "name": name, + "values": values, + } + + if self._mode == "collect": + self._operations.append(op) + else: + self._execute_operation(op) + + def drop_enum_type(self, name: str) -> None: + op = { + "type": "drop_enum_type", + "name": name, + } + + if self._mode == "collect": + self._operations.append(op) + else: + self._execute_operation(op) + + def add_enum_value( + self, + name: str, + value: str, + fields: list[dict[str, Any]] | None = None, + ) -> None: + op = { + "type": "add_enum_value", + "name": name, + "value": value, + "fields": fields or [], + } + + if self._mode == "collect": + self._operations.append(op) + else: + self._execute_operation(op) + + def alter_enum_type( + self, + name: str, + old_values: list[str], + new_values: list[str], + ) -> None: + op = { + "type": "alter_enum_type", + "name": name, + "old_values": old_values, + "new_values": new_values, + } + + if self._mode == "collect": + self._operations.append(op) + + def require_manual(self, message: str) -> None: + if self._mode == "execute": + raise RuntimeError(message) + # ======================================================================== # Table operations # ======================================================================== @@ -482,7 +548,7 @@ async def _execute_collected_sql(self) -> None: This is called by the executor after upgrade() completes. Transaction behavior by dialect: - - PostgreSQL: DDL is transactional, uses Rust transaction API + - PostgreSQL: DDL is transactional except ALTER TYPE ADD VALUE chains - SQLite: DDL is transactional, uses Rust transaction API - MySQL: DDL is NOT transactional (implicit commit), no wrapping @@ -497,8 +563,7 @@ async def _execute_collected_sql(self) -> None: if self._db_conn is None: raise RuntimeError("Cannot execute SQL: no database connection provided") - # MySQL doesn't support transactional DDL - use_transaction = self._dialect in ("postgres", "sqlite") + use_transaction = self._should_use_transaction() tx_id = None try: @@ -528,5 +593,15 @@ async def _execute_collected_sql(self) -> None: # Clear collected statements self._sql_statements = [] + def _should_use_transaction(self) -> bool: + if self._dialect == "mysql": + return False + if self._dialect == "postgres" and any( + _is_postgres_enum_add_value_sql(sql) + for sql in getattr(self, "_sql_statements", []) + ): + return False + return self._dialect in ("postgres", "sqlite") + __all__ = ["MigrationContext"] diff --git a/python/oxyde/migrations/extract.py b/python/oxyde/migrations/extract.py index 7158c6f..8e31657 100644 --- a/python/oxyde/migrations/extract.py +++ b/python/oxyde/migrations/extract.py @@ -4,6 +4,7 @@ from datetime import date, datetime, time from decimal import Decimal +from enum import Enum from typing import Any from uuid import UUID @@ -79,6 +80,10 @@ def _serialize_default(value: Any, dialect: str) -> str | None: if callable(value): return None + if isinstance(value, Enum): + escaped = str(value.value).replace("'", "''") + return f"'{escaped}'" + # String - quote it if isinstance(value, str): # Escape single quotes diff --git a/python/oxyde/migrations/generator.py b/python/oxyde/migrations/generator.py index 8a0bf2b..1a480f6 100644 --- a/python/oxyde/migrations/generator.py +++ b/python/oxyde/migrations/generator.py @@ -66,7 +66,41 @@ def _operation_to_python(op: dict[str, Any], indent: str = " ") -> str: """ op_type = op.get("type") - if op_type == "create_table": + if op_type == "create_enum_type": + values_repr = _python_repr(op["values"]) + return f"{indent}ctx.create_enum_type({_python_repr(op['name'])}, {values_repr})" + + elif op_type == "drop_enum_type": + return f"{indent}ctx.drop_enum_type({_python_repr(op['name'])})" + + elif op_type == "add_enum_value": + args = f"{_python_repr(op['name'])}, {_python_repr(op['value'])}" + if fields := op.get("fields"): + return ( + f"{indent}ctx.add_enum_value(\n" + f"{indent} {args},\n" + f"{indent} fields={_python_repr(fields, indent=len(indent) + 11)},\n" + f"{indent})" + ) + return f"{indent}ctx.add_enum_value({args})" + + elif op_type == "alter_enum_type": + message = ( + f"Manual enum migration required for {op['name']}: " + f"{op['old_values']!r} -> {op['new_values']!r}. " + "Replace ctx.require_manual(...) with ctx.execute(...) statements and keep " + "ctx.alter_enum_type(...) to update migration replay state." + ) + return ( + f"{indent}ctx.alter_enum_type(\n" + f"{indent} {_python_repr(op['name'])},\n" + f"{indent} old_values={_python_repr(op['old_values'], indent=len(indent) + 15)},\n" + f"{indent} new_values={_python_repr(op['new_values'], indent=len(indent) + 15)},\n" + f"{indent})\n" + f"{indent}ctx.require_manual({_python_repr(message)})" + ) + + elif op_type == "create_table": table = op["table"] tname = table["name"] # Use consistent indentation for all kwargs @@ -201,7 +235,15 @@ def _infer_migration_name(operations: list[dict[str, Any]]) -> str: first_op = operations[0] op_type = first_op.get("type") - if op_type == "create_table": + if op_type == "create_enum_type": + return f"create_{first_op['name']}_enum" + elif op_type == "drop_enum_type": + return f"drop_{first_op['name']}_enum" + elif op_type == "add_enum_value": + return f"add_{first_op['value']}_to_{first_op['name']}_enum" + elif op_type == "alter_enum_type": + return f"alter_{first_op['name']}_enum" + elif op_type == "create_table": table_name = first_op["table"]["name"] return f"create_{table_name}_table" elif op_type == "drop_table": @@ -306,7 +348,46 @@ def generate_migration_file( op_type = op.get("type") # Generate reverse operation - if op_type == "create_table": + if op_type == "create_enum_type": + downgrade_lines.append( + f" ctx.drop_enum_type({_python_repr(op['name'])})" + ) + elif op_type == "drop_enum_type": + values = op.get("values") + if values: + values_repr = _python_repr(values) + downgrade_lines.append( + f" ctx.create_enum_type({_python_repr(op['name'])}, {values_repr})" + ) + else: + message = f"Cannot recreate enum type {op['name']} without its values" + downgrade_lines.append( + f" raise RuntimeError({_python_repr(message)})" + ) + elif op_type == "add_enum_value": + message = ( + f"Cannot automatically remove enum value {op['value']} " + f"from {op['name']}" + ) + downgrade_lines.append( + f" raise RuntimeError({_python_repr(message)})" + ) + elif op_type == "alter_enum_type": + message = ( + f"Manual enum migration required for {op['name']}: " + f"{op['new_values']!r} -> {op['old_values']!r}. " + "Replace ctx.require_manual(...) with ctx.execute(...) statements and keep " + "ctx.alter_enum_type(...) to update migration replay state." + ) + downgrade_lines.append( + f" ctx.alter_enum_type(\n" + f" {_python_repr(op['name'])},\n" + f" old_values={_python_repr(op['new_values'], indent=19)},\n" + f" new_values={_python_repr(op['old_values'], indent=19)},\n" + f" )\n" + f" ctx.require_manual({_python_repr(message)})" + ) + elif op_type == "create_table": downgrade_lines.append(f' ctx.drop_table("{op["table"]["name"]}")') elif op_type == "drop_table": # Reverse drop_table by recreating the table from stored structure diff --git a/python/oxyde/migrations/replay.py b/python/oxyde/migrations/replay.py index 16a8d26..bb00573 100644 --- a/python/oxyde/migrations/replay.py +++ b/python/oxyde/migrations/replay.py @@ -14,6 +14,37 @@ ) +def _add_enum_value_to_spec(spec: dict[str, Any], name: str, value: str) -> dict[str, Any]: + if spec.get("kind") == "enum" and spec.get("name") == name: + updated = dict(spec) + values = list(updated.get("values", [])) + if value not in values: + values.append(value) + updated["values"] = values + return updated + if spec.get("kind") == "array" and isinstance(spec.get("item"), dict): + updated = dict(spec) + updated["item"] = _add_enum_value_to_spec(spec["item"], name, value) + return updated + return spec + + +def _replace_enum_values_in_spec( + spec: dict[str, Any], + name: str, + values: list[str], +) -> dict[str, Any]: + if spec.get("kind") == "enum" and spec.get("name") == name: + updated = dict(spec) + updated["values"] = list(values) + return updated + if spec.get("kind") == "array" and isinstance(spec.get("item"), dict): + updated = dict(spec) + updated["item"] = _replace_enum_values_in_spec(spec["item"], name, values) + return updated + return spec + + class SchemaState: """Represents database schema in memory. @@ -32,7 +63,37 @@ def apply_operation(self, op: dict[str, Any]) -> None: """ op_type = op.get("type") - if op_type == "create_table": + if op_type == "create_enum_type": + pass + + elif op_type == "drop_enum_type": + pass + + elif op_type == "add_enum_value": + enum_name = op["name"] + enum_value = op["value"] + for table in self.tables.values(): + for field in table["fields"]: + if column_type := field.get("column_type"): + field["column_type"] = _add_enum_value_to_spec( + column_type, + enum_name, + enum_value, + ) + + elif op_type == "alter_enum_type": + enum_name = op["name"] + enum_values = op["new_values"] + for table in self.tables.values(): + for field in table["fields"]: + if column_type := field.get("column_type"): + field["column_type"] = _replace_enum_values_in_spec( + column_type, + enum_name, + enum_values, + ) + + elif op_type == "create_table": table = op["table"] self.tables[table["name"]] = { "name": table["name"], diff --git a/python/oxyde/queries/mixins/mutation.py b/python/oxyde/queries/mixins/mutation.py index 36165ff..343bb2c 100644 --- a/python/oxyde/queries/mixins/mutation.py +++ b/python/oxyde/queries/mixins/mutation.py @@ -4,7 +4,10 @@ import warnings from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Literal, overload +from enum import Enum +from typing import TYPE_CHECKING, Any, Literal, get_args, get_origin, overload + +from pydantic import TypeAdapter from oxyde._msgpack import msgpack from oxyde.core import ir @@ -14,6 +17,7 @@ _dump_insert_data, _normalize_instance, ) +from oxyde.models.utils import _unpack_annotated, _unwrap_optional from oxyde.queries.base import ( SupportsExecute, _build_column_types, @@ -22,7 +26,7 @@ _model_key, _resolve_execution_client, ) -from oxyde.queries.expressions import F, _serialize_value_for_ir +from oxyde.queries.expressions import F, _Expression, _serialize_value_for_ir from oxyde.queries.insert import InsertQuery if TYPE_CHECKING: @@ -66,6 +70,39 @@ def _decode_columnar_models(model_class: type[Model], result: list[Any]) -> list return _hydrate_models(model_class, result[0], result[1]) +def _is_enum_annotation(annotation: Any) -> bool: + annotation, _ = _unpack_annotated(annotation) + annotation, _ = _unwrap_optional(annotation) + origin = get_origin(annotation) + if origin is list: + args = get_args(annotation) + return bool(args) and _is_enum_annotation(args[0]) + if origin is not None: + return any( + arg is not type(None) and _is_enum_annotation(arg) + for arg in get_args(annotation) + ) + return isinstance(annotation, type) and issubclass(annotation, Enum) + + +def _validate_enum_update_values( + model_class: type[Model], + values: dict[str, Any], +) -> dict[str, Any]: + validated = dict(values) + metadata = model_class._db_meta.field_metadata + for field, value in values.items(): + if isinstance(value, (F, _Expression)): + continue + meta = metadata.get(field) + if meta is not None and _is_enum_annotation(meta.python_type): + if value is None and meta.nullable: + validated[field] = None + else: + validated[field] = TypeAdapter(meta.python_type).validate_python(value) + return validated + + class MutationMixin: """Mixin providing data mutation capabilities.""" @@ -164,6 +201,7 @@ async def update( """ exec_client = await _resolve_execution_client(using, client) column_types = _build_column_types(self.model_class) + values = _validate_enum_update_values(self.model_class, values) mapped_values = _map_values_to_columns(self.model_class, values) serialized_values = { key: _serialize_value_for_ir(value) for key, value in mapped_values.items() diff --git a/python/oxyde/tests/unit/test_db_types.py b/python/oxyde/tests/unit/test_db_types.py index 3f9f872..382c5cd 100644 --- a/python/oxyde/tests/unit/test_db_types.py +++ b/python/oxyde/tests/unit/test_db_types.py @@ -8,6 +8,7 @@ from datetime import date, datetime, time, timedelta from decimal import Decimal +from enum import Enum from typing import Annotated from uuid import UUID @@ -16,6 +17,11 @@ from oxyde import Field, Model +class Status(Enum): + DRAFT = "draft" + PUBLISHED = "published" + + # ── Model with all type variations ──────────────────────────────────── @@ -35,6 +41,7 @@ class DbTypesModel(Model): infer_uuid: UUID | None = Field(default=None, db_nullable=True) infer_decimal: Decimal | None = Field(default=None, db_nullable=True) infer_json: dict | None = Field(default=None, db_nullable=True) + infer_enum: Status = Field(default=Status.DRAFT) # Explicit db_type (scalar) db_uuid: str = Field(default="", db_type="UUID") @@ -67,12 +74,15 @@ class DbTypesModel(Model): db_real: float = Field(default=0.0, db_type="REAL") db_bytea: bytes | None = Field(default=None, db_nullable=True, db_type="BYTEA") db_blob: bytes | None = Field(default=None, db_nullable=True, db_type="BLOB") + db_enum: Status = Field(default=Status.DRAFT, db_type="post_status_enum") + db_enum_as_text: Status = Field(default=Status.DRAFT, db_type="TEXT") # Inferred array types (no db_type) infer_str_list: list[str] | None = Field(default=None, db_nullable=True) infer_int_list: list[int] | None = Field(default=None, db_nullable=True) infer_uuid_list: list[UUID] | None = Field(default=None, db_nullable=True) infer_decimal_list: list[Decimal] | None = Field(default=None, db_nullable=True) + infer_enum_list: list[Status] | None = Field(default=None, db_nullable=True) # Explicit db_type on arrays db_varchar_arr: list[str] | None = Field( @@ -120,6 +130,10 @@ class Meta: ("infer_uuid", {"kind": "uuid"}), ("infer_decimal", {"kind": "decimal"}), ("infer_json", {"kind": "json"}), + ( + "infer_enum", + {"kind": "enum", "name": "status_enum", "values": ["draft", "published"]}, + ), # Explicit db_type scalar — semantic kind via KNOWN_DB_TYPES, # the verbatim string travels separately (FieldDef.db_type) ("db_uuid", {"kind": "uuid"}), @@ -144,11 +158,31 @@ class Meta: ("db_real", {"kind": "double"}), ("db_bytea", {"kind": "blob"}), ("db_blob", {"kind": "blob"}), + ( + "db_enum", + { + "kind": "enum", + "name": "post_status_enum", + "values": ["draft", "published"], + }, + ), + ("db_enum_as_text", {"kind": "text"}), # Inferred arrays ("infer_str_list", {"kind": "array", "item": {"kind": "string"}}), ("infer_int_list", {"kind": "array", "item": {"kind": "big_integer"}}), ("infer_uuid_list", {"kind": "array", "item": {"kind": "uuid"}}), ("infer_decimal_list", {"kind": "array", "item": {"kind": "decimal"}}), + ( + "infer_enum_list", + { + "kind": "array", + "item": { + "kind": "enum", + "name": "status_enum", + "values": ["draft", "published"], + }, + }, + ), # Explicit db_type arrays — kind per element, params parsed ("db_varchar_arr", {"kind": "array", "item": {"kind": "string", "length": 100}}), ( diff --git a/python/oxyde/tests/unit/test_migrations_execution.py b/python/oxyde/tests/unit/test_migrations_execution.py index 90192bd..d76364e 100644 --- a/python/oxyde/tests/unit/test_migrations_execution.py +++ b/python/oxyde/tests/unit/test_migrations_execution.py @@ -13,7 +13,6 @@ from oxyde.migrations.replay import SchemaState - @pytest.fixture def temp_migrations_dir(): """Create a temporary migrations directory.""" @@ -355,6 +354,9 @@ def test_validate_migration_name(self): for name in invalid_names: assert not re.match(pattern, name), f"{name} should be invalid" + for name in invalid_names: + assert not re.match(pattern, name), f"{name} should be invalid" + def test_validate_table_name(self): """Test table name validation.""" import re @@ -382,6 +384,9 @@ def test_validate_column_name(self): for name in valid_names: assert re.match(pattern, name), f"{name} should be valid" + for name in invalid_names: + assert not re.match(pattern, name), f"{name} should be invalid" + class TestMigrationDependencies: """Test migration dependency resolution.""" @@ -564,6 +569,23 @@ def test_rust_diff_drop_index_roundtrip(self): assert any("idx_users_email" in s for s in sqls) +class TestMigrationTransactionMode: + def test_postgres_enum_add_value_runs_without_transaction(self): + ctx = MigrationContext(mode="execute", dialect="postgres") + ctx._sql_statements = [ + 'ALTER TYPE "post_status_enum" ADD VALUE IF NOT EXISTS \'archived\'', + 'CREATE INDEX "posts_archived_idx" ON "posts" ("status")', + ] + + assert ctx._should_use_transaction() is False + + def test_regular_postgres_migration_uses_transaction(self): + ctx = MigrationContext(mode="execute", dialect="postgres") + ctx._sql_statements = ['CREATE TABLE "users" ("id" BIGINT)'] + + assert ctx._should_use_transaction() is True + + ALL_DIALECTS = ["postgres", "mysql", "sqlite"] NON_SQLITE = ["postgres", "mysql"] PARTIAL_INDEX_DIALECTS = ["postgres", "sqlite"] diff --git a/python/oxyde/tests/unit/test_migrations_pipeline.py b/python/oxyde/tests/unit/test_migrations_pipeline.py index f537d09..a852021 100644 --- a/python/oxyde/tests/unit/test_migrations_pipeline.py +++ b/python/oxyde/tests/unit/test_migrations_pipeline.py @@ -15,6 +15,7 @@ from __future__ import annotations import json +from enum import Enum from pathlib import Path import pytest @@ -24,15 +25,20 @@ from oxyde.migrations.context import MigrationContext from oxyde.migrations.extract import extract_current_schema from oxyde.migrations.generator import generate_migration_file +from oxyde.migrations.replay import replay_migrations from oxyde.migrations.utils import load_migration_module from oxyde.models.registry import clear_registry - ALL_DIALECTS = ["postgres", "mysql", "sqlite"] NON_SQLITE = ["postgres", "mysql"] PARTIAL_INDEX_DIALECTS = ["postgres", "sqlite"] +class PublishState(Enum): + DRAFT = "draft" + PUBLISHED = "published" + + def _snapshot_from_models(models: list[type[Model]], dialect: str) -> dict: """Register given models fresh and return their schema snapshot.""" clear_registry() @@ -79,6 +85,47 @@ def _run_pipeline( return ops, up_sql, down_sql +def _enum_snapshot(values: list[str]) -> dict: + return { + "version": 1, + "tables": { + "articles": { + "name": "articles", + "fields": [ + { + "name": "id", + "column_type": {"kind": "big_integer"}, + "db_type": None, + "nullable": False, + "primary_key": True, + "unique": False, + "default": None, + "auto_increment": False, + }, + { + "name": "state", + "column_type": { + "kind": "enum", + "name": "publish_state_enum", + "values": values, + }, + "db_type": None, + "nullable": False, + "primary_key": False, + "unique": False, + "default": None, + "auto_increment": False, + }, + ], + "indexes": [], + "foreign_keys": [], + "checks": [], + "comment": None, + } + }, + } + + class TestCreateTablePipeline: @pytest.mark.parametrize("dialect", ALL_DIALECTS) def test_create_table_from_scratch(self, tmp_path, dialect): @@ -100,6 +147,117 @@ class Meta: assert any("DROP TABLE" in s.upper() and "users" in s for s in down_sql) +class TestEnumPipeline: + @pytest.mark.parametrize("dialect", ALL_DIALECTS) + def test_create_table_with_enum(self, tmp_path, dialect): + class Article(Model): + id: int | None = Field(default=None, db_pk=True) + state: PublishState = Field(default=PublishState.DRAFT) + + class Meta: + is_table = True + table_name = "articles" + + ops, up_sql, down_sql = _run_pipeline( + [], [Article], dialect, tmp_path, "create_articles" + ) + + assert [op["type"] for op in ops[:2]] == ["create_enum_type", "create_table"] + if dialect == "postgres": + assert up_sql[0] == ( + 'CREATE TYPE "publish_state_enum" AS ENUM (' + "'draft', 'published')" + ) + assert any('"state" "publish_state_enum" NOT NULL' in s for s in up_sql) + assert down_sql[-1] == 'DROP TYPE "publish_state_enum"' + elif dialect == "mysql": + assert not any("CREATE TYPE" in s.upper() for s in up_sql) + assert any("ENUM('draft','published')" in s for s in up_sql) + else: + assert not any("CREATE TYPE" in s.upper() for s in up_sql) + assert any('"state" TEXT NOT NULL' in s for s in up_sql) + + def test_add_enum_value_diff_emits_enum_op_only(self): + old = _enum_snapshot(["draft", "published"]) + new = _enum_snapshot(["draft", "published", "archived"]) + + ops_json = migration_compute_diff(json.dumps(old), json.dumps(new)) + ops = json.loads(ops_json) + + assert len(ops) == 1 + op = ops[0] + assert op["type"] == "add_enum_value" + assert op["name"] == "publish_state_enum" + assert op["value"] == "archived" + assert op["fields"][0]["table"] == "articles" + assert op["fields"][0]["field"]["name"] == "state" + assert migration_to_sql(ops_json, "postgres") == [ + 'ALTER TYPE "publish_state_enum" ADD VALUE IF NOT EXISTS \'archived\'' + ] + assert migration_to_sql(ops_json, "mysql") == [ + "ALTER TABLE `articles` MODIFY COLUMN `state` ENUM('draft','published','archived') NOT NULL" + ] + + def test_add_multiple_enum_values_updates_mysql_inline_enum_progressively(self): + old = _enum_snapshot(["draft", "published"]) + new = _enum_snapshot(["draft", "published", "archived", "deleted"]) + + ops_json = migration_compute_diff(json.dumps(old), json.dumps(new)) + ops = json.loads(ops_json) + + assert [op["value"] for op in ops] == ["archived", "deleted"] + assert migration_to_sql(ops_json, "mysql") == [ + "ALTER TABLE `articles` MODIFY COLUMN `state` ENUM('draft','published','archived') NOT NULL", + "ALTER TABLE `articles` MODIFY COLUMN `state` ENUM('draft','published','archived','deleted') NOT NULL", + ] + + def test_enum_value_removal_requires_manual_migration(self): + old = _enum_snapshot(["draft", "published"]) + new = _enum_snapshot(["draft"]) + + ops_json = migration_compute_diff(json.dumps(old), json.dumps(new)) + ops = json.loads(ops_json) + + assert ops == [ + { + "type": "alter_enum_type", + "name": "publish_state_enum", + "old_values": ["draft", "published"], + "new_values": ["draft"], + } + ] + + def test_manual_enum_migration_file_contains_replay_marker(self, tmp_path): + old = _enum_snapshot(["draft", "published"]) + new = _enum_snapshot(["draft"]) + ops = json.loads(migration_compute_diff(json.dumps(old), json.dumps(new))) + + empty = {"version": 1, "tables": {}} + create_ops = json.loads(migration_compute_diff(json.dumps(empty), json.dumps(old))) + generate_migration_file( + create_ops, + migrations_dir=tmp_path, + name="create_articles", + ) + filepath = generate_migration_file( + ops, + migrations_dir=tmp_path, + name="alter_publish_state", + ) + content = filepath.read_text() + + assert "Manual enum migration required for publish_state_enum" in content + assert "ctx.alter_enum_type(" in content + assert "ctx.require_manual(" in content + assert "raise RuntimeError" not in content + assert "old_values=[" in content + assert "new_values=[" in content + + replayed = replay_migrations(str(tmp_path)) + state_field = replayed["tables"]["articles"]["fields"][1] + assert state_field["column_type"]["values"] == ["draft"] + + class TestAddColumnPipeline: @pytest.mark.parametrize("dialect", ALL_DIALECTS) def test_add_column(self, tmp_path, dialect): diff --git a/python/oxyde/tests/unit/test_migrations_squash.py b/python/oxyde/tests/unit/test_migrations_squash.py index ef051a6..e3196af 100644 --- a/python/oxyde/tests/unit/test_migrations_squash.py +++ b/python/oxyde/tests/unit/test_migrations_squash.py @@ -118,6 +118,59 @@ def test_raw_sql_files_reported(self, tmp_path): assert result.raw_sql_files == ["0002_add_age.py"] assert sorted(result.legacy_files) == ["0001_users.py", "0002_add_age.py"] + def test_legacy_field_next_to_enum_add_value_replays(self, tmp_path): + (tmp_path / "0001_create.py").write_text( + '''depends_on = None + + +def upgrade(ctx): + ctx.create_enum_type("status_enum", ["draft"]) + ctx.create_table( + "articles", + fields=[ + { + "name": "id", + "python_type": "int", + "db_type": None, + "nullable": False, + "primary_key": True, + "unique": False, + "default": None, + "auto_increment": True, + }, + { + "name": "status", + "column_type": { + "kind": "enum", + "name": "status_enum", + "values": ["draft"], + }, + "db_type": None, + "nullable": False, + "primary_key": False, + "unique": False, + "default": None, + "auto_increment": False, + }, + ], + ) +''' + ) + (tmp_path / "0002_add_value.py").write_text( + '''depends_on = "0001_create" + + +def upgrade(ctx): + ctx.add_enum_value("status_enum", "published") +''' + ) + + snapshot = replay_migrations(str(tmp_path)) + + fields = snapshot["tables"]["articles"]["fields"] + status = next(field for field in fields if field["name"] == "status") + assert status["column_type"]["values"] == ["draft", "published"] + def test_squashed_history_replays_to_same_schema(self, tmp_path, recwarn): _write_legacy_history(tmp_path) before = replay_migrations(str(tmp_path)) diff --git a/python/oxyde/tests/unit/test_query.py b/python/oxyde/tests/unit/test_query.py index 7fb630e..abc9260 100644 --- a/python/oxyde/tests/unit/test_query.py +++ b/python/oxyde/tests/unit/test_query.py @@ -1,11 +1,13 @@ from __future__ import annotations from datetime import datetime, timezone +from enum import Enum from types import SimpleNamespace -from typing import Any +from typing import Any, ClassVar import msgpack import pytest +from pydantic import ValidationError from oxyde import Field, Model from oxyde.db import atomic @@ -161,7 +163,7 @@ class Meta: # db_type is NOT auto-inferred - only set if user explicitly specifies Field(db_type="...") # Type inference happens at schema extraction time based on dialect assert user_meta.field_metadata["email"].db_type is None - assert user_meta.field_metadata["email"].python_type == str + assert user_meta.field_metadata["email"].python_type is str article_meta = Article._db_meta @@ -414,6 +416,71 @@ class Meta: clear_registry() +@pytest.mark.asyncio +async def test_update_validates_enum_values() -> None: + clear_registry() + + class Status(Enum): + DRAFT = "draft" + PUBLISHED = "published" + + meta = type("Meta", (), {"is_table": True}) + Post = type( + "EnumUpdatePost", + (Model,), + { + "__module__": __name__, + "__annotations__": { + "id": int | None, + "status": Status, + "status_tags": list[Status] | None, + "Meta": ClassVar[type], + }, + "id": Field(default=None, db_pk=True), + "status": Field(default=Status.DRAFT), + "status_tags": Field(default=None, db_nullable=True), + "Meta": meta, + }, + ) + + valid_stub = StubExecuteClient([{"affected": 1}]) + result = await Post.objects.filter(id=1).update( + status="published", + status_tags=["draft", "published"], + client=valid_stub, + ) + + assert result == 1 + assert valid_stub.calls[0]["values"]["status"] == "published" + assert valid_stub.calls[0]["values"]["status_tags"] == ["draft", "published"] + + null_stub = StubExecuteClient([{"affected": 1}]) + result = await Post.objects.filter(id=1).update( + status_tags=None, + client=null_stub, + ) + assert result == 1 + assert null_stub.calls[0]["values"]["status_tags"] is None + + invalid_stub = StubExecuteClient([{"affected": 1}]) + with pytest.raises(ValidationError): + await Post.objects.filter(id=1).update( + status="deleted", + client=invalid_stub, + ) + assert invalid_stub.calls == [] + + invalid_list_stub = StubExecuteClient([{"affected": 1}]) + with pytest.raises(ValidationError): + await Post.objects.filter(id=1).update( + status_tags=["draft", "deleted"], + client=invalid_list_stub, + ) + assert invalid_list_stub.calls == [] + + clear_registry() + + def test_queryset_values_distinct_and_slicing() -> None: clear_registry() diff --git a/python/oxyde/tests/unit/test_type_binding.py b/python/oxyde/tests/unit/test_type_binding.py index 5ee7074..dff81f8 100644 --- a/python/oxyde/tests/unit/test_type_binding.py +++ b/python/oxyde/tests/unit/test_type_binding.py @@ -8,6 +8,7 @@ from datetime import date, datetime, time, timedelta from decimal import Decimal +from enum import Enum from uuid import UUID import pytest @@ -33,6 +34,11 @@ DEC = Decimal("99.99") +class Status(Enum): + DRAFT = "draft" + PUBLISHED = "published" + + # -- Model covering all supported types -- @@ -56,6 +62,8 @@ class TypeModel(Model): decimal_tags: list[Decimal] | None = Field( default=None, db_nullable=True, max_digits=10, decimal_places=2 ) + status: Status = Field(default=Status.DRAFT) + status_tags: list[Status] | None = Field(default=None, db_nullable=True) class Meta: is_table = True @@ -82,6 +90,7 @@ class Meta: ("decimal", {"price": DEC}, [("Decimal", str(DEC))]), ("bytes", {"blob": b"hello"}, [("Bytes", b"hello")]), ("dict_json", {"data": {"key": "val"}}, [("Json", '{"key":"val"}')]), + ("enum", {"status": Status.DRAFT}, [("String", "draft")]), ( "list_str_array", {"tags": ["a", "b"]}, @@ -102,6 +111,11 @@ class Meta: {"decimal_tags": [DEC]}, [("Array", ("Decimal", [str(DEC)]))], ), + ( + "list_enum_array", + {"status_tags": [Status.DRAFT, Status.PUBLISHED]}, + [("Array", ("String", ["draft", "published"]))], + ), ] @@ -225,6 +239,12 @@ def test_without_types_returns_plain_values(): assert isinstance(params[1], str) +def test_postgres_enum_filter_casts_parameter(): + sql, params = TypeModel.objects.filter(status=Status.DRAFT).sql(dialect="postgres") + assert '$1::"status_enum"' in sql + assert params == ["draft"] + + # ---- count() / exists() use to_ir() and inherit column_types ----