Skip to content

Commit

Permalink
Spark3.1, Spark3.2, Spark3.3: Support setting current snapshot with ref
Browse files Browse the repository at this point in the history
Back-port of apache#8163 to `spark/v3.3`, `spark/v3.2` and `spark/v3.1`
  • Loading branch information
manuzhang committed Aug 26, 2023
1 parent 181d3e2 commit 7f04c3f
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.assertj.core.api.Assertions;
import org.junit.After;
import org.junit.Assume;
import org.junit.Test;
Expand Down Expand Up @@ -219,14 +220,14 @@ public void testInvalidRollbackToSnapshotCases() {

AssertHelpers.assertThrows(
"Should reject calls without all required args",
AnalysisException.class,
"Missing required parameters",
IllegalArgumentException.class,
"Either snapshot_id or ref must be provided, not both",
() -> sql("CALL %s.system.set_current_snapshot('t')", catalogName));

AssertHelpers.assertThrows(
"Should reject calls without all required args",
AnalysisException.class,
"Missing required parameters",
IllegalArgumentException.class,
"Cannot parse identifier for arg table: 1",
() -> sql("CALL %s.system.set_current_snapshot(1L)", catalogName));

AssertHelpers.assertThrows(
Expand All @@ -237,8 +238,8 @@ public void testInvalidRollbackToSnapshotCases() {

AssertHelpers.assertThrows(
"Should reject calls without all required args",
AnalysisException.class,
"Missing required parameters",
IllegalArgumentException.class,
"Either snapshot_id or ref must be provided, not both",
() -> sql("CALL %s.system.set_current_snapshot(table => 't')", catalogName));

AssertHelpers.assertThrows(
Expand All @@ -252,5 +253,58 @@ public void testInvalidRollbackToSnapshotCases() {
IllegalArgumentException.class,
"Cannot handle an empty identifier",
() -> sql("CALL %s.system.set_current_snapshot('', 1L)", catalogName));

Assertions.assertThatThrownBy(
() ->
sql(
"CALL %s.system.set_current_snapshot(table => 't', snapshot_id => 1L, ref => 's1')",
catalogName))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Either snapshot_id or ref must be provided, not both");
}

// Exercises set_current_snapshot with the `ref` argument: tags the first snapshot, commits a
// second one, moves the table back to the tag, then checks that an unknown ref is rejected.
@Test
public void testSetCurrentSnapshotToRef() {
  final String tagName = "s1";
  final String missingRefName = "s2";

  // First commit; the tag created below points at this snapshot.
  sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
  sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);

  Table icebergTable = validationCatalog.loadTable(tableIdent);
  Snapshot taggedSnapshot = icebergTable.currentSnapshot();
  sql("ALTER TABLE %s CREATE TAG %s", tableName, tagName);

  // Second commit moves the table past the tagged snapshot.
  sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
  List<Object[]> rowsAfterSecondInsert = sql("SELECT * FROM %s ORDER BY id", tableName);
  assertEquals(
      "Should have expected rows",
      ImmutableList.of(row(1L, "a"), row(1L, "a")),
      rowsAfterSecondInsert);

  icebergTable.refresh();
  Snapshot headBeforeCall = icebergTable.currentSnapshot();

  List<Object[]> procedureResult =
      sql(
          "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')",
          catalogName, tableIdent, tagName);

  // The procedure reports (previous snapshot id, new current snapshot id).
  Object[] expectedResultRow = row(headBeforeCall.snapshotId(), taggedSnapshot.snapshotId());
  assertEquals("Procedure output must match", ImmutableList.of(expectedResultRow), procedureResult);

  // Only the row written by the first commit is visible again.
  List<Object[]> rowsAfterCall = sql("SELECT * FROM %s ORDER BY id", tableName);
  assertEquals("Set must be successful", ImmutableList.of(row(1L, "a")), rowsAfterCall);

  // A ref that was never created must be rejected.
  String expectedMessage = "Cannot find matching snapshot ID for ref " + missingRefName;
  Assertions.assertThatThrownBy(
          () ->
              sql(
                  "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')",
                  catalogName, tableIdent, missingRefName))
      .isInstanceOf(ValidationException.class)
      .hasMessage(expectedMessage);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
package org.apache.iceberg.spark.procedures;

import org.apache.iceberg.Snapshot;
import org.apache.iceberg.SnapshotRef;
import org.apache.iceberg.Table;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.Identifier;
Expand All @@ -42,7 +46,8 @@ class SetCurrentSnapshotProcedure extends BaseProcedure {
private static final ProcedureParameter[] PARAMETERS =
new ProcedureParameter[] {
ProcedureParameter.required("table", DataTypes.StringType),
ProcedureParameter.required("snapshot_id", DataTypes.LongType)
ProcedureParameter.optional("snapshot_id", DataTypes.LongType),
ProcedureParameter.optional("ref", DataTypes.StringType)
};

private static final StructType OUTPUT_TYPE =
Expand Down Expand Up @@ -78,17 +83,22 @@ public StructType outputType() {
/**
 * Sets the table's current snapshot and reports the previous and new snapshot IDs.
 *
 * <p>Argument 0 is the table identifier, argument 1 the optional {@code snapshot_id} and
 * argument 2 the optional {@code ref}. Exactly one of {@code snapshot_id} and {@code ref} must
 * be non-null.
 *
 * <p>NOTE(review): this listing shows lines from both sides of the change back to back, so
 * {@code snapshotId} and {@code outputRow} each appear to be declared twice. In each pair the
 * second declaration is the one that works with the {@code ref} argument.
 *
 * @param args positional procedure arguments as described above
 * @return a single row holding the previous snapshot ID (null when the table had no current
 *     snapshot) and the snapshot ID that was made current
 */
@Override
public InternalRow[] call(InternalRow args) {
Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name());
long snapshotId = args.getLong(1);
// Both snapshot_id and ref are read as nullable values so that either one may be omitted.
Long snapshotId = args.isNullAt(1) ? null : args.getLong(1);
String ref = args.isNullAt(2) ? null : args.getString(2);
// Fails when neither argument is supplied as well as when both are supplied.
Preconditions.checkArgument(
(snapshotId != null && ref == null) || (snapshotId == null && ref != null),
"Either snapshot_id or ref must be provided, not both");

// modifyIcebergTable is defined outside this view; presumably it resolves the table and applies
// the function to it -- TODO confirm.
return modifyIcebergTable(
tableIdent,
table -> {
// Capture the current snapshot before changing it; it may be absent.
Snapshot previousSnapshot = table.currentSnapshot();
Long previousSnapshotId = previousSnapshot != null ? previousSnapshot.snapshotId() : null;

table.manageSnapshots().setCurrentSnapshot(snapshotId).commit();
// An explicit snapshot_id is used directly; otherwise the ref is resolved to a snapshot ID.
long targetSnapshotId = snapshotId != null ? snapshotId : toSnapshotId(table, ref);
table.manageSnapshots().setCurrentSnapshot(targetSnapshotId).commit();

InternalRow outputRow = newInternalRow(previousSnapshotId, snapshotId);
InternalRow outputRow = newInternalRow(previousSnapshotId, targetSnapshotId);
return new InternalRow[] {outputRow};
});
}
Expand All @@ -97,4 +107,10 @@ public InternalRow[] call(InternalRow args) {
public String description() {
return "SetCurrentSnapshotProcedure";
}

/**
 * Resolves a named ref on the given table to the ID of the snapshot it points to.
 *
 * @param table the table whose refs are searched
 * @param refName name of the ref to look up
 * @return the snapshot ID recorded for {@code refName}
 * @throws ValidationException if the table has no ref named {@code refName}
 */
private long toSnapshotId(Table table, String refName) {
  SnapshotRef ref = table.refs().get(refName);
  // Pass the ref name as a format argument instead of concatenating it into the format string,
  // so that a name containing '%' is reported verbatim rather than parsed as a format specifier.
  ValidationException.check(ref != null, "Cannot find matching snapshot ID for ref %s", refName);
  return ref.snapshotId();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.assertj.core.api.Assertions;
import org.junit.After;
import org.junit.Assume;
import org.junit.Test;
Expand Down Expand Up @@ -219,14 +220,14 @@ public void testInvalidRollbackToSnapshotCases() {

AssertHelpers.assertThrows(
"Should reject calls without all required args",
AnalysisException.class,
"Missing required parameters",
IllegalArgumentException.class,
"Either snapshot_id or ref must be provided, not both",
() -> sql("CALL %s.system.set_current_snapshot('t')", catalogName));

AssertHelpers.assertThrows(
"Should reject calls without all required args",
AnalysisException.class,
"Missing required parameters",
IllegalArgumentException.class,
"Cannot parse identifier for arg table: 1",
() -> sql("CALL %s.system.set_current_snapshot(1L)", catalogName));

AssertHelpers.assertThrows(
Expand All @@ -237,8 +238,8 @@ public void testInvalidRollbackToSnapshotCases() {

AssertHelpers.assertThrows(
"Should reject calls without all required args",
AnalysisException.class,
"Missing required parameters",
IllegalArgumentException.class,
"Either snapshot_id or ref must be provided, not both",
() -> sql("CALL %s.system.set_current_snapshot(table => 't')", catalogName));

AssertHelpers.assertThrows(
Expand All @@ -252,5 +253,58 @@ public void testInvalidRollbackToSnapshotCases() {
IllegalArgumentException.class,
"Cannot handle an empty identifier",
() -> sql("CALL %s.system.set_current_snapshot('', 1L)", catalogName));

Assertions.assertThatThrownBy(
() ->
sql(
"CALL %s.system.set_current_snapshot(table => 't', snapshot_id => 1L, ref => 's1')",
catalogName))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Either snapshot_id or ref must be provided, not both");
}

// Exercises set_current_snapshot with the `ref` argument: tags the first snapshot, commits a
// second one, moves the table back to the tag, then checks that an unknown ref is rejected.
@Test
public void testSetCurrentSnapshotToRef() {
  final String tagName = "s1";
  final String missingRefName = "s2";

  // First commit; the tag created below points at this snapshot.
  sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName);
  sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);

  Table icebergTable = validationCatalog.loadTable(tableIdent);
  Snapshot taggedSnapshot = icebergTable.currentSnapshot();
  sql("ALTER TABLE %s CREATE TAG %s", tableName, tagName);

  // Second commit moves the table past the tagged snapshot.
  sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
  List<Object[]> rowsAfterSecondInsert = sql("SELECT * FROM %s ORDER BY id", tableName);
  assertEquals(
      "Should have expected rows",
      ImmutableList.of(row(1L, "a"), row(1L, "a")),
      rowsAfterSecondInsert);

  icebergTable.refresh();
  Snapshot headBeforeCall = icebergTable.currentSnapshot();

  List<Object[]> procedureResult =
      sql(
          "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')",
          catalogName, tableIdent, tagName);

  // The procedure reports (previous snapshot id, new current snapshot id).
  Object[] expectedResultRow = row(headBeforeCall.snapshotId(), taggedSnapshot.snapshotId());
  assertEquals("Procedure output must match", ImmutableList.of(expectedResultRow), procedureResult);

  // Only the row written by the first commit is visible again.
  List<Object[]> rowsAfterCall = sql("SELECT * FROM %s ORDER BY id", tableName);
  assertEquals("Set must be successful", ImmutableList.of(row(1L, "a")), rowsAfterCall);

  // A ref that was never created must be rejected.
  String expectedMessage = "Cannot find matching snapshot ID for ref " + missingRefName;
  Assertions.assertThatThrownBy(
          () ->
              sql(
                  "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')",
                  catalogName, tableIdent, missingRefName))
      .isInstanceOf(ValidationException.class)
      .hasMessage(expectedMessage);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
package org.apache.iceberg.spark.procedures;

import org.apache.iceberg.Snapshot;
import org.apache.iceberg.SnapshotRef;
import org.apache.iceberg.Table;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.Identifier;
Expand All @@ -42,7 +46,8 @@ class SetCurrentSnapshotProcedure extends BaseProcedure {
private static final ProcedureParameter[] PARAMETERS =
new ProcedureParameter[] {
ProcedureParameter.required("table", DataTypes.StringType),
ProcedureParameter.required("snapshot_id", DataTypes.LongType)
ProcedureParameter.optional("snapshot_id", DataTypes.LongType),
ProcedureParameter.optional("ref", DataTypes.StringType)
};

private static final StructType OUTPUT_TYPE =
Expand Down Expand Up @@ -78,17 +83,22 @@ public StructType outputType() {
@Override
public InternalRow[] call(InternalRow args) {
// Argument 0 is the table identifier, argument 1 the optional snapshot_id and argument 2 the
// optional ref.
// NOTE(review): this listing shows lines from both sides of the change back to back, so
// snapshotId and outputRow each appear to be declared twice. In each pair the second
// declaration is the one that works with the ref argument.
Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name());
long snapshotId = args.getLong(1);
// Both snapshot_id and ref are read as nullable values so that either one may be omitted.
Long snapshotId = args.isNullAt(1) ? null : args.getLong(1);
String ref = args.isNullAt(2) ? null : args.getString(2);
// Fails when neither argument is supplied as well as when both are supplied.
Preconditions.checkArgument(
(snapshotId != null && ref == null) || (snapshotId == null && ref != null),
"Either snapshot_id or ref must be provided, not both");

// modifyIcebergTable is defined outside this view; presumably it resolves the table and applies
// the function to it -- TODO confirm.
return modifyIcebergTable(
tableIdent,
table -> {
// Capture the current snapshot before changing it; it may be absent.
Snapshot previousSnapshot = table.currentSnapshot();
Long previousSnapshotId = previousSnapshot != null ? previousSnapshot.snapshotId() : null;

table.manageSnapshots().setCurrentSnapshot(snapshotId).commit();
// An explicit snapshot_id is used directly; otherwise the ref is resolved to a snapshot ID.
long targetSnapshotId = snapshotId != null ? snapshotId : toSnapshotId(table, ref);
table.manageSnapshots().setCurrentSnapshot(targetSnapshotId).commit();

InternalRow outputRow = newInternalRow(previousSnapshotId, snapshotId);
// The returned row holds the previous snapshot ID (possibly null) and the new current one.
InternalRow outputRow = newInternalRow(previousSnapshotId, targetSnapshotId);
return new InternalRow[] {outputRow};
});
}
Expand All @@ -97,4 +107,10 @@ public InternalRow[] call(InternalRow args) {
public String description() {
return "SetCurrentSnapshotProcedure";
}

/**
 * Resolves a named ref on the given table to the ID of the snapshot it points to.
 *
 * @param table the table whose refs are searched
 * @param refName name of the ref to look up
 * @return the snapshot ID recorded for {@code refName}
 * @throws ValidationException if the table has no ref named {@code refName}
 */
private long toSnapshotId(Table table, String refName) {
  SnapshotRef ref = table.refs().get(refName);
  // Pass the ref name as a format argument instead of concatenating it into the format string,
  // so that a name containing '%' is reported verbatim rather than parsed as a format specifier.
  ValidationException.check(ref != null, "Cannot find matching snapshot ID for ref %s", refName);
  return ref.snapshotId();
}
}
Loading

0 comments on commit 7f04c3f

Please sign in to comment.