update scalafmt for more rules (#386)
birdstorm authored and ilovesoup committed Jul 9, 2018
1 parent 469d963 commit 4dc8769
Showing 25 changed files with 108 additions and 194 deletions.
2 changes: 1 addition & 1 deletion core/pom.xml
@@ -253,7 +253,7 @@
<plugin>
<groupId>org.antipathy</groupId>
<artifactId>mvn-scalafmt</artifactId>
<version>0.5_1.3.0</version>
<version>0.7_1.5.1</version>
<configuration>
<configLocation>${project.basedir}/scalafmt.conf</configLocation>
</configuration>
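The new version string appears to follow the plugin's convention of pairing its own release with the scalafmt version it bundles, so this bump effectively moves the build from scalafmt 1.3.0 to 1.5.1, presumably so the options that previously failed to parse in scalafmt.conf (see the next file) can now be enabled. A minimal invocation sketch; the `format` goal name comes from the mvn-scalafmt documentation and is an assumption, not something shown in this diff:

# run from the repository root; formats the core module's sources
# ('format' goal name assumed from the plugin docs, not from this commit)
mvn org.antipathy:mvn-scalafmt:0.7_1.5.1:format -pl core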
36 changes: 17 additions & 19 deletions core/scalafmt.conf
@@ -27,22 +27,20 @@ lineEndings = unix

maxColumn = 100

# TODO Somehow the following lines causes a configuration parsing error...

# newlines {
# alwaysBeforeElseAfterCurlyIf = false
# alwaysBeforeTopLevelStatements = false
# penalizeSingleSelectMultiArgList = false
# sometimesBeforeColonInMethodReturnType = false
# }
#
# optIn.breakChainOnFirstMethodDot = true
#
# rewrite.rules = [
# RedundantBraces
# RedundantParens
# SortImports
# PreferCurlyFors
# ]
#
# spaces.afterKeywordBeforeParen = true
newlines {
alwaysBeforeElseAfterCurlyIf = false
alwaysBeforeTopLevelStatements = false
penalizeSingleSelectMultiArgList = false
sometimesBeforeColonInMethodReturnType = false
}

optIn.breakChainOnFirstMethodDot = true

rewrite.rules = [
RedundantBraces
RedundantParens
SortImports
PreferCurlyFors
]

spaces.afterKeywordBeforeParen = true
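
These rewrite rules account for most of the mechanical churn in the Scala files below. An illustrative before/after sketch of their effect on assumed example code (not taken from this repository):

// before formatting
import org.apache.spark.{SparkConf, sql}

def isEnabled(conf: Map[String, String]): Boolean = {
  conf.getOrElse("enabled", "true").toBoolean
}

val pairs = for (x <- Seq(1, 2); y <- Seq(3, 4)) yield (x, y)

// after: SortImports reorders the import selectors, RedundantBraces drops the
// braces around the single-expression method body, and PreferCurlyFors turns
// the multi-generator for-comprehension into curly-brace form
import org.apache.spark.{sql, SparkConf}

def isEnabled(conf: Map[String, String]): Boolean =
  conf.getOrElse("enabled", "true").toBoolean

val pairs = for {
  x <- Seq(1, 2)
  y <- Seq(3, 4)
} yield (x, y)

RedundantParens similarly strips parentheses that are not needed, for example around guard conditions in pattern matches.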
15 changes: 5 additions & 10 deletions core/src/main/scala/com/pingcap/tispark/MetaManager.scala
@@ -23,23 +23,18 @@ import scala.collection.JavaConversions._
// Likely this needs to be merge to client project
// and serving inside metastore if any
class MetaManager(catalog: Catalog) {
def reloadMeta(): Unit = {
def reloadMeta(): Unit =
catalog.reloadCache()
}

def getDatabases: List[TiDBInfo] = {
def getDatabases: List[TiDBInfo] =
catalog.listDatabases().toList
}

def getTables(db: TiDBInfo): List[TiTableInfo] = {
def getTables(db: TiDBInfo): List[TiTableInfo] =
catalog.listTables(db).toList
}

def getTable(dbName: String, tableName: String): Option[TiTableInfo] = {
def getTable(dbName: String, tableName: String): Option[TiTableInfo] =
Option(catalog.getTable(dbName, tableName))
}

def getDatabase(dbName: String): Option[TiDBInfo] = {
def getDatabase(dbName: String): Option[TiDBInfo] =
Option(catalog.getDatabase(dbName))
}
}
17 changes: 6 additions & 11 deletions core/src/main/scala/com/pingcap/tispark/TiUtils.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal, NamedExpression}
import org.apache.spark.sql.types.{DataType, DataTypes, MetadataBuilder, StructField, StructType}
import org.apache.spark.sql.{SparkSession, TiStrategy}
import org.apache.spark.{SparkConf, sql}
import org.apache.spark.{sql, SparkConf}

import scala.collection.JavaConversions._
import scala.collection.mutable
@@ -45,15 +45,14 @@ object TiUtils {

def isSupportedAggregate(aggExpr: AggregateExpression,
tiDBRelation: TiDBRelation,
blacklist: ExpressionBlacklist): Boolean = {
blacklist: ExpressionBlacklist): Boolean =
aggExpr.aggregateFunction match {
case Average(_) | Sum(_) | PromotedSum(_) | Count(_) | Min(_) | Max(_) =>
!aggExpr.isDistinct &&
aggExpr.aggregateFunction.children
.forall(isSupportedBasicExpression(_, tiDBRelation, blacklist))
case _ => false
}
}

def isSupportedBasicExpression(expr: Expression,
tiDBRelation: TiDBRelation,
@@ -103,9 +102,8 @@

def isSupportedFilter(expr: Expression,
source: TiDBRelation,
blacklist: ExpressionBlacklist): Boolean = {
blacklist: ExpressionBlacklist): Boolean =
isSupportedBasicExpression(expr, source, blacklist) && isPushDownSupported(expr, source)
}

// if contains UDF / functions that cannot be folded
def isSupportedGroupingExpr(expr: NamedExpression,
@@ -114,7 +112,7 @@
isSupportedBasicExpression(expr, source, blacklist) && isPushDownSupported(expr, source)

// convert tikv-java client FieldType to Spark DataType
def toSparkDataType(tp: TiDataType): DataType = {
def toSparkDataType(tp: TiDataType): DataType =
tp match {
case _: StringType => sql.types.StringType
case _: BytesType => sql.types.BinaryType
@@ -146,9 +144,8 @@
case _: SetType => sql.types.LongType
case _: YearType => sql.types.LongType
}
}

def fromSparkType(tp: DataType): TiDataType = {
def fromSparkType(tp: DataType): TiDataType =
tp match {
case _: sql.types.BinaryType => BytesType.BLOB
case _: sql.types.StringType => StringType.VARCHAR
@@ -158,7 +155,6 @@
case _: sql.types.TimestampType => TimestampType.TIMESTAMP
case _: sql.types.DateType => DateType.DATE
}
}

def getSchemaFromTable(table: TiTableInfo): StructType = {
val fields = new Array[StructField](table.getColumns.size())
@@ -231,11 +227,10 @@ object TiUtils {
StatisticsManager.initStatisticsManager(tiSession, session)
}

def getReqEstCountStr(req: TiDAGRequest): String = {
def getReqEstCountStr(req: TiDAGRequest): String =
if (req.getEstimatedCount > 0) {
import java.text.DecimalFormat
val df = new DecimalFormat("#.#")
s" EstimatedCount:${df.format(req.getEstimatedCount)}"
} else ""
}
}
@@ -54,7 +54,7 @@ object CacheInvalidateListener {
* @param sc The spark SparkContext used for attaching a cache listener.
* @param regionManager The RegionManager to invalidate local cache.
*/
def initCacheListener(sc: SparkContext, regionManager: RegionManager): Unit = {
def initCacheListener(sc: SparkContext, regionManager: RegionManager): Unit =
if (manager == null) {
synchronized {
if (manager == null) {
@@ -67,9 +67,8 @@
}
}
}
}

def init(sc: SparkContext, regionManager: RegionManager, manager: CacheInvalidateListener): Unit = {
def init(sc: SparkContext, regionManager: RegionManager, manager: CacheInvalidateListener): Unit =
if (sc != null && regionManager != null) {
sc.register(manager.CACHE_INVALIDATE_ACCUMULATOR, manager.CACHE_ACCUMULATOR_NAME)
sc.addSparkListener(
@@ -79,5 +78,4 @@
)
)
}
}
}
@@ -28,7 +28,7 @@ class PDCacheInvalidateListener(accumulator: CacheInvalidateAccumulator,
extends SparkListener {
private final val logger: Logger = Logger.getLogger(getClass.getName)

override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit =
if (accumulator != null && !accumulator.isZero && handler != null) {
synchronized {
if (!accumulator.isZero) {
@@ -42,5 +42,4 @@
}
}
}
}
}
@@ -200,7 +200,7 @@ object StatisticsHelper {

private[statistics] def buildHistogramsRequest(histTable: TiTableInfo,
targetTblId: Long,
startTs: Long): TiDAGRequest = {
startTs: Long): TiDAGRequest =
TiDAGRequest.Builder
.newBuilder()
.setFullTableScan(histTable)
@@ -213,14 +213,13 @@
)
.setStartTs(startTs)
.build(PushDownType.NORMAL)
}

private def checkColExists(table: TiTableInfo, column: String): Boolean =
table.getColumns.exists { _.matchName(column) }

private[statistics] def buildMetaRequest(metaTable: TiTableInfo,
targetTblId: Long,
startTs: Long): TiDAGRequest = {
startTs: Long): TiDAGRequest =
TiDAGRequest.Builder
.newBuilder()
.setFullTableScan(metaTable)
@@ -231,11 +230,10 @@
.addRequiredCols(metaRequiredCols.filter(checkColExists(metaTable, _)))
.setStartTs(startTs)
.build(PushDownType.NORMAL)
}

private[statistics] def buildBucketRequest(bucketTable: TiTableInfo,
targetTblId: Long,
startTs: Long): TiDAGRequest = {
startTs: Long): TiDAGRequest =
TiDAGRequest.Builder
.newBuilder()
.setFullTableScan(bucketTable)
@@ -250,5 +248,4 @@
)
.setStartTs(startTs)
.build(PushDownType.NORMAL)
}
}
@@ -231,9 +231,8 @@ class StatisticsManager(tiSession: TiSession) {
.toSeq
}

def getTableStatistics(id: Long): TableStatistics = {
def getTableStatistics(id: Long): TableStatistics =
statisticsMap.getIfPresent(id)
}

/**
* Estimated row count of one table
@@ -259,15 +258,14 @@
object StatisticsManager {
private var manager: StatisticsManager = _

def initStatisticsManager(tiSession: TiSession, session: SparkSession): Unit = {
def initStatisticsManager(tiSession: TiSession, session: SparkSession): Unit =
if (manager == null) {
synchronized {
if (manager == null) {
manager = new StatisticsManager(tiSession)
}
}
}
}

def reset(): Unit = manager = null

6 changes: 2 additions & 4 deletions core/src/main/scala/org/apache/spark/sql/TiContext.scala
@@ -47,9 +47,8 @@ class TiContext(val session: SparkSession) extends Serializable with Logging {
conf.getBoolean("spark.tispark.statistics.auto_load", defaultValue = true)

class DebugTool {
def getRegionDistribution(dbName: String, tableName: String): Map[String, Integer] = {
def getRegionDistribution(dbName: String, tableName: String): Map[String, Integer] =
RegionUtils.getRegionDistribution(tiSession, dbName, tableName).asScala.toMap
}

/**
* Balance region leaders of a single table.
@@ -146,9 +145,8 @@
df
}

def tidbMapDatabase(dbName: String, dbNameAsPrefix: Boolean): Unit = {
def tidbMapDatabase(dbName: String, dbNameAsPrefix: Boolean): Unit =
tidbMapDatabase(dbName, dbNameAsPrefix, autoLoad)
}

def tidbMapDatabase(dbName: String,
dbNameAsPrefix: Boolean = false,
35 changes: 13 additions & 22 deletions core/src/main/scala/org/apache/spark/sql/TiStrategy.scala
@@ -65,44 +65,38 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {
new TypeBlacklist(blacklistString)
}

private def allowAggregationPushdown(): Boolean = {
private def allowAggregationPushdown(): Boolean =
sqlConf.getConfString(TiConfigConst.ALLOW_AGG_PUSHDOWN, "true").toBoolean
}

private def allowIndexDoubleRead(): Boolean = {
private def allowIndexDoubleRead(): Boolean =
sqlConf.getConfString(TiConfigConst.ALLOW_INDEX_READ, "false").toBoolean
}

private def useStreamingProcess(): Boolean = {
private def useStreamingProcess(): Boolean =
sqlConf.getConfString(TiConfigConst.COPROCESS_STREAMING, "false").toBoolean
}

private def timeZoneOffset(): Int = {
private def timeZoneOffset(): Int =
sqlConf
.getConfString(
TiConfigConst.KV_TIMEZONE_OFFSET,
String.valueOf(ZonedDateTime.now.getOffset.getTotalSeconds)
)
.toInt
}

private def pushDownType(): PushDownType = {
private def pushDownType(): PushDownType =
if (useStreamingProcess()) {
PushDownType.STREAMING
} else {
PushDownType.NORMAL
}
}

override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
override def apply(plan: LogicalPlan): Seq[SparkPlan] =
plan
.collectFirst {
case LogicalRelation(relation: TiDBRelation, _, _) =>
doPlan(relation, plan)
}
.toSeq
.flatten
}

private def toCoprocessorRDD(
source: TiDBRelation,
@@ -394,9 +388,8 @@
case e if e.deterministic => e.canonicalized -> Alias(e, e.toString())()
}.toMap

def aliasPushedPartialResult(e: AggregateExpression): Alias = {
def aliasPushedPartialResult(e: AggregateExpression): Alias =
deterministicAggAliases.getOrElse(e.canonicalized, Alias(e, e.toString())())
}

val residualAggregateExpressions = aggregateExpressions.map { aggExpr =>
// As `aggExpr` is being pushing down to TiKV, we need to replace the original Catalyst
@@ -464,18 +457,17 @@
aggregateExpressions: Seq[AggregateExpression],
filters: Seq[Expression],
source: TiDBRelation
): Boolean = {
): Boolean =
allowAggregationPushdown &&
filters.forall(TiUtils.isSupportedFilter(_, source, blacklist)) &&
groupingExpressions.forall(TiUtils.isSupportedGroupingExpr(_, source, blacklist)) &&
aggregateExpressions.forall(TiUtils.isSupportedAggregate(_, source, blacklist)) &&
!aggregateExpressions.exists(_.isDistinct)
}
filters.forall(TiUtils.isSupportedFilter(_, source, blacklist)) &&
groupingExpressions.forall(TiUtils.isSupportedGroupingExpr(_, source, blacklist)) &&
aggregateExpressions.forall(TiUtils.isSupportedAggregate(_, source, blacklist)) &&
!aggregateExpressions.exists(_.isDistinct)

// We do through similar logic with original Spark as in SparkStrategies.scala
// Difference is we need to test if a sub-plan can be consumed all together by TiKV
// and then we don't return (don't planLater) and plan the remaining all at once
private def doPlan(source: TiDBRelation, plan: LogicalPlan): Seq[SparkPlan] = {
private def doPlan(source: TiDBRelation, plan: LogicalPlan): Seq[SparkPlan] =
// TODO: This test should be done once for all children
plan match {
case logical.ReturnAnswer(rootPlan) =>
Expand Down Expand Up @@ -534,7 +526,6 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {
)
case _ => Nil
}
}
}

object TiAggregation {