Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions crates/oxyde-migrate/src/diff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,31 @@ pub fn compute_diff(old: &Snapshot, new: &Snapshot) -> Result<Vec<MigrationOp>>
}
}

// Find dropped and changed check constraints
for old_check in &old_table.checks {
match new_table.checks.iter().find(|c| c.name == old_check.name) {
Some(new_check) if new_check.expression != old_check.expression => {
ops.push(MigrationOp::DropCheck {
table: name.clone(),
name: old_check.name.clone(),
check_def: Some(old_check.clone()),
});
ops.push(MigrationOp::AddCheck {
table: name.clone(),
check: new_check.clone(),
});
}
None => {
ops.push(MigrationOp::DropCheck {
table: name.clone(),
name: old_check.name.clone(),
check_def: Some(old_check.clone()),
});
}
_ => {}
}
}

// Find added check constraints
for new_check in &new_table.checks {
if !old_table.checks.iter().any(|c| c.name == new_check.name) {
Expand All @@ -325,17 +350,6 @@ pub fn compute_diff(old: &Snapshot, new: &Snapshot) -> Result<Vec<MigrationOp>>
});
}
}

// Find dropped check constraints
for old_check in &old_table.checks {
if !new_table.checks.iter().any(|c| c.name == old_check.name) {
ops.push(MigrationOp::DropCheck {
table: name.clone(),
name: old_check.name.clone(),
check_def: Some(old_check.clone()),
});
}
}
}
}

Expand Down
44 changes: 44 additions & 0 deletions crates/oxyde-migrate/tests/migration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,50 @@ fn test_snapshot_serialization_roundtrip() {
assert_eq!(snapshot, deserialized);
}

#[test]
fn test_check_constraint_expression_change_generates_drop_and_add() {
let mut old = Snapshot::new();
let mut old_table = sample_table();
old_table.checks.push(CheckDef {
name: "age_positive".into(),
expression: "age >= 0".into(),
});
old.add_table(old_table);

let mut new = Snapshot::new();
let mut new_table = sample_table();
new_table.checks.push(CheckDef {
name: "age_positive".into(),
expression: "age > 0".into(),
});
new.add_table(new_table);

let ops = compute_diff(&old, &new).unwrap();
assert_eq!(ops.len(), 2);

match &ops[0] {
MigrationOp::DropCheck {
table,
name,
check_def,
} => {
assert_eq!(table, "users");
assert_eq!(name, "age_positive");
assert_eq!(check_def.as_ref().unwrap().expression, "age >= 0");
}
op => panic!("expected DropCheck, got {:?}", op),
}

match &ops[1] {
MigrationOp::AddCheck { table, check } => {
assert_eq!(table, "users");
assert_eq!(check.name, "age_positive");
assert_eq!(check.expression, "age > 0");
}
op => panic!("expected AddCheck, got {:?}", op),
}
}

#[test]
fn test_migration_create_table_generates_sql() {
let sql = MigrationOp::CreateTable {
Expand Down
21 changes: 21 additions & 0 deletions python/oxyde/tests/unit/test_migrations_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,27 @@ def test_roundtrip_add_check(self, dialect):
assert any(op["type"] == "add_check" for op in json.loads(ops_json))
migration_to_sql(ops_json, dialect)

@pytest.mark.parametrize("dialect", NON_SQLITE)
def test_roundtrip_change_check_expression(self, dialect):
old = _base_snapshot()
old["tables"]["users"]["checks"].append(
{"name": "chk_users_age", "expression": "age >= 0"}
)
new = _base_snapshot()
new["tables"]["users"]["checks"].append(
{"name": "chk_users_age", "expression": "age > 0"}
)

ops_json = migration_compute_diff(json.dumps(old), json.dumps(new))
ops = json.loads(ops_json)

assert [op["type"] for op in ops] == ["drop_check", "add_check"]
assert ops[0]["name"] == "chk_users_age"
assert ops[0]["check_def"]["expression"] == "age >= 0"
assert ops[1]["check"]["name"] == "chk_users_age"
assert ops[1]["check"]["expression"] == "age > 0"
migration_to_sql(ops_json, dialect)

@pytest.mark.parametrize("dialect", NON_SQLITE)
def test_roundtrip_drop_check(self, dialect):
old = _base_snapshot()
Expand Down
Loading