From 4dc8769930f356e6138f88f290336f88fb409d3d Mon Sep 17 00:00:00 2001
From: birdstorm
Date: Mon, 9 Jul 2018 14:23:03 +0800
Subject: [PATCH] update scalafmt for more rules (#386)

---
 core/pom.xml                                  |  2 +-
 core/scalafmt.conf                            | 36 +++++++++----------
 .../com/pingcap/tispark/MetaManager.scala     | 15 +++-----
 .../scala/com/pingcap/tispark/TiUtils.scala   | 17 ++++-----
 .../listener/CacheInvalidateListener.scala    |  6 ++--
 .../listener/PDCacheInvalidateListener.scala  |  3 +-
 .../tispark/statistics/StatisticsHelper.scala |  9 ++---
 .../statistics/StatisticsManager.scala        |  6 ++--
 .../org/apache/spark/sql/TiContext.scala      |  6 ++--
 .../org/apache/spark/sql/TiStrategy.scala     | 35 +++++++-----------
 .../aggregate/CollectHandles.scala            |  9 ++---
 .../spark/sql/execution/CoprocessorRDD.scala  | 23 +++++-------
 .../spark/sql/hive/TiSessionCatalog.scala     | 18 ++++------
 .../org/apache/spark/SparkFunSuite.scala      |  6 ++--
 .../spark/sql/AlterTableTestSuite.scala       |  3 +-
 .../apache/spark/sql/BaseTiSparkSuite.scala   | 15 +++-----
 .../org/apache/spark/sql/IssueTestSuite.scala |  3 +-
 .../org/apache/spark/sql/QueryTest.scala      | 27 +++++---------
 .../spark/sql/TiDBMapDatabaseSuite.scala      |  3 +-
 .../spark/sql/catalyst/plans/PlanTest.scala   |  9 ++---
 .../index/PrefixIndexTestSuite.scala          |  3 +-
 .../expression/index/UnsignedTestSuite.scala  |  3 +-
 .../statistics/StatisticsManagerSuite.scala   | 21 ++++-------
 .../spark/sql/test/SharedSQLContext.scala     | 15 +++-----
 .../org/apache/spark/sql/test/Utils.scala     |  9 ++---
 25 files changed, 108 insertions(+), 194 deletions(-)

diff --git a/core/pom.xml b/core/pom.xml
index f83b65f757..a5a55b5227 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -253,7 +253,7 @@
         org.antipathy
         mvn-scalafmt
-        0.5_1.3.0
+        0.7_1.5.1
         ${project.basedir}/scalafmt.conf
diff --git a/core/scalafmt.conf b/core/scalafmt.conf
index 7e506da6b7..1c4b4cd98a 100644
--- a/core/scalafmt.conf
+++ b/core/scalafmt.conf
@@ -27,22 +27,20 @@ lineEndings = unix
 maxColumn = 100
-# TODO Somehow the following lines causes a configuration parsing error...
- -# newlines { -# alwaysBeforeElseAfterCurlyIf = false -# alwaysBeforeTopLevelStatements = false -# penalizeSingleSelectMultiArgList = false -# sometimesBeforeColonInMethodReturnType = false -# } -# -# optIn.breakChainOnFirstMethodDot = true -# -# rewrite.rules = [ -# RedundantBraces -# RedundantParens -# SortImports -# PreferCurlyFors -# ] -# -# spaces.afterKeywordBeforeParen = true +newlines { + alwaysBeforeElseAfterCurlyIf = false + alwaysBeforeTopLevelStatements = false + penalizeSingleSelectMultiArgList = false + sometimesBeforeColonInMethodReturnType = false +} + +optIn.breakChainOnFirstMethodDot = true + +rewrite.rules = [ + RedundantBraces + RedundantParens + SortImports + PreferCurlyFors +] + +spaces.afterKeywordBeforeParen = true diff --git a/core/src/main/scala/com/pingcap/tispark/MetaManager.scala b/core/src/main/scala/com/pingcap/tispark/MetaManager.scala index 461a4260b4..9204a5b4f9 100644 --- a/core/src/main/scala/com/pingcap/tispark/MetaManager.scala +++ b/core/src/main/scala/com/pingcap/tispark/MetaManager.scala @@ -23,23 +23,18 @@ import scala.collection.JavaConversions._ // Likely this needs to be merge to client project // and serving inside metastore if any class MetaManager(catalog: Catalog) { - def reloadMeta(): Unit = { + def reloadMeta(): Unit = catalog.reloadCache() - } - def getDatabases: List[TiDBInfo] = { + def getDatabases: List[TiDBInfo] = catalog.listDatabases().toList - } - def getTables(db: TiDBInfo): List[TiTableInfo] = { + def getTables(db: TiDBInfo): List[TiTableInfo] = catalog.listTables(db).toList - } - def getTable(dbName: String, tableName: String): Option[TiTableInfo] = { + def getTable(dbName: String, tableName: String): Option[TiTableInfo] = Option(catalog.getTable(dbName, tableName)) - } - def getDatabase(dbName: String): Option[TiDBInfo] = { + def getDatabase(dbName: String): Option[TiDBInfo] = Option(catalog.getDatabase(dbName)) - } } diff --git a/core/src/main/scala/com/pingcap/tispark/TiUtils.scala b/core/src/main/scala/com/pingcap/tispark/TiUtils.scala index 8ac07f35b6..b6a42e07a0 100644 --- a/core/src/main/scala/com/pingcap/tispark/TiUtils.scala +++ b/core/src/main/scala/com/pingcap/tispark/TiUtils.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal, NamedExpression} import org.apache.spark.sql.types.{DataType, DataTypes, MetadataBuilder, StructField, StructType} import org.apache.spark.sql.{SparkSession, TiStrategy} -import org.apache.spark.{SparkConf, sql} +import org.apache.spark.{sql, SparkConf} import scala.collection.JavaConversions._ import scala.collection.mutable @@ -45,7 +45,7 @@ object TiUtils { def isSupportedAggregate(aggExpr: AggregateExpression, tiDBRelation: TiDBRelation, - blacklist: ExpressionBlacklist): Boolean = { + blacklist: ExpressionBlacklist): Boolean = aggExpr.aggregateFunction match { case Average(_) | Sum(_) | PromotedSum(_) | Count(_) | Min(_) | Max(_) => !aggExpr.isDistinct && @@ -53,7 +53,6 @@ object TiUtils { .forall(isSupportedBasicExpression(_, tiDBRelation, blacklist)) case _ => false } - } def isSupportedBasicExpression(expr: Expression, tiDBRelation: TiDBRelation, @@ -103,9 +102,8 @@ object TiUtils { def isSupportedFilter(expr: Expression, source: TiDBRelation, - blacklist: ExpressionBlacklist): Boolean = { + blacklist: ExpressionBlacklist): Boolean = isSupportedBasicExpression(expr, source, blacklist) && isPushDownSupported(expr, source) - } // if contains UDF / functions that 
cannot be folded def isSupportedGroupingExpr(expr: NamedExpression, @@ -114,7 +112,7 @@ object TiUtils { isSupportedBasicExpression(expr, source, blacklist) && isPushDownSupported(expr, source) // convert tikv-java client FieldType to Spark DataType - def toSparkDataType(tp: TiDataType): DataType = { + def toSparkDataType(tp: TiDataType): DataType = tp match { case _: StringType => sql.types.StringType case _: BytesType => sql.types.BinaryType @@ -146,9 +144,8 @@ object TiUtils { case _: SetType => sql.types.LongType case _: YearType => sql.types.LongType } - } - def fromSparkType(tp: DataType): TiDataType = { + def fromSparkType(tp: DataType): TiDataType = tp match { case _: sql.types.BinaryType => BytesType.BLOB case _: sql.types.StringType => StringType.VARCHAR @@ -158,7 +155,6 @@ object TiUtils { case _: sql.types.TimestampType => TimestampType.TIMESTAMP case _: sql.types.DateType => DateType.DATE } - } def getSchemaFromTable(table: TiTableInfo): StructType = { val fields = new Array[StructField](table.getColumns.size()) @@ -231,11 +227,10 @@ object TiUtils { StatisticsManager.initStatisticsManager(tiSession, session) } - def getReqEstCountStr(req: TiDAGRequest): String = { + def getReqEstCountStr(req: TiDAGRequest): String = if (req.getEstimatedCount > 0) { import java.text.DecimalFormat val df = new DecimalFormat("#.#") s" EstimatedCount:${df.format(req.getEstimatedCount)}" } else "" - } } diff --git a/core/src/main/scala/com/pingcap/tispark/listener/CacheInvalidateListener.scala b/core/src/main/scala/com/pingcap/tispark/listener/CacheInvalidateListener.scala index f208291d41..f287c16499 100644 --- a/core/src/main/scala/com/pingcap/tispark/listener/CacheInvalidateListener.scala +++ b/core/src/main/scala/com/pingcap/tispark/listener/CacheInvalidateListener.scala @@ -54,7 +54,7 @@ object CacheInvalidateListener { * @param sc The spark SparkContext used for attaching a cache listener. * @param regionManager The RegionManager to invalidate local cache. 
*/ - def initCacheListener(sc: SparkContext, regionManager: RegionManager): Unit = { + def initCacheListener(sc: SparkContext, regionManager: RegionManager): Unit = if (manager == null) { synchronized { if (manager == null) { @@ -67,9 +67,8 @@ object CacheInvalidateListener { } } } - } - def init(sc: SparkContext, regionManager: RegionManager, manager: CacheInvalidateListener): Unit = { + def init(sc: SparkContext, regionManager: RegionManager, manager: CacheInvalidateListener): Unit = if (sc != null && regionManager != null) { sc.register(manager.CACHE_INVALIDATE_ACCUMULATOR, manager.CACHE_ACCUMULATOR_NAME) sc.addSparkListener( @@ -79,5 +78,4 @@ object CacheInvalidateListener { ) ) } - } } diff --git a/core/src/main/scala/com/pingcap/tispark/listener/PDCacheInvalidateListener.scala b/core/src/main/scala/com/pingcap/tispark/listener/PDCacheInvalidateListener.scala index 02da762d75..e9cf24bee6 100644 --- a/core/src/main/scala/com/pingcap/tispark/listener/PDCacheInvalidateListener.scala +++ b/core/src/main/scala/com/pingcap/tispark/listener/PDCacheInvalidateListener.scala @@ -28,7 +28,7 @@ class PDCacheInvalidateListener(accumulator: CacheInvalidateAccumulator, extends SparkListener { private final val logger: Logger = Logger.getLogger(getClass.getName) - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = if (accumulator != null && !accumulator.isZero && handler != null) { synchronized { if (!accumulator.isZero) { @@ -42,5 +42,4 @@ class PDCacheInvalidateListener(accumulator: CacheInvalidateAccumulator, } } } - } } diff --git a/core/src/main/scala/com/pingcap/tispark/statistics/StatisticsHelper.scala b/core/src/main/scala/com/pingcap/tispark/statistics/StatisticsHelper.scala index fa64a8da55..1b0dc0e294 100644 --- a/core/src/main/scala/com/pingcap/tispark/statistics/StatisticsHelper.scala +++ b/core/src/main/scala/com/pingcap/tispark/statistics/StatisticsHelper.scala @@ -200,7 +200,7 @@ object StatisticsHelper { private[statistics] def buildHistogramsRequest(histTable: TiTableInfo, targetTblId: Long, - startTs: Long): TiDAGRequest = { + startTs: Long): TiDAGRequest = TiDAGRequest.Builder .newBuilder() .setFullTableScan(histTable) @@ -213,14 +213,13 @@ object StatisticsHelper { ) .setStartTs(startTs) .build(PushDownType.NORMAL) - } private def checkColExists(table: TiTableInfo, column: String): Boolean = table.getColumns.exists { _.matchName(column) } private[statistics] def buildMetaRequest(metaTable: TiTableInfo, targetTblId: Long, - startTs: Long): TiDAGRequest = { + startTs: Long): TiDAGRequest = TiDAGRequest.Builder .newBuilder() .setFullTableScan(metaTable) @@ -231,11 +230,10 @@ object StatisticsHelper { .addRequiredCols(metaRequiredCols.filter(checkColExists(metaTable, _))) .setStartTs(startTs) .build(PushDownType.NORMAL) - } private[statistics] def buildBucketRequest(bucketTable: TiTableInfo, targetTblId: Long, - startTs: Long): TiDAGRequest = { + startTs: Long): TiDAGRequest = TiDAGRequest.Builder .newBuilder() .setFullTableScan(bucketTable) @@ -250,5 +248,4 @@ object StatisticsHelper { ) .setStartTs(startTs) .build(PushDownType.NORMAL) - } } diff --git a/core/src/main/scala/com/pingcap/tispark/statistics/StatisticsManager.scala b/core/src/main/scala/com/pingcap/tispark/statistics/StatisticsManager.scala index 2ff6d79e48..3f95bbc667 100644 --- a/core/src/main/scala/com/pingcap/tispark/statistics/StatisticsManager.scala +++ b/core/src/main/scala/com/pingcap/tispark/statistics/StatisticsManager.scala @@ -231,9 
+231,8 @@ class StatisticsManager(tiSession: TiSession) { .toSeq } - def getTableStatistics(id: Long): TableStatistics = { + def getTableStatistics(id: Long): TableStatistics = statisticsMap.getIfPresent(id) - } /** * Estimated row count of one table @@ -259,7 +258,7 @@ class StatisticsManager(tiSession: TiSession) { object StatisticsManager { private var manager: StatisticsManager = _ - def initStatisticsManager(tiSession: TiSession, session: SparkSession): Unit = { + def initStatisticsManager(tiSession: TiSession, session: SparkSession): Unit = if (manager == null) { synchronized { if (manager == null) { @@ -267,7 +266,6 @@ object StatisticsManager { } } } - } def reset(): Unit = manager = null diff --git a/core/src/main/scala/org/apache/spark/sql/TiContext.scala b/core/src/main/scala/org/apache/spark/sql/TiContext.scala index c1e3fc811b..2d33f6aad5 100644 --- a/core/src/main/scala/org/apache/spark/sql/TiContext.scala +++ b/core/src/main/scala/org/apache/spark/sql/TiContext.scala @@ -47,9 +47,8 @@ class TiContext(val session: SparkSession) extends Serializable with Logging { conf.getBoolean("spark.tispark.statistics.auto_load", defaultValue = true) class DebugTool { - def getRegionDistribution(dbName: String, tableName: String): Map[String, Integer] = { + def getRegionDistribution(dbName: String, tableName: String): Map[String, Integer] = RegionUtils.getRegionDistribution(tiSession, dbName, tableName).asScala.toMap - } /** * Balance region leaders of a single table. @@ -146,9 +145,8 @@ class TiContext(val session: SparkSession) extends Serializable with Logging { df } - def tidbMapDatabase(dbName: String, dbNameAsPrefix: Boolean): Unit = { + def tidbMapDatabase(dbName: String, dbNameAsPrefix: Boolean): Unit = tidbMapDatabase(dbName, dbNameAsPrefix, autoLoad) - } def tidbMapDatabase(dbName: String, dbNameAsPrefix: Boolean = false, diff --git a/core/src/main/scala/org/apache/spark/sql/TiStrategy.scala b/core/src/main/scala/org/apache/spark/sql/TiStrategy.scala index c0ff110a54..b7f0d7781d 100644 --- a/core/src/main/scala/org/apache/spark/sql/TiStrategy.scala +++ b/core/src/main/scala/org/apache/spark/sql/TiStrategy.scala @@ -65,36 +65,31 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging { new TypeBlacklist(blacklistString) } - private def allowAggregationPushdown(): Boolean = { + private def allowAggregationPushdown(): Boolean = sqlConf.getConfString(TiConfigConst.ALLOW_AGG_PUSHDOWN, "true").toBoolean - } - private def allowIndexDoubleRead(): Boolean = { + private def allowIndexDoubleRead(): Boolean = sqlConf.getConfString(TiConfigConst.ALLOW_INDEX_READ, "false").toBoolean - } - private def useStreamingProcess(): Boolean = { + private def useStreamingProcess(): Boolean = sqlConf.getConfString(TiConfigConst.COPROCESS_STREAMING, "false").toBoolean - } - private def timeZoneOffset(): Int = { + private def timeZoneOffset(): Int = sqlConf .getConfString( TiConfigConst.KV_TIMEZONE_OFFSET, String.valueOf(ZonedDateTime.now.getOffset.getTotalSeconds) ) .toInt - } - private def pushDownType(): PushDownType = { + private def pushDownType(): PushDownType = if (useStreamingProcess()) { PushDownType.STREAMING } else { PushDownType.NORMAL } - } - override def apply(plan: LogicalPlan): Seq[SparkPlan] = { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan .collectFirst { case LogicalRelation(relation: TiDBRelation, _, _) => @@ -102,7 +97,6 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging { } .toSeq .flatten - } private def toCoprocessorRDD( source: 
TiDBRelation, @@ -394,9 +388,8 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging { case e if e.deterministic => e.canonicalized -> Alias(e, e.toString())() }.toMap - def aliasPushedPartialResult(e: AggregateExpression): Alias = { + def aliasPushedPartialResult(e: AggregateExpression): Alias = deterministicAggAliases.getOrElse(e.canonicalized, Alias(e, e.toString())()) - } val residualAggregateExpressions = aggregateExpressions.map { aggExpr => // As `aggExpr` is being pushing down to TiKV, we need to replace the original Catalyst @@ -464,18 +457,17 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging { aggregateExpressions: Seq[AggregateExpression], filters: Seq[Expression], source: TiDBRelation - ): Boolean = { + ): Boolean = allowAggregationPushdown && - filters.forall(TiUtils.isSupportedFilter(_, source, blacklist)) && - groupingExpressions.forall(TiUtils.isSupportedGroupingExpr(_, source, blacklist)) && - aggregateExpressions.forall(TiUtils.isSupportedAggregate(_, source, blacklist)) && - !aggregateExpressions.exists(_.isDistinct) - } + filters.forall(TiUtils.isSupportedFilter(_, source, blacklist)) && + groupingExpressions.forall(TiUtils.isSupportedGroupingExpr(_, source, blacklist)) && + aggregateExpressions.forall(TiUtils.isSupportedAggregate(_, source, blacklist)) && + !aggregateExpressions.exists(_.isDistinct) // We do through similar logic with original Spark as in SparkStrategies.scala // Difference is we need to test if a sub-plan can be consumed all together by TiKV // and then we don't return (don't planLater) and plan the remaining all at once - private def doPlan(source: TiDBRelation, plan: LogicalPlan): Seq[SparkPlan] = { + private def doPlan(source: TiDBRelation, plan: LogicalPlan): Seq[SparkPlan] = // TODO: This test should be done once for all children plan match { case logical.ReturnAnswer(rootPlan) => @@ -534,7 +526,6 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging { ) case _ => Nil } - } } object TiAggregation { diff --git a/core/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CollectHandles.scala b/core/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CollectHandles.scala index e644a9f4dc..79164e973d 100644 --- a/core/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CollectHandles.scala +++ b/core/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CollectHandles.scala @@ -68,9 +68,8 @@ case class CollectHandles(child: Expression, // actual order of input rows. 
override def deterministic: Boolean = false - override def initialize(b: InternalRow): Unit = { + override def initialize(b: InternalRow): Unit = buffer.clear() - } override def update(b: InternalRow, input: InternalRow): Unit = { val value = child.eval(input) @@ -79,11 +78,9 @@ case class CollectHandles(child: Expression, } } - override def merge(buffer: InternalRow, input: InternalRow): Unit = { + override def merge(buffer: InternalRow, input: InternalRow): Unit = sys.error("Collect cannot be used in partial aggregations.") - } - override def eval(input: InternalRow): Any = { + override def eval(input: InternalRow): Any = new GenericArrayData(buffer.toArray) - } } diff --git a/core/src/main/scala/org/apache/spark/sql/execution/CoprocessorRDD.scala b/core/src/main/scala/org/apache/spark/sql/execution/CoprocessorRDD.scala index 35c055bd02..939fea75ef 100644 --- a/core/src/main/scala/org/apache/spark/sql/execution/CoprocessorRDD.scala +++ b/core/src/main/scala/org/apache/spark/sql/execution/CoprocessorRDD.scala @@ -68,10 +68,9 @@ case class CoprocessorRDD(output: Seq[Attribute], tiRdd: TiRDD) extends LeafExec } } - override def verboseString: String = { + override def verboseString: String = s"TiDB $nodeName{${tiRdd.dagRequest.toString}}" + s"${TiUtils.getReqEstCountStr(tiRdd.dagRequest)}" - } override def simpleString: String = verboseString } @@ -119,10 +118,9 @@ case class HandleRDDExec(tiHandleRDD: TiHandleRDD) extends LeafExecNode { override def output: Seq[Attribute] = attributeRef - override def verboseString: String = { + override def verboseString: String = s"TiDB $nodeName{${tiHandleRDD.dagRequest.toString}}" + s"${TiUtils.getReqEstCountStr(tiHandleRDD.dagRequest)}" - } override def simpleString: String = verboseString } @@ -260,9 +258,8 @@ case class RegionTaskExec(child: SparkPlan, * * @return true, the number of handle ranges retrieved exceeds the `downgradeThreshold` after handle merge, false otherwise. 
*/ - def satisfyDowngradeThreshold: Boolean = { + def satisfyDowngradeThreshold: Boolean = indexTaskRanges.size() > downgradeThreshold - } /** * If one task's ranges list exceeds some threshold, we split it into two sub tasks and @@ -296,22 +293,20 @@ case class RegionTaskExec(child: SparkPlan, finalTasks } - def isTaskRangeSizeInvalid(task: RegionTask): Boolean = { + def isTaskRangeSizeInvalid(task: RegionTask): Boolean = task == null || - task.getRanges.size() > tiConf.getMaxRequestKeyRangeSize - } + task.getRanges.size() > tiConf.getMaxRequestKeyRangeSize def submitTasks(tasks: List[RegionTask], dagRequest: TiDAGRequest): Unit = { taskCount += 1 val task = new Callable[util.Iterator[TiRow]] { - override def call(): util.Iterator[TiRow] = { + override def call(): util.Iterator[TiRow] = CoprocessIterator.getRowIterator(dagRequest, tasks, session) - } } completionService.submit(task) } - def doIndexScan(): Unit = { + def doIndexScan(): Unit = while (handleIterator.hasNext) { val handleList = feedBatch() numHandles += handleList.size() @@ -343,7 +338,6 @@ case class RegionTaskExec(child: SparkPlan, submitTasks(tasks.toList, dagRequest) numIndexRangesScanned += taskRange.size } - } /** * We merge potentially discrete index ranges from `taskRanges` into one large range @@ -463,9 +457,8 @@ case class RegionTaskExec(child: SparkPlan, } } - override def verboseString: String = { + override def verboseString: String = s"TiSpark $nodeName{downgradeThreshold=$downgradeThreshold,downgradeFilter=${dagRequest.getFilters}" - } override def simpleString: String = verboseString } diff --git a/core/src/main/scala/org/apache/spark/sql/hive/TiSessionCatalog.scala b/core/src/main/scala/org/apache/spark/sql/hive/TiSessionCatalog.scala index f11493cf0a..b9d09fd106 100644 --- a/core/src/main/scala/org/apache/spark/sql/hive/TiSessionCatalog.scala +++ b/core/src/main/scala/org/apache/spark/sql/hive/TiSessionCatalog.scala @@ -49,7 +49,7 @@ class TiSessionCatalog(externalCatalog: HiveExternalCatalog, val meta: MetaManager = new MetaManager(session.getCatalog) - override def lookupRelation(tableIdent: TableIdentifier, alias: Option[String]): LogicalPlan = { + override def lookupRelation(tableIdent: TableIdentifier, alias: Option[String]): LogicalPlan = synchronized { val table = formatTableName(tableIdent.table) val db = formatDatabaseName(tableIdent.database.getOrElse(currentDb)) @@ -63,7 +63,6 @@ class TiSessionCatalog(externalCatalog: HiveExternalCatalog, super.lookupRelation(tableIdent, alias) } } - } override def databaseExists(db: String): Boolean = { val dbName = formatDatabaseName(db) @@ -76,15 +75,13 @@ class TiSessionCatalog(externalCatalog: HiveExternalCatalog, } } - override def listDatabases(): Seq[String] = { + override def listDatabases(): Seq[String] = meta.getDatabases .map(_.getName) .union(super.listDatabases()) - } - override def listDatabases(pattern: String): Seq[String] = { + override def listDatabases(pattern: String): Seq[String] = StringUtils.filterPattern(listDatabases(), pattern) - } override def tableExists(name: TableIdentifier): Boolean = synchronized { val db = formatDatabaseName(name.database.getOrElse(currentDb)) @@ -98,11 +95,10 @@ class TiSessionCatalog(externalCatalog: HiveExternalCatalog, } } - private def requireDbExists(db: String): Unit = { + private def requireDbExists(db: String): Unit = if (!databaseExists(db)) { throw new NoSuchDatabaseException(db) } - } override def getDatabaseMetadata(db: String): CatalogDatabase = { val dbName = formatDatabaseName(db) @@ -153,11 
+149,10 @@ class TiSessionCatalog(externalCatalog: HiveExternalCatalog, } } - def tiDBToCatalogDatabase(db: TiDBInfo): CatalogDatabase = { + def tiDBToCatalogDatabase(db: TiDBInfo): CatalogDatabase = CatalogDatabase(db.getName, "TiDB Database", null, null) - } - def tiTableToCatalogTable(name: TableIdentifier, tiTable: TiTableInfo): CatalogTable = { + def tiTableToCatalogTable(name: TableIdentifier, tiTable: TiTableInfo): CatalogTable = CatalogTable( name, CatalogTableType.EXTERNAL, @@ -165,5 +160,4 @@ class TiSessionCatalog(externalCatalog: HiveExternalCatalog, TiUtils.getSchemaFromTable(tiTable), Option("TiDB") ) - } } diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 0b5a8a19b4..b0c9890bd9 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -26,13 +26,11 @@ abstract class SparkFunSuite extends FunSuite with BeforeAndAfterAll with Loggin protected val logger: Logger = log // helper function - protected final def getTestResourceFile(file: String): File = { + protected final def getTestResourceFile(file: String): File = new File(getClass.getClassLoader.getResource(file).getFile) - } - protected final def getTestResourcePath(file: String): String = { + protected final def getTestResourcePath(file: String): String = getTestResourceFile(file).getCanonicalPath - } /** * Log the suite name and the test name before and after each test. diff --git a/core/src/test/scala/org/apache/spark/sql/AlterTableTestSuite.scala b/core/src/test/scala/org/apache/spark/sql/AlterTableTestSuite.scala index 7a59848432..ee2068026e 100644 --- a/core/src/test/scala/org/apache/spark/sql/AlterTableTestSuite.scala +++ b/core/src/test/scala/org/apache/spark/sql/AlterTableTestSuite.scala @@ -83,11 +83,10 @@ class AlterTableTestSuite extends BaseTiSparkSuite { alterTable("blob", "\"0\"", "null", "", defaultNullOnly = true) } - override def afterAll(): Unit = { + override def afterAll(): Unit = try { tidbStmt.execute("drop table if exists t") } finally { super.afterAll() } - } } diff --git a/core/src/test/scala/org/apache/spark/sql/BaseTiSparkSuite.scala b/core/src/test/scala/org/apache/spark/sql/BaseTiSparkSuite.scala index fa6712ea92..4011c0e7b3 100644 --- a/core/src/test/scala/org/apache/spark/sql/BaseTiSparkSuite.scala +++ b/core/src/test/scala/org/apache/spark/sql/BaseTiSparkSuite.scala @@ -116,9 +116,8 @@ class BaseTiSparkSuite extends QueryTest with SharedSQLContext { initializeTimeZone() } - def setLogLevel(level: String): Unit = { + def setLogLevel(level: String): Unit = spark.sparkContext.setLogLevel(level) - } /** Rename JDBC tables * - currently we use table names with `_j` suffix for JDBC tests @@ -144,9 +143,8 @@ class BaseTiSparkSuite extends QueryTest with SharedSQLContext { protected def judge(str: String, skipped: Boolean = false, checkLimit: Boolean = true): Unit = assert(execDBTSAndJudge(str, skipped, checkLimit)) - private def compSparkWithTiDB(sql: String, checkLimit: Boolean = true): Boolean = { + private def compSparkWithTiDB(sql: String, checkLimit: Boolean = true): Boolean = compSqlResult(sql, querySpark(sql), queryTiDB(sql), checkLimit) - } protected def execDBTSAndJudge(str: String, skipped: Boolean = false, @@ -189,7 +187,7 @@ class BaseTiSparkSuite extends QueryTest with SharedSQLContext { rTiDB: List[List[Any]] = null, skipJDBC: Boolean = false, skipTiDB: Boolean = false, - checkLimit: Boolean = true): Unit = { + checkLimit: 
Boolean = true): Unit = try { explainSpark(qSpark) if (qJDBC == null) { @@ -210,7 +208,6 @@ class BaseTiSparkSuite extends QueryTest with SharedSQLContext { } catch { case e: Throwable => fail(e) } - } /** Run test with sql `qSpark` for TiSpark and TiDB, `qJDBC` for Spark-JDBC. Throw fail exception when * - TiSpark query throws exception @@ -239,7 +236,7 @@ class BaseTiSparkSuite extends QueryTest with SharedSQLContext { rTiDB: List[List[Any]] = null, skipJDBC: Boolean = false, skipTiDB: Boolean = false, - checkLimit: Boolean = true): Unit = { + checkLimit: Boolean = true): Unit = runTestWithoutReplaceTableName( qSpark, replaceJDBCTableName(qSpark, skipJDBC), @@ -251,7 +248,6 @@ class BaseTiSparkSuite extends QueryTest with SharedSQLContext { skipTiDB, checkLimit ) - } /** Run test with sql `qSpark` for TiSpark and TiDB, `qJDBC` for Spark-JDBC. Throw fail exception when * - TiSpark query throws exception @@ -340,7 +336,7 @@ class BaseTiSparkSuite extends QueryTest with SharedSQLContext { private def mapStringList(result: List[Any]): String = if (result == null) "null" else "List(" + result.map(mapString).mkString(",") + ")" - private def mapString(result: Any): String = { + private def mapString(result: Any): String = if (result == null) "null" else result match { @@ -354,5 +350,4 @@ class BaseTiSparkSuite extends QueryTest with SharedSQLContext { case _ => result.toString } - } } diff --git a/core/src/test/scala/org/apache/spark/sql/IssueTestSuite.scala b/core/src/test/scala/org/apache/spark/sql/IssueTestSuite.scala index 6794bfbd74..5344e59d14 100644 --- a/core/src/test/scala/org/apache/spark/sql/IssueTestSuite.scala +++ b/core/src/test/scala/org/apache/spark/sql/IssueTestSuite.scala @@ -210,7 +210,7 @@ class IssueTestSuite extends BaseTiSparkSuite { judge("select count(c1 + c2) from t") } - override def afterAll(): Unit = { + override def afterAll(): Unit = try { tidbStmt.execute("drop table if exists t") tidbStmt.execute("drop table if exists tmp_debug") @@ -220,5 +220,4 @@ class IssueTestSuite extends BaseTiSparkSuite { } finally { super.afterAll() } - } } diff --git a/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 8e0a7dc15e..e57a39c45c 100644 --- a/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -75,11 +75,10 @@ abstract class QueryTest extends PlanTest { case d: Number => d.longValue() } - def toString(value: Any): String = { + def toString(value: Any): String = new SimpleDateFormat("yy-MM-dd HH:mm:ss").format(value) - } - def compValue(lhs: Any, rhs: Any): Boolean = { + def compValue(lhs: Any, rhs: Any): Boolean = if (lhs == rhs || lhs.toString == rhs.toString) { true } else @@ -108,9 +107,8 @@ abstract class QueryTest extends PlanTest { case _ => false } - } - def compRow(lhs: List[Any], rhs: List[Any]): Boolean = { + def compRow(lhs: List[Any], rhs: List[Any]): Boolean = if (lhs == null && rhs == null) { true } else if (lhs == null || rhs == null) { @@ -120,13 +118,11 @@ abstract class QueryTest extends PlanTest { case (value, i) => !compValue(value, rhs(i)) } } - } - def comp(lhs: List[List[Any]], rhs: List[List[Any]]): Boolean = { + def comp(lhs: List[List[Any]], rhs: List[List[Any]]): Boolean = !lhs.zipWithIndex.exists { case (row, i) => !compRow(row, rhs(i)) } - } if (lhs != null && rhs != null) { try { @@ -301,13 +297,11 @@ abstract class QueryTest extends PlanTest { } } - protected def checkAnswer(df: => DataFrame, 
expectedAnswer: Row): Unit = { + protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = checkAnswer(df, Seq(expectedAnswer)) - } - protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = { + protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = checkAnswer(df, expectedAnswer.collect()) - } /** * Runs the plan and makes sure the answer is within absTol of the expected result. @@ -334,9 +328,8 @@ abstract class QueryTest extends PlanTest { protected def checkAggregatesWithTol(dataFrame: DataFrame, expectedAnswer: Row, - absTol: Double): Unit = { + absTol: Double): Unit = checkAggregatesWithTol(dataFrame, Seq(expectedAnswer), absTol) - } /** * Asserts that a given [[Dataset]] will be executed using the given number of cached results. @@ -431,7 +424,7 @@ object QueryTest { } // We need to call prepareRow recursively to handle schemas with struct types. - def prepareRow(row: Row): Row = { + def prepareRow(row: Row): Row = Row.fromSeq(row.toSeq.map { case null => null case d: java.math.BigDecimal => BigDecimal(d) @@ -440,7 +433,6 @@ object QueryTest { case r: Row => prepareRow(r) case o => o }) - } def sameRows(expectedAnswer: Seq[Row], sparkAnswer: Seq[Row], @@ -490,10 +482,9 @@ object QueryTest { } } - def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = { + def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = checkAnswer(df, expectedAnswer.asScala) match { case Some(errorMessage) => errorMessage case None => null } - } } diff --git a/core/src/test/scala/org/apache/spark/sql/TiDBMapDatabaseSuite.scala b/core/src/test/scala/org/apache/spark/sql/TiDBMapDatabaseSuite.scala index 4537b332ef..b2237a9e2f 100644 --- a/core/src/test/scala/org/apache/spark/sql/TiDBMapDatabaseSuite.scala +++ b/core/src/test/scala/org/apache/spark/sql/TiDBMapDatabaseSuite.scala @@ -47,7 +47,7 @@ class TiDBMapDatabaseSuite extends BaseTiSparkSuite { judge("select * from `t-a`") } - override def afterAll(): Unit = { + override def afterAll(): Unit = try { tidbStmt.execute("drop database if exists `test-a`") tidbStmt.execute("drop database if exists `decimals`") @@ -55,5 +55,4 @@ class TiDBMapDatabaseSuite extends BaseTiSparkSuite { } finally { super.afterAll() } - } } diff --git a/core/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/core/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 86cfab6a2a..d9ca9b4394 100644 --- a/core/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/core/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -30,7 +30,7 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { * Since attribute references are given globally unique ids during analysis, * we must normalize them to check if two different queries are identical. */ - protected def normalizeExprIds(plan: LogicalPlan): plan.type = { + protected def normalizeExprIds(plan: LogicalPlan): plan.type = plan transformAllExpressions { case s: ScalarSubquery => s.copy(exprId = ExprId(0)) @@ -47,7 +47,6 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { case ae: AggregateExpression => ae.copy(resultId = ExprId(0)) } - } /** * Normalizes plans: @@ -57,7 +56,7 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { * - Sample the seed will replaced by 0L. * - Join conditions will be resorted by hashCode. 
*/ - private def normalizePlan(plan: LogicalPlan): LogicalPlan = { + private def normalizePlan(plan: LogicalPlan): LogicalPlan = plan transform { case filter @ Filter(condition: Expression, child: LogicalPlan) => Filter( @@ -77,7 +76,6 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { .reduce(And) Join(left, right, joinType, Some(newCondition)) } - } /** * Rewrite [[EqualTo]] and [[EqualNullSafe]] operator to keep order. The following cases will be @@ -106,7 +104,6 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { } /** Fails the test if the two expressions do not match */ - protected def compareExpressions(e1: Expression, e2: Expression): Unit = { + protected def compareExpressions(e1: Expression, e2: Expression): Unit = comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation)) - } } diff --git a/core/src/test/scala/org/apache/spark/sql/expression/index/PrefixIndexTestSuite.scala b/core/src/test/scala/org/apache/spark/sql/expression/index/PrefixIndexTestSuite.scala index 6dd884fac0..e25f8581ab 100644 --- a/core/src/test/scala/org/apache/spark/sql/expression/index/PrefixIndexTestSuite.scala +++ b/core/src/test/scala/org/apache/spark/sql/expression/index/PrefixIndexTestSuite.scala @@ -49,11 +49,10 @@ class PrefixIndexTestSuite extends BaseTiSparkSuite { explainAndTest("select a, b from prefix where b LIKE '%'") } - override def afterAll(): Unit = { + override def afterAll(): Unit = try { tidbStmt.execute("drop table if exists prefix") } finally { super.afterAll() } - } } diff --git a/core/src/test/scala/org/apache/spark/sql/expression/index/UnsignedTestSuite.scala b/core/src/test/scala/org/apache/spark/sql/expression/index/UnsignedTestSuite.scala index b1439ca278..d80549915b 100644 --- a/core/src/test/scala/org/apache/spark/sql/expression/index/UnsignedTestSuite.scala +++ b/core/src/test/scala/org/apache/spark/sql/expression/index/UnsignedTestSuite.scala @@ -116,11 +116,10 @@ class UnsignedTestSuite extends BaseTiSparkSuite { } } - override def afterAll(): Unit = { + override def afterAll(): Unit = try { tidbStmt.execute("drop table if exists `unsigned_test`") } finally { super.afterAll() } - } } diff --git a/core/src/test/scala/org/apache/spark/sql/statistics/StatisticsManagerSuite.scala b/core/src/test/scala/org/apache/spark/sql/statistics/StatisticsManagerSuite.scala index 3b60d9209f..bca87ac015 100644 --- a/core/src/test/scala/org/apache/spark/sql/statistics/StatisticsManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/sql/statistics/StatisticsManagerSuite.scala @@ -169,53 +169,47 @@ class StatisticsManagerSuite extends BaseTiSparkSuite { } }) - private def isDoubleRead(executedPlan: SparkPlan): Boolean = { + private def isDoubleRead(executedPlan: SparkPlan): Boolean = executedPlan .find(_.isInstanceOf[HandleRDDExec]) .isDefined - } /** * Extract first Coprocessor tiRdd exec node from the given query * * @throws java.util.NoSuchElementException if the query does not contain any handle rdd exec node. */ - private def extractCoprocessorRDD(executedPlan: SparkPlan): CoprocessorRDD = { + private def extractCoprocessorRDD(executedPlan: SparkPlan): CoprocessorRDD = executedPlan .find(_.isInstanceOf[CoprocessorRDD]) .get .asInstanceOf[CoprocessorRDD] - } /** * Extract first handle rdd exec node from the given query * * @throws java.util.NoSuchElementException if the query does not contain any handle rdd exec node. 
*/ - private def extractHandleRDDExec(executedPlan: SparkPlan): HandleRDDExec = { + private def extractHandleRDDExec(executedPlan: SparkPlan): HandleRDDExec = executedPlan .find(_.isInstanceOf[HandleRDDExec]) .get .asInstanceOf[HandleRDDExec] - } - private def extractUsedIndex(coprocessorRDD: CoprocessorRDD): String = { + private def extractUsedIndex(coprocessorRDD: CoprocessorRDD): String = getIndexName(coprocessorRDD.tiRdd.dagRequest.getIndexInfo) - } - private def extractUsedIndex(handleRDDExec: HandleRDDExec): String = { + private def extractUsedIndex(handleRDDExec: HandleRDDExec): String = getIndexName(handleRDDExec.tiHandleRDD.dagRequest.getIndexInfo) - } - private def getIndexName(indexInfo: TiIndexInfo): String = { + private def getIndexName(indexInfo: TiIndexInfo): String = if (indexInfo != null) { indexInfo.getName } else { "" } - } - override def afterAll(): Unit = { + override def afterAll(): Unit = try { tidbStmt.execute("DROP TABLE IF EXISTS `tb_fixed_float`") tidbStmt.execute("DROP TABLE IF EXISTS `tb_fixed_int`") @@ -223,5 +217,4 @@ class StatisticsManagerSuite extends BaseTiSparkSuite { } finally { super.afterAll() } - } } diff --git a/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 91f955daab..4fc061a9ce 100644 --- a/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -131,23 +131,20 @@ object SharedSQLContext extends Logging { * 'initializeSession' between a 'describe' and an 'it' call than it does to * call 'beforeAll'. */ - protected def initializeSession(): Unit = { + protected def initializeSession(): Unit = if (_spark == null) { _spark = _sparkSession } - } - private def initializeJDBC(): Unit = { + private def initializeJDBC(): Unit = if (_sparkJDBC == null) { _sparkJDBC = _sparkSession } - } - protected def initializeTiContext(): Unit = { + protected def initializeTiContext(): Unit = if (_spark != null && _ti == null) { _ti = new TiContext(_spark) } - } private def initStatistics(): Unit = { logger.info("Analyzing table tispark_test.full_data_type_table_idx...") @@ -157,7 +154,7 @@ object SharedSQLContext extends Logging { logger.info("Analyzing table finished.") } - private def initializeTiDB(forceNotLoad: Boolean = false): Unit = { + private def initializeTiDB(forceNotLoad: Boolean = false): Unit = if (_tidbConnection == null) { val jdbcUsername = getOrElse(_tidbConf, TiDB_USER, "root") @@ -199,9 +196,8 @@ object SharedSQLContext extends Logging { initStatistics() } } - } - private def initializeConf(): Unit = { + private def initializeConf(): Unit = if (_tidbConf == null) { val confStream = Thread .currentThread() @@ -220,7 +216,6 @@ object SharedSQLContext extends Logging { _tidbConf = prop _sparkSession = new TestSparkSession(sparkConf) } - } /** * Make sure the [[TestSparkSession]] is initialized before any tests are run. 
diff --git a/core/src/test/scala/org/apache/spark/sql/test/Utils.scala b/core/src/test/scala/org/apache/spark/sql/test/Utils.scala
index b143a525e6..25cc451df5 100644
--- a/core/src/test/scala/org/apache/spark/sql/test/Utils.scala
+++ b/core/src/test/scala/org/apache/spark/sql/test/Utils.scala
@@ -10,23 +10,20 @@ import scala.collection.JavaConversions._
 
 object Utils {
 
-  def TryResource[T](res: T)(closeOp: T => Unit)(taskOp: T => Unit): Unit = {
+  def TryResource[T](res: T)(closeOp: T => Unit)(taskOp: T => Unit): Unit =
     try {
       taskOp(res)
     } finally {
       closeOp(res)
     }
-  }
 
-  def writeFile(content: String, path: String): Unit = {
+  def writeFile(content: String, path: String): Unit =
     TryResource(new PrintWriter(path))(_.close()) {
       _.print(content)
     }
-  }
 
-  def readFile(path: String): List[String] = {
+  def readFile(path: String): List[String] =
     Files.readAllLines(Paths.get(path)).toList
-  }
 
   def getOrThrow(prop: Properties, key: String): String = {
     val jvmProp = System.getProperty(key)