diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java index 51db8d321059..e1ea2207e630 100644 --- a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java @@ -212,12 +212,12 @@ public void testInvalidRollbackToSnapshotCases() { Assertions.assertThatThrownBy( () -> sql("CALL %s.system.set_current_snapshot('t')", catalogName)) - .isInstanceOf(AnalysisException.class) - .hasMessage("Missing required parameters: [snapshot_id]"); + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("snapshot_id and ref cannot both be null"); Assertions.assertThatThrownBy(() -> sql("CALL %s.system.set_current_snapshot(1L)", catalogName)) - .isInstanceOf(AnalysisException.class) - .hasMessage("Missing required parameters: [snapshot_id]"); + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot parse identifier for arg table: 1"); Assertions.assertThatThrownBy( () -> sql("CALL %s.system.set_current_snapshot(snapshot_id => 1L)", catalogName)) @@ -226,8 +226,8 @@ public void testInvalidRollbackToSnapshotCases() { Assertions.assertThatThrownBy( () -> sql("CALL %s.system.set_current_snapshot(table => 't')", catalogName)) - .isInstanceOf(AnalysisException.class) - .hasMessage("Missing required parameters: [snapshot_id]"); + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("snapshot_id and ref cannot both be null"); Assertions.assertThatThrownBy( () -> sql("CALL %s.system.set_current_snapshot('t', 2.2)", catalogName)) @@ -239,4 +239,49 @@ public void testInvalidRollbackToSnapshotCases() { .isInstanceOf(IllegalArgumentException.class) .hasMessage("Cannot handle an empty identifier for argument table"); } + + @Test + public void testSetCurrentSnapshotToRef() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String ref = "s1"; + sql("ALTER TABLE %s CREATE TAG %s", tableName, ref); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = + sql( + "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')", + catalogName, tableIdent, ref); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Set must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + String notExistRef = "s2"; + Assertions.assertThatThrownBy( + () -> + sql( + "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')", + catalogName, tableIdent, notExistRef)) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot find matching snapshot ID for ref " + notExistRef); + } } diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java index f8f8049c22b6..d6928cd3e764 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java @@ -19,6 +19,10 @@ package org.apache.iceberg.spark.procedures; import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.catalog.Identifier; @@ -42,7 +46,8 @@ class SetCurrentSnapshotProcedure extends BaseProcedure { private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[] { ProcedureParameter.required("table", DataTypes.StringType), - ProcedureParameter.required("snapshot_id", DataTypes.LongType) + ProcedureParameter.optional("snapshot_id", DataTypes.LongType), + ProcedureParameter.optional("ref", DataTypes.StringType) }; private static final StructType OUTPUT_TYPE = @@ -78,7 +83,10 @@ public StructType outputType() { @Override public InternalRow[] call(InternalRow args) { Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); - long snapshotId = args.getLong(1); + Long snapshotId = args.isNullAt(1) ? null : args.getLong(1); + String ref = args.isNullAt(2) ? null : args.getString(2); + Preconditions.checkArgument( + snapshotId != null || ref != null, "snapshot_id and ref cannot both be null"); return modifyIcebergTable( tableIdent, @@ -86,9 +94,10 @@ public InternalRow[] call(InternalRow args) { Snapshot previousSnapshot = table.currentSnapshot(); Long previousSnapshotId = previousSnapshot != null ? previousSnapshot.snapshotId() : null; - table.manageSnapshots().setCurrentSnapshot(snapshotId).commit(); + long sid = snapshotId != null ? snapshotId : getSnapshotIdFromRef(table, ref); + table.manageSnapshots().setCurrentSnapshot(sid).commit(); - InternalRow outputRow = newInternalRow(previousSnapshotId, snapshotId); + InternalRow outputRow = newInternalRow(previousSnapshotId, sid); return new InternalRow[] {outputRow}; }); } @@ -97,4 +106,10 @@ public InternalRow[] call(InternalRow args) { public String description() { return "SetCurrentSnapshotProcedure"; } + + private long getSnapshotIdFromRef(Table table, String refName) { + SnapshotRef ref = table.refs().get(refName); + ValidationException.check(ref != null, "Cannot find matching snapshot ID for ref " + refName); + return ref.snapshotId(); + } }