Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions crates/oxyde-codec/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ pub enum ColumnTypeSpec {
Json,
/// JSONB on Postgres; identical to Json elsewhere.
JsonBinary,
Enum {
name: String,
values: Vec<String>,
},
Array {
item: Box<ColumnTypeSpec>,
},
Expand Down Expand Up @@ -613,6 +617,19 @@ mod tests {
);
}

#[test]
fn test_spec_enum() {
assert_eq!(
spec_from_json(
r#"{"kind": "enum", "name": "post_status_enum", "values": ["draft", "published"]}"#
),
ColumnTypeSpec::Enum {
name: "post_status_enum".into(),
values: vec!["draft".into(), "published".into()],
}
);
}

#[test]
fn test_spec_unknown_kind_is_error() {
let result: std::result::Result<ColumnTypeSpec, _> =
Expand All @@ -632,6 +649,10 @@ mod tests {
ColumnTypeSpec::Array {
item: Box::new(ColumnTypeSpec::Uuid),
},
ColumnTypeSpec::Enum {
name: "post_status_enum".into(),
values: vec!["draft".into(), "published".into()],
},
ColumnTypeSpec::Unknown,
];
for spec in specs {
Expand Down
2 changes: 1 addition & 1 deletion crates/oxyde-driver/src/convert/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl CellEncoder for MySqlEncoder {
}
true
}
ColumnTypeSpec::Text | ColumnTypeSpec::String { .. } => {
ColumnTypeSpec::Text | ColumnTypeSpec::String { .. } | ColumnTypeSpec::Enum { .. } => {
match row.try_get::<Option<String>, _>(idx) {
Ok(Some(v)) => write_str(buf, &v),
Ok(None) => write_nil(buf),
Expand Down
46 changes: 45 additions & 1 deletion crates/oxyde-driver/src/convert/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,46 @@
use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
use oxyde_codec::ColumnTypeSpec;
use rust_decimal::Decimal;
use sqlx::{postgres::PgRow, Row};
use sqlx::{
error::BoxDynError,
postgres::{PgHasArrayType, PgRow, PgTypeInfo, PgTypeKind, PgValueRef, Postgres},
Decode, Row, Type,
};
use uuid::Uuid;

use super::encoder::*;

pub struct PgEncoder;

#[derive(Debug, Clone)]
struct PgEnumText(String);

impl Type<Postgres> for PgEnumText {
fn type_info() -> PgTypeInfo {
<String as Type<Postgres>>::type_info()
}

fn compatible(ty: &PgTypeInfo) -> bool {
<String as Type<Postgres>>::compatible(ty) || matches!(ty.kind(), PgTypeKind::Enum(_))
}
}

impl PgHasArrayType for PgEnumText {
fn array_type_info() -> PgTypeInfo {
<String as PgHasArrayType>::array_type_info()
}

fn array_compatible(ty: &PgTypeInfo) -> bool {
matches!(ty.kind(), PgTypeKind::Array(element) if Self::compatible(element))
}
}

impl<'r> Decode<'r, Postgres> for PgEnumText {
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(Self(value.as_str()?.to_string()))
}
}

impl CellEncoder for PgEncoder {
type Row = PgRow;

Expand Down Expand Up @@ -44,6 +77,14 @@ impl CellEncoder for PgEncoder {
}
true
}
ColumnTypeSpec::Enum { .. } => {
match row.try_get::<Option<PgEnumText>, _>(idx) {
Ok(Some(v)) => write_str(buf, &v.0),
Ok(None) => write_nil(buf),
Err(_) => write_nil(buf),
}
true
}
ColumnTypeSpec::Double => {
match row.try_get::<Option<f64>, _>(idx) {
Ok(Some(v)) => write_f64(buf, v),
Expand Down Expand Up @@ -259,6 +300,9 @@ fn encode_pg_array(buf: &mut Vec<u8>, row: &PgRow, idx: usize, item: &ColumnType
ColumnTypeSpec::Text | ColumnTypeSpec::String { .. } => {
encode_pg_array_ref::<String>(buf, row, idx, |b, v| write_str(b, v));
}
ColumnTypeSpec::Enum { .. } => {
encode_pg_array_ref::<PgEnumText>(buf, row, idx, |b, v| write_str(b, &v.0));
}
ColumnTypeSpec::Uuid => {
encode_pg_array_ref::<Uuid>(buf, row, idx, |b, v| {
write_str(b, &v.to_string());
Expand Down
1 change: 1 addition & 0 deletions crates/oxyde-driver/src/convert/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ impl CellEncoder for SqliteEncoder {
// datetime, date, time, decimal, uuid — stored as TEXT in SQLite
ColumnTypeSpec::Text
| ColumnTypeSpec::String { .. }
| ColumnTypeSpec::Enum { .. }
| ColumnTypeSpec::DateTime
| ColumnTypeSpec::DateTimeUtc
| ColumnTypeSpec::Date
Expand Down
209 changes: 196 additions & 13 deletions crates/oxyde-migrate/src/diff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

use std::collections::{HashMap, HashSet};

use crate::op::MigrationOp;
use crate::op::{EnumFieldRef, MigrationOp};
use crate::types::{Dialect, MigrateError, Result, Snapshot, TableDef};
use oxyde_codec::ColumnTypeSpec;
use serde::{Deserialize, Serialize};

/// Topologically sort table names so that referenced tables come before
Expand Down Expand Up @@ -86,6 +87,136 @@ fn topo_sort_table_names(tables: &HashMap<String, TableDef>) -> Result<Vec<Strin
Ok(result)
}

fn collect_enum_defs(snapshot: &Snapshot) -> Result<HashMap<String, Vec<String>>> {
let mut defs = HashMap::new();
for table in snapshot.tables.values() {
for field in &table.fields {
collect_enum_def_from_spec(&field.column_type, &mut defs)?;
}
}
Ok(defs)
}

fn collect_enum_def_from_spec(
spec: &ColumnTypeSpec,
defs: &mut HashMap<String, Vec<String>>,
) -> Result<()> {
match spec {
ColumnTypeSpec::Enum { name, values } => {
if let Some(existing) = defs.get(name) {
if existing != values {
return Err(MigrateError::DiffError(format!(
"enum type '{}' has conflicting value sets",
name
)));
}
} else {
defs.insert(name.clone(), values.clone());
}
}
ColumnTypeSpec::Array { item } => collect_enum_def_from_spec(item, defs)?,
_ => {}
}
Ok(())
}

fn sorted_keys(map: &HashMap<String, Vec<String>>) -> Vec<String> {
let mut keys = map.keys().cloned().collect::<Vec<_>>();
keys.sort();
keys
}

fn enum_values_are_append_only(old_values: &[String], new_values: &[String]) -> bool {
new_values.len() >= old_values.len() && &new_values[..old_values.len()] == old_values
}

fn column_type_requires_alter(old: &ColumnTypeSpec, new: &ColumnTypeSpec) -> bool {
match (old, new) {
(
ColumnTypeSpec::Enum { name: old_name, .. },
ColumnTypeSpec::Enum { name: new_name, .. },
) => old_name != new_name,
(ColumnTypeSpec::Array { item: old_item }, ColumnTypeSpec::Array { item: new_item }) => {
column_type_requires_alter(old_item, new_item)
}
_ => old != new,
}
}

fn db_type_requires_alter(
old_type: &ColumnTypeSpec,
new_type: &ColumnTypeSpec,
old_db_type: &Option<String>,
new_db_type: &Option<String>,
) -> bool {
if !column_type_requires_alter(old_type, new_type) && contains_enum_type(old_type) {
return false;
}
old_db_type != new_db_type
}

fn contains_enum_type(spec: &ColumnTypeSpec) -> bool {
match spec {
ColumnTypeSpec::Enum { .. } => true,
ColumnTypeSpec::Array { item } => contains_enum_type(item),
_ => false,
}
}

fn scalar_enum_name(spec: &ColumnTypeSpec) -> Option<&str> {
match spec {
ColumnTypeSpec::Enum { name, .. } => Some(name),
_ => None,
}
}

fn existing_scalar_enum_fields(
old: &Snapshot,
new: &Snapshot,
enum_name: &str,
values: &[String],
) -> Vec<EnumFieldRef> {
let mut fields = Vec::new();
let mut table_names = old
.tables
.keys()
.filter(|name| new.tables.contains_key(*name))
.cloned()
.collect::<Vec<_>>();
table_names.sort();

for table_name in table_names {
let old_table = &old.tables[&table_name];
let new_table = &new.tables[&table_name];
for old_field in &old_table.fields {
if scalar_enum_name(&old_field.column_type) != Some(enum_name) {
continue;
}
if let Some(new_field) = new_table
.fields
.iter()
.find(|field| field.name == old_field.name)
.filter(|field| scalar_enum_name(&field.column_type) == Some(enum_name))
{
let mut field = new_field.clone();
if let ColumnTypeSpec::Enum {
values: field_values,
..
} = &mut field.column_type
{
*field_values = values.to_vec();
}
fields.push(EnumFieldRef {
table: table_name.clone(),
field,
});
}
}
}

fields
}

/// Migration file
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Migration {
Expand Down Expand Up @@ -124,19 +255,23 @@ impl Migration {
/// This ensures referenced tables exist before FK constraints are added
/// (PG/MySQL emit FK as separate ALTER TABLE, not inline in CREATE TABLE).
pub fn to_sql(&self, dialect: Dialect) -> Result<Vec<String>> {
let mut all_sql = Vec::new();
let mut all_sql: Vec<(u8, String)> = Vec::new();
for op in &self.operations {
let sqls = op.to_sql(dialect)?;
all_sql.extend(sqls);
}
all_sql.sort_by_key(|s| {
if s.trim_start().starts_with("ALTER") {
1
} else {
0
for sql in sqls {
let bucket = match op {
MigrationOp::CreateEnumType { .. } => 0,
MigrationOp::AddEnumValue { .. } => 1,
MigrationOp::DropEnumType { .. } => 4,
MigrationOp::AlterEnumType { .. } => 5,
_ if sql.trim_start().starts_with("ALTER TABLE") => 3,
_ => 2,
};
all_sql.push((bucket, sql));
}
});
Ok(all_sql)
}
all_sql.sort_by_key(|(bucket, _)| *bucket);
Ok(all_sql.into_iter().map(|(_, sql)| sql).collect())
}
}

Expand All @@ -147,6 +282,39 @@ impl Migration {
/// Cycles among unchanged tables are irrelevant and pass through silently.
pub fn compute_diff(old: &Snapshot, new: &Snapshot) -> Result<Vec<MigrationOp>> {
let mut ops = Vec::new();
let old_enums = collect_enum_defs(old)?;
let new_enums = collect_enum_defs(new)?;

for name in sorted_keys(&new_enums) {
if !old_enums.contains_key(&name) {
ops.push(MigrationOp::CreateEnumType {
name: name.clone(),
values: new_enums[&name].clone(),
});
}
}

for name in sorted_keys(&new_enums) {
if let Some(old_values) = old_enums.get(&name) {
let new_values = &new_enums[&name];
if enum_values_are_append_only(old_values, new_values) {
for (index, value) in new_values[old_values.len()..].iter().cloned().enumerate() {
let values = &new_values[..old_values.len() + index + 1];
ops.push(MigrationOp::AddEnumValue {
name: name.clone(),
value,
fields: existing_scalar_enum_fields(old, new, &name, values),
});
}
} else {
ops.push(MigrationOp::AlterEnumType {
name: name.clone(),
old_values: old_values.clone(),
new_values: new_values.clone(),
});
}
}
}

// Topo-sort only the subset of tables that are actually being created.
// FKs from this subset to tables that already exist in `old` are not
Expand Down Expand Up @@ -215,8 +383,14 @@ pub fn compute_diff(old: &Snapshot, new: &Snapshot) -> Result<Vec<MigrationOp>>
if let Some(old_field) = old_table.fields.iter().find(|f| f.name == new_field.name)
{
// Check if type changed using column_type or db_type
let type_changed = old_field.column_type != new_field.column_type
|| old_field.db_type != new_field.db_type;
let type_changed =
column_type_requires_alter(&old_field.column_type, &new_field.column_type)
|| db_type_requires_alter(
&old_field.column_type,
&new_field.column_type,
&old_field.db_type,
&new_field.db_type,
);

let nullable_changed = old_field.nullable != new_field.nullable;
let default_changed = old_field.default != new_field.default;
Expand Down Expand Up @@ -350,5 +524,14 @@ pub fn compute_diff(old: &Snapshot, new: &Snapshot) -> Result<Vec<MigrationOp>>
}
}

for name in sorted_keys(&old_enums) {
if !new_enums.contains_key(&name) {
ops.push(MigrationOp::DropEnumType {
name: name.clone(),
values: Some(old_enums[&name].clone()),
});
}
}

Ok(ops)
}
Loading