diff --git a/crates/oxyde-migrate/src/diff.rs b/crates/oxyde-migrate/src/diff.rs index e807ba1..023e44a 100644 --- a/crates/oxyde-migrate/src/diff.rs +++ b/crates/oxyde-migrate/src/diff.rs @@ -316,6 +316,31 @@ pub fn compute_diff(old: &Snapshot, new: &Snapshot) -> Result> } } + // Find dropped and changed check constraints + for old_check in &old_table.checks { + match new_table.checks.iter().find(|c| c.name == old_check.name) { + Some(new_check) if new_check.expression != old_check.expression => { + ops.push(MigrationOp::DropCheck { + table: name.clone(), + name: old_check.name.clone(), + check_def: Some(old_check.clone()), + }); + ops.push(MigrationOp::AddCheck { + table: name.clone(), + check: new_check.clone(), + }); + } + None => { + ops.push(MigrationOp::DropCheck { + table: name.clone(), + name: old_check.name.clone(), + check_def: Some(old_check.clone()), + }); + } + _ => {} + } + } + // Find added check constraints for new_check in &new_table.checks { if !old_table.checks.iter().any(|c| c.name == new_check.name) { @@ -325,17 +350,6 @@ pub fn compute_diff(old: &Snapshot, new: &Snapshot) -> Result> }); } } - - // Find dropped check constraints - for old_check in &old_table.checks { - if !new_table.checks.iter().any(|c| c.name == old_check.name) { - ops.push(MigrationOp::DropCheck { - table: name.clone(), - name: old_check.name.clone(), - check_def: Some(old_check.clone()), - }); - } - } } } diff --git a/crates/oxyde-migrate/tests/migration_tests.rs b/crates/oxyde-migrate/tests/migration_tests.rs index d59e8f7..72ad4b7 100644 --- a/crates/oxyde-migrate/tests/migration_tests.rs +++ b/crates/oxyde-migrate/tests/migration_tests.rs @@ -65,6 +65,50 @@ fn test_snapshot_serialization_roundtrip() { assert_eq!(snapshot, deserialized); } +#[test] +fn test_check_constraint_expression_change_generates_drop_and_add() { + let mut old = Snapshot::new(); + let mut old_table = sample_table(); + old_table.checks.push(CheckDef { + name: "age_positive".into(), + expression: "age >= 0".into(), + }); + old.add_table(old_table); + + let mut new = Snapshot::new(); + let mut new_table = sample_table(); + new_table.checks.push(CheckDef { + name: "age_positive".into(), + expression: "age > 0".into(), + }); + new.add_table(new_table); + + let ops = compute_diff(&old, &new).unwrap(); + assert_eq!(ops.len(), 2); + + match &ops[0] { + MigrationOp::DropCheck { + table, + name, + check_def, + } => { + assert_eq!(table, "users"); + assert_eq!(name, "age_positive"); + assert_eq!(check_def.as_ref().unwrap().expression, "age >= 0"); + } + op => panic!("expected DropCheck, got {:?}", op), + } + + match &ops[1] { + MigrationOp::AddCheck { table, check } => { + assert_eq!(table, "users"); + assert_eq!(check.name, "age_positive"); + assert_eq!(check.expression, "age > 0"); + } + op => panic!("expected AddCheck, got {:?}", op), + } +} + #[test] fn test_migration_create_table_generates_sql() { let sql = MigrationOp::CreateTable { diff --git a/python/oxyde/tests/unit/test_migrations_execution.py b/python/oxyde/tests/unit/test_migrations_execution.py index 8ec375f..32deccd 100644 --- a/python/oxyde/tests/unit/test_migrations_execution.py +++ b/python/oxyde/tests/unit/test_migrations_execution.py @@ -909,6 +909,27 @@ def test_roundtrip_add_check(self, dialect): assert any(op["type"] == "add_check" for op in json.loads(ops_json)) migration_to_sql(ops_json, dialect) + @pytest.mark.parametrize("dialect", NON_SQLITE) + def test_roundtrip_change_check_expression(self, dialect): + old = _base_snapshot() + old["tables"]["users"]["checks"].append( + {"name": "chk_users_age", "expression": "age >= 0"} + ) + new = _base_snapshot() + new["tables"]["users"]["checks"].append( + {"name": "chk_users_age", "expression": "age > 0"} + ) + + ops_json = migration_compute_diff(json.dumps(old), json.dumps(new)) + ops = json.loads(ops_json) + + assert [op["type"] for op in ops] == ["drop_check", "add_check"] + assert ops[0]["name"] == "chk_users_age" + assert ops[0]["check_def"]["expression"] == "age >= 0" + assert ops[1]["check"]["name"] == "chk_users_age" + assert ops[1]["check"]["expression"] == "age > 0" + migration_to_sql(ops_json, dialect) + @pytest.mark.parametrize("dialect", NON_SQLITE) def test_roundtrip_drop_check(self, dialect): old = _base_snapshot()