diff --git a/crates/iceberg/src/transaction/rollback_to_snapshot.rs b/crates/iceberg/src/transaction/rollback_to_snapshot.rs index db6da553b3..ac3c05fe06 100644 --- a/crates/iceberg/src/transaction/rollback_to_snapshot.rs +++ b/crates/iceberg/src/transaction/rollback_to_snapshot.rs @@ -29,6 +29,7 @@ use crate::{Error, ErrorKind, TableRequirement, TableUpdate}; #[derive(Default)] pub struct RollbackToSnapshotAction { snapshot_id: Option<i64>, + rollback_schema: bool, } impl RollbackToSnapshotAction { @@ -42,6 +43,12 @@ impl RollbackToSnapshotAction { self.snapshot_id = Some(snapshot_id); self } + + /// Updates the table's current schema to match the schema of the target snapshot. + pub fn with_schema_rollback(mut self) -> Self { + self.rollback_schema = true; + self + } } #[async_trait] @@ -51,10 +58,9 @@ impl TransactionAction for RollbackToSnapshotAction { return Err(Error::new(ErrorKind::DataInvalid, "snapshot id is not set")); }; - table + let snapshot = table .metadata() - .snapshots() - .find(|s| s.snapshot_id() == snapshot_id) + .snapshot_by_id(snapshot_id) .ok_or_else(|| { Error::new( ErrorKind::DataInvalid, @@ -68,12 +74,12 @@ impl TransactionAction for RollbackToSnapshotAction { let reference = SnapshotReference::new(snapshot_id, SnapshotRetention::branch(None, None, None)); - let updates = vec![TableUpdate::SetSnapshotRef { + let mut updates = vec![TableUpdate::SetSnapshotRef { ref_name: MAIN_BRANCH.to_string(), reference, }]; - let requirements = vec![ + let mut requirements = vec![ TableRequirement::UuidMatch { uuid: table.metadata().uuid(), }, @@ -83,6 +89,17 @@ impl TransactionAction for RollbackToSnapshotAction { }, ]; + let current_schema_id = table.metadata().current_schema_id(); + if self.rollback_schema + && let Some(snapshot_schema_id) = snapshot.schema_id() + && current_schema_id != snapshot_schema_id + { + updates.push(TableUpdate::SetCurrentSchema { + schema_id: snapshot_schema_id, + }); + requirements.push(TableRequirement::CurrentSchemaIdMatch { 
current_schema_id }); + } + Ok(ActionCommit::new(updates, requirements)) } } @@ -95,6 +112,7 @@ mod tests { use arrow_array::cast::AsArray; use arrow_array::types::Int32Type; use arrow_array::{RecordBatch, record_batch}; + use arrow_schema::{DataType, Field, Schema}; use futures::TryStreamExt; use itertools::Itertools; use uuid::Uuid; @@ -107,7 +125,7 @@ mod tests { }; use crate::table::Table; use crate::transaction::tests::make_v3_minimal_table_in_catalog; - use crate::transaction::{ApplyTransactionAction, Transaction, TransactionAction}; + use crate::transaction::{AddColumn, ApplyTransactionAction, Transaction, TransactionAction}; use crate::writer::base_writer::data_file_writer::DataFileWriterBuilder; use crate::writer::file_writer::ParquetWriterBuilder; use crate::writer::file_writer::location_generator::{ @@ -116,19 +134,23 @@ mod tests { use crate::writer::file_writer::rolling_writer::RollingFileWriterBuilder; use crate::writer::{IcebergWriter, IcebergWriterBuilder}; use crate::{ - Catalog, NamespaceIdent, TableCreation, TableIdent, TableRequirement, TableUpdate, + Catalog, Error, NamespaceIdent, TableCreation, TableIdent, TableRequirement, TableUpdate, }; static FILE_NAME_GENERATOR: LazyLock<DefaultFileNameGenerator> = LazyLock::new(|| { DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet) }); - async fn write_and_commit(table: &Table, catalog: &dyn Catalog, batch: RecordBatch) -> Table { + async fn apply_write_and_commit( + tx: Transaction, + table: &Table, + batch: RecordBatch, + ) -> Result<Transaction, Error> { let iceberg_schema = table.metadata().current_schema(); - let arrow_schema = schema_to_arrow_schema(iceberg_schema).unwrap(); - let batch = batch.with_schema(Arc::new(arrow_schema)).unwrap(); + let arrow_schema = schema_to_arrow_schema(iceberg_schema)?; + let batch = batch.with_schema(Arc::new(arrow_schema))?; - let location_generator = DefaultLocationGenerator::new(table.metadata().clone()).unwrap(); + let location_generator = 
DefaultLocationGenerator::new(table.metadata().clone())?; let parquet_writer_builder = ParquetWriterBuilder::new( parquet::file::properties::WriterProperties::default(), table.metadata().current_schema().clone(), @@ -140,14 +162,22 @@ mod tests { FILE_NAME_GENERATOR.clone(), ); let data_file_writer_builder = DataFileWriterBuilder::new(rolling_file_writer_builder); - let mut data_file_writer = data_file_writer_builder.build(None).await.unwrap(); - data_file_writer.write(batch).await.unwrap(); - let data_file = data_file_writer.close().await.unwrap(); + let mut data_file_writer = data_file_writer_builder.build(None).await?; + data_file_writer.write(batch).await?; + let data_file = data_file_writer.close().await?; - let tx = Transaction::new(table); let append_action = tx.fast_append().add_data_files(data_file); - let tx = append_action.apply(tx).unwrap(); - tx.commit(catalog).await.unwrap() + append_action.apply(tx) + } + + async fn write_and_commit( + table: &Table, + catalog: &dyn Catalog, + batch: RecordBatch, + ) -> Result<Table, Error> { + let tx = Transaction::new(table); + let tx = apply_write_and_commit(tx, table, batch).await?; + tx.commit(catalog).await } async fn get_batches(table: &Table) -> Vec<RecordBatch> { @@ -162,9 +192,7 @@ mod tests { batch_stream.try_collect().await.unwrap() } - #[tokio::test] - async fn test_rollback_to_snapshot() { - let catalog = new_memory_catalog().await; + async fn create_test_table(catalog: &dyn Catalog) -> Table { let namespace_ident = NamespaceIdent::new(format!("ns-{}", Uuid::new_v4())); let table_ident = TableIdent::new(namespace_ident.clone(), format!("table-{}", Uuid::new_v4())); @@ -187,10 +215,16 @@ mod tests { .await .unwrap(); - let table = catalog + catalog .create_table(&namespace_ident, table_creation) .await - .unwrap(); + .unwrap() + } + + #[tokio::test] + async fn test_rollback_to_snapshot() { + let catalog = new_memory_catalog().await; + let table = create_test_table(&catalog).await; let get_id_columns = |batches: &[RecordBatch]| { 
batches @@ -206,14 +240,20 @@ mod tests { }; let insert_batch = record_batch!(("id", Int32, [1, 2])).unwrap(); - let table = write_and_commit(&table, &catalog, insert_batch).await; + let table = write_and_commit(&table, &catalog, insert_batch) + .await + .unwrap(); + let snapshot_id_1 = table.metadata().current_snapshot_id().unwrap(); let batch_1 = get_batches(&table).await; let ids = get_id_columns(&batch_1); assert_eq!(ids, [1, 2]); let insert_batch = record_batch!(("id", Int32, [3, 4])).unwrap(); - let table = write_and_commit(&table, &catalog, insert_batch).await; + let table = write_and_commit(&table, &catalog, insert_batch) + .await + .unwrap(); + let snapshot_id_2 = table.metadata().current_snapshot_id().unwrap(); let batch_2 = get_batches(&table).await; let ids = get_id_columns(&batch_2); @@ -233,7 +273,10 @@ mod tests { assert_eq!(ids, [1, 2]); let insert_batch = record_batch!(("id", Int32, [5, 6])).unwrap(); - let table = write_and_commit(&table, &catalog, insert_batch).await; + let table = write_and_commit(&table, &catalog, insert_batch) + .await + .unwrap(); + let snapshot_id_3 = table.metadata().current_snapshot_id().unwrap(); assert_ne!(snapshot_id_3, snapshot_id_2); @@ -268,6 +311,103 @@ mod tests { assert_eq!(ids, [1, 2, 5, 6]); } + struct SchemaRollbackTestSetup { + table: Table, + snapshot_id_before_schema_change: i64, + initial_schema_id: i32, + } + + async fn create_table_with_modified_schema(catalog: &dyn Catalog) -> SchemaRollbackTestSetup { + let table = create_test_table(catalog).await; + + let insert_batch = record_batch!(("id", Int32, [1, 2])).unwrap(); + let table = write_and_commit(&table, catalog, insert_batch) + .await + .unwrap(); + + let snapshot_id_before_schema_change = table.metadata().current_snapshot_id().unwrap(); + let initial_schema_id = table.metadata().current_schema_id(); + + let tx = Transaction::new(&table); + let add_column = AddColumn::optional("new_column", Type::Primitive(PrimitiveType::Int)); + let action = 
tx.update_schema().add_column(add_column); + let tx = action.apply(tx).unwrap(); + let table = tx.commit(catalog).await.unwrap(); + + let new_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("new_column", DataType::Int32, true), + ]); + let insert_batch = + record_batch!(("id", Int32, [3, 4]), ("new_column", Int32, [100, 200])).unwrap(); + let insert_batch = insert_batch.with_schema(Arc::new(new_schema)).unwrap(); + let table = write_and_commit(&table, catalog, insert_batch) + .await + .unwrap(); + + assert_eq!(table.metadata().current_schema_id(), initial_schema_id + 1); + let current_schema = table.metadata().current_schema(); + assert!(current_schema.field_by_name("new_column").is_some()); + + SchemaRollbackTestSetup { + table, + snapshot_id_before_schema_change, + initial_schema_id, + } + } + + #[tokio::test] + async fn test_rollback_to_snapshot_with_schema_update() { + let catalog = new_memory_catalog().await; + let setup = create_table_with_modified_schema(&catalog).await; + + let tx = Transaction::new(&setup.table); + let action = tx + .rollback_to_snapshot() + .set_snapshot_id(setup.snapshot_id_before_schema_change) + .apply(tx) + .unwrap(); + let table = action.commit(&catalog).await.unwrap(); + + assert_eq!( + table.metadata().current_snapshot_id(), + Some(setup.snapshot_id_before_schema_change) + ); + + let current_schema = table.metadata().current_schema(); + assert_eq!( + table.metadata().current_schema_id(), + setup.initial_schema_id + 1, + ); + assert!(current_schema.field_by_name("new_column").is_some()); + } + + #[tokio::test] + async fn test_rollback_to_snapshot_with_schema_rollback() { + let catalog = new_memory_catalog().await; + let setup = create_table_with_modified_schema(&catalog).await; + + let tx = Transaction::new(&setup.table); + let action = tx + .rollback_to_snapshot() + .set_snapshot_id(setup.snapshot_id_before_schema_change) + .with_schema_rollback() + .apply(tx) + .unwrap(); + + let table = 
action.commit(&catalog).await.unwrap(); + assert_eq!( + table.metadata().current_snapshot_id(), + Some(setup.snapshot_id_before_schema_change) + ); + assert_eq!( + table.metadata().current_schema_id(), + setup.initial_schema_id + ); + let current_schema = table.metadata().current_schema(); + assert!(current_schema.field_by_name("new_column").is_none()); + } + async fn insert_data(catalog: &dyn Catalog, table_ident: &TableIdent) -> Table { let table = catalog.load_table(table_ident).await.unwrap(); let data_file = DataFileBuilder::default()