Skip to content

Commit

Permalink
add parition
Browse files Browse the repository at this point in the history
  • Loading branch information
FANNG1 committed Mar 20, 2024
1 parent b53ff6b commit d404224
Show file tree
Hide file tree
Showing 12 changed files with 790 additions and 28 deletions.
1 change: 1 addition & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ iceberg = '1.3.1' # 1.4.0 causes test to fail
trino = '426'
spark = "3.4.1" # 3.5.0 causes tests to fail
scala-collection-compat = "2.7.0"
scala-java-compat = "1.0.2"
sqlite-jdbc = "3.42.0.0"
testng = "7.5.1"
testcontainers = "1.19.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo.SparkColumnInfo;
import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfoChecker;
import com.google.common.collect.ImmutableMap;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.hadoop.fs.Path;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
Expand All @@ -23,16 +25,10 @@
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIf;
import org.junit.platform.commons.util.StringUtils;

public abstract class SparkCommonIT extends SparkEnvIT {
private static String getSelectAllSql(String tableName) {
return String.format("SELECT * FROM %s", tableName);
}

private static String getInsertWithoutPartitionSql(String tableName, String values) {
return String.format("INSERT INTO %s VALUES (%s)", tableName, values);
}

// To generate test data for write&read table.
private static final Map<DataType, String> typeConstant =
Expand All @@ -51,8 +47,25 @@ private static String getInsertWithoutPartitionSql(String tableName, String valu
DataTypes.createStructField("col2", DataTypes.StringType, true))),
"struct(1, 'a')");

// Use a custom database not the original default database because SparkCommonIT couldn't
// read&write data to tables in default database. The main reason is default database location is
private static String getSelectAllSql(String tableName) {
return String.format("SELECT * FROM %s", tableName);
}

private static String getInsertWithoutPartitionSql(String tableName, String values) {
return String.format("INSERT INTO %s VALUES (%s)", tableName, values);
}

private static String getInsertWithPartitionSql(
String tableName, String partitionString, String values) {
return String.format(
"INSERT OVERWRITE %s PARTITION (%s) VALUES (%s)", tableName, partitionString, values);
}

// Whether supports [CLUSTERED BY col_name3 SORTED BY col_name INTO num_buckets BUCKETS]
protected abstract boolean supportsSparkSQLClusteredBy();

// Use a custom database not the original default database because SparkIT couldn't read&write
// data to tables in default database. The main reason is default database location is
// determined by `hive.metastore.warehouse.dir` in hive-site.xml which is local HDFS address
// not real HDFS address. The location of tables created under default database is like
// hdfs://localhost:9000/xxx which couldn't read write data from SparkCommonIT. Will use default
Expand All @@ -69,10 +82,6 @@ void init() {
sql("USE " + getDefaultDatabase());
}

protected String getDefaultDatabase() {
return "default_db";
}

@Test
void testLoadCatalogs() {
Set<String> catalogs = getCatalogs();
Expand Down Expand Up @@ -442,24 +451,94 @@ void testComplexType() {
checkTableReadWrite(tableInfo);
}

private void checkTableColumns(
String tableName, List<SparkColumnInfo> columnInfos, SparkTableInfo tableInfo) {
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(columnInfos)
.withComment(null)
.check(tableInfo);
@Test
void testCreateDatasourceFormatPartitionTable() {
String tableName = "datasource_partition_table";

dropTableIfExists(tableName);
String createTableSQL = getCreateSimpleTableString(tableName);
createTableSQL = createTableSQL + "USING PARQUET PARTITIONED BY (name, age)";
sql(createTableSQL);
SparkTableInfo tableInfo = getTableInfo(tableName);
SparkTableInfoChecker checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(getSimpleTableColumn())
.withIdentifyPartition(Arrays.asList("name", "age"));
checker.check(tableInfo);
checkTableReadWrite(tableInfo);
checkPartitionDirExists(tableInfo);
}

@Test
@EnabledIf("supportsSparkSQLClusteredBy")
void testCreateBucketTable() {
String tableName = "bucket_table";

dropTableIfExists(tableName);
String createTableSQL = getCreateSimpleTableString(tableName);
createTableSQL = createTableSQL + "CLUSTERED BY (id, name) INTO 4 buckets;";
sql(createTableSQL);
SparkTableInfo tableInfo = getTableInfo(tableName);
SparkTableInfoChecker checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(getSimpleTableColumn())
.withBucket(4, Arrays.asList("id", "name"));
checker.check(tableInfo);
checkTableReadWrite(tableInfo);
}

private void checkTableReadWrite(SparkTableInfo table) {
@Test
@EnabledIf("supportsSparkSQLClusteredBy")
void testCreateSortBucketTable() {
String tableName = "sort_bucket_table";

dropTableIfExists(tableName);
String createTableSQL = getCreateSimpleTableString(tableName);
createTableSQL =
createTableSQL + "CLUSTERED BY (id, name) SORTED BY (name, id) INTO 4 buckets;";
sql(createTableSQL);
SparkTableInfo tableInfo = getTableInfo(tableName);
SparkTableInfoChecker checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(getSimpleTableColumn())
.withBucket(4, Arrays.asList("id", "name"), Arrays.asList("name", "id"));
checker.check(tableInfo);
checkTableReadWrite(tableInfo);
}

protected void checkPartitionDirExists(SparkTableInfo table) {
Assertions.assertTrue(table.isPartitionTable(), "Not a partition table");
String tableLocation = table.getTableLocation();
String partitionExpression = getPartitionExpression(table, "/").replace("'", "");
Path partitionPath = new Path(tableLocation, partitionExpression);
try {
Assertions.assertTrue(
hdfs.exists(partitionPath), "Partition directory not exists," + partitionPath);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

protected void checkTableReadWrite(SparkTableInfo table) {
String name = table.getTableIdentifier();
boolean isPartitionTable = table.isPartitionTable();
String insertValues =
table.getColumns().stream()
table.getUnPartitionedColumns().stream()
.map(columnInfo -> typeConstant.get(columnInfo.getType()))
.map(Object::toString)
.collect(Collectors.joining(","));

sql(getInsertWithoutPartitionSql(name, insertValues));
String insertDataSQL = "";
if (isPartitionTable) {
String partitionExpressions = getPartitionExpression(table, ",");
insertDataSQL = getInsertWithPartitionSql(name, partitionExpressions, insertValues);
} else {
insertDataSQL = getInsertWithoutPartitionSql(name, insertValues);
}
sql(insertDataSQL);

// do something to match the query result:
// 1. remove "'" from values, such as 'a' is trans to a
Expand Down Expand Up @@ -514,23 +593,43 @@ private void checkTableReadWrite(SparkTableInfo table) {
Assertions.assertEquals(checkValues, queryResult.get(0));
}

private String getCreateSimpleTableString(String tableName) {
protected String getCreateSimpleTableString(String tableName) {
return String.format(
"CREATE TABLE %s (id INT COMMENT 'id comment', name STRING COMMENT '', age INT)",
tableName);
}

private List<SparkColumnInfo> getSimpleTableColumn() {
protected List<SparkColumnInfo> getSimpleTableColumn() {
return Arrays.asList(
SparkColumnInfo.of("id", DataTypes.IntegerType, "id comment"),
SparkColumnInfo.of("name", DataTypes.StringType, ""),
SparkColumnInfo.of("age", DataTypes.IntegerType, null));
}

protected String getDefaultDatabase() {
return "default_db";
}

// Helper method to create a simple table, and could use corresponding
// getSimpleTableColumn to check table column.
private void createSimpleTable(String identifier) {
String createTableSql = getCreateSimpleTableString(identifier);
sql(createTableSql);
}

private void checkTableColumns(
String tableName, List<SparkColumnInfo> columns, SparkTableInfo tableInfo) {
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(columns)
.withComment(null)
.check(tableInfo);
}

// partition expression may contain "'", like a='s'/b=1
private String getPartitionExpression(SparkTableInfo table, String delimiter) {
return table.getPartitionedColumns().stream()
.map(column -> column.getName() + "=" + typeConstant.get(column.getType()))
.collect(Collectors.joining(delimiter));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
import com.datastrato.gravitino.spark.connector.GravitinoSparkConfig;
import com.datastrato.gravitino.spark.connector.plugin.GravitinoSparkPlugin;
import com.google.common.collect.Maps;
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
Expand All @@ -28,11 +31,12 @@ public abstract class SparkEnvIT extends SparkUtilIT {
private static final Logger LOG = LoggerFactory.getLogger(SparkEnvIT.class);
private static final ContainerSuite containerSuite = ContainerSuite.getInstance();

protected FileSystem hdfs;
private final String metalakeName = "test";

private SparkSession sparkSession;
private String hiveMetastoreUri;
private String gravitinoUri;
private String hiveMetastoreUri = "thrift://127.0.0.1:9083";
private String gravitinoUri = "http://127.0.0.1:8090";

protected abstract String getCatalogName();

Expand All @@ -47,6 +51,7 @@ protected SparkSession getSparkSession() {
@BeforeAll
void startUp() {
initHiveEnv();
initHdfsFileSystem();
initGravitinoEnv();
initMetalakeAndCatalogs();
initSparkEnv();
Expand All @@ -58,6 +63,13 @@ void startUp() {

@AfterAll
void stop() {
if (hdfs != null) {
try {
hdfs.close();
} catch (IOException e) {
LOG.warn("Close HDFS filesystem failed,", e);
}
}
if (sparkSession != null) {
sparkSession.close();
}
Expand Down Expand Up @@ -92,6 +104,22 @@ private void initHiveEnv() {
HiveContainer.HIVE_METASTORE_PORT);
}

private void initHdfsFileSystem() {
Configuration conf = new Configuration();
conf.set(
"fs.defaultFS",
String.format(
"hdfs://%s:%d",
containerSuite.getHiveContainer().getContainerIpAddress(),
HiveContainer.HDFS_DEFAULTFS_PORT));
try {
hdfs = FileSystem.get(conf);
} catch (IOException e) {
LOG.error("Create HDFS filesystem failed", e);
throw new RuntimeException(e);
}
}

private void initSparkEnv() {
sparkSession =
SparkSession.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
package com.datastrato.gravitino.integration.test.spark.hive;

import com.datastrato.gravitino.integration.test.spark.SparkCommonIT;
import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo;
import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo.SparkColumnInfo;
import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfoChecker;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.sql.types.DataTypes;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

@Tag("gravitino-docker-it")
Expand All @@ -21,4 +29,33 @@ protected String getCatalogName() {
protected String getProvider() {
return "hive";
}

@Override
protected boolean supportsSparkSQLClusteredBy() {
return true;
}

@Test
public void testCreateHiveFormatPartitionTable() {
String tableName = "hive_partition_table";

dropTableIfExists(tableName);
String createTableSQL = getCreateSimpleTableString(tableName);
createTableSQL = createTableSQL + "PARTITIONED BY (age_p1 INT, age_p2 STRING)";
sql(createTableSQL);

List<SparkColumnInfo> columns = new ArrayList<>(getSimpleTableColumn());
columns.add(SparkColumnInfo.of("age_p1", DataTypes.IntegerType));
columns.add(SparkColumnInfo.of("age_p2", DataTypes.StringType));

SparkTableInfo tableInfo = getTableInfo(tableName);
SparkTableInfoChecker checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(columns)
.withIdentifyPartition(Arrays.asList("age_p1", "age_p2"));
checker.check(tableInfo);
checkTableReadWrite(tableInfo);
checkPartitionDirExists(tableInfo);
}
}
Loading

0 comments on commit d404224

Please sign in to comment.