Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 165 additions & 25 deletions crates/iceberg/src/transaction/rollback_to_snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use crate::{Error, ErrorKind, TableRequirement, TableUpdate};
#[derive(Default)]
pub struct RollbackToSnapshotAction {
    // Target snapshot to roll back to; committing without setting this is an error.
    snapshot_id: Option<i64>,
    // When true, also reset the table's current schema to the target snapshot's
    // schema id (if the snapshot records one and it differs from the current id).
    rollback_schema: bool,
}

impl RollbackToSnapshotAction {
Expand All @@ -42,6 +43,12 @@ impl RollbackToSnapshotAction {
self.snapshot_id = Some(snapshot_id);
self
}

/// Updates the table's current schema to match the schema of the target snapshot.
pub fn with_schema_rollback(mut self) -> Self {
self.rollback_schema = true;
self
}
}

#[async_trait]
Expand All @@ -51,10 +58,9 @@ impl TransactionAction for RollbackToSnapshotAction {
return Err(Error::new(ErrorKind::DataInvalid, "snapshot id is not set"));
};

table
let snapshot = table
.metadata()
.snapshots()
.find(|s| s.snapshot_id() == snapshot_id)
.snapshot_by_id(snapshot_id)
.ok_or_else(|| {
Error::new(
ErrorKind::DataInvalid,
Expand All @@ -68,12 +74,12 @@ impl TransactionAction for RollbackToSnapshotAction {
let reference =
SnapshotReference::new(snapshot_id, SnapshotRetention::branch(None, None, None));

let updates = vec![TableUpdate::SetSnapshotRef {
let mut updates = vec![TableUpdate::SetSnapshotRef {
ref_name: MAIN_BRANCH.to_string(),
reference,
}];

let requirements = vec![
let mut requirements = vec![
TableRequirement::UuidMatch {
uuid: table.metadata().uuid(),
},
Expand All @@ -83,6 +89,17 @@ impl TransactionAction for RollbackToSnapshotAction {
},
];

let current_schema_id = table.metadata().current_schema_id();
if self.rollback_schema
&& let Some(snapshot_schema_id) = snapshot.schema_id()
&& current_schema_id != snapshot_schema_id
{
updates.push(TableUpdate::SetCurrentSchema {
schema_id: snapshot_schema_id,
});
requirements.push(TableRequirement::CurrentSchemaIdMatch { current_schema_id });
}

Ok(ActionCommit::new(updates, requirements))
}
}
Expand All @@ -95,6 +112,7 @@ mod tests {
use arrow_array::cast::AsArray;
use arrow_array::types::Int32Type;
use arrow_array::{RecordBatch, record_batch};
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use itertools::Itertools;
use uuid::Uuid;
Expand All @@ -107,7 +125,7 @@ mod tests {
};
use crate::table::Table;
use crate::transaction::tests::make_v3_minimal_table_in_catalog;
use crate::transaction::{ApplyTransactionAction, Transaction, TransactionAction};
use crate::transaction::{AddColumn, ApplyTransactionAction, Transaction, TransactionAction};
use crate::writer::base_writer::data_file_writer::DataFileWriterBuilder;
use crate::writer::file_writer::ParquetWriterBuilder;
use crate::writer::file_writer::location_generator::{
Expand All @@ -116,19 +134,23 @@ mod tests {
use crate::writer::file_writer::rolling_writer::RollingFileWriterBuilder;
use crate::writer::{IcebergWriter, IcebergWriterBuilder};
use crate::{
Catalog, NamespaceIdent, TableCreation, TableIdent, TableRequirement, TableUpdate,
Catalog, Error, NamespaceIdent, TableCreation, TableIdent, TableRequirement, TableUpdate,
};

static FILE_NAME_GENERATOR: LazyLock<DefaultFileNameGenerator> = LazyLock::new(|| {
DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet)
});

async fn write_and_commit(table: &Table, catalog: &dyn Catalog, batch: RecordBatch) -> Table {
async fn apply_write_and_commit(
tx: Transaction,
table: &Table,
batch: RecordBatch,
) -> Result<Transaction, Error> {
let iceberg_schema = table.metadata().current_schema();
let arrow_schema = schema_to_arrow_schema(iceberg_schema).unwrap();
let batch = batch.with_schema(Arc::new(arrow_schema)).unwrap();
let arrow_schema = schema_to_arrow_schema(iceberg_schema)?;
let batch = batch.with_schema(Arc::new(arrow_schema))?;

let location_generator = DefaultLocationGenerator::new(table.metadata().clone()).unwrap();
let location_generator = DefaultLocationGenerator::new(table.metadata().clone())?;
let parquet_writer_builder = ParquetWriterBuilder::new(
parquet::file::properties::WriterProperties::default(),
table.metadata().current_schema().clone(),
Expand All @@ -140,14 +162,22 @@ mod tests {
FILE_NAME_GENERATOR.clone(),
);
let data_file_writer_builder = DataFileWriterBuilder::new(rolling_file_writer_builder);
let mut data_file_writer = data_file_writer_builder.build(None).await.unwrap();
data_file_writer.write(batch).await.unwrap();
let data_file = data_file_writer.close().await.unwrap();
let mut data_file_writer = data_file_writer_builder.build(None).await?;
data_file_writer.write(batch).await?;
let data_file = data_file_writer.close().await?;

let tx = Transaction::new(table);
let append_action = tx.fast_append().add_data_files(data_file);
let tx = append_action.apply(tx).unwrap();
tx.commit(catalog).await.unwrap()
append_action.apply(tx)
}

/// Writes `batch` to `table` inside a fresh transaction and commits it,
/// returning the updated table on success.
async fn write_and_commit(
    table: &Table,
    catalog: &dyn Catalog,
    batch: RecordBatch,
) -> Result<Table, Error> {
    apply_write_and_commit(Transaction::new(table), table, batch)
        .await?
        .commit(catalog)
        .await
}

async fn get_batches(table: &Table) -> Vec<RecordBatch> {
Expand All @@ -162,9 +192,7 @@ mod tests {
batch_stream.try_collect().await.unwrap()
}

#[tokio::test]
async fn test_rollback_to_snapshot() {
let catalog = new_memory_catalog().await;
async fn create_test_table(catalog: &dyn Catalog) -> Table {
let namespace_ident = NamespaceIdent::new(format!("ns-{}", Uuid::new_v4()));
let table_ident =
TableIdent::new(namespace_ident.clone(), format!("table-{}", Uuid::new_v4()));
Expand All @@ -187,10 +215,16 @@ mod tests {
.await
.unwrap();

let table = catalog
catalog
.create_table(&namespace_ident, table_creation)
.await
.unwrap();
.unwrap()
}

#[tokio::test]
async fn test_rollback_to_snapshot() {
let catalog = new_memory_catalog().await;
let table = create_test_table(&catalog).await;

let get_id_columns = |batches: &[RecordBatch]| {
batches
Expand All @@ -206,14 +240,20 @@ mod tests {
};

let insert_batch = record_batch!(("id", Int32, [1, 2])).unwrap();
let table = write_and_commit(&table, &catalog, insert_batch).await;
let table = write_and_commit(&table, &catalog, insert_batch)
.await
.unwrap();

let snapshot_id_1 = table.metadata().current_snapshot_id().unwrap();
let batch_1 = get_batches(&table).await;
let ids = get_id_columns(&batch_1);
assert_eq!(ids, [1, 2]);

let insert_batch = record_batch!(("id", Int32, [3, 4])).unwrap();
let table = write_and_commit(&table, &catalog, insert_batch).await;
let table = write_and_commit(&table, &catalog, insert_batch)
.await
.unwrap();

let snapshot_id_2 = table.metadata().current_snapshot_id().unwrap();
let batch_2 = get_batches(&table).await;
let ids = get_id_columns(&batch_2);
Expand All @@ -233,7 +273,10 @@ mod tests {
assert_eq!(ids, [1, 2]);

let insert_batch = record_batch!(("id", Int32, [5, 6])).unwrap();
let table = write_and_commit(&table, &catalog, insert_batch).await;
let table = write_and_commit(&table, &catalog, insert_batch)
.await
.unwrap();

let snapshot_id_3 = table.metadata().current_snapshot_id().unwrap();
assert_ne!(snapshot_id_3, snapshot_id_2);

Expand Down Expand Up @@ -268,6 +311,103 @@ mod tests {
assert_eq!(ids, [1, 2, 5, 6]);
}

// Fixture produced by `create_table_with_modified_schema`: a table whose
// schema has been evolved after an initial snapshot, plus the ids needed to
// roll back across that schema change.
struct SchemaRollbackTestSetup {
    // Table whose current snapshot/schema are the post-evolution versions.
    table: Table,
    // Snapshot id taken before the schema was evolved (the rollback target).
    snapshot_id_before_schema_change: i64,
    // Schema id that was current before the evolution.
    initial_schema_id: i32,
}

/// Builds a table with one committed snapshot, evolves its schema by adding an
/// optional `new_column`, then commits a second snapshot written against the
/// evolved schema. Returns the table plus the ids needed to roll back across
/// the schema change.
async fn create_table_with_modified_schema(catalog: &dyn Catalog) -> SchemaRollbackTestSetup {
    // Snapshot #1: written with the original single-column schema.
    let base_table = create_test_table(catalog).await;
    let first_batch = record_batch!(("id", Int32, [1, 2])).unwrap();
    let base_table = write_and_commit(&base_table, catalog, first_batch)
        .await
        .unwrap();

    // Remember where we were before the schema evolves.
    let snapshot_id_before_schema_change = base_table.metadata().current_snapshot_id().unwrap();
    let initial_schema_id = base_table.metadata().current_schema_id();

    // Evolve the schema: add an optional int column.
    let tx = Transaction::new(&base_table);
    let add_col = AddColumn::optional("new_column", Type::Primitive(PrimitiveType::Int));
    let schema_update = tx.update_schema().add_column(add_col);
    let tx = schema_update.apply(tx).unwrap();
    let evolved_table = tx.commit(catalog).await.unwrap();

    // Snapshot #2: written with the two-column schema.
    let two_column_schema = Schema::new(vec![
        Field::new("id", DataType::Int32, true),
        Field::new("new_column", DataType::Int32, true),
    ]);
    let second_batch = record_batch!(("id", Int32, [3, 4]), ("new_column", Int32, [100, 200]))
        .unwrap()
        .with_schema(Arc::new(two_column_schema))
        .unwrap();
    let evolved_table = write_and_commit(&evolved_table, catalog, second_batch)
        .await
        .unwrap();

    // Sanity-check that the evolution actually happened.
    assert_eq!(
        evolved_table.metadata().current_schema_id(),
        initial_schema_id + 1
    );
    assert!(
        evolved_table
            .metadata()
            .current_schema()
            .field_by_name("new_column")
            .is_some()
    );

    SchemaRollbackTestSetup {
        table: evolved_table,
        snapshot_id_before_schema_change,
        initial_schema_id,
    }
}

#[tokio::test]
async fn test_rollback_to_snapshot_with_schema_update() {
    // Rolling back WITHOUT `with_schema_rollback` must move the snapshot
    // pointer but leave the (already evolved) current schema untouched.
    let catalog = new_memory_catalog().await;
    let setup = create_table_with_modified_schema(&catalog).await;

    let tx = Transaction::new(&setup.table);
    let rollback_tx = tx
        .rollback_to_snapshot()
        .set_snapshot_id(setup.snapshot_id_before_schema_change)
        .apply(tx)
        .unwrap();
    let table = rollback_tx.commit(&catalog).await.unwrap();

    // Snapshot pointer moved back to the pre-evolution snapshot...
    assert_eq!(
        table.metadata().current_snapshot_id(),
        Some(setup.snapshot_id_before_schema_change)
    );

    // ...but the schema stays at the post-evolution version.
    assert_eq!(
        table.metadata().current_schema_id(),
        setup.initial_schema_id + 1,
    );
    assert!(
        table
            .metadata()
            .current_schema()
            .field_by_name("new_column")
            .is_some()
    );
}

#[tokio::test]
async fn test_rollback_to_snapshot_with_schema_rollback() {
    // Rolling back WITH `with_schema_rollback` must move both the snapshot
    // pointer and the current schema back to the pre-evolution state.
    let catalog = new_memory_catalog().await;
    let setup = create_table_with_modified_schema(&catalog).await;

    let tx = Transaction::new(&setup.table);
    let rollback_tx = tx
        .rollback_to_snapshot()
        .set_snapshot_id(setup.snapshot_id_before_schema_change)
        .with_schema_rollback()
        .apply(tx)
        .unwrap();
    let table = rollback_tx.commit(&catalog).await.unwrap();

    // Snapshot pointer moved back to the pre-evolution snapshot.
    assert_eq!(
        table.metadata().current_snapshot_id(),
        Some(setup.snapshot_id_before_schema_change)
    );
    // Schema id was restored to the pre-evolution version, and the added
    // column is no longer part of the current schema.
    assert_eq!(
        table.metadata().current_schema_id(),
        setup.initial_schema_id
    );
    assert!(
        table
            .metadata()
            .current_schema()
            .field_by_name("new_column")
            .is_none()
    );
}

async fn insert_data(catalog: &dyn Catalog, table_ident: &TableIdent) -> Table {
let table = catalog.load_table(table_ident).await.unwrap();
let data_file = DataFileBuilder::default()
Expand Down
Loading