Skip to content

Commit

Permalink
[#3264] feat(spark-connector): Support Iceberg time travel in SQL que…
Browse files Browse the repository at this point in the history
…ries (#3265)

### What changes were proposed in this pull request?

Support Iceberg time travel in SQL queries

### Why are the changes needed?
supports time travel in SQL queries using `TIMESTAMP AS OF`, `FOR
SYSTEM_TIME AS OF` or `VERSION AS OF`, `FOR SYSTEM_VERSION AS OF`
clauses.

Fix: #3264

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
New ITs.
  • Loading branch information
caican00 authored May 20, 2024
1 parent 2e57a44 commit 9929a99
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import lombok.Data;
Expand Down Expand Up @@ -252,6 +253,74 @@ void testIcebergCallOperations() throws NoSuchTableException {
testIcebergCallRewritePositionDeleteFiles();
}

@Test
void testIcebergTimeTravelQuery() throws NoSuchTableException {
String tableName = "test_iceberg_as_of_query";
dropTableIfExists(tableName);
createSimpleTable(tableName);
checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName));

sql(String.format("INSERT INTO %s VALUES (1, '1', 1)", tableName));
List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

SparkIcebergTable sparkIcebergTable = getSparkIcebergTableInstance(tableName);
long snapshotId = getCurrentSnapshotId(tableName);
sparkIcebergTable.table().manageSnapshots().createBranch("test_branch", snapshotId).commit();
sparkIcebergTable.table().manageSnapshots().createTag("test_tag", snapshotId).commit();
long snapshotTimestamp = getCurrentSnapshotTimestamp(tableName);
long timestamp = waitUntilAfter(snapshotTimestamp + 1000);
long timestampInSeconds = TimeUnit.MILLISECONDS.toSeconds(timestamp);

// create a second snapshot
sql(String.format("INSERT INTO %s VALUES (2, '2', 2)", tableName));
tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(2, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));

tableData =
getQueryData(
String.format("SELECT * FROM %s TIMESTAMP AS OF %s", tableName, timestampInSeconds));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
tableData =
getQueryData(
String.format(
"SELECT * FROM %s FOR SYSTEM_TIME AS OF %s", tableName, timestampInSeconds));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

tableData =
getQueryData(String.format("SELECT * FROM %s VERSION AS OF %d", tableName, snapshotId));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
tableData =
getQueryData(
String.format("SELECT * FROM %s FOR SYSTEM_VERSION AS OF %d", tableName, snapshotId));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

tableData =
getQueryData(String.format("SELECT * FROM %s VERSION AS OF 'test_branch'", tableName));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
tableData =
getQueryData(
String.format("SELECT * FROM %s FOR SYSTEM_VERSION AS OF 'test_branch'", tableName));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

tableData = getQueryData(String.format("SELECT * FROM %s VERSION AS OF 'test_tag'", tableName));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
tableData =
getQueryData(
String.format("SELECT * FROM %s FOR SYSTEM_VERSION AS OF 'test_tag'", tableName));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
}

private void testMetadataColumns() {
String tableName = "test_metadata_columns";
dropTableIfExists(tableName);
Expand Down Expand Up @@ -722,13 +791,31 @@ static IcebergTableWriteProperties of(
}
}

private long getCurrentSnapshotId(String tableName) throws NoSuchTableException {
private SparkIcebergTable getSparkIcebergTableInstance(String tableName)
throws NoSuchTableException {
CatalogPlugin catalogPlugin =
getSparkSession().sessionState().catalogManager().catalog(getCatalogName());
Assertions.assertInstanceOf(TableCatalog.class, catalogPlugin);
TableCatalog catalog = (TableCatalog) catalogPlugin;
Table table = catalog.loadTable(Identifier.of(new String[] {getDefaultDatabase()}, tableName));
SparkIcebergTable sparkIcebergTable = (SparkIcebergTable) table;
return (SparkIcebergTable) table;
}

private long getCurrentSnapshotTimestamp(String tableName) throws NoSuchTableException {
SparkIcebergTable sparkIcebergTable = getSparkIcebergTableInstance(tableName);
return sparkIcebergTable.table().currentSnapshot().timestampMillis();
}

private long getCurrentSnapshotId(String tableName) throws NoSuchTableException {
SparkIcebergTable sparkIcebergTable = getSparkIcebergTableInstance(tableName);
return sparkIcebergTable.table().currentSnapshot().snapshotId();
}

private long waitUntilAfter(Long timestampMillis) {
long current = System.currentTimeMillis();
while (current <= timestampMillis) {
current = System.currentTimeMillis();
}
return current;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ protected abstract TableCatalog createAndInitSparkCatalog(
*
* @param identifier Spark's table identifier
* @param gravitinoTable Gravitino table to do DDL operations
* @param sparkTable Spark internal table to do IO operations
* @param sparkCatalog specific Spark catalog to do IO operations
* @param propertiesConverter transform properties between Gravitino and Spark
* @param sparkTransformConverter sparkTransformConverter convert transforms between Gravitino and
Expand All @@ -101,6 +102,7 @@ protected abstract TableCatalog createAndInitSparkCatalog(
protected abstract Table createSparkTable(
Identifier identifier,
com.datastrato.gravitino.rel.Table gravitinoTable,
Table sparkTable,
TableCatalog sparkCatalog,
PropertiesConverter propertiesConverter,
SparkTransformConverter sparkTransformConverter);
Expand Down Expand Up @@ -194,8 +196,14 @@ public Table createTable(
partitionings,
distributionAndSortOrdersInfo.getDistribution(),
distributionAndSortOrdersInfo.getSortOrders());
org.apache.spark.sql.connector.catalog.Table sparkTable = loadSparkTable(ident);
return createSparkTable(
ident, gravitinoTable, sparkCatalog, propertiesConverter, sparkTransformConverter);
ident,
gravitinoTable,
sparkTable,
sparkCatalog,
propertiesConverter,
sparkTransformConverter);
} catch (NoSuchSchemaException e) {
throw new NoSuchNamespaceException(ident.namespace());
} catch (com.datastrato.gravitino.exceptions.TableAlreadyExistsException e) {
Expand All @@ -206,14 +214,16 @@ public Table createTable(
@Override
public Table loadTable(Identifier ident) throws NoSuchTableException {
try {
String database = getDatabase(ident);
com.datastrato.gravitino.rel.Table gravitinoTable =
gravitinoCatalogClient
.asTableCatalog()
.loadTable(NameIdentifier.of(metalakeName, catalogName, database, ident.name()));
com.datastrato.gravitino.rel.Table gravitinoTable = loadGravitinoTable(ident);
org.apache.spark.sql.connector.catalog.Table sparkTable = loadSparkTable(ident);
// Will create a catalog specific table
return createSparkTable(
ident, gravitinoTable, sparkCatalog, propertiesConverter, sparkTransformConverter);
ident,
gravitinoTable,
sparkTable,
sparkCatalog,
propertiesConverter,
sparkTransformConverter);
} catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) {
throw new NoSuchTableException(ident);
}
Expand All @@ -240,8 +250,14 @@ public Table alterTable(Identifier ident, TableChange... changes) throws NoSuchT
.alterTable(
NameIdentifier.of(metalakeName, catalogName, getDatabase(ident), ident.name()),
gravitinoTableChanges);
org.apache.spark.sql.connector.catalog.Table sparkTable = loadSparkTable(ident);
return createSparkTable(
ident, gravitinoTable, sparkCatalog, propertiesConverter, sparkTransformConverter);
ident,
gravitinoTable,
sparkTable,
sparkCatalog,
propertiesConverter,
sparkTransformConverter);
} catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) {
throw new NoSuchTableException(ident);
}
Expand Down Expand Up @@ -377,6 +393,25 @@ public boolean dropNamespace(String[] namespace, boolean cascade)
}
}

protected com.datastrato.gravitino.rel.Table loadGravitinoTable(Identifier ident)
throws NoSuchTableException {
try {
String database = getDatabase(ident);
return gravitinoCatalogClient
.asTableCatalog()
.loadTable(NameIdentifier.of(metalakeName, catalogName, database, ident.name()));
} catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) {
throw new NoSuchTableException(ident);
}
}

protected String getDatabase(Identifier sparkIdentifier) {
if (sparkIdentifier.namespace().length > 0) {
return sparkIdentifier.namespace()[0];
}
return getCatalogDefaultNamespace();
}

private void validateNamespace(String[] namespace) {
Preconditions.checkArgument(
namespace.length == 1,
Expand All @@ -403,13 +438,6 @@ private com.datastrato.gravitino.rel.Column createGravitinoColumn(Column sparkCo
com.datastrato.gravitino.rel.Column.DEFAULT_VALUE_NOT_SET);
}

protected String getDatabase(Identifier sparkIdentifier) {
if (sparkIdentifier.namespace().length > 0) {
return sparkIdentifier.namespace()[0];
}
return getCatalogDefaultNamespace();
}

private String getDatabase(NameIdentifier gravitinoIdentifier) {
Preconditions.checkArgument(
gravitinoIdentifier.namespace().length() == 3,
Expand Down Expand Up @@ -497,4 +525,16 @@ private static com.datastrato.gravitino.rel.TableChange.ColumnPosition transform
"Unsupported table column position %s", columnPosition.getClass().getName()));
}
}

private Table loadSparkTable(Identifier ident) {
try {
return sparkCatalog.loadTable(ident);
} catch (NoSuchTableException e) {
throw new RuntimeException(
String.format(
"Failed to load the real sparkTable: %s",
String.join(".", getDatabase(ident), ident.name())),
e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import java.util.Map;
import org.apache.kyuubi.spark.connector.hive.HiveTable;
import org.apache.kyuubi.spark.connector.hive.HiveTableCatalog;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.connector.catalog.Identifier;
import org.apache.spark.sql.connector.catalog.TableCatalog;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
Expand All @@ -34,19 +33,10 @@ protected TableCatalog createAndInitSparkCatalog(
protected org.apache.spark.sql.connector.catalog.Table createSparkTable(
Identifier identifier,
Table gravitinoTable,
org.apache.spark.sql.connector.catalog.Table sparkTable,
TableCatalog sparkHiveCatalog,
PropertiesConverter propertiesConverter,
SparkTransformConverter sparkTransformConverter) {
org.apache.spark.sql.connector.catalog.Table sparkTable;
try {
sparkTable = sparkHiveCatalog.loadTable(identifier);
} catch (NoSuchTableException e) {
throw new RuntimeException(
String.format(
"Failed to load the real sparkTable: %s",
String.join(".", getDatabase(identifier), identifier.name())),
e);
}
return new SparkHiveTable(
identifier,
gravitinoTable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,10 @@ protected TableCatalog createAndInitSparkCatalog(
protected org.apache.spark.sql.connector.catalog.Table createSparkTable(
Identifier identifier,
Table gravitinoTable,
org.apache.spark.sql.connector.catalog.Table sparkTable,
TableCatalog sparkIcebergCatalog,
PropertiesConverter propertiesConverter,
SparkTransformConverter sparkTransformConverter) {
org.apache.spark.sql.connector.catalog.Table sparkTable;
try {
sparkTable = sparkIcebergCatalog.loadTable(identifier);
} catch (NoSuchTableException e) {
throw new RuntimeException(
String.format(
"Failed to load the real sparkTable: %s",
String.join(".", getDatabase(identifier), identifier.name())),
e);
}
return new SparkIcebergTable(
identifier,
gravitinoTable,
Expand Down Expand Up @@ -128,6 +119,44 @@ public Catalog icebergCatalog() {
return ((SparkCatalog) sparkCatalog).icebergCatalog();
}

@Override
public org.apache.spark.sql.connector.catalog.Table loadTable(Identifier ident, String version)
throws NoSuchTableException {
try {
com.datastrato.gravitino.rel.Table gravitinoTable = loadGravitinoTable(ident);
org.apache.spark.sql.connector.catalog.Table sparkTable = loadSparkTable(ident, version);
// Will create a catalog specific table
return createSparkTable(
ident,
gravitinoTable,
sparkTable,
sparkCatalog,
propertiesConverter,
sparkTransformConverter);
} catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) {
throw new NoSuchTableException(ident);
}
}

@Override
public org.apache.spark.sql.connector.catalog.Table loadTable(Identifier ident, long timestamp)
throws NoSuchTableException {
try {
com.datastrato.gravitino.rel.Table gravitinoTable = loadGravitinoTable(ident);
org.apache.spark.sql.connector.catalog.Table sparkTable = loadSparkTable(ident, timestamp);
// Will create a catalog specific table
return createSparkTable(
ident,
gravitinoTable,
sparkTable,
sparkCatalog,
propertiesConverter,
sparkTransformConverter);
} catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) {
throw new NoSuchTableException(ident);
}
}

private boolean isSystemNamespace(String[] namespace)
throws NoSuchMethodException, InvocationTargetException, IllegalAccessException,
ClassNotFoundException {
Expand All @@ -136,4 +165,30 @@ private boolean isSystemNamespace(String[] namespace)
isSystemNamespace.setAccessible(true);
return (Boolean) isSystemNamespace.invoke(baseCatalog, (Object) namespace);
}

private org.apache.spark.sql.connector.catalog.Table loadSparkTable(
Identifier ident, String version) {
try {
return sparkCatalog.loadTable(ident, version);
} catch (NoSuchTableException e) {
throw new RuntimeException(
String.format(
"Failed to load the real sparkTable: %s",
String.join(".", getDatabase(ident), ident.name())),
e);
}
}

private org.apache.spark.sql.connector.catalog.Table loadSparkTable(
Identifier ident, long timestamp) {
try {
return sparkCatalog.loadTable(ident, timestamp);
} catch (NoSuchTableException e) {
throw new RuntimeException(
String.format(
"Failed to load the real sparkTable: %s",
String.join(".", getDatabase(ident), ident.name())),
e);
}
}
}
Loading

0 comments on commit 9929a99

Please sign in to comment.