
Commit

add partition
FANNG1 committed Mar 18, 2024
1 parent 9a77dff commit fcbe74b
Showing 11 changed files with 750 additions and 10 deletions.
1 change: 1 addition & 0 deletions gradle/libs.versions.toml
@@ -31,6 +31,7 @@ iceberg = '1.3.1' # 1.4.0 causes test to fail
trino = '426'
spark = "3.4.1" # 3.5.0 causes tests to fail
scala-collection-compat = "2.7.0"
scala-java-compat = "1.0.2"
sqlite-jdbc = "3.42.0.0"
testng = "7.5.1"
testcontainers = "1.19.0"
@@ -14,8 +14,11 @@
import com.datastrato.gravitino.spark.connector.GravitinoSparkConfig;
import com.datastrato.gravitino.spark.connector.plugin.GravitinoSparkPlugin;
import com.google.common.collect.Maps;
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
@@ -29,11 +32,12 @@ public class SparkEnvIT extends SparkUtilIT {
private static final ContainerSuite containerSuite = ContainerSuite.getInstance();

protected final String hiveCatalogName = "hive";
protected FileSystem hdfs;
private final String metalakeName = "test";

private SparkSession sparkSession;
private String hiveMetastoreUri;
private String gravitinoUri;
private String hiveMetastoreUri = "thrift://127.0.0.1:9083";
private String gravitinoUri = "http://127.0.0.1:8090";

@Override
protected SparkSession getSparkSession() {
@@ -44,6 +48,7 @@ protected SparkSession getSparkSession() {
@BeforeAll
void startUp() {
initHiveEnv();
initHdfsFileSystem();
initGravitinoEnv();
initMetalakeAndCatalogs();
initSparkEnv();
@@ -55,6 +60,13 @@ void startUp() {

@AfterAll
void stop() {
if (hdfs != null) {
try {
hdfs.close();
} catch (IOException e) {
LOG.warn("Close HDFS filesystem failed,", e);
}
}
if (sparkSession != null) {
sparkSession.close();
}
@@ -89,6 +101,22 @@ private void initHiveEnv() {
HiveContainer.HIVE_METASTORE_PORT);
}

private void initHdfsFileSystem() {
Configuration conf = new Configuration();
conf.set(
"fs.defaultFS",
String.format(
"hdfs://%s:%d",
containerSuite.getHiveContainer().getContainerIpAddress(),
HiveContainer.HDFS_DEFAULTFS_PORT));
try {
hdfs = FileSystem.get(conf);
} catch (IOException e) {
LOG.error("Create HDFS filesystem failed", e);
throw new RuntimeException(e);
}
}

private void initSparkEnv() {
sparkSession =
SparkSession.builder()
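
For context, a minimal sketch of how the new hdfs handle can be used by a test, assuming a hypothetical helper method; the partition checks in SparkIT below do essentially this.

// Sketch only; assertHdfsPathExists is a hypothetical helper, not part of this commit.
private void assertHdfsPathExists(String location) throws IOException {
  org.apache.hadoop.fs.Path path = new org.apache.hadoop.fs.Path(location);
  Assertions.assertTrue(hdfs.exists(path), "Path does not exist: " + path);
}
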
@@ -8,12 +8,14 @@
import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo.SparkColumnInfo;
import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfoChecker;
import com.google.common.collect.ImmutableMap;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.hadoop.fs.Path;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
@@ -31,6 +33,11 @@
@Tag("gravitino-docker-it")
@TestInstance(Lifecycle.PER_CLASS)
public class SparkIT extends SparkEnvIT {

// To generate test data for write&read table.
private static final Map<DataType, String> typeConstant =
ImmutableMap.of(DataTypes.IntegerType, "2", DataTypes.StringType, "'gravitino_it_test'");

private static String getSelectAllSql(String tableName) {
return String.format("SELECT * FROM %s", tableName);
}
@@ -39,9 +46,11 @@ private static String getInsertWithoutPartitionSql(String tableName, String valu
return String.format("INSERT INTO %s VALUES (%s)", tableName, values);
}

// To generate test data for write&read table.
private static final Map<DataType, String> typeConstant =
ImmutableMap.of(DataTypes.IntegerType, "2", DataTypes.StringType, "'gravitino_it_test'");
private static String getInsertWithPartitionSql(
String tableName, String partitionString, String values) {
return String.format(
"INSERT OVERWRITE %s PARTITION (%s) VALUES (%s)", tableName, partitionString, values);
}

// Use a custom database not the original default database because SparkIT couldn't read&write
// data to tables in default database. The main reason is default database location is
@@ -406,23 +415,124 @@ void testAlterTableUpdateColumnComment() {
}

private void checkTableColumns(
String tableName, List<SparkColumnInfo> columnInfos, SparkTableInfo tableInfo) {
String tableName, List<SparkColumnInfo> columns, SparkTableInfo tableInfo) {
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(columnInfos)
.withColumns(columns)
.withComment(null)
.check(tableInfo);
}

@Test
public void testCreateDatasourceFormatPartitionTable() {
String tableName = "datasource_partition_table";

dropTableIfExists(tableName);
String createTableSQL = getCreateSimpleTableString(tableName);
createTableSQL = createTableSQL + "USING PARQUET PARTITIONED BY (name, age)";
sql(createTableSQL);
SparkTableInfo tableInfo = getTableInfo(tableName);
SparkTableInfoChecker checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(getSimpleTableColumn())
.withIdentifyPartition(Arrays.asList("name", "age"));
checker.check(tableInfo);
checkTableReadWrite(tableInfo);
checkPartitionDirExists(tableInfo);
}

@Test
public void testCreateHiveFormatPartitionTable() {
String tableName = "hive_partition_table";

dropTableIfExists(tableName);
String createTableSQL = getCreateSimpleTableString(tableName);
createTableSQL = createTableSQL + "PARTITIONED BY (age_p1 INT, age_p2 STRING)";
sql(createTableSQL);

List<SparkColumnInfo> columns = new ArrayList<>(getSimpleTableColumn());
columns.add(SparkColumnInfo.of("age_p1", DataTypes.IntegerType));
columns.add(SparkColumnInfo.of("age_p2", DataTypes.StringType));

SparkTableInfo tableInfo = getTableInfo(tableName);
SparkTableInfoChecker checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(columns)
.withIdentifyPartition(Arrays.asList("age_p1", "age_p2"));
checker.check(tableInfo);
checkTableReadWrite(tableInfo);
checkPartitionDirExists(tableInfo);
}

@Test
public void testCreateBucketTable() {
String tableName = "bucket_table";

dropTableIfExists(tableName);
String createTableSQL = getCreateSimpleTableString(tableName);
createTableSQL = createTableSQL + "CLUSTERED BY (id, name) INTO 4 buckets;";
sql(createTableSQL);
SparkTableInfo tableInfo = getTableInfo(tableName);
SparkTableInfoChecker checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(getSimpleTableColumn())
.withBucket(4, Arrays.asList("id", "name"));
checker.check(tableInfo);
checkTableReadWrite(tableInfo);
}

@Test
public void testCreateSortBucketTable() {
String tableName = "sort_bucket_table";

dropTableIfExists(tableName);
String createTableSQL = getCreateSimpleTableString(tableName);
createTableSQL =
createTableSQL + "CLUSTERED BY (id, name) SORTED BY (name, id) INTO 4 buckets;";
sql(createTableSQL);
SparkTableInfo tableInfo = getTableInfo(tableName);
SparkTableInfoChecker checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(getSimpleTableColumn())
.withBucket(4, Arrays.asList("id", "name"), Arrays.asList("name", "id"));
checker.check(tableInfo);
checkTableReadWrite(tableInfo);
}

private void checkPartitionDirExists(SparkTableInfo table) {
Assertions.assertTrue(table.isPartitionTable(), "Not a partition table");
String tableLocation = table.getTableLocation();
String partitionExpression = getPartitionExpression(table, "/").replace("'", "");
Path partitionPath = new Path(tableLocation, partitionExpression);
try {
Assertions.assertTrue(
hdfs.exists(partitionPath), "Partition directory does not exist: " + partitionPath);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

private void checkTableReadWrite(SparkTableInfo table) {
String name = table.getTableIdentifier();
boolean isPartitionTable = table.isPartitionTable();
String insertValues =
table.getColumns().stream()
table.getNonPartitionColumns().stream()
.map(columnInfo -> typeConstant.get(columnInfo.getType()))
.map(Object::toString)
.collect(Collectors.joining(","));

sql(getInsertWithoutPartitionSql(name, insertValues));
String insertDataSQL = "";
if (isPartitionTable) {
String partitionExpressions = getPartitionExpression(table, ",");
insertDataSQL = getInsertWithPartitionSql(name, partitionExpressions, insertValues);
} else {
insertDataSQL = getInsertWithoutPartitionSql(name, insertValues);
}
sql(insertDataSQL);

// remove "'" from values, such as 'a' is trans to a
String checkValues =
@@ -469,4 +579,11 @@ private void createSimpleTable(String identifier) {
String createTableSql = getCreateSimpleTableString(identifier);
sql(createTableSql);
}

// partition expression may contain "'", like a='s'/b=1
private String getPartitionExpression(SparkTableInfo table, String delimiter) {
return table.getPartitionColumns().stream()
.map(column -> column.getName() + "=" + typeConstant.get(column.getType()))
.collect(Collectors.joining(delimiter));
}
}
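
For reference, an illustrative sketch of what the partition helpers above produce for the hive_partition_table case, assuming the partition columns (age_p1 INT, age_p2 STRING) and the typeConstant values shown above:

// Illustrative only: expected helper output, derived from typeConstant
// (IntegerType -> 2, StringType -> 'gravitino_it_test').
String partitionExpression = getPartitionExpression(table, ",");
// "age_p1=2,age_p2='gravitino_it_test'"
String partitionSubDir = getPartitionExpression(table, "/").replace("'", "");
// "age_p1=2/age_p2=gravitino_it_test", the sub-path checked by checkPartitionDirExists
String insertSql = getInsertWithPartitionSql(table.getTableIdentifier(), partitionExpression, insertValues);
// "INSERT OVERWRITE <identifier> PARTITION (age_p1=2,age_p2='gravitino_it_test') VALUES (<non-partition column values>)"
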
@@ -9,11 +9,18 @@
import com.datastrato.gravitino.spark.connector.table.SparkBaseTable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import javax.ws.rs.NotSupportedException;
import lombok.Data;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.sql.connector.expressions.BucketTransform;
import org.apache.spark.sql.connector.expressions.IdentityTransform;
import org.apache.spark.sql.connector.expressions.SortedBucketTransform;
import org.apache.spark.sql.connector.expressions.Transform;
import org.apache.spark.sql.types.DataType;
import org.junit.jupiter.api.Assertions;

@@ -26,6 +33,9 @@ public class SparkTableInfo {
private List<SparkColumnInfo> columns;
private Map<String, String> tableProperties;
private List<String> unknownItems = new ArrayList<>();
private Transform bucket;
private List<Transform> partitions = new ArrayList<>();
private Set<String> partitionColumnNames = new HashSet<>();

public SparkTableInfo() {}

@@ -42,6 +52,28 @@ public String getTableIdentifier() {
}
}

public String getTableLocation() {
return tableProperties.get(ConnectorConstants.LOCATION);
}

public boolean isPartitionTable() {
return partitions.size() > 0;
}

void setBucket(Transform bucket) {
Assertions.assertNull(this.bucket, "Should have only one distribution");
this.bucket = bucket;
}

void addPartition(Transform partition) {
this.partitions.add(partition);
if (partition instanceof IdentityTransform) {
partitionColumnNames.add(((IdentityTransform) partition).reference().fieldNames()[0]);
} else {
throw new NotSupportedException(partition.name() + " is not supported yet.");
}
}

static SparkTableInfo create(SparkBaseTable baseTable) {
SparkTableInfo sparkTableInfo = new SparkTableInfo();
String identifier = baseTable.name();
@@ -62,9 +94,33 @@ static SparkTableInfo create(SparkBaseTable baseTable) {
.collect(Collectors.toList());
sparkTableInfo.comment = baseTable.properties().remove(ConnectorConstants.COMMENT);
sparkTableInfo.tableProperties = baseTable.properties();
Arrays.stream(baseTable.partitioning())
.forEach(
transform -> {
if (transform instanceof BucketTransform
|| transform instanceof SortedBucketTransform) {
sparkTableInfo.setBucket(transform);
} else if (transform instanceof IdentityTransform) {
sparkTableInfo.addPartition(transform);
} else {
throw new NotSupportedException("Not support Spark transform: " + transform.name());
}
});
return sparkTableInfo;
}

public List<SparkColumnInfo> getNonPartitionColumns() {
return columns.stream()
.filter(column -> !partitionColumnNames.contains(column.name))
.collect(Collectors.toList());
}

public List<SparkColumnInfo> getPartitionColumns() {
return columns.stream()
.filter(column -> partitionColumnNames.contains(column.name))
.collect(Collectors.toList());
}

@Data
public static class SparkColumnInfo {
private String name;
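
A short illustrative sketch of how the new transform classification in SparkTableInfo.create behaves, assuming Spark's org.apache.spark.sql.connector.expressions.Expressions factory methods:

// Sketch only, not part of this commit.
Transform identity = Expressions.identity("age_p1");    // addPartition(...); partitionColumnNames = {"age_p1"}
Transform bucket = Expressions.bucket(4, "id", "name"); // setBucket(...)
Transform days = Expressions.days("event_time");        // rejected with NotSupportedException
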

