update scalafmt for more rules #386

Merged: 1 commit, Jul 9, 2018
2 changes: 1 addition & 1 deletion core/pom.xml
@@ -253,7 +253,7 @@
<plugin>
<groupId>org.antipathy</groupId>
<artifactId>mvn-scalafmt</artifactId>
<version>0.5_1.3.0</version>
<version>0.7_1.5.1</version>
<configuration>
<configLocation>${project.basedir}/scalafmt.conf</configLocation>
</configuration>
36 changes: 17 additions & 19 deletions core/scalafmt.conf
@@ -27,22 +27,20 @@ lineEndings = unix

maxColumn = 100

# TODO Somehow the following lines causes a configuration parsing error...

# newlines {
# alwaysBeforeElseAfterCurlyIf = false
# alwaysBeforeTopLevelStatements = false
# penalizeSingleSelectMultiArgList = false
# sometimesBeforeColonInMethodReturnType = false
# }
#
# optIn.breakChainOnFirstMethodDot = true
#
# rewrite.rules = [
# RedundantBraces
# RedundantParens
# SortImports
# PreferCurlyFors
# ]
#
# spaces.afterKeywordBeforeParen = true
newlines {
alwaysBeforeElseAfterCurlyIf = false
alwaysBeforeTopLevelStatements = false
penalizeSingleSelectMultiArgList = false
sometimesBeforeColonInMethodReturnType = false
}

optIn.breakChainOnFirstMethodDot = true

rewrite.rules = [
RedundantBraces
RedundantParens
SortImports
PreferCurlyFors
]

spaces.afterKeywordBeforeParen = true
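
To make the newly enabled options concrete, here is a small illustrative sketch (not code from this PR; the object and method names are hypothetical) of the shapes these rules produce. RedundantBraces is the rule behind the brace removals in the Scala files below, and SortImports reorders import selectors as seen in the TiUtils.scala import change.

```scala
// Illustrative sketch only, not part of this PR; all names here are hypothetical.
object ScalafmtRuleExamples {
  private val cache = scala.collection.mutable.Map("a" -> 1)

  // RedundantBraces: braces around a single-expression method body are dropped,
  // turning `def f(): T = { expr }` into `def f(): T =` followed by the expression.
  def lookup(name: String): Option[Int] =
    cache.get(name)

  // PreferCurlyFors: a for-comprehension with more than one generator is
  // rewritten to use curly braces, one generator per line.
  def pairs(xs: List[Int], ys: List[Int]): List[(Int, Int)] =
    for {
      x <- xs
      y <- ys
    } yield (x, y)

  // spaces.afterKeywordBeforeParen = true keeps a space between `if` and `(`,
  // and newlines.alwaysBeforeElseAfterCurlyIf = false keeps `else` on the same line.
  def sign(n: Int): Int =
    if (n >= 0) 1 else -1
}
```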
15 changes: 5 additions & 10 deletions core/src/main/scala/com/pingcap/tispark/MetaManager.scala
@@ -23,23 +23,18 @@ import scala.collection.JavaConversions._
// Likely this needs to be merge to client project
// and serving inside metastore if any
class MetaManager(catalog: Catalog) {
def reloadMeta(): Unit = {
def reloadMeta(): Unit =
catalog.reloadCache()
}

def getDatabases: List[TiDBInfo] = {
def getDatabases: List[TiDBInfo] =
catalog.listDatabases().toList
}

def getTables(db: TiDBInfo): List[TiTableInfo] = {
def getTables(db: TiDBInfo): List[TiTableInfo] =
catalog.listTables(db).toList
}

def getTable(dbName: String, tableName: String): Option[TiTableInfo] = {
def getTable(dbName: String, tableName: String): Option[TiTableInfo] =
Option(catalog.getTable(dbName, tableName))
}

def getDatabase(dbName: String): Option[TiDBInfo] = {
def getDatabase(dbName: String): Option[TiDBInfo] =
Option(catalog.getDatabase(dbName))
}
}
17 changes: 6 additions & 11 deletions core/src/main/scala/com/pingcap/tispark/TiUtils.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal, NamedExpression}
import org.apache.spark.sql.types.{DataType, DataTypes, MetadataBuilder, StructField, StructType}
import org.apache.spark.sql.{SparkSession, TiStrategy}
import org.apache.spark.{SparkConf, sql}
import org.apache.spark.{sql, SparkConf}

import scala.collection.JavaConversions._
import scala.collection.mutable
@@ -45,15 +45,14 @@ object TiUtils {

def isSupportedAggregate(aggExpr: AggregateExpression,
tiDBRelation: TiDBRelation,
blacklist: ExpressionBlacklist): Boolean = {
blacklist: ExpressionBlacklist): Boolean =
aggExpr.aggregateFunction match {
case Average(_) | Sum(_) | PromotedSum(_) | Count(_) | Min(_) | Max(_) =>
!aggExpr.isDistinct &&
aggExpr.aggregateFunction.children
.forall(isSupportedBasicExpression(_, tiDBRelation, blacklist))
case _ => false
}
}

def isSupportedBasicExpression(expr: Expression,
tiDBRelation: TiDBRelation,
@@ -103,9 +102,8 @@

def isSupportedFilter(expr: Expression,
source: TiDBRelation,
blacklist: ExpressionBlacklist): Boolean = {
blacklist: ExpressionBlacklist): Boolean =
isSupportedBasicExpression(expr, source, blacklist) && isPushDownSupported(expr, source)
}

// if contains UDF / functions that cannot be folded
def isSupportedGroupingExpr(expr: NamedExpression,
@@ -114,7 +112,7 @@
isSupportedBasicExpression(expr, source, blacklist) && isPushDownSupported(expr, source)

// convert tikv-java client FieldType to Spark DataType
def toSparkDataType(tp: TiDataType): DataType = {
def toSparkDataType(tp: TiDataType): DataType =
tp match {
case _: StringType => sql.types.StringType
case _: BytesType => sql.types.BinaryType
@@ -146,9 +144,8 @@
case _: SetType => sql.types.LongType
case _: YearType => sql.types.LongType
}
}

def fromSparkType(tp: DataType): TiDataType = {
def fromSparkType(tp: DataType): TiDataType =
tp match {
case _: sql.types.BinaryType => BytesType.BLOB
case _: sql.types.StringType => StringType.VARCHAR
@@ -158,7 +155,6 @@
case _: sql.types.TimestampType => TimestampType.TIMESTAMP
case _: sql.types.DateType => DateType.DATE
}
}

def getSchemaFromTable(table: TiTableInfo): StructType = {
val fields = new Array[StructField](table.getColumns.size())
@@ -231,11 +227,10 @@ object TiUtils {
StatisticsManager.initStatisticsManager(tiSession, session)
}

def getReqEstCountStr(req: TiDAGRequest): String = {
def getReqEstCountStr(req: TiDAGRequest): String =
if (req.getEstimatedCount > 0) {
import java.text.DecimalFormat
val df = new DecimalFormat("#.#")
s" EstimatedCount:${df.format(req.getEstimatedCount)}"
} else ""
}
}
@@ -54,7 +54,7 @@ object CacheInvalidateListener {
* @param sc The spark SparkContext used for attaching a cache listener.
* @param regionManager The RegionManager to invalidate local cache.
*/
def initCacheListener(sc: SparkContext, regionManager: RegionManager): Unit = {
def initCacheListener(sc: SparkContext, regionManager: RegionManager): Unit =
if (manager == null) {
synchronized {
if (manager == null) {
@@ -67,9 +67,8 @@
}
}
}
}

def init(sc: SparkContext, regionManager: RegionManager, manager: CacheInvalidateListener): Unit = {
def init(sc: SparkContext, regionManager: RegionManager, manager: CacheInvalidateListener): Unit =
if (sc != null && regionManager != null) {
sc.register(manager.CACHE_INVALIDATE_ACCUMULATOR, manager.CACHE_ACCUMULATOR_NAME)
sc.addSparkListener(
@@ -79,5 +78,4 @@
)
)
}
}
}
@@ -28,7 +28,7 @@ class PDCacheInvalidateListener(accumulator: CacheInvalidateAccumulator,
extends SparkListener {
private final val logger: Logger = Logger.getLogger(getClass.getName)

override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit =
if (accumulator != null && !accumulator.isZero && handler != null) {
synchronized {
if (!accumulator.isZero) {
@@ -42,5 +42,4 @@
}
}
}
}
}
@@ -200,7 +200,7 @@ object StatisticsHelper {

private[statistics] def buildHistogramsRequest(histTable: TiTableInfo,
targetTblId: Long,
startTs: Long): TiDAGRequest = {
startTs: Long): TiDAGRequest =
TiDAGRequest.Builder
.newBuilder()
.setFullTableScan(histTable)
@@ -213,14 +213,13 @@
)
.setStartTs(startTs)
.build(PushDownType.NORMAL)
}

private def checkColExists(table: TiTableInfo, column: String): Boolean =
table.getColumns.exists { _.matchName(column) }

private[statistics] def buildMetaRequest(metaTable: TiTableInfo,
targetTblId: Long,
startTs: Long): TiDAGRequest = {
startTs: Long): TiDAGRequest =
TiDAGRequest.Builder
.newBuilder()
.setFullTableScan(metaTable)
@@ -231,11 +230,10 @@
.addRequiredCols(metaRequiredCols.filter(checkColExists(metaTable, _)))
.setStartTs(startTs)
.build(PushDownType.NORMAL)
}

private[statistics] def buildBucketRequest(bucketTable: TiTableInfo,
targetTblId: Long,
startTs: Long): TiDAGRequest = {
startTs: Long): TiDAGRequest =
TiDAGRequest.Builder
.newBuilder()
.setFullTableScan(bucketTable)
@@ -250,5 +248,4 @@
)
.setStartTs(startTs)
.build(PushDownType.NORMAL)
}
}
@@ -231,9 +231,8 @@ class StatisticsManager(tiSession: TiSession) {
.toSeq
}

def getTableStatistics(id: Long): TableStatistics = {
def getTableStatistics(id: Long): TableStatistics =
statisticsMap.getIfPresent(id)
}

/**
* Estimated row count of one table
@@ -259,15 +258,14 @@
object StatisticsManager {
private var manager: StatisticsManager = _

def initStatisticsManager(tiSession: TiSession, session: SparkSession): Unit = {
def initStatisticsManager(tiSession: TiSession, session: SparkSession): Unit =
if (manager == null) {
synchronized {
if (manager == null) {
manager = new StatisticsManager(tiSession)
}
}
}
}

def reset(): Unit = manager = null

6 changes: 2 additions & 4 deletions core/src/main/scala/org/apache/spark/sql/TiContext.scala
@@ -47,9 +47,8 @@ class TiContext(val session: SparkSession) extends Serializable with Logging {
conf.getBoolean("spark.tispark.statistics.auto_load", defaultValue = true)

class DebugTool {
def getRegionDistribution(dbName: String, tableName: String): Map[String, Integer] = {
def getRegionDistribution(dbName: String, tableName: String): Map[String, Integer] =
RegionUtils.getRegionDistribution(tiSession, dbName, tableName).asScala.toMap
}

/**
* Balance region leaders of a single table.
Expand Down Expand Up @@ -146,9 +145,8 @@ class TiContext(val session: SparkSession) extends Serializable with Logging {
df
}

def tidbMapDatabase(dbName: String, dbNameAsPrefix: Boolean): Unit = {
def tidbMapDatabase(dbName: String, dbNameAsPrefix: Boolean): Unit =
tidbMapDatabase(dbName, dbNameAsPrefix, autoLoad)
}

def tidbMapDatabase(dbName: String,
dbNameAsPrefix: Boolean = false,
35 changes: 13 additions & 22 deletions core/src/main/scala/org/apache/spark/sql/TiStrategy.scala
@@ -65,44 +65,38 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {
new TypeBlacklist(blacklistString)
}

private def allowAggregationPushdown(): Boolean = {
private def allowAggregationPushdown(): Boolean =
sqlConf.getConfString(TiConfigConst.ALLOW_AGG_PUSHDOWN, "true").toBoolean
}

private def allowIndexDoubleRead(): Boolean = {
private def allowIndexDoubleRead(): Boolean =
sqlConf.getConfString(TiConfigConst.ALLOW_INDEX_READ, "false").toBoolean
}

private def useStreamingProcess(): Boolean = {
private def useStreamingProcess(): Boolean =
sqlConf.getConfString(TiConfigConst.COPROCESS_STREAMING, "false").toBoolean
}

private def timeZoneOffset(): Int = {
private def timeZoneOffset(): Int =
sqlConf
.getConfString(
TiConfigConst.KV_TIMEZONE_OFFSET,
String.valueOf(ZonedDateTime.now.getOffset.getTotalSeconds)
)
.toInt
}

private def pushDownType(): PushDownType = {
private def pushDownType(): PushDownType =
if (useStreamingProcess()) {
PushDownType.STREAMING
} else {
PushDownType.NORMAL
}
}

override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
override def apply(plan: LogicalPlan): Seq[SparkPlan] =
plan
.collectFirst {
case LogicalRelation(relation: TiDBRelation, _, _) =>
doPlan(relation, plan)
}
.toSeq
.flatten
}

private def toCoprocessorRDD(
source: TiDBRelation,
@@ -394,9 +388,8 @@
case e if e.deterministic => e.canonicalized -> Alias(e, e.toString())()
}.toMap

def aliasPushedPartialResult(e: AggregateExpression): Alias = {
def aliasPushedPartialResult(e: AggregateExpression): Alias =
deterministicAggAliases.getOrElse(e.canonicalized, Alias(e, e.toString())())
}

val residualAggregateExpressions = aggregateExpressions.map { aggExpr =>
// As `aggExpr` is being pushing down to TiKV, we need to replace the original Catalyst
@@ -464,18 +457,17 @@
aggregateExpressions: Seq[AggregateExpression],
filters: Seq[Expression],
source: TiDBRelation
): Boolean = {
): Boolean =
allowAggregationPushdown &&
filters.forall(TiUtils.isSupportedFilter(_, source, blacklist)) &&
groupingExpressions.forall(TiUtils.isSupportedGroupingExpr(_, source, blacklist)) &&
aggregateExpressions.forall(TiUtils.isSupportedAggregate(_, source, blacklist)) &&
!aggregateExpressions.exists(_.isDistinct)
}

// We do through similar logic with original Spark as in SparkStrategies.scala
// Difference is we need to test if a sub-plan can be consumed all together by TiKV
// and then we don't return (don't planLater) and plan the remaining all at once
private def doPlan(source: TiDBRelation, plan: LogicalPlan): Seq[SparkPlan] = {
private def doPlan(source: TiDBRelation, plan: LogicalPlan): Seq[SparkPlan] =
// TODO: This test should be done once for all children
plan match {
case logical.ReturnAnswer(rootPlan) =>
Expand Down Expand Up @@ -534,7 +526,6 @@ class TiStrategy(context: SQLContext) extends Strategy with Logging {
)
case _ => Nil
}
}
}

object TiAggregation {