From 7f04c3ff4772e07a0476483b92f95d0fb6520ff3 Mon Sep 17 00:00:00 2001 From: manuzhang Date: Sat, 26 Aug 2023 09:13:52 +0800 Subject: [PATCH] Spark3.1, Spark3.2, Spark3.3: Support setting current snapshot with ref Back-port of https://github.com/apache/iceberg/pull/8163 to `spark/v3.3`, `spark/v3.2` and `spark/v3.1` --- .../TestSetCurrentSnapshotProcedure.java | 66 +++++++++++++++++-- .../SetCurrentSnapshotProcedure.java | 24 +++++-- .../TestSetCurrentSnapshotProcedure.java | 66 +++++++++++++++++-- .../SetCurrentSnapshotProcedure.java | 24 +++++-- .../TestSetCurrentSnapshotProcedure.java | 66 +++++++++++++++++-- .../SetCurrentSnapshotProcedure.java | 24 +++++-- 6 files changed, 240 insertions(+), 30 deletions(-) diff --git a/spark/v3.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java b/spark/v3.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java index 8a8a974bbebe..da101e46de28 100644 --- a/spark/v3.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java +++ b/spark/v3.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java @@ -31,6 +31,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.Iterables; import org.apache.spark.sql.AnalysisException; import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.assertj.core.api.Assertions; import org.junit.After; import org.junit.Assume; import org.junit.Test; @@ -219,14 +220,14 @@ public void testInvalidRollbackToSnapshotCases() { AssertHelpers.assertThrows( "Should reject calls without all required args", - AnalysisException.class, - "Missing required parameters", + IllegalArgumentException.class, + "Either snapshot_id or ref must be provided, not both", () -> sql("CALL %s.system.set_current_snapshot('t')", catalogName)); AssertHelpers.assertThrows( "Should reject calls without all required args", - AnalysisException.class, - "Missing required parameters", + IllegalArgumentException.class, + "Cannot parse identifier for arg table: 1", () -> sql("CALL %s.system.set_current_snapshot(1L)", catalogName)); AssertHelpers.assertThrows( @@ -237,8 +238,8 @@ public void testInvalidRollbackToSnapshotCases() { AssertHelpers.assertThrows( "Should reject calls without all required args", - AnalysisException.class, - "Missing required parameters", + IllegalArgumentException.class, + "Either snapshot_id or ref must be provided, not both", () -> sql("CALL %s.system.set_current_snapshot(table => 't')", catalogName)); AssertHelpers.assertThrows( @@ -252,5 +253,58 @@ public void testInvalidRollbackToSnapshotCases() { IllegalArgumentException.class, "Cannot handle an empty identifier", () -> sql("CALL %s.system.set_current_snapshot('', 1L)", catalogName)); + + Assertions.assertThatThrownBy( + () -> + sql( + "CALL %s.system.set_current_snapshot(table => 't', snapshot_id => 1L, ref => 's1')", + catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Either snapshot_id or ref must be provided, not both"); + } + + @Test + public void testSetCurrentSnapshotToRef() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String ref = "s1"; + sql("ALTER TABLE %s CREATE TAG %s", tableName, ref); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = + sql( + "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')", + catalogName, tableIdent, ref); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Set must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + String notExistRef = "s2"; + Assertions.assertThatThrownBy( + () -> + sql( + "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')", + catalogName, tableIdent, notExistRef)) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot find matching snapshot ID for ref " + notExistRef); } } diff --git a/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java b/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java index f8f8049c22b6..22719e43c057 100644 --- a/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java +++ b/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java @@ -19,6 +19,10 @@ package org.apache.iceberg.spark.procedures; import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.catalog.Identifier; @@ -42,7 +46,8 @@ class SetCurrentSnapshotProcedure extends BaseProcedure { private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[] { ProcedureParameter.required("table", DataTypes.StringType), - ProcedureParameter.required("snapshot_id", DataTypes.LongType) + ProcedureParameter.optional("snapshot_id", DataTypes.LongType), + ProcedureParameter.optional("ref", DataTypes.StringType) }; private static final StructType OUTPUT_TYPE = @@ -78,7 +83,11 @@ public StructType outputType() { @Override public InternalRow[] call(InternalRow args) { Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); - long snapshotId = args.getLong(1); + Long snapshotId = args.isNullAt(1) ? null : args.getLong(1); + String ref = args.isNullAt(2) ? null : args.getString(2); + Preconditions.checkArgument( + (snapshotId != null && ref == null) || (snapshotId == null && ref != null), + "Either snapshot_id or ref must be provided, not both"); return modifyIcebergTable( tableIdent, @@ -86,9 +95,10 @@ public InternalRow[] call(InternalRow args) { Snapshot previousSnapshot = table.currentSnapshot(); Long previousSnapshotId = previousSnapshot != null ? previousSnapshot.snapshotId() : null; - table.manageSnapshots().setCurrentSnapshot(snapshotId).commit(); + long targetSnapshotId = snapshotId != null ? snapshotId : toSnapshotId(table, ref); + table.manageSnapshots().setCurrentSnapshot(targetSnapshotId).commit(); - InternalRow outputRow = newInternalRow(previousSnapshotId, snapshotId); + InternalRow outputRow = newInternalRow(previousSnapshotId, targetSnapshotId); return new InternalRow[] {outputRow}; }); } @@ -97,4 +107,10 @@ public InternalRow[] call(InternalRow args) { public String description() { return "SetCurrentSnapshotProcedure"; } + + private long toSnapshotId(Table table, String refName) { + SnapshotRef ref = table.refs().get(refName); + ValidationException.check(ref != null, "Cannot find matching snapshot ID for ref " + refName); + return ref.snapshotId(); + } } diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java index 8a8a974bbebe..da101e46de28 100644 --- a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java +++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java @@ -31,6 +31,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.Iterables; import org.apache.spark.sql.AnalysisException; import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.assertj.core.api.Assertions; import org.junit.After; import org.junit.Assume; import org.junit.Test; @@ -219,14 +220,14 @@ public void testInvalidRollbackToSnapshotCases() { AssertHelpers.assertThrows( "Should reject calls without all required args", - AnalysisException.class, - "Missing required parameters", + IllegalArgumentException.class, + "Either snapshot_id or ref must be provided, not both", () -> sql("CALL %s.system.set_current_snapshot('t')", catalogName)); AssertHelpers.assertThrows( "Should reject calls without all required args", - AnalysisException.class, - "Missing required parameters", + IllegalArgumentException.class, + "Cannot parse identifier for arg table: 1", () -> sql("CALL %s.system.set_current_snapshot(1L)", catalogName)); AssertHelpers.assertThrows( @@ -237,8 +238,8 @@ public void testInvalidRollbackToSnapshotCases() { AssertHelpers.assertThrows( "Should reject calls without all required args", - AnalysisException.class, - "Missing required parameters", + IllegalArgumentException.class, + "Either snapshot_id or ref must be provided, not both", () -> sql("CALL %s.system.set_current_snapshot(table => 't')", catalogName)); AssertHelpers.assertThrows( @@ -252,5 +253,58 @@ public void testInvalidRollbackToSnapshotCases() { IllegalArgumentException.class, "Cannot handle an empty identifier", () -> sql("CALL %s.system.set_current_snapshot('', 1L)", catalogName)); + + Assertions.assertThatThrownBy( + () -> + sql( + "CALL %s.system.set_current_snapshot(table => 't', snapshot_id => 1L, ref => 's1')", + catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Either snapshot_id or ref must be provided, not both"); + } + + @Test + public void testSetCurrentSnapshotToRef() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String ref = "s1"; + sql("ALTER TABLE %s CREATE TAG %s", tableName, ref); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = + sql( + "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')", + catalogName, tableIdent, ref); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Set must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + String notExistRef = "s2"; + Assertions.assertThatThrownBy( + () -> + sql( + "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')", + catalogName, tableIdent, notExistRef)) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot find matching snapshot ID for ref " + notExistRef); } } diff --git a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java index f8f8049c22b6..22719e43c057 100644 --- a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java +++ b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java @@ -19,6 +19,10 @@ package org.apache.iceberg.spark.procedures; import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.catalog.Identifier; @@ -42,7 +46,8 @@ class SetCurrentSnapshotProcedure extends BaseProcedure { private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[] { ProcedureParameter.required("table", DataTypes.StringType), - ProcedureParameter.required("snapshot_id", DataTypes.LongType) + ProcedureParameter.optional("snapshot_id", DataTypes.LongType), + ProcedureParameter.optional("ref", DataTypes.StringType) }; private static final StructType OUTPUT_TYPE = @@ -78,7 +83,11 @@ public StructType outputType() { @Override public InternalRow[] call(InternalRow args) { Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); - long snapshotId = args.getLong(1); + Long snapshotId = args.isNullAt(1) ? null : args.getLong(1); + String ref = args.isNullAt(2) ? null : args.getString(2); + Preconditions.checkArgument( + (snapshotId != null && ref == null) || (snapshotId == null && ref != null), + "Either snapshot_id or ref must be provided, not both"); return modifyIcebergTable( tableIdent, @@ -86,9 +95,10 @@ public InternalRow[] call(InternalRow args) { Snapshot previousSnapshot = table.currentSnapshot(); Long previousSnapshotId = previousSnapshot != null ? previousSnapshot.snapshotId() : null; - table.manageSnapshots().setCurrentSnapshot(snapshotId).commit(); + long targetSnapshotId = snapshotId != null ? snapshotId : toSnapshotId(table, ref); + table.manageSnapshots().setCurrentSnapshot(targetSnapshotId).commit(); - InternalRow outputRow = newInternalRow(previousSnapshotId, snapshotId); + InternalRow outputRow = newInternalRow(previousSnapshotId, targetSnapshotId); return new InternalRow[] {outputRow}; }); } @@ -97,4 +107,10 @@ public InternalRow[] call(InternalRow args) { public String description() { return "SetCurrentSnapshotProcedure"; } + + private long toSnapshotId(Table table, String refName) { + SnapshotRef ref = table.refs().get(refName); + ValidationException.check(ref != null, "Cannot find matching snapshot ID for ref " + refName); + return ref.snapshotId(); + } } diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java index 8a8a974bbebe..da101e46de28 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java @@ -31,6 +31,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.Iterables; import org.apache.spark.sql.AnalysisException; import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException; +import org.assertj.core.api.Assertions; import org.junit.After; import org.junit.Assume; import org.junit.Test; @@ -219,14 +220,14 @@ public void testInvalidRollbackToSnapshotCases() { AssertHelpers.assertThrows( "Should reject calls without all required args", - AnalysisException.class, - "Missing required parameters", + IllegalArgumentException.class, + "Either snapshot_id or ref must be provided, not both", () -> sql("CALL %s.system.set_current_snapshot('t')", catalogName)); AssertHelpers.assertThrows( "Should reject calls without all required args", - AnalysisException.class, - "Missing required parameters", + IllegalArgumentException.class, + "Cannot parse identifier for arg table: 1", () -> sql("CALL %s.system.set_current_snapshot(1L)", catalogName)); AssertHelpers.assertThrows( @@ -237,8 +238,8 @@ public void testInvalidRollbackToSnapshotCases() { AssertHelpers.assertThrows( "Should reject calls without all required args", - AnalysisException.class, - "Missing required parameters", + IllegalArgumentException.class, + "Either snapshot_id or ref must be provided, not both", () -> sql("CALL %s.system.set_current_snapshot(table => 't')", catalogName)); AssertHelpers.assertThrows( @@ -252,5 +253,58 @@ public void testInvalidRollbackToSnapshotCases() { IllegalArgumentException.class, "Cannot handle an empty identifier", () -> sql("CALL %s.system.set_current_snapshot('', 1L)", catalogName)); + + Assertions.assertThatThrownBy( + () -> + sql( + "CALL %s.system.set_current_snapshot(table => 't', snapshot_id => 1L, ref => 's1')", + catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Either snapshot_id or ref must be provided, not both"); + } + + @Test + public void testSetCurrentSnapshotToRef() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String ref = "s1"; + sql("ALTER TABLE %s CREATE TAG %s", tableName, ref); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = + sql( + "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')", + catalogName, tableIdent, ref); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Set must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + String notExistRef = "s2"; + Assertions.assertThatThrownBy( + () -> + sql( + "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')", + catalogName, tableIdent, notExistRef)) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot find matching snapshot ID for ref " + notExistRef); } } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java index f8f8049c22b6..22719e43c057 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java @@ -19,6 +19,10 @@ package org.apache.iceberg.spark.procedures; import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.catalog.Identifier; @@ -42,7 +46,8 @@ class SetCurrentSnapshotProcedure extends BaseProcedure { private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[] { ProcedureParameter.required("table", DataTypes.StringType), - ProcedureParameter.required("snapshot_id", DataTypes.LongType) + ProcedureParameter.optional("snapshot_id", DataTypes.LongType), + ProcedureParameter.optional("ref", DataTypes.StringType) }; private static final StructType OUTPUT_TYPE = @@ -78,7 +83,11 @@ public StructType outputType() { @Override public InternalRow[] call(InternalRow args) { Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); - long snapshotId = args.getLong(1); + Long snapshotId = args.isNullAt(1) ? null : args.getLong(1); + String ref = args.isNullAt(2) ? null : args.getString(2); + Preconditions.checkArgument( + (snapshotId != null && ref == null) || (snapshotId == null && ref != null), + "Either snapshot_id or ref must be provided, not both"); return modifyIcebergTable( tableIdent, @@ -86,9 +95,10 @@ public InternalRow[] call(InternalRow args) { Snapshot previousSnapshot = table.currentSnapshot(); Long previousSnapshotId = previousSnapshot != null ? previousSnapshot.snapshotId() : null; - table.manageSnapshots().setCurrentSnapshot(snapshotId).commit(); + long targetSnapshotId = snapshotId != null ? snapshotId : toSnapshotId(table, ref); + table.manageSnapshots().setCurrentSnapshot(targetSnapshotId).commit(); - InternalRow outputRow = newInternalRow(previousSnapshotId, snapshotId); + InternalRow outputRow = newInternalRow(previousSnapshotId, targetSnapshotId); return new InternalRow[] {outputRow}; }); } @@ -97,4 +107,10 @@ public InternalRow[] call(InternalRow args) { public String description() { return "SetCurrentSnapshotProcedure"; } + + private long toSnapshotId(Table table, String refName) { + SnapshotRef ref = table.refs().get(refName); + ValidationException.check(ref != null, "Cannot find matching snapshot ID for ref " + refName); + return ref.snapshotId(); + } }