Skip to content

Commit

Permalink
Spark 3.4: Support setting current snapshot to ref
Browse files Browse the repository at this point in the history
  • Loading branch information
manuzhang committed Aug 8, 2023
1 parent 6b1c9f0 commit 744d706
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,12 @@ public void testInvalidRollbackToSnapshotCases() {

Assertions.assertThatThrownBy(
() -> sql("CALL %s.system.set_current_snapshot('t')", catalogName))
.isInstanceOf(AnalysisException.class)
.hasMessage("Missing required parameters: [snapshot_id]");
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("snapshot_id and ref cannot both be null");

Assertions.assertThatThrownBy(() -> sql("CALL %s.system.set_current_snapshot(1L)", catalogName))
.isInstanceOf(AnalysisException.class)
.hasMessage("Missing required parameters: [snapshot_id]");
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Cannot parse identifier for arg table: 1");

Assertions.assertThatThrownBy(
() -> sql("CALL %s.system.set_current_snapshot(snapshot_id => 1L)", catalogName))
Expand All @@ -226,8 +226,8 @@ public void testInvalidRollbackToSnapshotCases() {

Assertions.assertThatThrownBy(
() -> sql("CALL %s.system.set_current_snapshot(table => 't')", catalogName))
.isInstanceOf(AnalysisException.class)
.hasMessage("Missing required parameters: [snapshot_id]");
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("snapshot_id and ref cannot both be null");

Assertions.assertThatThrownBy(
() -> sql("CALL %s.system.set_current_snapshot('t', 2.2)", catalogName))
Expand All @@ -239,4 +239,49 @@ public void testInvalidRollbackToSnapshotCases() {
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Cannot handle an empty identifier for argument table");
}

@Test
public void testSetCurrentSnapshotToRef() {
  // Create the table, write one row, and remember the snapshot that the tag will point to.
  sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
  sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);

  Table icebergTable = validationCatalog.loadTable(tableIdent);
  Snapshot taggedSnapshot = icebergTable.currentSnapshot();
  String tagName = "s1";
  sql("ALTER TABLE %s CREATE TAG %s", tableName, tagName);

  // A second write advances the current snapshot beyond the tagged one.
  sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
  List<Object[]> rowsBeforeReset = sql("SELECT * FROM %s ORDER BY id", tableName);
  assertEquals(
      "Should have expected rows", ImmutableList.of(row(1L, "a"), row(1L, "a")), rowsBeforeReset);

  icebergTable.refresh();
  Snapshot latestSnapshot = icebergTable.currentSnapshot();

  // Point the table back at the tag by name instead of by snapshot ID.
  List<Object[]> procedureResult =
      sql(
          "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')",
          catalogName, tableIdent, tagName);

  // The procedure reports (previous snapshot ID, new current snapshot ID).
  List<Object[]> expectedResult =
      ImmutableList.of(row(latestSnapshot.snapshotId(), taggedSnapshot.snapshotId()));
  assertEquals("Procedure output must match", expectedResult, procedureResult);

  // Only the row written before the tag was created should be visible again.
  List<Object[]> rowsAfterReset = sql("SELECT * FROM %s ORDER BY id", tableName);
  assertEquals("Set must be successful", ImmutableList.of(row(1L, "a")), rowsAfterReset);

  // A ref that was never created must be rejected.
  String missingRef = "s2";
  Assertions.assertThatThrownBy(
          () ->
              sql(
                  "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')",
                  catalogName, tableIdent, missingRef))
      .isInstanceOf(ValidationException.class)
      .hasMessage("Cannot find matching snapshot ID for ref " + missingRef);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
package org.apache.iceberg.spark.procedures;

import org.apache.iceberg.Snapshot;
import org.apache.iceberg.SnapshotRef;
import org.apache.iceberg.Table;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.Identifier;
Expand All @@ -42,7 +46,8 @@ class SetCurrentSnapshotProcedure extends BaseProcedure {
// Parameters accepted by the procedure, in positional order: table, snapshot_id, ref.
// NOTE(review): this listing is a rendered diff whose +/- markers were lost. The
// required("snapshot_id") entry looks like the pre-change line and the two optional(...)
// entries its replacement -- confirm against the actual file before relying on this.
private static final ProcedureParameter[] PARAMETERS =
new ProcedureParameter[] {
ProcedureParameter.required("table", DataTypes.StringType),
ProcedureParameter.required("snapshot_id", DataTypes.LongType)
ProcedureParameter.optional("snapshot_id", DataTypes.LongType),
ProcedureParameter.optional("ref", DataTypes.StringType)
};

private static final StructType OUTPUT_TYPE =
Expand Down Expand Up @@ -78,17 +83,21 @@ public StructType outputType() {
/**
 * Sets the table's current snapshot, selected either by an explicit snapshot ID or by a ref name.
 *
 * <p>NOTE(review): this listing is a rendered diff whose +/- markers were lost, so pre-change and
 * post-change statements appear side by side; presumably only the post-change ones exist in the
 * real file -- confirm before editing.
 *
 * @param args positional arguments: table identifier (0), snapshot_id (1), ref (2)
 * @return a single row holding the previous current snapshot ID (null when the table had no
 *     current snapshot) and the snapshot ID that was just set
 * @throws IllegalArgumentException if both snapshot_id and ref are null
 */
@Override
public InternalRow[] call(InternalRow args) {
Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name());
// Presumably the pre-change form: snapshot_id read unconditionally as a primitive.
long snapshotId = args.getLong(1);
// Post-change form: either selector may be absent, so each is read as null when not supplied.
Long snapshotId = args.isNullAt(1) ? null : args.getLong(1);
String ref = args.isNullAt(2) ? null : args.getString(2);
// At least one selector is required; supplying both is not rejected here.
Preconditions.checkArgument(
snapshotId != null || ref != null, "snapshot_id and ref cannot both be null");

return modifyIcebergTable(
tableIdent,
table -> {
// Capture the current snapshot before the change so it can be reported in the output row.
Snapshot previousSnapshot = table.currentSnapshot();
Long previousSnapshotId = previousSnapshot != null ? previousSnapshot.snapshotId() : null;

// Presumably the pre-change form of the commit.
table.manageSnapshots().setCurrentSnapshot(snapshotId).commit();
// An explicit snapshot_id takes precedence; the ref is resolved only when snapshot_id is null.
long sid = snapshotId != null ? snapshotId : getSnapshotIdFromRef(table, ref);
table.manageSnapshots().setCurrentSnapshot(sid).commit();

// Presumably the pre-change form of the output row.
InternalRow outputRow = newInternalRow(previousSnapshotId, snapshotId);
InternalRow outputRow = newInternalRow(previousSnapshotId, sid);
return new InternalRow[] {outputRow};
});
}
Expand All @@ -97,4 +106,10 @@ public InternalRow[] call(InternalRow args) {
public String description() {
  // Fixed, human-readable label identifying this procedure.
  final String label = "SetCurrentSnapshotProcedure";
  return label;
}

/**
 * Resolves a named ref (for example a tag) on the table to the ID of the snapshot it points to.
 *
 * @param table table whose refs are searched
 * @param refName name of the ref to look up
 * @return the snapshot ID the ref points to
 * @throws ValidationException if the table has no ref with the given name
 */
private long getSnapshotIdFromRef(Table table, String refName) {
  SnapshotRef ref = table.refs().get(refName);
  // The message is a String.format template, so the user-supplied name is passed as an argument
  // rather than concatenated in; a '%' in the name would otherwise break message formatting.
  ValidationException.check(ref != null, "Cannot find matching snapshot ID for ref %s", refName);
  return ref.snapshotId();
}
}

0 comments on commit 744d706

Please sign in to comment.