[KYUUBI #5550] Optimizing TPC-DS dataset generation for 10x speedup
### _Why are the changes needed?_

1. This PR fixes the precision loss issue in the `xx_gmt_offset` columns. Note that since the `xx_gmt_offset` values are whole numbers, no data is actually lost; the fix only restores the declared decimal scale so the output matches Trino.

```
trino:tiny> select cc_gmt_offset from call_center ;
 cc_gmt_offset
---------------
         -5.00
         -5.00
```

Before this PR:

```scala
scala> spark.sql("select cc_gmt_offset from tpcds.tiny.call_center").show
+-------------+
|cc_gmt_offset|
+-------------+
|           -5|
|           -5|
+-------------+
```

After this PR:
```scala
scala> spark.sql("select cc_gmt_offset from tpcds.tiny.call_center").show
+-------------+
|cc_gmt_offset|
+-------------+
|        -5.00|
|        -5.00|
+-------------+
```

2. This PR accelerates the generation of the TPC-DS dataset by optimizing the way rows are generated.

Before this PR, the process converted **Trino TableRow** into **String Row** and then further into **Spark InternalRow**.

After this PR, the process is streamlined by converting **Trino TableRow** directly into **Spark InternalRow**, eliminating the unnecessary `toString`/re-parse round trip. This change significantly improves the speed of TPC-DS dataset generation.
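As a rough, self-contained sketch of the difference (simplified to plain Spark types; the actual conversion added by this PR is `KyuubiResultsIterator.toInternalRow` in the diff below):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.UTF8String

// Stand-in for one generated TPC-DS value, e.g. cc_gmt_offset = -5.00.
val raw = BigDecimal("-5.00")

// Old path (sketch): render every value to a String, then re-parse it per data type.
val viaString = Decimal(raw.toString)

// New path (sketch): convert the typed value directly, keeping precision and scale.
val direct = Decimal(raw, 5, 2)

// Either way the value ends up in an InternalRow, but the new path skips the
// toString/parse round trip for every column of every row.
val row = InternalRow(direct, UTF8String.fromString("call_center_name"))
```

The before/after task-duration screenshots were collected with a full scan of `catalog_sales` at scale factor 1000: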

```scala
spark.table("tpcds.sf1000.catalog_sales").foreach(r => ())
```

Task Duration before this PR:

![Screenshot 2023-10-30 16:04:12](https://github.com/apache/kyuubi/assets/8537877/69bd9938-2886-4044-99b8-79ed20d4791c)

Task Duration after this PR:

![Screenshot 2023-10-30 16:02:08](https://github.com/apache/kyuubi/assets/8537877/ddfe01a9-081c-41b5-b82c-a0934dd8686c)

### _How was this patch tested?_

- New UT `tpcds.tiny count and checksum`
- Compare checksum values before and after this PR on the 1TB dataset (a sketch of such a verification query follows the table below)

| table_name             | count           | checksum                  |
|------------------------|-----------------|---------------------------|
| call_center            | 42              | 95607401475               |
| catalog_page           | 30000           | 64470199469085            |
| catalog_returns        | 143996756       | 309202327050775220        |
| catalog_sales          | 1439980416      | 3092267266923848000       |
| customer               | 12000000        | 25769069905636795         |
| customer_address       | 6000000         | 12889423380880973         |
| customer_demographics  | 1920800         | 4124183189708148          |
| date_dim               | 73049           | 156926081012862           |
| household_demographics | 7200            | 15494873325812            |
| income_band            | 20              | 41180951007               |
| inventory              | 783000000       | 1681487454682584456       |
| item                   | 300000          | 643000708260945           |
| promotion              | 1500            | 3270935493709             |
| reason                 | 65              | 118806664977              |
| ship_mode              | 20              | 52349078860               |
| store                  | 1002            | 2096408105720             |
| store_returns          | 287999764       | 618451374856897114        |
| store_sales            | 2879987999      | 6184670571185100839       |
| time_dim               | 86400           | 186045071019485           |
| warehouse              | 20              | 31374161844               |
| web_page               | 3000            | 6502456139647             |
| web_returns            | 71997522        | 154614570845312413        |
| web_sales              | 720000376       | 1546188452223821591       |
| web_site               | 54              | 107485781738              |
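
For reference, a minimal sketch of how a per-table count/checksum can be computed. It assumes a CRC32-based checksum over all columns rendered as pipe-joined strings, which may differ from the exact function used by the `tpcds.tiny count and checksum` test:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, concat_ws, count, crc32, lit, sum}

val spark = SparkSession.builder().getOrCreate()

// Hypothetical verification query: row count plus an order-insensitive checksum per table.
val df = spark.table("tpcds.sf1000.call_center")
df.agg(
  count(lit(1)).as("count"),
  sum(crc32(concat_ws("|", df.columns.map(col): _*))).as("checksum"))
  .show(truncate = false)
```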

### _Was this patch authored or co-authored using generative AI tooling?_

No

Closes #5562 from cfmcgrady/tpcds-perf.

Closes #5550

a789b9e [Fu Chen] maxPartitionBytes=384m
659e209 [Fu Chen] style
916f6d2 [Fu Chen] unnecessary change
75981af [Fu Chen] tpcds perf

Authored-by: Fu Chen <[email protected]>
Signed-off-by: Cheng Pan <[email protected]>
(cherry picked from commit 4c915b7)
Signed-off-by: Cheng Pan <[email protected]>
cfmcgrady authored and pan3793 committed Oct 31, 2023
1 parent dfb90ef commit 2f063cf
Showing 5 changed files with 1,790 additions and 71 deletions.
@@ -0,0 +1,167 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.kyuubi.spark.connector.tpcds

import java.lang.{Iterable => JIterable}
import java.lang.reflect.InvocationTargetException
import java.util.{Iterator => JIterator}

import com.google.common.collect.AbstractIterator
import io.trino.tpcds._
import io.trino.tpcds.`type`.{Decimal => TPCDSDecimal}
import io.trino.tpcds.row.generator.RowGenerator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, RebaseDateTime}
import org.apache.spark.sql.types.{CharType, DateType, Decimal, DecimalType, IntegerType, LongType, StringType, StructType, VarcharType}
import org.apache.spark.unsafe.types.UTF8String

import org.apache.kyuubi.spark.connector.tpcds.KyuubiResultsIterator.{FALSE_STRING, TRUE_STRING}
import org.apache.kyuubi.spark.connector.tpcds.row.KyuubiTableRows

class KyuubiTPCDSResults(
    val table: Table,
    val startingRowNumber: Long,
    val rowCount: Long,
    val session: Session,
    val schema: StructType) extends JIterable[InternalRow] {

  override def iterator: JIterator[InternalRow] =
    new KyuubiResultsIterator(table, startingRowNumber, rowCount, session, schema)
}

object KyuubiTPCDSResults {
  def constructResults(table: Table, session: Session, schema: StructType): KyuubiTPCDSResults = {
    val chunkBoundaries = io.trino.tpcds.Parallel.splitWork(table, session)
    new KyuubiTPCDSResults(
      table,
      chunkBoundaries.getFirstRow(),
      chunkBoundaries.getLastRow(),
      session,
      schema)
  }
}

class KyuubiResultsIterator(
    val table: Table,
    val startingRowNumber: Long,
    val endingRowNumber: Long,
    val session: Session,
    val sparkSchema: StructType) extends AbstractIterator[InternalRow] {
  private var rowNumber: Long = 0L
  private var rowGenerator: RowGenerator = _
  private var parentRowGenerator: Option[RowGenerator] = None
  private var childRowGenerator: Option[RowGenerator] = None

  try {
    require(table != null, "table is null")
    require(session != null, "session is null")
    require(startingRowNumber >= 1, s"starting row number is less than 1: $startingRowNumber")
    require(
      endingRowNumber <= session.getScaling.getRowCount(table),
      s"starting row number is greater than the total rows in $table: $endingRowNumber")
    rowNumber = startingRowNumber
    rowGenerator = table.getRowGeneratorClass().getDeclaredConstructor().newInstance()
    parentRowGenerator = if (table.isChild()) {
      Some(table.getParent().getRowGeneratorClass().getDeclaredConstructor().newInstance())
    } else None
    childRowGenerator = if (table.hasChild()) {
      Some(table.getChild().getRowGeneratorClass().getDeclaredConstructor().newInstance())
    } else None
  } catch {
    case e @ (_: NoSuchMethodException |
        _: InstantiationException |
        _: InvocationTargetException |
        _: IllegalAccessException) =>
      throw new TpcdsException(e.toString());
  }
  skipRowsUntilStartingRowNumber(startingRowNumber)

  private def skipRowsUntilStartingRowNumber(startingRowNumber: Long): Unit = {
    rowGenerator.skipRowsUntilStartingRowNumber(startingRowNumber)
    parentRowGenerator.foreach(_.skipRowsUntilStartingRowNumber(startingRowNumber))
    childRowGenerator.foreach(_.skipRowsUntilStartingRowNumber(startingRowNumber))
  }

  override protected def computeNext(): InternalRow = {
    if (rowNumber > endingRowNumber) {
      return endOfData
    }
    val result = rowGenerator.generateRowAndChildRows(
      rowNumber,
      session,
      parentRowGenerator.orNull,
      childRowGenerator.orNull)
    var row: InternalRow = null
    if (!result.getRowAndChildRows.isEmpty) {
      row = toInternalRow(KyuubiTableRows.getValues(result.getRowAndChildRows.get(0)))
    }

    if (result.shouldEndRow) {
      rowStop()
      rowNumber += 1
    }
    if (result.getRowAndChildRows().isEmpty()) {
      row = computeNext()
    }
    row
  }

  private def rowStop(): Unit = {
    rowGenerator.consumeRemainingSeedsForRow()
    parentRowGenerator.foreach(_.consumeRemainingSeedsForRow())
    childRowGenerator.foreach(_.consumeRemainingSeedsForRow())
  }

  private val reusedRow = new Array[Any](sparkSchema.length)

  def toInternalRow(values: Array[Any]): InternalRow = {
    var i = 0
    while (i < values.length) {
      reusedRow(i) = (values(i), sparkSchema(i).dataType) match {
        case (None | null, _) => null
        case (Some(Options.DEFAULT_NULL_STRING), _) => null
        case (Some(v: Boolean), _) => if (v) TRUE_STRING else FALSE_STRING
        case (Some(v: Int), IntegerType) => v
        case (Some(v: Long), IntegerType) => v.toInt
        case (Some(v: Int), LongType) => v.toLong
        case (Some(v: Long), LongType) => v
        case (Some(v: Long), DateType) =>
          RebaseDateTime.rebaseJulianToGregorianDays(v.toInt) - DateTimeUtils.JULIAN_DAY_OF_EPOCH
        case (Some(v), StringType) => UTF8String.fromString(v.toString)
        case (Some(v), CharType(_)) => UTF8String.fromString(v.toString)
        case (Some(v), VarcharType(_)) => UTF8String.fromString(v.toString)
        case (Some(v: TPCDSDecimal), t: DecimalType) =>
          Decimal(v.getNumber, t.precision, t.scale)
        case (Some(v: Int), t: DecimalType) =>
          val decimal = Decimal(v)
          decimal.changePrecision(t.precision, t.scale)
          decimal
        case (Some(v), dt) => throw new IllegalArgumentException(
            s"value: $v, value class: ${v.getClass.getName} type: $dt")
      }
      i += 1
    }
    new GenericInternalRow(reusedRow)
  }
}

object KyuubiResultsIterator {
  private val TRUE_STRING = UTF8String.fromString("Y")
  private val FALSE_STRING = UTF8String.fromString("N")
}
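
For orientation, a minimal standalone sketch of driving this iterator outside the connector. The `Options`/`Table` setup and the hand-written `income_band` schema below are illustrative assumptions, not part of the patch:

```scala
import io.trino.tpcds.{Options, Table}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Hypothetical driver code: generate all rows of the small income_band table at sf1.
val opts = new Options()
opts.scale = 1.0
val session = opts.toSession
val table = Table.getTable("income_band")

// The Spark schema must describe every column of the table, in generator order.
val schema = StructType(Seq(
  StructField("ib_income_band_sk", IntegerType),
  StructField("ib_lower_bound", IntegerType),
  StructField("ib_upper_bound", IntegerType)))

val rows = KyuubiTPCDSResults.constructResults(table, session, schema).iterator()
while (rows.hasNext) {
  val row = rows.next()  // org.apache.spark.sql.catalyst.InternalRow
  println(row.getInt(0)) // ib_income_band_sk
}
```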
@@ -17,20 +17,15 @@

package org.apache.kyuubi.spark.connector.tpcds

import java.time.LocalDate
import java.time.format.DateTimeFormatter
import java.util.OptionalLong

import scala.collection.JavaConverters._

import io.trino.tpcds._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

case class TPCDSTableChuck(table: String, scale: Double, parallelism: Int, index: Int)
case class TPCDSTableChunk(table: String, scale: Double, parallelism: Int, index: Int)
extends InputPartition

class TPCDSBatchScan(
@@ -62,10 +57,10 @@ class TPCDSBatchScan(
override def readSchema: StructType = schema

override def planInputPartitions: Array[InputPartition] =
(1 to parallelism).map { i => TPCDSTableChuck(table.getName, scale, parallelism, i) }.toArray
(1 to parallelism).map { i => TPCDSTableChunk(table.getName, scale, parallelism, i) }.toArray

def createReaderFactory: PartitionReaderFactory = (partition: InputPartition) => {
val chuck = partition.asInstanceOf[TPCDSTableChuck]
val chuck = partition.asInstanceOf[TPCDSTableChunk]
new TPCDSPartitionReader(chuck.table, chuck.scale, chuck.parallelism, chuck.index, schema)
}

@@ -90,32 +85,9 @@ class TPCDSPartitionReader(
opt.toSession.withChunkNumber(index)
}

private lazy val dateFmt: DateTimeFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd")

private val reusedRow = new Array[Any](schema.length)
private val iterator = Results
.constructResults(chuckInfo.getOnlyTableToGenerate, chuckInfo)
.iterator.asScala
.map { _.get(0).asScala } // the 1st row is specific table row
.map { stringRow =>
var i = 0
while (i < stringRow.length) {
reusedRow(i) = (stringRow(i), schema(i).dataType) match {
case (null, _) => null
case (Options.DEFAULT_NULL_STRING, _) => null
case (v, IntegerType) => v.toInt
case (v, LongType) => v.toLong
case (v, DateType) => LocalDate.parse(v, dateFmt).toEpochDay.toInt
case (v, StringType) => UTF8String.fromString(v)
case (v, CharType(_)) => UTF8String.fromString(v)
case (v, VarcharType(_)) => UTF8String.fromString(v)
case (v, DecimalType()) => Decimal(v)
case (v, dt) => throw new IllegalArgumentException(s"value: $v, type: $dt")
}
i += 1
}
InternalRow(reusedRow: _*)
}
private val iterator = KyuubiTPCDSResults
.constructResults(chuckInfo.getOnlyTableToGenerate, chuckInfo, schema)
.iterator

private var currentRow: InternalRow = _

@@ -81,5 +81,5 @@ object TPCDSConf {

val TPCDS_CONNECTOR_READ_CONF_PREFIX = s"$TPCDS_CONNECTOR_CONF_PREFIX.read"
val MAX_PARTITION_BYTES_CONF = "maxPartitionBytes"
val MAX_PARTITION_BYTES_DEFAULT = "128m"
val MAX_PARTITION_BYTES_DEFAULT = "384m"
}
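
A hedged usage sketch for the new default: the catalog class name below comes from the Kyuubi TPC-DS connector package, while the exact read-option key (`read.maxPartitionBytes` under the catalog prefix) is an assumption based on `TPCDS_CONNECTOR_READ_CONF_PREFIX` and `MAX_PARTITION_BYTES_CONF` above:

```scala
import org.apache.spark.sql.SparkSession

// Illustrative only: register the TPC-DS catalog and override the read split size.
// Smaller values yield more input partitions; 384m is the new default from this PR.
val spark = SparkSession.builder()
  .config("spark.sql.catalog.tpcds", "org.apache.kyuubi.spark.connector.tpcds.TPCDSCatalog")
  .config("spark.sql.catalog.tpcds.read.maxPartitionBytes", "128m") // assumed key
  .getOrCreate()

spark.table("tpcds.sf1000.catalog_sales").count()
```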

