From 77a2fc5b521788b406bb32bcc3c637c1d7406e58 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Wed, 14 Jun 2017 11:48:32 -0700 Subject: [PATCH 001/118] Revert "[SPARK-20941][SQL] Fix SubqueryExec Reuse" This reverts commit f7cf2096fdecb8edab61c8973c07c6fc877ee32d. --- .../apache/spark/sql/internal/SQLConf.scala | 8 ----- .../execution/basicPhysicalOperators.scala | 3 -- .../apache/spark/sql/execution/subquery.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 35 ------------------- 4 files changed, 1 insertion(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3ea808926e10b..9f7c760fb9d21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -552,12 +552,6 @@ object SQLConf { .booleanConf .createWithDefault(true) - val SUBQUERY_REUSE_ENABLED = buildConf("spark.sql.subquery.reuse") - .internal() - .doc("When true, the planner will try to find out duplicated subqueries and re-use them.") - .booleanConf - .createWithDefault(true) - val STATE_STORE_PROVIDER_CLASS = buildConf("spark.sql.streaming.stateStore.providerClass") .internal() @@ -938,8 +932,6 @@ class SQLConf extends Serializable with Logging { def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) - def subqueryReuseEnabled: Boolean = getConf(SUBQUERY_REUSE_ENABLED) - def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 04c130314388a..bd7a5c5d914c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -599,9 +599,6 @@ case class OutputFakerExec(output: Seq[Attribute], child: SparkPlan) extends Spa */ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { - // Ignore this wrapper for canonicalizing. - override lazy val canonicalized: SparkPlan = child.canonicalized - override lazy val metrics = Map( "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"), "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 2abeadfe45362..d11045fb6ac8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -156,7 +156,7 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] { case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { - if (!conf.subqueryReuseEnabled) { + if (!conf.exchangeReuseEnabled) { return plan } // Build a hash map using schema of subqueries to avoid O(N*N) sameResult calls. 
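For reference, the reuse rule touched above avoids O(N*N) `sameResult` comparisons by first bucketing candidate plans by schema; a minimal standalone sketch of that idea (with a simplified, hypothetical `Plan` type rather than the real `SparkPlan`) looks like:

```scala
object SubqueryDedupSketch {
  import scala.collection.mutable

  // Hypothetical stand-in for a physical plan; result equality is approximated here.
  case class Plan(schema: String, id: Int) {
    def sameResult(other: Plan): Boolean = this == other
  }

  // Bucket plans by schema so sameResult is only called within a bucket,
  // avoiding a quadratic scan over all plans.
  def deduplicate(plans: Seq[Plan]): Seq[Plan] = {
    val seen = mutable.HashMap.empty[String, mutable.ArrayBuffer[Plan]]
    plans.map { p =>
      val bucket = seen.getOrElseUpdate(p.schema, mutable.ArrayBuffer.empty[Plan])
      bucket.find(_.sameResult(p)) match {
        case Some(existing) => existing // reuse the plan seen earlier
        case None => bucket += p; p     // first occurrence: remember and keep it
      }
    }
  }
}
```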
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a7efcafa0166a..68f61cfab6d2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -23,12 +23,9 @@ import java.net.{MalformedURLException, URL} import java.sql.Timestamp import java.util.concurrent.atomic.AtomicBoolean -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.{AccumulatorSuite, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils -import org.apache.spark.sql.execution.{ScalarSubquery, SubqueryExec} import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ @@ -703,38 +700,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } - test("Verify spark.sql.subquery.reuse") { - Seq(true, false).foreach { reuse => - withSQLConf(SQLConf.SUBQUERY_REUSE_ENABLED.key -> reuse.toString) { - val df = sql( - """ - |SELECT key, (SELECT avg(key) FROM testData) - |FROM testData - |WHERE key > (SELECT avg(key) FROM testData) - |ORDER BY key - |LIMIT 3 - """.stripMargin) - - checkAnswer(df, Row(51, 50.5) :: Row(52, 50.5) :: Row(53, 50.5) :: Nil) - - val subqueries = ArrayBuffer.empty[SubqueryExec] - df.queryExecution.executedPlan.transformAllExpressions { - case s @ ScalarSubquery(plan: SubqueryExec, _) => - subqueries += plan - s - } - - assert(subqueries.size == 2, "Two ScalarSubquery are expected in the plan") - - if (reuse) { - assert(subqueries.distinct.size == 1, "Only one ScalarSubquery exists in the plan") - } else { - assert(subqueries.distinct.size == 2, "Reuse is not expected") - } - } - } - } - test("cartesian product join") { withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { checkAnswer( From e254e868f1e640b59d8d3e0e01a5e0c488dd6e70 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 14 Jun 2017 14:28:21 -0700 Subject: [PATCH 002/118] [SPARK-21091][SQL] Move constraint code into QueryPlanConstraints ## What changes were proposed in this pull request? This patch moves constraint related code into a separate trait QueryPlanConstraints, so we don't litter QueryPlan with a lot of constraint private functions. ## How was this patch tested? This is a simple move refactoring and should be covered by existing tests. Author: Reynold Xin Closes #18298 from rxin/SPARK-21091. 
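The split relies on a self-typed trait that is mixed back into the abstract base class; a rough sketch of that shape (hypothetical `BasePlan`/`PlanConstraints` names, not the real classes) is:

```scala
// Simplified shape of the refactoring: the helper logic lives in a trait with a
// self-type, so the abstract base class stays small while the trait can still use
// the base class's API.
abstract class BasePlan[T <: BasePlan[T]] extends PlanConstraints[T] {
  self: T =>
  def output: Seq[String]
}

trait PlanConstraints[T <: BasePlan[T]] {
  self: BasePlan[T] =>
  // Members here can refer to BasePlan's API (output) through the self-type.
  lazy val constraintHints: Seq[String] = output.filter(_.nonEmpty)
}
```

The self-type lets the trait's members see the base class's members without inheriting from it directly, which is why the move does not change any behavior.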
--- .../spark/sql/catalyst/plans/QueryPlan.scala | 187 +--------------- .../catalyst/plans/QueryPlanConstraints.scala | 206 ++++++++++++++++++ 2 files changed, 210 insertions(+), 183 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 5ba043e17a128..8bc462e1e72c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -21,193 +21,14 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} -abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] { +abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] + extends TreeNode[PlanType] + with QueryPlanConstraints[PlanType] { + self: PlanType => def output: Seq[Attribute] - /** - * Extracts the relevant constraints from a given set of constraints based on the attributes that - * appear in the [[outputSet]]. - */ - protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = { - constraints - .union(inferAdditionalConstraints(constraints)) - .union(constructIsNotNullConstraints(constraints)) - .filter(constraint => - constraint.references.nonEmpty && constraint.references.subsetOf(outputSet) && - constraint.deterministic) - } - - /** - * Infers a set of `isNotNull` constraints from null intolerant expressions as well as - * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this - * returns a constraint of the form `isNotNull(a)` - */ - private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { - // First, we propagate constraints from the null intolerant expressions. - var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints) - - // Second, we infer additional constraints from non-nullable attributes that are part of the - // operator's output - val nonNullableAttributes = output.filterNot(_.nullable) - isNotNullConstraints ++= nonNullableAttributes.map(IsNotNull).toSet - - isNotNullConstraints -- constraints - } - - /** - * Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions - * of constraints. - */ - private def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] = - constraint match { - // When the root is IsNotNull, we can push IsNotNull through the child null intolerant - // expressions - case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_)) - // Constraints always return true for all the inputs. That means, null will never be returned. - // Thus, we can infer `IsNotNull(constraint)`, and also push IsNotNull through the child - // null intolerant expressions. - case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull(_)) - } - - /** - * Recursively explores the expressions which are null intolerant and returns all attributes - * in these expressions. 
- */ - private def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match { - case a: Attribute => Seq(a) - case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) - case _ => Seq.empty[Attribute] - } - - // Collect aliases from expressions of the whole tree rooted by the current QueryPlan node, so - // we may avoid producing recursive constraints. - private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( - expressions.collect { - case a: Alias => (a.toAttribute, a.child) - } ++ children.flatMap(_.aliasMap)) - - /** - * Infers an additional set of constraints from a given set of equality constraints. - * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an - * additional constraint of the form `b = 5`. - * - * [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)` - * as they are often useless and can lead to a non-converging set of constraints. - */ - private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { - val constraintClasses = generateEquivalentConstraintClasses(constraints) - - var inferredConstraints = Set.empty[Expression] - constraints.foreach { - case eq @ EqualTo(l: Attribute, r: Attribute) => - val candidateConstraints = constraints - eq - inferredConstraints ++= candidateConstraints.map(_ transform { - case a: Attribute if a.semanticEquals(l) && - !isRecursiveDeduction(r, constraintClasses) => r - }) - inferredConstraints ++= candidateConstraints.map(_ transform { - case a: Attribute if a.semanticEquals(r) && - !isRecursiveDeduction(l, constraintClasses) => l - }) - case _ => // No inference - } - inferredConstraints -- constraints - } - - /* - * Generate a sequence of expression sets from constraints, where each set stores an equivalence - * class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following - * expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal - * to an selected attribute. - */ - private def generateEquivalentConstraintClasses( - constraints: Set[Expression]): Seq[Set[Expression]] = { - var constraintClasses = Seq.empty[Set[Expression]] - constraints.foreach { - case eq @ EqualTo(l: Attribute, r: Attribute) => - // Transform [[Alias]] to its child. - val left = aliasMap.getOrElse(l, l) - val right = aliasMap.getOrElse(r, r) - // Get the expression set for an equivalence constraint class. - val leftConstraintClass = getConstraintClass(left, constraintClasses) - val rightConstraintClass = getConstraintClass(right, constraintClasses) - if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) { - // Combine the two sets. - constraintClasses = constraintClasses - .diff(leftConstraintClass :: rightConstraintClass :: Nil) :+ - (leftConstraintClass ++ rightConstraintClass) - } else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty - // Update equivalence class of `left` expression. - constraintClasses = constraintClasses - .diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right) - } else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty - // Update equivalence class of `right` expression. - constraintClasses = constraintClasses - .diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left) - } else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty - // Create new equivalence constraint class since neither expression presents - // in any classes. 
- constraintClasses = constraintClasses :+ Set(left, right) - } - case _ => // Skip - } - - constraintClasses - } - - /* - * Get all expressions equivalent to the selected expression. - */ - private def getConstraintClass( - expr: Expression, - constraintClasses: Seq[Set[Expression]]): Set[Expression] = - constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression]) - - /* - * Check whether replace by an [[Attribute]] will cause a recursive deduction. Generally it - * has the form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function. - * Here we first get all expressions equal to `attr` and then check whether at least one of them - * is a child of the referenced expression. - */ - private def isRecursiveDeduction( - attr: Attribute, - constraintClasses: Seq[Set[Expression]]): Boolean = { - val expr = aliasMap.getOrElse(attr, attr) - getConstraintClass(expr, constraintClasses).exists { e => - expr.children.exists(_.semanticEquals(e)) - } - } - - /** - * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For - * example, if this set contains the expression `a = 2` then that expression is guaranteed to - * evaluate to `true` for all rows produced. - */ - lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints)) - - /** - * Returns [[constraints]] depending on the config of enabling constraint propagation. If the - * flag is disabled, simply returning an empty constraints. - */ - private[spark] def getConstraints(constraintPropagationEnabled: Boolean): ExpressionSet = - if (constraintPropagationEnabled) { - constraints - } else { - ExpressionSet(Set.empty) - } - - /** - * This method can be overridden by any child class of QueryPlan to specify a set of constraints - * based on the given operator's constraint propagation logic. These constraints are then - * canonicalized and filtered automatically to contain only those attributes that appear in the - * [[outputSet]]. - * - * See [[Canonicalize]] for more details. - */ - protected def validConstraints: Set[Expression] = Set.empty - /** * Returns the set of attributes that are output by this node. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala new file mode 100644 index 0000000000000..7d8a17d97759c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.sql.catalyst.expressions._ + + +trait QueryPlanConstraints[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[PlanType] => + + /** + * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For + * example, if this set contains the expression `a = 2` then that expression is guaranteed to + * evaluate to `true` for all rows produced. + */ + lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints)) + + /** + * Returns [[constraints]] depending on the config of enabling constraint propagation. If the + * flag is disabled, simply returning an empty constraints. + */ + def getConstraints(constraintPropagationEnabled: Boolean): ExpressionSet = + if (constraintPropagationEnabled) { + constraints + } else { + ExpressionSet(Set.empty) + } + + /** + * This method can be overridden by any child class of QueryPlan to specify a set of constraints + * based on the given operator's constraint propagation logic. These constraints are then + * canonicalized and filtered automatically to contain only those attributes that appear in the + * [[outputSet]]. + * + * See [[Canonicalize]] for more details. + */ + protected def validConstraints: Set[Expression] = Set.empty + + /** + * Extracts the relevant constraints from a given set of constraints based on the attributes that + * appear in the [[outputSet]]. + */ + protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = { + constraints + .union(inferAdditionalConstraints(constraints)) + .union(constructIsNotNullConstraints(constraints)) + .filter(constraint => + constraint.references.nonEmpty && constraint.references.subsetOf(outputSet) && + constraint.deterministic) + } + + /** + * Infers a set of `isNotNull` constraints from null intolerant expressions as well as + * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this + * returns a constraint of the form `isNotNull(a)` + */ + private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { + // First, we propagate constraints from the null intolerant expressions. + var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints) + + // Second, we infer additional constraints from non-nullable attributes that are part of the + // operator's output + val nonNullableAttributes = output.filterNot(_.nullable) + isNotNullConstraints ++= nonNullableAttributes.map(IsNotNull).toSet + + isNotNullConstraints -- constraints + } + + /** + * Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions + * of constraints. + */ + private def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] = + constraint match { + // When the root is IsNotNull, we can push IsNotNull through the child null intolerant + // expressions + case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_)) + // Constraints always return true for all the inputs. That means, null will never be returned. + // Thus, we can infer `IsNotNull(constraint)`, and also push IsNotNull through the child + // null intolerant expressions. + case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull(_)) + } + + /** + * Recursively explores the expressions which are null intolerant and returns all attributes + * in these expressions. 
+ */ + private def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match { + case a: Attribute => Seq(a) + case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) + case _ => Seq.empty[Attribute] + } + + // Collect aliases from expressions of the whole tree rooted by the current QueryPlan node, so + // we may avoid producing recursive constraints. + private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( + expressions.collect { + case a: Alias => (a.toAttribute, a.child) + } ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints[PlanType]].aliasMap)) + + /** + * Infers an additional set of constraints from a given set of equality constraints. + * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an + * additional constraint of the form `b = 5`. + * + * [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)` + * as they are often useless and can lead to a non-converging set of constraints. + */ + private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + val constraintClasses = generateEquivalentConstraintClasses(constraints) + + var inferredConstraints = Set.empty[Expression] + constraints.foreach { + case eq @ EqualTo(l: Attribute, r: Attribute) => + val candidateConstraints = constraints - eq + inferredConstraints ++= candidateConstraints.map(_ transform { + case a: Attribute if a.semanticEquals(l) && + !isRecursiveDeduction(r, constraintClasses) => r + }) + inferredConstraints ++= candidateConstraints.map(_ transform { + case a: Attribute if a.semanticEquals(r) && + !isRecursiveDeduction(l, constraintClasses) => l + }) + case _ => // No inference + } + inferredConstraints -- constraints + } + + /** + * Generate a sequence of expression sets from constraints, where each set stores an equivalence + * class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following + * expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal + * to an selected attribute. + */ + private def generateEquivalentConstraintClasses( + constraints: Set[Expression]): Seq[Set[Expression]] = { + var constraintClasses = Seq.empty[Set[Expression]] + constraints.foreach { + case eq @ EqualTo(l: Attribute, r: Attribute) => + // Transform [[Alias]] to its child. + val left = aliasMap.getOrElse(l, l) + val right = aliasMap.getOrElse(r, r) + // Get the expression set for an equivalence constraint class. + val leftConstraintClass = getConstraintClass(left, constraintClasses) + val rightConstraintClass = getConstraintClass(right, constraintClasses) + if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) { + // Combine the two sets. + constraintClasses = constraintClasses + .diff(leftConstraintClass :: rightConstraintClass :: Nil) :+ + (leftConstraintClass ++ rightConstraintClass) + } else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty + // Update equivalence class of `left` expression. + constraintClasses = constraintClasses + .diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right) + } else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty + // Update equivalence class of `right` expression. 
+ constraintClasses = constraintClasses + .diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left) + } else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty + // Create new equivalence constraint class since neither expression presents + // in any classes. + constraintClasses = constraintClasses :+ Set(left, right) + } + case _ => // Skip + } + + constraintClasses + } + + /** + * Get all expressions equivalent to the selected expression. + */ + private def getConstraintClass( + expr: Expression, + constraintClasses: Seq[Set[Expression]]): Set[Expression] = + constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression]) + + /** + * Check whether replace by an [[Attribute]] will cause a recursive deduction. Generally it + * has the form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function. + * Here we first get all expressions equal to `attr` and then check whether at least one of them + * is a child of the referenced expression. + */ + private def isRecursiveDeduction( + attr: Attribute, + constraintClasses: Seq[Set[Expression]]): Boolean = { + val expr = aliasMap.getOrElse(attr, attr) + getConstraintClass(expr, constraintClasses).exists { e => + expr.children.exists(_.semanticEquals(e)) + } + } +} From 29246744061ee96afd5f57e113ad69c354e4ba4a Mon Sep 17 00:00:00 2001 From: Li Yichao Date: Thu, 15 Jun 2017 08:08:26 +0800 Subject: [PATCH 003/118] [SPARK-19900][CORE] Remove driver when relaunching. This is https://github.com/apache/spark/pull/17888 . Below are some spark ui snapshots. Master, after worker disconnects: master_disconnect Master, after worker reconnects, notice the `running drivers` part: master_reconnects This patch, after worker disconnects: patch_disconnect This patch, after worker reconnects: ![image](https://cloud.githubusercontent.com/assets/2576762/26398037/d313769c-40aa-11e7-8613-5f157d193150.png) cc cloud-fan jiangxb1987 Author: Li Yichao Closes #18084 from liyichao/SPARK-19900-1. --- .../apache/spark/deploy/master/Master.scala | 16 ++- .../spark/deploy/master/MasterSuite.scala | 109 ++++++++++++++++++ 2 files changed, 122 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index b78ae1f3fc150..f10a41286c52f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -799,9 +799,19 @@ private[deploy] class Master( } private def relaunchDriver(driver: DriverInfo) { - driver.worker = None - driver.state = DriverState.RELAUNCHING - waitingDrivers += driver + // We must setup a new driver with a new driver id here, because the original driver may + // be still running. Consider this scenario: a worker is network partitioned with master, + // the master then relaunches driver driverID1 with a driver id driverID2, then the worker + // reconnects to master. From this point on, if driverID2 is equal to driverID1, then master + // can not distinguish the statusUpdate of the original driver and the newly relaunched one, + // for example, when DriverStateChanged(driverID1, KILLED) arrives at master, master will + // remove driverID1, so the newly relaunched driver disappears too. See SPARK-19900 for details. 
+ removeDriver(driver.id, DriverState.RELAUNCHING, None) + val newDriver = createDriver(driver.desc) + persistenceEngine.addDriver(newDriver) + drivers.add(newDriver) + waitingDrivers += newDriver + schedule() } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 4f432e4cf21c7..6bb0eec040787 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -19,8 +19,10 @@ package org.apache.spark.deploy.master import java.util.Date import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.collection.mutable.{HashMap, HashSet} import scala.concurrent.duration._ import scala.io.Source @@ -40,6 +42,49 @@ import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.serializer +object MockWorker { + val counter = new AtomicInteger(10000) +} + +class MockWorker(master: RpcEndpointRef, conf: SparkConf = new SparkConf) extends RpcEndpoint { + val seq = MockWorker.counter.incrementAndGet() + val id = seq.toString + override val rpcEnv: RpcEnv = RpcEnv.create("worker", "localhost", seq, + conf, new SecurityManager(conf)) + var apps = new mutable.HashMap[String, String]() + val driverIdToAppId = new mutable.HashMap[String, String]() + def newDriver(driverId: String): RpcEndpointRef = { + val name = s"driver_${drivers.size}" + rpcEnv.setupEndpoint(name, new RpcEndpoint { + override val rpcEnv: RpcEnv = MockWorker.this.rpcEnv + override def receive: PartialFunction[Any, Unit] = { + case RegisteredApplication(appId, _) => + apps(appId) = appId + driverIdToAppId(driverId) = appId + } + }) + } + + val appDesc = DeployTestUtils.createAppDesc() + val drivers = mutable.HashSet[String]() + override def receive: PartialFunction[Any, Unit] = { + case RegisteredWorker(masterRef, _, _) => + masterRef.send(WorkerLatestState(id, Nil, drivers.toSeq)) + case LaunchDriver(driverId, desc) => + drivers += driverId + master.send(RegisterApplication(appDesc, newDriver(driverId))) + case KillDriver(driverId) => + master.send(DriverStateChanged(driverId, DriverState.KILLED, None)) + drivers -= driverId + driverIdToAppId.get(driverId) match { + case Some(appId) => + apps.remove(appId) + master.send(UnregisterApplication(appId)) + } + driverIdToAppId.remove(driverId) + } +} + class MasterSuite extends SparkFunSuite with Matchers with Eventually with PrivateMethodTester with BeforeAndAfter { @@ -588,6 +633,70 @@ class MasterSuite extends SparkFunSuite } } + test("SPARK-19900: there should be a corresponding driver for the app after relaunching driver") { + val conf = new SparkConf().set("spark.worker.timeout", "1") + val master = makeMaster(conf) + master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + eventually(timeout(10.seconds)) { + val masterState = master.self.askSync[MasterStateResponse](RequestMasterState) + assert(masterState.status === RecoveryState.ALIVE, "Master is not alive") + } + val worker1 = new MockWorker(master.self) + worker1.rpcEnv.setupEndpoint("worker", worker1) + val worker1Reg = RegisterWorker( + worker1.id, + "localhost", + 9998, + worker1.self, + 10, + 1024, + "http://localhost:8080", + RpcAddress("localhost2", 10000)) + master.self.send(worker1Reg) + val driver = 
DeployTestUtils.createDriverDesc().copy(supervise = true) + master.self.askSync[SubmitDriverResponse](RequestSubmitDriver(driver)) + + eventually(timeout(10.seconds)) { + assert(worker1.apps.nonEmpty) + } + + eventually(timeout(10.seconds)) { + val masterState = master.self.askSync[MasterStateResponse](RequestMasterState) + assert(masterState.workers(0).state == WorkerState.DEAD) + } + + val worker2 = new MockWorker(master.self) + worker2.rpcEnv.setupEndpoint("worker", worker2) + master.self.send(RegisterWorker( + worker2.id, + "localhost", + 9999, + worker2.self, + 10, + 1024, + "http://localhost:8081", + RpcAddress("localhost", 10001))) + eventually(timeout(10.seconds)) { + assert(worker2.apps.nonEmpty) + } + + master.self.send(worker1Reg) + eventually(timeout(10.seconds)) { + val masterState = master.self.askSync[MasterStateResponse](RequestMasterState) + + val worker = masterState.workers.filter(w => w.id == worker1.id) + assert(worker.length == 1) + // make sure the `DriverStateChanged` arrives at Master. + assert(worker(0).drivers.isEmpty) + assert(worker1.apps.isEmpty) + assert(worker1.drivers.isEmpty) + assert(worker2.apps.size == 1) + assert(worker2.drivers.size == 1) + assert(masterState.activeDrivers.length == 1) + assert(masterState.activeApps.length == 1) + } + } + private def getDrivers(master: Master): HashSet[DriverInfo] = { master.invokePrivate(_drivers()) } From fffeb6d7c37ee673a32584f3b2fd3afe86af793a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 14 Jun 2017 22:11:41 -0700 Subject: [PATCH 004/118] [SPARK-21092][SQL] Wire SQLConf in logical plan and expressions ## What changes were proposed in this pull request? It is really painful to not have configs in logical plan and expressions. We had to add all sorts of hacks (e.g. pass SQLConf explicitly in functions). This patch exposes SQLConf in logical plan, using a thread local variable and a getter closure that's set once there is an active SparkSession. The implementation is a bit of a hack, since we didn't anticipate this need in the beginning (config was only exposed in physical plan). The implementation is described in `SQLConf.get`. In terms of future work, we should follow up to clean up CBO (remove the need for passing in config). ## How was this patch tested? Updated relevant tests for constraint propagation. Author: Reynold Xin Closes #18299 from rxin/SPARK-21092. 
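The pattern described above — a swappable getter combined with a thread-local fallback — can be sketched roughly as follows (hypothetical names such as `Conf` and `setGetter`; this is not the actual Spark code):

```scala
import java.util.concurrent.atomic.AtomicReference

// Minimal sketch: a globally reachable config whose lookup is delegated to a
// swappable getter, with a per-thread fallback used when no session registered one.
class Conf {
  private var settings = Map.empty[String, String]
  def set(key: String, value: String): Unit = settings += key -> value
  def get(key: String, default: String): String = settings.getOrElse(key, default)
}

object Conf {
  // Thread-local default so threads that never see a session (e.g. parallel tests)
  // do not share mutable state.
  private val fallback = new ThreadLocal[Conf] {
    override def initialValue(): Conf = new Conf
  }

  // Swapped once a session exists; until then it resolves to the fallback.
  private val getter = new AtomicReference[() => Conf](() => fallback.get())

  def setGetter(g: () => Conf): Unit = getter.set(g)

  // Any code, including code that cannot take a config parameter, can call Conf.get.
  def get: Conf = getter.get()()
}
```

A session would register itself once, e.g. `Conf.setGetter(() => activeSession.conf)`, which is the role `SparkSession` plays in the diff below.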
--- .../sql/catalyst/optimizer/Optimizer.scala | 25 +++++------ .../spark/sql/catalyst/optimizer/joins.scala | 5 +-- .../spark/sql/catalyst/plans/QueryPlan.scala | 3 ++ .../catalyst/plans/QueryPlanConstraints.scala | 33 +++++---------- .../apache/spark/sql/internal/SQLConf.scala | 42 +++++++++++++++++++ .../BinaryComparisonSimplificationSuite.scala | 2 +- .../BooleanSimplificationSuite.scala | 2 +- .../InferFiltersFromConstraintsSuite.scala | 24 +++++------ .../optimizer/OuterJoinEliminationSuite.scala | 37 ++++++++-------- .../PropagateEmptyRelationSuite.scala | 4 +- .../optimizer/PruneFiltersSuite.scala | 36 +++++++--------- .../optimizer/SetOperationSuite.scala | 2 +- .../plans/ConstraintPropagationSuite.scala | 29 ++++++++----- .../org/apache/spark/sql/SparkSession.scala | 5 +++ 14 files changed, 141 insertions(+), 108 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d16689a34298a..3ab70fb90470c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -77,12 +77,12 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) // Operator push down PushProjectionThroughUnion, ReorderJoin(conf), - EliminateOuterJoin(conf), + EliminateOuterJoin, PushPredicateThroughJoin, PushDownPredicate, LimitPushDown(conf), ColumnPruning, - InferFiltersFromConstraints(conf), + InferFiltersFromConstraints, // Operator combine CollapseRepartition, CollapseProject, @@ -102,7 +102,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) SimplifyConditionals, RemoveDispensableExpressions, SimplifyBinaryComparison, - PruneFilters(conf), + PruneFilters, EliminateSorts, SimplifyCasts, SimplifyCaseConversionExpressions, @@ -619,14 +619,15 @@ object CollapseWindow extends Rule[LogicalPlan] { * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and * LeftSemi joins. */ -case class InferFiltersFromConstraints(conf: SQLConf) - extends Rule[LogicalPlan] with PredicateHelper { - def apply(plan: LogicalPlan): LogicalPlan = if (conf.constraintPropagationEnabled) { - inferFilters(plan) - } else { - plan - } +object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = { + if (SQLConf.get.constraintPropagationEnabled) { + inferFilters(plan) + } else { + plan + } + } private def inferFilters(plan: LogicalPlan): LogicalPlan = plan transform { case filter @ Filter(condition, child) => @@ -717,7 +718,7 @@ object EliminateSorts extends Rule[LogicalPlan] { * 2) by substituting a dummy empty relation when the filter will always evaluate to `false`. * 3) by eliminating the always-true conditions given the constraints on the child's output. */ -case class PruneFilters(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { +object PruneFilters extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // If the filter condition always evaluate to true, remove the filter. 
case Filter(Literal(true, BooleanType), child) => child @@ -730,7 +731,7 @@ case class PruneFilters(conf: SQLConf) extends Rule[LogicalPlan] with PredicateH case f @ Filter(fc, p: LogicalPlan) => val (prunedPredicates, remainingPredicates) = splitConjunctivePredicates(fc).partition { cond => - cond.deterministic && p.getConstraints(conf.constraintPropagationEnabled).contains(cond) + cond.deterministic && p.constraints.contains(cond) } if (prunedPredicates.isEmpty) { f diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 2fe3039774423..bb97e2c808b9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -113,7 +113,7 @@ case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHe * * This rule should be executed before pushing down the Filter */ -case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { +object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { /** * Returns whether the expression returns null or false when all inputs are nulls. @@ -129,8 +129,7 @@ case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with Pred } private def buildNewJoinType(filter: Filter, join: Join): JoinType = { - val conditions = splitConjunctivePredicates(filter.condition) ++ - filter.getConstraints(conf.constraintPropagationEnabled) + val conditions = splitConjunctivePredicates(filter.condition) ++ filter.constraints val leftConditions = conditions.filter(_.references.subsetOf(join.left.outputSet)) val rightConditions = conditions.filter(_.references.subsetOf(join.right.outputSet)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 8bc462e1e72c9..9130b14763e24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] @@ -27,6 +28,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] self: PlanType => + def conf: SQLConf = SQLConf.get + def output: Seq[Attribute] /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala index 7d8a17d97759c..b08a009f0dca1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala @@ -27,18 +27,20 @@ trait QueryPlanConstraints[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[Pl * example, if this set contains the expression `a = 2` then that expression is guaranteed to * evaluate to `true` for all rows produced. 
*/ - lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints)) - - /** - * Returns [[constraints]] depending on the config of enabling constraint propagation. If the - * flag is disabled, simply returning an empty constraints. - */ - def getConstraints(constraintPropagationEnabled: Boolean): ExpressionSet = - if (constraintPropagationEnabled) { - constraints + lazy val constraints: ExpressionSet = { + if (conf.constraintPropagationEnabled) { + ExpressionSet( + validConstraints + .union(inferAdditionalConstraints(validConstraints)) + .union(constructIsNotNullConstraints(validConstraints)) + .filter { c => + c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic + } + ) } else { ExpressionSet(Set.empty) } + } /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints @@ -50,19 +52,6 @@ trait QueryPlanConstraints[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[Pl */ protected def validConstraints: Set[Expression] = Set.empty - /** - * Extracts the relevant constraints from a given set of constraints based on the attributes that - * appear in the [[outputSet]]. - */ - protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = { - constraints - .union(inferAdditionalConstraints(constraints)) - .union(constructIsNotNullConstraints(constraints)) - .filter(constraint => - constraint.references.nonEmpty && constraint.references.subsetOf(outputSet) && - constraint.deterministic) - } - /** * Infers a set of `isNotNull` constraints from null intolerant expressions as well as * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9f7c760fb9d21..6ab3a615e6cc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.internal import java.util.{Locale, NoSuchElementException, Properties, TimeZone} import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ import scala.collection.immutable @@ -64,6 +65,47 @@ object SQLConf { } } + /** + * Default config. Only used when there is no active SparkSession for the thread. + * See [[get]] for more information. + */ + private val fallbackConf = new ThreadLocal[SQLConf] { + override def initialValue: SQLConf = new SQLConf + } + + /** See [[get]] for more information. */ + def getFallbackConf: SQLConf = fallbackConf.get() + + /** + * Defines a getter that returns the SQLConf within scope. + * See [[get]] for more information. + */ + private val confGetter = new AtomicReference[() => SQLConf](() => fallbackConf.get()) + + /** + * Sets the active config object within the current scope. + * See [[get]] for more information. + */ + def setSQLConfGetter(getter: () => SQLConf): Unit = { + confGetter.set(getter) + } + + /** + * Returns the active config object within the current scope. If there is an active SparkSession, + * the proper SQLConf associated with the thread's session is used. + * + * The way this works is a little bit convoluted, due to the fact that config was added initially + * only for physical plans (and as a result not in sql/catalyst module). 
+ * + * The first time a SparkSession is instantiated, we set the [[confGetter]] to return the + * active SparkSession's config. If there is no active SparkSession, it returns using the thread + * local [[fallbackConf]]. The reason [[fallbackConf]] is a thread local (rather than just a conf) + * is to support setting different config options for different threads so we can potentially + * run tests in parallel. At the time this feature was implemented, this was a no-op since we + * run unit tests (that does not involve SparkSession) in serial order. + */ + def get: SQLConf = confGetter.get()() + val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() .doc("The max number of iterations the optimizer and analyzer runs.") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala index b29e1cbd14943..2a04bd588dc1d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -37,7 +37,7 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper ConstantFolding, BooleanSimplification, SimplifyBinaryComparison, - PruneFilters(conf)) :: Nil + PruneFilters) :: Nil } val nullableRelation = LocalRelation('a.int.withNullability(true)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index c275f997ba6e9..1df0a89cf0bf1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -38,7 +38,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { NullPropagation(conf), ConstantFolding, BooleanSimplification, - PruneFilters(conf)) :: Nil + PruneFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 9a4bcdb011435..cdc9f25cf8777 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED +import org.apache.spark.sql.internal.SQLConf class InferFiltersFromConstraintsSuite extends PlanTest { @@ -32,20 +32,11 @@ class InferFiltersFromConstraintsSuite extends PlanTest { Batch("InferAndPushDownFilters", FixedPoint(100), PushPredicateThroughJoin, PushDownPredicate, - InferFiltersFromConstraints(conf), + InferFiltersFromConstraints, CombineFilters, BooleanSimplification) :: Nil } - object OptimizeWithConstraintPropagationDisabled extends 
RuleExecutor[LogicalPlan] { - val batches = - Batch("InferAndPushDownFilters", FixedPoint(100), - PushPredicateThroughJoin, - PushDownPredicate, - InferFiltersFromConstraints(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)), - CombineFilters) :: Nil - } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) test("filter: filter out constraints in condition") { @@ -215,8 +206,13 @@ class InferFiltersFromConstraintsSuite extends PlanTest { } test("No inferred filter when constraint propagation is disabled") { - val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze - val optimized = OptimizeWithConstraintPropagationDisabled.execute(originalQuery) - comparePlans(optimized, originalQuery) + try { + SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) + val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, originalQuery) + } finally { + SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala index b7136703b7541..a37bc4bca2422 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Coalesce, IsNotNull} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED +import org.apache.spark.sql.internal.SQLConf class OuterJoinEliminationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -32,16 +32,7 @@ class OuterJoinEliminationSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubqueryAliases) :: Batch("Outer Join Elimination", Once, - EliminateOuterJoin(conf), - PushPredicateThroughJoin) :: Nil - } - - object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] { - val batches = - Batch("Subqueries", Once, - EliminateSubqueryAliases) :: - Batch("Outer Join Elimination", Once, - EliminateOuterJoin(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)), + EliminateOuterJoin, PushPredicateThroughJoin) :: Nil } @@ -243,19 +234,25 @@ class OuterJoinEliminationSuite extends PlanTest { } test("no outer join elimination if constraint propagation is disabled") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) + try { + SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) - // The predicate "x.b + y.d >= 3" will be inferred constraints like: - // "x.b != null" and "y.d != null", if constraint propagation is enabled. - // When we disable it, the predicate can't be evaluated on left or right plan and used to - // filter out nulls. So the Outer Join will not be eliminated. - val originalQuery = + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + // The predicate "x.b + y.d >= 3" will be inferred constraints like: + // "x.b != null" and "y.d != null", if constraint propagation is enabled. + // When we disable it, the predicate can't be evaluated on left or right plan and used to + // filter out nulls. So the Outer Join will not be eliminated. 
+ val originalQuery = x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) .where("x.b".attr + "y.d".attr >= 3) - val optimized = OptimizeWithConstraintPropagationDisabled.execute(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) - comparePlans(optimized, originalQuery.analyze) + comparePlans(optimized, originalQuery.analyze) + } finally { + SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 38dff4733f714..2285be16938d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -33,7 +33,7 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters(conf), + PruneFilters, PropagateEmptyRelation) :: Nil } @@ -45,7 +45,7 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters(conf)) :: Nil + PruneFilters) :: Nil } val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala index 741dd0cf428d0..706634cdd29b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED class PruneFiltersSuite extends PlanTest { @@ -34,18 +35,7 @@ class PruneFiltersSuite extends PlanTest { EliminateSubqueryAliases) :: Batch("Filter Pushdown and Pruning", Once, CombineFilters, - PruneFilters(conf), - PushDownPredicate, - PushPredicateThroughJoin) :: Nil - } - - object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] { - val batches = - Batch("Subqueries", Once, - EliminateSubqueryAliases) :: - Batch("Filter Pushdown and Pruning", Once, - CombineFilters, - PruneFilters(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)), + PruneFilters, PushDownPredicate, PushPredicateThroughJoin) :: Nil } @@ -159,15 +149,19 @@ class PruneFiltersSuite extends PlanTest { ("tr1.a".attr > 10 || "tr1.c".attr < 10) && 'd.attr < 100) - val optimized = - OptimizeWithConstraintPropagationDisabled.execute(queryWithUselessFilter.analyze) - // When constraint propagation is disabled, the useless filter won't be pruned. - // It gets pushed down. Because the rule `CombineFilters` runs only once, there are redundant - // and duplicate filters. 
- val correctAnswer = tr1 - .where("tr1.a".attr > 10 || "tr1.c".attr < 10).where("tr1.a".attr > 10 || "tr1.c".attr < 10) - .join(tr2.where('d.attr < 100).where('d.attr < 100), + SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) + try { + val optimized = Optimize.execute(queryWithUselessFilter.analyze) + // When constraint propagation is disabled, the useless filter won't be pruned. + // It gets pushed down. Because the rule `CombineFilters` runs only once, there are redundant + // and duplicate filters. + val correctAnswer = tr1 + .where("tr1.a".attr > 10 || "tr1.c".attr < 10).where("tr1.a".attr > 10 || "tr1.c".attr < 10) + .join(tr2.where('d.attr < 100).where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)).analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized, correctAnswer) + } finally { + SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index 756e0f35b2178..21b7f49e14bd5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -34,7 +34,7 @@ class SetOperationSuite extends PlanTest { CombineUnions, PushProjectionThroughUnion, PushDownPredicate, - PruneFilters(conf)) :: Nil + PruneFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 4061394b862a6..a3948d90b0e4d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType} class ConstraintPropagationSuite extends SparkFunSuite { @@ -399,20 +400,26 @@ class ConstraintPropagationSuite extends SparkFunSuite { } test("enable/disable constraint propagation") { - val tr = LocalRelation('a.int, 'b.string, 'c.int) - val filterRelation = tr.where('a.attr > 10) + try { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + val filterRelation = tr.where('a.attr > 10) - verifyConstraints( - filterRelation.analyze.getConstraints(constraintPropagationEnabled = true), - filterRelation.analyze.constraints) + SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true) + assert(filterRelation.analyze.constraints.nonEmpty) - assert(filterRelation.analyze.getConstraints(constraintPropagationEnabled = false).isEmpty) + SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) + assert(filterRelation.analyze.constraints.isEmpty) - val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) - .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3) + val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) + .groupBy('a, 'c, 'b)('a, 'c.as("c1"), 
count('a).as("a3")).select('c1, 'a, 'a3) - verifyConstraints(aliasedRelation.analyze.getConstraints(constraintPropagationEnabled = true), - aliasedRelation.analyze.constraints) - assert(aliasedRelation.analyze.getConstraints(constraintPropagationEnabled = false).isEmpty) + SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true) + assert(aliasedRelation.analyze.constraints.nonEmpty) + + SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false) + assert(aliasedRelation.analyze.constraints.isEmpty) + } finally { + SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index d2bf350711936..2c38f7d7c88da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -87,6 +87,11 @@ class SparkSession private( sparkContext.assertNotStopped() + // If there is no active SparkSession, uses the default SQL conf. Otherwise, use the session's. + SQLConf.setSQLConfGetter(() => { + SparkSession.getActiveSession.map(_.sessionState.conf).getOrElse(SQLConf.getFallbackConf) + }) + /** * The version of Spark on which this application is running. * From 2051428173d8cd548702eb1a2e1c1ca76b8f2fd5 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Thu, 15 Jun 2017 13:18:19 +0800 Subject: [PATCH 005/118] [SPARK-20980][SQL] Rename `wholeFile` to `multiLine` for both CSV and JSON ### What changes were proposed in this pull request? The current option name `wholeFile` is misleading for CSV users. Currently, it is not representing a record per file. Actually, one file could have multiple records. Thus, we should rename it. Now, the proposal is `multiLine`. ### How was this patch tested? N/A Author: Xiao Li Closes #18202 from gatorsmile/renameCVSOption. --- R/pkg/R/SQLContext.R | 6 ++--- python/pyspark/sql/readwriter.py | 14 +++++------ python/pyspark/sql/streaming.py | 14 +++++------ python/pyspark/sql/tests.py | 8 +++---- .../spark/sql/catalyst/json/JSONOptions.scala | 2 +- .../apache/spark/sql/DataFrameReader.scala | 6 ++--- .../datasources/csv/CSVDataSource.scala | 6 ++--- .../datasources/csv/CSVOptions.scala | 2 +- .../datasources/json/JsonDataSource.scala | 6 ++--- .../sql/streaming/DataStreamReader.scala | 6 ++--- .../execution/datasources/csv/CSVSuite.scala | 24 +++++++++---------- .../datasources/json/JsonSuite.scala | 14 +++++------ 12 files changed, 54 insertions(+), 54 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index f5c3a749fe0a1..e3528bc7c3135 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -334,7 +334,7 @@ setMethod("toDF", signature(x = "RDD"), #' #' Loads a JSON file, returning the result as a SparkDataFrame #' By default, (\href{http://jsonlines.org/}{JSON Lines text format or newline-delimited JSON} -#' ) is supported. For JSON (one record per file), set a named property \code{wholeFile} to +#' ) is supported. For JSON (one record per file), set a named property \code{multiLine} to #' \code{TRUE}. #' It goes through the entire dataset once to determine the schema. 
#' @@ -348,7 +348,7 @@ setMethod("toDF", signature(x = "RDD"), #' sparkR.session() #' path <- "path/to/file.json" #' df <- read.json(path) -#' df <- read.json(path, wholeFile = TRUE) +#' df <- read.json(path, multiLine = TRUE) #' df <- jsonFile(path) #' } #' @name read.json @@ -598,7 +598,7 @@ tableToDF <- function(tableName) { #' df1 <- read.df("path/to/file.json", source = "json") #' schema <- structType(structField("name", "string"), #' structField("info", "map")) -#' df2 <- read.df(mapTypeJsonPath, "json", schema, wholeFile = TRUE) +#' df2 <- read.df(mapTypeJsonPath, "json", schema, multiLine = TRUE) #' df3 <- loadDF("data/test_table", "parquet", mergeSchema = "true") #' } #' @name read.df diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 5cf719bd65ae4..aef71f9ca7001 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -174,12 +174,12 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - wholeFile=None): + multiLine=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. `JSON Lines `_ (newline-delimited JSON) is supported by default. - For JSON (one record per file), set the ``wholeFile`` parameter to ``true``. + For JSON (one record per file), set the ``multiLine`` parameter to ``true``. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. @@ -230,7 +230,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. - :param wholeFile: parse one record, which may span multiple lines, per file. If None is + :param multiLine: parse one record, which may span multiple lines, per file. If None is set, it uses the default value, ``false``. >>> df1 = spark.read.json('python/test_support/sql/people.json') @@ -248,7 +248,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, - timestampFormat=timestampFormat, wholeFile=wholeFile) + timestampFormat=timestampFormat, multiLine=multiLine) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -322,7 +322,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, wholeFile=None): + columnNameOfCorruptRecord=None, multiLine=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -396,7 +396,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ``spark.sql.columnNameOfCorruptRecord``. 
If None is set, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. - :param wholeFile: parse records, which may span multiple lines. If None is + :param multiLine: parse records, which may span multiple lines. If None is set, it uses the default value, ``false``. >>> df = spark.read.csv('python/test_support/sql/ages.csv') @@ -411,7 +411,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, - columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) + columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine) if isinstance(path, basestring): path = [path] return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 76e8c4f47d8ad..58aa2468e006d 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -401,12 +401,12 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - wholeFile=None): + multiLine=None): """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. `JSON Lines `_ (newline-delimited JSON) is supported by default. - For JSON (one record per file), set the ``wholeFile`` parameter to ``true``. + For JSON (one record per file), set the ``multiLine`` parameter to ``true``. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. @@ -458,7 +458,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. - :param wholeFile: parse one record, which may span multiple lines, per file. If None is + :param multiLine: parse one record, which may span multiple lines, per file. If None is set, it uses the default value, ``false``. >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) @@ -473,7 +473,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, - timestampFormat=timestampFormat, wholeFile=wholeFile) + timestampFormat=timestampFormat, multiLine=multiLine) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: @@ -532,7 +532,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, wholeFile=None): + columnNameOfCorruptRecord=None, multiLine=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. 
This function will go through the input once to determine the input schema if @@ -607,7 +607,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ``spark.sql.columnNameOfCorruptRecord``. If None is set, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. - :param wholeFile: parse one record, which may span multiple lines. If None is + :param multiLine: parse one record, which may span multiple lines. If None is set, it uses the default value, ``false``. >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) @@ -624,7 +624,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, - columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) + columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 845e1c7619cc4..31f932a363225 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -457,15 +457,15 @@ def test_udf_registration_returns_udf(self): df.select(add_three("id").alias("plus_three")).collect() ) - def test_wholefile_json(self): + def test_multiLine_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", - wholeFile=True) + multiLine=True) self.assertEqual(people1.collect(), people_array.collect()) - def test_wholefile_csv(self): + def test_multiline_csv(self): ages_newlines = self.spark.read.csv( - "python/test_support/sql/ages_newlines.csv", wholeFile=True) + "python/test_support/sql/ages_newlines.csv", multiLine=True) expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'), Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'), Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 7930515038355..1fd680ab64b5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -81,7 +81,7 @@ private[sql] class JSONOptions( FastDateFormat.getInstance( parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) - val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) + val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 0f96e82cedf4e..a1d8b7f4af1b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -295,7 +295,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * Loads JSON files and returns the results as a `DataFrame`. * * JSON Lines (newline-delimited JSON) is supported by - * default. 
For JSON (one record per file), set the `wholeFile` option to true. + * default. For JSON (one record per file), set the `multiLine` option to true. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -335,7 +335,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `wholeFile` (default `false`): parse one record, which may span multiple lines, + *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • * * @@ -537,7 +537,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `columnNameOfCorruptRecord` (default is the value specified in * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
  • - *
  • `wholeFile` (default `false`): parse one record, which may span multiple lines.
  • + *
  • `multiLine` (default `false`): parse one record, which may span multiple lines.
  • * * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 76f121c0c955f..eadc6c94f4b3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -111,8 +111,8 @@ abstract class CSVDataSource extends Serializable { object CSVDataSource { def apply(options: CSVOptions): CSVDataSource = { - if (options.wholeFile) { - WholeFileCSVDataSource + if (options.multiLine) { + MultiLineCSVDataSource } else { TextInputCSVDataSource } @@ -197,7 +197,7 @@ object TextInputCSVDataSource extends CSVDataSource { } } -object WholeFileCSVDataSource extends CSVDataSource { +object MultiLineCSVDataSource extends CSVDataSource { override val isSplitable: Boolean = false override def readFile( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 78c16b75ee684..a13a5a34b4a84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -128,7 +128,7 @@ class CSVOptions( FastDateFormat.getInstance( parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) - val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) + val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) val maxColumns = getInt("maxColumns", 20480) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 4f2963da9ace9..5a92a71d19e78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -86,8 +86,8 @@ abstract class JsonDataSource extends Serializable { object JsonDataSource { def apply(options: JSONOptions): JsonDataSource = { - if (options.wholeFile) { - WholeFileJsonDataSource + if (options.multiLine) { + MultiLineJsonDataSource } else { TextInputJsonDataSource } @@ -147,7 +147,7 @@ object TextInputJsonDataSource extends JsonDataSource { } } -object WholeFileJsonDataSource extends JsonDataSource { +object MultiLineJsonDataSource extends JsonDataSource { override val isSplitable: Boolean = { false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 766776230257d..7e8e6394b4862 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -163,7 +163,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * Loads a JSON file stream and returns the results as a `DataFrame`. * * JSON Lines (newline-delimited JSON) is supported by - * default. For JSON (one record per file), set the `wholeFile` option to true. + * default. For JSON (one record per file), set the `multiLine` option to true. 
* * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -205,7 +205,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `wholeFile` (default `false`): parse one record, which may span multiple lines, + *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • * * @@ -276,7 +276,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `columnNameOfCorruptRecord` (default is the value specified in * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
  • - *
  • `wholeFile` (default `false`): parse one record, which may span multiple lines.
  • + *
  • `multiLine` (default `false`): parse one record, which may span multiple lines.
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 352dba79a4c08..89d9b69dec7ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -261,10 +261,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for DROPMALFORMED parsing mode") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val cars = spark.read .format("csv") - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .options(Map("header" -> "true", "mode" -> "dropmalformed")) .load(testFile(carsFile)) @@ -284,11 +284,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for FAILFAST parsing mode") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val exception = intercept[SparkException] { spark.read .format("csv") - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .options(Map("header" -> "true", "mode" -> "failfast")) .load(testFile(carsFile)).collect() } @@ -990,13 +990,13 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("SPARK-18699 put malformed records in a `columnNameOfCorruptRecord` field") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val schema = new StructType().add("a", IntegerType).add("b", TimestampType) // We use `PERMISSIVE` mode by default if invalid string is given. val df1 = spark .read .option("mode", "abcd") - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schema) .csv(testFile(valueMalformedFile)) checkAnswer(df1, @@ -1011,7 +1011,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .read .option("mode", "Permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schemaWithCorrField1) .csv(testFile(valueMalformedFile)) checkAnswer(df2, @@ -1028,7 +1028,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .read .option("mode", "permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schemaWithCorrField2) .csv(testFile(valueMalformedFile)) checkAnswer(df3, @@ -1041,7 +1041,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .read .option("mode", "PERMISSIVE") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schema.add(columnNameOfCorruptRecord, IntegerType)) .csv(testFile(valueMalformedFile)) .collect @@ -1073,7 +1073,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val df = spark.read .option("header", true) - .option("wholeFile", true) + .option("multiLine", true) .csv(path.getAbsolutePath) // Check if headers have new lines in the names. 
@@ -1096,10 +1096,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Empty file produces empty dataframe with empty schema") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val df = spark.read.format("csv") .option("header", true) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .load(testFile(emptyFile)) assert(df.schema === spark.emptyDataFrame.schema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 65472cda9c1c0..704823ad516c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1814,7 +1814,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(new File(path).listFiles().exists(_.getName.endsWith(".gz"))) - val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDF = spark.read.option("multiLine", true).json(path) val jsonDir = new File(dir, "json").getCanonicalPath jsonDF.coalesce(1).write .option("compression", "gZiP") @@ -1836,7 +1836,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDF = spark.read.option("multiLine", true).json(path) val jsonDir = new File(dir, "json").getCanonicalPath jsonDF.coalesce(1).write.json(jsonDir) @@ -1865,7 +1865,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDF = spark.read.option("multiLine", true).json(path) // no corrupt record column should be created assert(jsonDF.schema === StructType(Seq())) // only the first object should be read @@ -1886,7 +1886,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).option("mode", "PERMISSIVE").json(path) + val jsonDF = spark.read.option("multiLine", true).option("mode", "PERMISSIVE").json(path) assert(jsonDF.count() === corruptRecordCount) assert(jsonDF.schema === new StructType() .add("_corrupt_record", StringType) @@ -1917,7 +1917,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).option("mode", "DROPMALFORMED").json(path) + val jsonDF = spark.read.option("multiLine", true).option("mode", "DROPMALFORMED").json(path) checkAnswer(jsonDF, Seq(Row("test"))) } } @@ -1940,7 +1940,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // `FAILFAST` mode should throw an exception for corrupt records. 
val exceptionOne = intercept[SparkException] { spark.read - .option("wholeFile", true) + .option("multiLine", true) .option("mode", "FAILFAST") .json(path) } @@ -1949,7 +1949,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val exceptionTwo = intercept[SparkException] { spark.read - .option("wholeFile", true) + .option("multiLine", true) .option("mode", "FAILFAST") .schema(schema) .json(path) From b32b2123ddca66e00acf4c9d956232e07f779f9f Mon Sep 17 00:00:00 2001 From: ALeksander Eskilson Date: Thu, 15 Jun 2017 13:45:08 +0800 Subject: [PATCH 006/118] [SPARK-18016][SQL][CATALYST] Code Generation: Constant Pool Limit - Class Splitting ## What changes were proposed in this pull request? This pull-request exclusively includes the class splitting feature described in #16648. When code for a given class would grow beyond 1600k bytes, a private, nested sub-class is generated into which subsequent functions are inlined. Additional sub-classes are generated as the code threshold is met subsequent times. This code includes 3 changes: 1. Includes helper maps, lists, and functions for keeping track of sub-classes during code generation (included in the `CodeGenerator` class). These helper functions allow nested classes and split functions to be initialized/declared/inlined to the appropriate locations in the various projection classes. 2. Changes `addNewFunction` to return a string to support instances where a split function is inlined to a nested class and not the outer class (and so must be invoked using the class-qualified name). Uses of `addNewFunction` throughout the codebase are modified so that the returned name is properly used. 3. Removes instances of the `this` keyword when used on data inside generated classes. All state declared in the outer class is by default global and accessible to the nested classes. However, if a reference to global state in a nested class is prepended with the `this` keyword, it would attempt to reference state belonging to the nested class (which would not exist), rather than the correct variable belonging to the outer class. ## How was this patch tested? Added a test case to the `GeneratedProjectionSuite` that increases the number of columns tested in various projections to a threshold that would previously have triggered a `JaninoRuntimeException` for the Constant Pool. Note: This PR does not address the second Constant Pool issue with code generation (also mentioned in #16648): excess global mutable state. A second PR may be opened to resolve that issue. Author: ALeksander Eskilson Closes #18075 from bdrillard/class_splitting_only. 
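For readers skimming the diff below, here is a minimal, self-contained Scala sketch of the splitting bookkeeping described above. It is not Spark's `CodegenContext`: the class `SplittingContext`, the example object, the `doA`/`doB` function names, and the tiny threshold are invented purely for illustration; the real change lives in `CodegenContext.addNewFunction` in the diff that follows.

```scala
import scala.collection.mutable

// Toy model of the splitting scheme: functions accumulate in the current class until
// its code size crosses a threshold, then a fresh private nested class is started.
// The returned name is class-qualified whenever the function is not in the outer
// class, so call sites work the same way in both cases.
class SplittingContext(maxClassSize: Int = 1600000) {
  private val outerClassName = "OuterClass"
  // (className, instanceName), most recently added first; the outer class needs no qualifier.
  private val classes = mutable.ListBuffer[(String, String)](outerClassName -> "")
  private val classSize = mutable.Map[String, Int](outerClassName -> 0)
  private val classFunctions =
    mutable.Map(outerClassName -> mutable.Map.empty[String, String])

  /** Registers `funcCode` under `funcName` and returns the name call sites must use. */
  def addNewFunction(funcName: String, funcCode: String): String = {
    val (className, classInstance) =
      if (classSize(classes.head._1) > maxClassSize) {
        // Current class is full: start a new nested class and add to it instead.
        val name = s"NestedClass${classes.size}"
        val instance = s"nestedClassInstance${classes.size}"
        classes.prepend(name -> instance)
        classSize(name) = 0
        classFunctions(name) = mutable.Map.empty[String, String]
        name -> instance
      } else {
        classes.head
      }
    classSize(className) += funcCode.length
    classFunctions(className)(funcName) = funcCode
    if (className == outerClassName) funcName else s"$classInstance.$funcName"
  }
}

// Call sites must interpolate the returned name rather than the bare funcName,
// because the function may have landed in a nested class.
object SplittingContextExample extends App {
  val ctx = new SplittingContext(maxClassSize = 10) // tiny threshold to force a split
  val first = ctx.addNewFunction("doA", "private void doA() { /* ... */ }")
  val second = ctx.addNewFunction("doB", "private void doB() { /* ... */ }")
  println(s"$first(); $second();") // prints: doA(); nestedClassInstance1.doB();
}
```

The essential contract change is exactly this: callers interpolate the string returned by `addNewFunction` instead of the bare function name, which is why the diff below updates `SortExec`, `HashAggregateExec`, `ColumnarBatchScan`, and the other call sites accordingly.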
--- sql/catalyst/pom.xml | 7 + .../sql/catalyst/expressions/ScalaUDF.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 140 +++++++++++++++--- .../codegen/GenerateMutableProjection.scala | 17 ++- .../codegen/GenerateOrdering.scala | 3 + .../codegen/GeneratePredicate.scala | 3 + .../codegen/GenerateSafeProjection.scala | 9 +- .../codegen/GenerateUnsafeProjection.scala | 9 +- .../expressions/complexTypeCreator.scala | 6 +- .../expressions/conditionalExpressions.scala | 4 +- .../sql/catalyst/expressions/generators.scala | 6 +- .../expressions/objects/objects.scala | 2 +- .../codegen/GeneratedProjectionSuite.scala | 72 +++++++-- sql/core/pom.xml | 7 + .../sql/execution/ColumnarBatchScan.scala | 6 +- .../apache/spark/sql/execution/SortExec.scala | 4 +- .../sql/execution/WholeStageCodegenExec.scala | 3 + .../aggregate/HashAggregateExec.scala | 8 +- .../execution/basicPhysicalOperators.scala | 11 +- .../columnar/GenerateColumnAccessor.scala | 13 +- .../execution/joins/SortMergeJoinExec.scala | 2 +- .../apache/spark/sql/execution/limit.scala | 2 +- 22 files changed, 259 insertions(+), 81 deletions(-) diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 8d80f8eca5dba..36948ba52b064 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -131,6 +131,13 @@ + + org.scalatest + scalatest-maven-plugin + + -Xmx4g -Xss4096k -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + + org.antlr antlr4-maven-plugin diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index af1eba26621bd..a54f6d0e11147 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -988,7 +988,7 @@ case class ScalaUDF( val converterTerm = ctx.freshName("converter") val expressionIdx = ctx.references.size - 1 ctx.addMutableState(converterClassName, converterTerm, - s"this.$converterTerm = ($converterClassName)$typeConvertersClassName" + + s"$converterTerm = ($converterClassName)$typeConvertersClassName" + s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" + s"references[$expressionIdx]).getChildren().apply($index))).dataType());") converterTerm @@ -1005,7 +1005,7 @@ case class ScalaUDF( // Generate codes used to convert the returned value of user-defined functions to Catalyst type val catalystConverterTerm = ctx.freshName("catalystConverter") ctx.addMutableState(converterClassName, catalystConverterTerm, - s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + + s"$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + s".createToCatalystConverter($scalaUDF.dataType());") val resultTerm = ctx.freshName("result") @@ -1019,7 +1019,7 @@ case class ScalaUDF( val funcTerm = ctx.freshName("udf") ctx.addMutableState(funcClassName, funcTerm, - s"this.$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();") + s"$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();") // codegen for children expressions val evals = children.map(_.genCode(ctx)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index fd9780245fcfb..5158949b95629 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -28,7 +28,6 @@ import scala.util.control.NonFatal import com.google.common.cache.{CacheBuilder, CacheLoader} import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException} -import org.apache.commons.lang3.exception.ExceptionUtils import org.codehaus.commons.compiler.CompileException import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, JaninoRuntimeException, SimpleCompiler} import org.codehaus.janino.util.ClassFile @@ -113,7 +112,7 @@ class CodegenContext { val idx = references.length references += obj val clsName = Option(className).getOrElse(obj.getClass.getName) - addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];") + addMutableState(clsName, term, s"$term = ($clsName) references[$idx];") term } @@ -202,16 +201,6 @@ class CodegenContext { partitionInitializationStatements.mkString("\n") } - /** - * Holding all the functions those will be added into generated class. - */ - val addedFunctions: mutable.Map[String, String] = - mutable.Map.empty[String, String] - - def addNewFunction(funcName: String, funcCode: String): Unit = { - addedFunctions += ((funcName, funcCode)) - } - /** * Holds expressions that are equivalent. Used to perform subexpression elimination * during codegen. @@ -233,10 +222,118 @@ class CodegenContext { // The collection of sub-expression result resetting methods that need to be called on each row. val subexprFunctions = mutable.ArrayBuffer.empty[String] - def declareAddedFunctions(): String = { - addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n") + val outerClassName = "OuterClass" + + /** + * Holds the class and instance names to be generated, where `OuterClass` is a placeholder + * standing for whichever class is generated as the outermost class and which will contain any + * nested sub-classes. All other classes and instance names in this list will represent private, + * nested sub-classes. + */ + private val classes: mutable.ListBuffer[(String, String)] = + mutable.ListBuffer[(String, String)](outerClassName -> null) + + // A map holding the current size in bytes of each class to be generated. + private val classSize: mutable.Map[String, Int] = + mutable.Map[String, Int](outerClassName -> 0) + + // Nested maps holding function names and their code belonging to each class. + private val classFunctions: mutable.Map[String, mutable.Map[String, String]] = + mutable.Map(outerClassName -> mutable.Map.empty[String, String]) + + // Returns the size of the most recently added class. + private def currClassSize(): Int = classSize(classes.head._1) + + // Returns the class name and instance name for the most recently added class. + private def currClass(): (String, String) = classes.head + + // Adds a new class. Requires the class' name, and its instance name. + private def addClass(className: String, classInstance: String): Unit = { + classes.prepend(className -> classInstance) + classSize += className -> 0 + classFunctions += className -> mutable.Map.empty[String, String] + } + + /** + * Adds a function to the generated class. If the code for the `OuterClass` grows too large, the + * function will be inlined into a new private, nested class, and a class-qualified name for the + * function will be returned. 
Otherwise, the function will be inined to the `OuterClass` the + * simple `funcName` will be returned. + * + * @param funcName the class-unqualified name of the function + * @param funcCode the body of the function + * @param inlineToOuterClass whether the given code must be inlined to the `OuterClass`. This + * can be necessary when a function is declared outside of the context + * it is eventually referenced and a returned qualified function name + * cannot otherwise be accessed. + * @return the name of the function, qualified by class if it will be inlined to a private, + * nested sub-class + */ + def addNewFunction( + funcName: String, + funcCode: String, + inlineToOuterClass: Boolean = false): String = { + // The number of named constants that can exist in the class is limited by the Constant Pool + // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a + // threshold of 1600k bytes to determine when a function should be inlined to a private, nested + // sub-class. + val (className, classInstance) = if (inlineToOuterClass) { + outerClassName -> "" + } else if (currClassSize > 1600000) { + val className = freshName("NestedClass") + val classInstance = freshName("nestedClassInstance") + + addClass(className, classInstance) + + className -> classInstance + } else { + currClass() + } + + classSize(className) += funcCode.length + classFunctions(className) += funcName -> funcCode + + if (className == outerClassName) { + funcName + } else { + + s"$classInstance.$funcName" + } + } + + /** + * Instantiates all nested, private sub-classes as objects to the `OuterClass` + */ + private[sql] def initNestedClasses(): String = { + // Nested, private sub-classes have no mutable state (though they do reference the outer class' + // mutable state), so we declare and initialize them inline to the OuterClass. + classes.filter(_._1 != outerClassName).map { + case (className, classInstance) => + s"private $className $classInstance = new $className();" + }.mkString("\n") + } + + /** + * Declares all function code that should be inlined to the `OuterClass`. + */ + private[sql] def declareAddedFunctions(): String = { + classFunctions(outerClassName).values.mkString("\n") } + /** + * Declares all nested, private sub-classes and the function code that should be inlined to them. + */ + private[sql] def declareNestedClasses(): String = { + classFunctions.filterKeys(_ != outerClassName).map { + case (className, functions) => + s""" + |private class $className { + | ${functions.values.mkString("\n")} + |} + """.stripMargin + } + }.mkString("\n") + final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -556,8 +653,7 @@ class CodegenContext { return 0; } """ - addNewFunction(compareFunc, funcCode) - s"this.$compareFunc($c1, $c2)" + s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)" case schema: StructType => val comparisons = GenerateOrdering.genComparisons(this, schema) val compareFunc = freshName("compareStruct") @@ -573,8 +669,7 @@ class CodegenContext { return 0; } """ - addNewFunction(compareFunc, funcCode) - s"this.$compareFunc($c1, $c2)" + s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)" case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2) case _ => @@ -629,7 +724,9 @@ class CodegenContext { /** * Splits the generated code of expressions into multiple functions, because function has - * 64kb code size limit in JVM + * 64kb code size limit in JVM. 
If the class to which the function would be inlined would grow + * beyond 1600kb, we declare a private, nested sub-class, and the function is inlined to it + * instead, because classes have a constant pool limit of 65,536 named values. * * @param row the variable name of row that is used by expressions * @param expressions the codes to evaluate expressions. @@ -689,7 +786,6 @@ class CodegenContext { |} """.stripMargin addNewFunction(name, code) - name } foldFunctions(functions.map(name => s"$name(${arguments.map(_._2).mkString(", ")})")) @@ -773,8 +869,6 @@ class CodegenContext { |} """.stripMargin - addNewFunction(fnName, fn) - // Add a state and a mapping of the common subexpressions that are associate with this // state. Adding this expression to subExprEliminationExprMap means it will call `fn` // when it is code generated. This decision should be a cost based one. @@ -792,7 +886,7 @@ class CodegenContext { addMutableState(javaType(expr.dataType), value, s"$value = ${defaultValue(expr.dataType)};") - subexprFunctions += s"$fnName($INPUT_ROW);" + subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" val state = SubExprEliminationState(isNull, value) e.foreach(subExprEliminationExprs.put(_, state)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 4d732445544a8..635766835029b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -63,21 +63,21 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP if (e.nullable) { val isNull = s"isNull_$i" val value = s"value_$i" - ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") + ctx.addMutableState("boolean", isNull, s"$isNull = true;") ctx.addMutableState(ctx.javaType(e.dataType), value, - s"this.$value = ${ctx.defaultValue(e.dataType)};") + s"$value = ${ctx.defaultValue(e.dataType)};") s""" ${ev.code} - this.$isNull = ${ev.isNull}; - this.$value = ${ev.value}; + $isNull = ${ev.isNull}; + $value = ${ev.value}; """ } else { val value = s"value_$i" ctx.addMutableState(ctx.javaType(e.dataType), value, - s"this.$value = ${ctx.defaultValue(e.dataType)};") + s"$value = ${ctx.defaultValue(e.dataType)};") s""" ${ev.code} - this.$value = ${ev.value}; + $value = ${ev.value}; """ } } @@ -87,7 +87,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val updates = validExpr.zip(index).map { case (e, i) => - val ev = ExprCode("", s"this.isNull_$i", s"this.value_$i") + val ev = ExprCode("", s"isNull_$i", s"value_$i") ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } @@ -135,6 +135,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP $allUpdates return mutableRow; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index f7fc2d54a047b..a31943255b995 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -179,6 +179,9 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR $comparisons return 0; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} }""" val code = CodeFormatter.stripOverlappingComments( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index dcd1ed96a298e..b400783bb5e55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -72,6 +72,9 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] { ${eval.code} return !${eval.isNull} && ${eval.value}; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} }""" val code = CodeFormatter.stripOverlappingComments( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index b1cb6edefb852..f708aeff2b146 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -49,7 +49,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val output = ctx.freshName("safeRow") val values = ctx.freshName("values") // These expressions could be split into multiple functions - ctx.addMutableState("Object[]", values, s"this.$values = null;") + ctx.addMutableState("Object[]", values, s"$values = null;") val rowClass = classOf[GenericInternalRow].getName @@ -65,10 +65,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val allFields = ctx.splitExpressions(tmp, fieldWriters) val code = s""" final InternalRow $tmp = $input; - this.$values = new Object[${schema.length}]; + $values = new Object[${schema.length}]; $allFields final InternalRow $output = new $rowClass($values); - this.$values = null; + $values = null; """ ExprCode(code, "false", output) @@ -184,6 +184,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] $allExpressions return mutableRow; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index efbbc038bd33b..6be69d119bf8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -82,7 +82,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val rowWriterClass = classOf[UnsafeRowWriter].getName val rowWriter = ctx.freshName("rowWriter") ctx.addMutableState(rowWriterClass, rowWriter, - s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") + s"$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") val 
resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder, @@ -182,7 +182,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.freshName("arrayWriter") ctx.addMutableState(arrayWriterClass, arrayWriter, - s"this.$arrayWriter = new $arrayWriterClass();") + s"$arrayWriter = new $arrayWriterClass();") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") val element = ctx.freshName("element") @@ -321,7 +321,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val holder = ctx.freshName("holder") val holderClass = classOf[BufferHolder].getName ctx.addMutableState(holderClass, holder, - s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});") + s"$holder = new $holderClass($result, ${numVarLenFields * 32});") val resetBufferHolder = if (numVarLenFields == 0) { "" @@ -402,6 +402,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${eval.code.trim} return ${eval.value}; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index b6675a84ece48..98c4cbee38dee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -93,7 +93,7 @@ private [sql] object GenArrayData { if (!ctx.isPrimitiveType(elementType)) { val genericArrayClass = classOf[GenericArrayData].getName ctx.addMutableState("Object[]", arrayName, - s"this.$arrayName = new Object[${numElements}];") + s"$arrayName = new Object[${numElements}];") val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { @@ -340,7 +340,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, s"this.$values = null;") + ctx.addMutableState("Object[]", values, s"$values = null;") ev.copy(code = s""" $values = new Object[${valExprs.size}];""" + @@ -357,7 +357,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc }) + s""" final InternalRow ${ev.value} = new $rowClass($values); - this.$values = null; + $values = null; """, isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index ee365fe636614..ae8efb673f91c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -131,8 +131,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi | $globalValue = ${ev.value}; |} """.stripMargin - ctx.addNewFunction(funcName, funcBody) - (funcName, globalIsNull, globalValue) + val fullFuncName = ctx.addNewFunction(funcName, funcBody) + (fullFuncName, 
globalIsNull, globalValue) } override def toString: String = s"if ($predicate) $trueValue else $falseValue" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index e023f0567ea87..c217aa875d9eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -200,7 +200,7 @@ case class Stack(children: Seq[Expression]) extends Generator { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Rows - we write these into an array. val rowData = ctx.freshName("rows") - ctx.addMutableState("InternalRow[]", rowData, s"this.$rowData = new InternalRow[$numRows];") + ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];") val values = children.tail val dataTypes = values.take(numFields).map(_.dataType) val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row => @@ -209,7 +209,7 @@ case class Stack(children: Seq[Expression]) extends Generator { if (index < values.length) values(index) else Literal(null, dataTypes(col)) } val eval = CreateStruct(fields).genCode(ctx) - s"${eval.code}\nthis.$rowData[$row] = ${eval.value};" + s"${eval.code}\n$rowData[$row] = ${eval.value};" }) // Create the collection. @@ -217,7 +217,7 @@ case class Stack(children: Seq[Expression]) extends Generator { ctx.addMutableState( s"$wrapperClass", ev.value, - s"this.${ev.value} = $wrapperClass$$.MODULE$$.make(this.$rowData);") + s"${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);") ev.copy(code = code, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 5bb0febc943f2..073993cccdf8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1163,7 +1163,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val code = s""" ${instanceGen.code} - this.${javaBeanInstance} = ${instanceGen.value}; + ${javaBeanInstance} = ${instanceGen.value}; if (!${instanceGen.isNull}) { $initializeCode } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index b69b74b4240bd..58ea5b9cb52d3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -33,10 +33,10 @@ class GeneratedProjectionSuite extends SparkFunSuite { test("generated projections on wider table") { val N = 1000 - val wideRow1 = new GenericInternalRow((1 to N).toArray[Any]) + val wideRow1 = new GenericInternalRow((0 until N).toArray[Any]) val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) val wideRow2 = new GenericInternalRow( - (1 to N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) + (0 until N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) val schema2 = StructType((1 to N).map(i => 
StructField("", StringType))) val joined = new JoinedRow(wideRow1, wideRow2) val joinedSchema = StructType(schema1 ++ schema2) @@ -48,12 +48,12 @@ class GeneratedProjectionSuite extends SparkFunSuite { val unsafeProj = UnsafeProjection.create(nestedSchema) val unsafe: UnsafeRow = unsafeProj(nested) (0 until N).foreach { i => - val s = UTF8String.fromString((i + 1).toString) - assert(i + 1 === unsafe.getInt(i + 2)) + val s = UTF8String.fromString(i.toString) + assert(i === unsafe.getInt(i + 2)) assert(s === unsafe.getUTF8String(i + 2 + N)) - assert(i + 1 === unsafe.getStruct(0, N * 2).getInt(i)) + assert(i === unsafe.getStruct(0, N * 2).getInt(i)) assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N)) - assert(i + 1 === unsafe.getStruct(1, N * 2).getInt(i)) + assert(i === unsafe.getStruct(1, N * 2).getInt(i)) assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N)) } @@ -62,13 +62,63 @@ class GeneratedProjectionSuite extends SparkFunSuite { val result = safeProj(unsafe) // Can't compare GenericInternalRow with JoinedRow directly (0 until N).foreach { i => - val r = i + 1 - val s = UTF8String.fromString((i + 1).toString) - assert(r === result.getInt(i + 2)) + val s = UTF8String.fromString(i.toString) + assert(i === result.getInt(i + 2)) assert(s === result.getUTF8String(i + 2 + N)) - assert(r === result.getStruct(0, N * 2).getInt(i)) + assert(i === result.getStruct(0, N * 2).getInt(i)) assert(s === result.getStruct(0, N * 2).getUTF8String(i + N)) - assert(r === result.getStruct(1, N * 2).getInt(i)) + assert(i === result.getStruct(1, N * 2).getInt(i)) + assert(s === result.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated MutableProjection + val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, true) + } + val mutableProj = GenerateMutableProjection.generate(exprs) + val row1 = mutableProj(result) + assert(result === row1) + val row2 = mutableProj(result) + assert(result === row2) + } + + test("SPARK-18016: generated projections on wider table requiring class-splitting") { + val N = 4000 + val wideRow1 = new GenericInternalRow((0 until N).toArray[Any]) + val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) + val wideRow2 = new GenericInternalRow( + (0 until N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) + val schema2 = StructType((1 to N).map(i => StructField("", StringType))) + val joined = new JoinedRow(wideRow1, wideRow2) + val joinedSchema = StructType(schema1 ++ schema2) + val nested = new JoinedRow(InternalRow(joined, joined), joined) + val nestedSchema = StructType( + Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema) + + // test generated UnsafeProjection + val unsafeProj = UnsafeProjection.create(nestedSchema) + val unsafe: UnsafeRow = unsafeProj(nested) + (0 until N).foreach { i => + val s = UTF8String.fromString(i.toString) + assert(i === unsafe.getInt(i + 2)) + assert(s === unsafe.getUTF8String(i + 2 + N)) + assert(i === unsafe.getStruct(0, N * 2).getInt(i)) + assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N)) + assert(i === unsafe.getStruct(1, N * 2).getInt(i)) + assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated SafeProjection + val safeProj = FromUnsafeProjection(nestedSchema) + val result = safeProj(unsafe) + // Can't compare GenericInternalRow with JoinedRow directly + (0 until N).foreach { i => + val s = UTF8String.fromString(i.toString) + assert(i === result.getInt(i + 2)) + assert(s === 
result.getUTF8String(i + 2 + N)) + assert(i === result.getStruct(0, N * 2).getInt(i)) + assert(s === result.getStruct(0, N * 2).getUTF8String(i + N)) + assert(i === result.getStruct(1, N * 2).getInt(i)) assert(s === result.getStruct(1, N * 2).getUTF8String(i + N)) } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index fe4be963e8184..7327c9b0c9c50 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -183,6 +183,13 @@ + + org.scalatest + scalatest-maven-plugin + + -Xmx4g -Xss4096k -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + + org.codehaus.mojo build-helper-maven-plugin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index e86116680a57a..74a47da2deef2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -93,7 +93,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } val nextBatch = ctx.freshName("nextBatch") - ctx.addNewFunction(nextBatch, + val nextBatchFuncName = ctx.addNewFunction(nextBatch, s""" |private void $nextBatch() throws java.io.IOException { | long getBatchStart = System.nanoTime(); @@ -121,7 +121,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } s""" |if ($batch == null) { - | $nextBatch(); + | $nextBatchFuncName(); |} |while ($batch != null) { | int $numRows = $batch.numRows(); @@ -133,7 +133,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { | } | $idx = $numRows; | $batch = null; - | $nextBatch(); + | $nextBatchFuncName(); |} |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000)); |$scanTimeTotalNs = 0; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index f98ae82574d20..ff71fd4dc7bb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -141,7 +141,7 @@ case class SortExec( ctx.addMutableState("scala.collection.Iterator", sortedIterator, "") val addToSorter = ctx.freshName("addToSorter") - ctx.addNewFunction(addToSorter, + val addToSorterFuncName = ctx.addNewFunction(addToSorter, s""" | private void $addToSorter() throws java.io.IOException { | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} @@ -160,7 +160,7 @@ case class SortExec( s""" | if ($needToSort) { | long $spillSizeBefore = $metrics.memoryBytesSpilled(); - | $addToSorter(); + | $addToSorterFuncName(); | $sortedIterator = $sorterVariable.sort(); | $sortTime.add($sorterVariable.getSortTimeNanos() / 1000000); | $peakMemory.add($sorterVariable.getPeakMemoryUsage()); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index ac30b11557adb..0bd28e36135c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -357,6 +357,9 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co protected void processNext() throws java.io.IOException { ${code.trim} } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} } """.trim diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 9df5e58f70add..5027a615ced7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -212,7 +212,7 @@ case class HashAggregateExec( } val doAgg = ctx.freshName("doAggregateWithoutKey") - ctx.addNewFunction(doAgg, + val doAggFuncName = ctx.addNewFunction(doAgg, s""" | private void $doAgg() throws java.io.IOException { | // initialize aggregation buffer @@ -229,7 +229,7 @@ case class HashAggregateExec( | while (!$initAgg) { | $initAgg = true; | long $beforeAgg = System.nanoTime(); - | $doAgg(); + | $doAggFuncName(); | $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000); | | // output the result @@ -600,7 +600,7 @@ case class HashAggregateExec( } else "" } - ctx.addNewFunction(doAgg, + val doAggFuncName = ctx.addNewFunction(doAgg, s""" ${generateGenerateCode} private void $doAgg() throws java.io.IOException { @@ -681,7 +681,7 @@ case class HashAggregateExec( if (!$initAgg) { $initAgg = true; long $beforeAgg = System.nanoTime(); - $doAgg(); + $doAggFuncName(); $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index bd7a5c5d914c1..f3ca8397047fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -281,10 +281,8 @@ case class SampleExec( val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName val initSampler = ctx.freshName("initSampler") ctx.copyResult = true - ctx.addMutableState(s"$samplerClass", sampler, - s"$initSampler();") - ctx.addNewFunction(initSampler, + val initSamplerFuncName = ctx.addNewFunction(initSampler, s""" | private void $initSampler() { | $sampler = new $samplerClass($upperBound - $lowerBound, false); @@ -299,6 +297,9 @@ case class SampleExec( | } """.stripMargin.trim) + ctx.addMutableState(s"$samplerClass", sampler, + s"$initSamplerFuncName();") + val samplingCount = ctx.freshName("samplingCount") s""" | int $samplingCount = $sampler.sample(); @@ -394,7 +395,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // The default size of a batch, which must be positive integer val batchSize = 1000 - ctx.addNewFunction("initRange", + val initRangeFuncName = ctx.addNewFunction("initRange", s""" | private void initRange(int idx) { | $BigInt index = $BigInt.valueOf(idx); @@ -451,7 +452,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | // initialize Range | if (!$initTerm) { | $initTerm = true; - | initRange(partitionIndex); + | $initRangeFuncName(partitionIndex); | } | | while (true) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 14024d6c10558..d3fa0dcd2d7c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -128,9 +128,7 @@ object 
GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera } else { val groupedAccessorsItr = initializeAccessors.grouped(numberOfStatementsThreshold) val groupedExtractorsItr = extractors.grouped(numberOfStatementsThreshold) - var groupedAccessorsLength = 0 - groupedAccessorsItr.zipWithIndex.foreach { case (body, i) => - groupedAccessorsLength += 1 + val accessorNames = groupedAccessorsItr.zipWithIndex.map { case (body, i) => val funcName = s"accessors$i" val funcCode = s""" |private void $funcName() { @@ -139,7 +137,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera """.stripMargin ctx.addNewFunction(funcName, funcCode) } - groupedExtractorsItr.zipWithIndex.foreach { case (body, i) => + val extractorNames = groupedExtractorsItr.zipWithIndex.map { case (body, i) => val funcName = s"extractors$i" val funcCode = s""" |private void $funcName() { @@ -148,8 +146,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera """.stripMargin ctx.addNewFunction(funcName, funcCode) } - ((0 to groupedAccessorsLength - 1).map { i => s"accessors$i();" }.mkString("\n"), - (0 to groupedAccessorsLength - 1).map { i => s"extractors$i();" }.mkString("\n")) + (accessorNames.map { accessorName => s"$accessorName();" }.mkString("\n"), + extractorNames.map { extractorName => s"$extractorName();"}.mkString("\n")) } val codeBody = s""" @@ -224,6 +222,9 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera unsafeRow.setTotalSize(bufferHolder.totalSize()); return unsafeRow; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} }""" val code = CodeFormatter.stripOverlappingComments( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 26fb6103953fc..8445c26eeee58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -478,7 +478,7 @@ case class SortMergeJoinExec( | } | return false; // unreachable |} - """.stripMargin) + """.stripMargin, inlineToOuterClass = true) (leftRow, matches) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 757fe2185d302..73a0f8735ed45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -75,7 +75,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { protected boolean stopEarly() { return $stopEarly; } - """) + """, inlineToOuterClass = true) val countTerm = ctx.freshName("count") ctx.addMutableState("int", countTerm, s"$countTerm = 0;") s""" From 1bf55e396c7b995a276df61d9a4eb8e60bcee334 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 14 Jun 2017 23:08:05 -0700 Subject: [PATCH 007/118] [SPARK-20980][DOCS] update doc to reflect multiLine change ## What changes were proposed in this pull request? doc only change ## How was this patch tested? manually Author: Felix Cheung Closes #18312 from felixcheung/sqljsonwholefiledoc. 
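A minimal Scala sketch of the renamed option in use follows; the SparkSession setup and the file path are illustrative assumptions rather than part of this patch, and only the `multiLine` option name comes from the change itself.

```scala
import org.apache.spark.sql.SparkSession

object MultiLineJsonSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("multiLine JSON sketch")
      .master("local[*]")
      .getOrCreate()

    // By default each input line must hold a complete JSON object (JSON Lines).
    // For a single JSON document spread over several lines, enable the option
    // renamed by this change: `multiLine` (formerly `wholeFile`).
    val df = spark.read
      .option("multiLine", "true")
      .json("path/to/multiline.json") // illustrative path

    df.printSchema()
    spark.stop()
  }
}
```

The `include_example` snippets referenced in the updated docs remain the authoritative runnable examples; the sketch above only shows where the renamed option goes.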
--- docs/sql-programming-guide.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 314ff6ef80d29..8e722ae6adca6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -998,7 +998,7 @@ Note that the file that is offered as _a json file_ is not a typical JSON file. line must contain a separate, self-contained valid JSON object. For more information, please see [JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). -For a regular multi-line JSON file, set the `wholeFile` option to `true`. +For a regular multi-line JSON file, set the `multiLine` option to `true`. {% include_example json_dataset scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} @@ -1012,7 +1012,7 @@ Note that the file that is offered as _a json file_ is not a typical JSON file. line must contain a separate, self-contained valid JSON object. For more information, please see [JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). -For a regular multi-line JSON file, set the `wholeFile` option to `true`. +For a regular multi-line JSON file, set the `multiLine` option to `true`. {% include_example json_dataset java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %} @@ -1025,7 +1025,7 @@ Note that the file that is offered as _a json file_ is not a typical JSON file. line must contain a separate, self-contained valid JSON object. For more information, please see [JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). -For a regular multi-line JSON file, set the `wholeFile` parameter to `True`. +For a regular multi-line JSON file, set the `multiLine` parameter to `True`. {% include_example json_dataset python/sql/datasource.py %} @@ -1039,7 +1039,7 @@ Note that the file that is offered as _a json file_ is not a typical JSON file. line must contain a separate, self-contained valid JSON object. For more information, please see [JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). -For a regular multi-line JSON file, set a named parameter `wholeFile` to `TRUE`. +For a regular multi-line JSON file, set a named parameter `multiLine` to `TRUE`. {% include_example json_dataset r/RSparkSQLExample.R %} From 7dc3e697c74864a4e3cca7342762f1427058b3c3 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 16 Jun 2017 00:06:54 +0800 Subject: [PATCH 008/118] [SPARK-16251][SPARK-20200][CORE][TEST] Flaky test: org.apache.spark.rdd.LocalCheckpointSuite.missing checkpoint block fails with informative message ## What changes were proposed in this pull request? Currently we don't wait to confirm the removal of the block from the slave's BlockManager; if the removal takes too much time, the assertion in this test case fails. The failure can be easily reproduced if we sleep for a while before we remove the block in BlockManagerSlaveEndpoint.receiveAndReply(). ## How was this patch tested? N/A Author: Xingbo Jiang Closes #18314 from jiangxb1987/LocalCheckpointSuite. 
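The fix relies on ScalaTest's `Eventually` pattern: retry an assertion about an asynchronous effect until it holds or a timeout expires, rather than asserting immediately. Below is a standalone sketch of that pattern with the same timeout and interval; the suite name, background thread, and `removed` flag are illustrative stand-ins for the real condition, `bmm.getBlockStatus(blockId).isEmpty`, and are not code from this patch.

```scala
import java.util.concurrent.atomic.AtomicBoolean

import scala.concurrent.duration._

import org.scalatest.FunSuite
import org.scalatest.concurrent.Eventually.{eventually, interval, timeout}

class AsyncRemovalSketchSuite extends FunSuite {

  test("wait for an asynchronous removal before asserting") {
    val removed = new AtomicBoolean(false)

    // Simulate a slow asynchronous cleanup, e.g. a block manager endpoint
    // processing a removeBlock message some time after it was requested.
    new Thread(new Runnable {
      override def run(): Unit = {
        Thread.sleep(300)
        removed.set(true)
      }
    }).start()

    // Poll the condition every 100 ms and give up after 1 second, mirroring
    // the timeout/interval used in the patched test.
    eventually(timeout(1.second), interval(100.milliseconds)) {
      assert(removed.get(), "the simulated block should eventually be removed")
    }
  }
}
```

The diff that follows applies exactly this wait inside LocalCheckpointSuite.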
--- .../scala/org/apache/spark/rdd/LocalCheckpointSuite.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala index 2802cd975292c..9e204f5cc33fe 100644 --- a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.rdd +import scala.concurrent.duration._ + +import org.scalatest.concurrent.Eventually.{eventually, interval, timeout} + import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -168,6 +172,10 @@ class LocalCheckpointSuite extends SparkFunSuite with LocalSparkContext { // Collecting the RDD should now fail with an informative exception val blockId = RDDBlockId(rdd.id, numPartitions - 1) bmm.removeBlock(blockId) + // Wait until the block has been removed successfully. + eventually(timeout(1 seconds), interval(100 milliseconds)) { + assert(bmm.getBlockStatus(blockId).isEmpty) + } try { rdd.collect() fail("Collect should have failed if local checkpoint block is removed...") From a18d637112b97d2caaca0a8324bdd99086664b24 Mon Sep 17 00:00:00 2001 From: Michael Gummelt Date: Thu, 15 Jun 2017 11:46:00 -0700 Subject: [PATCH 009/118] [SPARK-20434][YARN][CORE] Move Hadoop delegation token code from yarn to core ## What changes were proposed in this pull request? Move Hadoop delegation token code from `spark-yarn` to `spark-core`, so that other schedulers (such as Mesos) may use it. In order to avoid exposing Hadoop interfaces in spark-core, the new Hadoop delegation token classes are kept private. In order to provide backward compatibility, and to allow YARN users to continue to load their own delegation token providers via Java service loading, the old YARN interfaces, as well as the client code that uses them, have been retained. Summary: - Move registered `yarn.security.ServiceCredentialProvider` classes from `spark-yarn` to `spark-core`. Moved them into a new, private hierarchy under `HadoopDelegationTokenProvider`. Client code in `HadoopDelegationTokenManager` now loads credentials from a whitelist of three providers (`HadoopFSDelegationTokenProvider`, `HiveDelegationTokenProvider`, `HBaseDelegationTokenProvider`), instead of service loading, which means that users are not able to implement their own delegation token providers, as they can in the `spark-yarn` module. - The `yarn.security.ServiceCredentialProvider` interface has been kept for backwards compatibility, and to continue to allow YARN users to implement their own delegation token provider implementations. Client code in YARN now fetches tokens via the new `YARNHadoopDelegationTokenManager` class, which fetches tokens from the core providers through `HadoopDelegationTokenManager`, and also service-loads them from `yarn.security.ServiceCredentialProvider`. 
Old Hierarchy: ``` yarn.security.ServiceCredentialProvider (service loaded) HadoopFSCredentialProvider HiveCredentialProvider HBaseCredentialProvider yarn.security.ConfigurableCredentialManager ``` New Hierarchy: ``` HadoopDelegationTokenManager HadoopDelegationTokenProvider (not service loaded) HadoopFSDelegationTokenProvider HiveDelegationTokenProvider HBaseDelegationTokenProvider yarn.security.ServiceCredentialProvider (service loaded) yarn.security.YARNHadoopDelegationTokenManager ``` ## How was this patch tested? unit tests Author: Michael Gummelt Author: Dr. Stefan Schimanski Closes #17723 from mgummelt/SPARK-20434-refactor-kerberos. --- core/pom.xml | 28 ++++ .../HBaseDelegationTokenProvider.scala | 11 +- .../HadoopDelegationTokenManager.scala | 119 ++++++++++++++ .../HadoopDelegationTokenProvider.scala | 50 ++++++ .../HadoopFSDelegationTokenProvider.scala | 126 +++++++++++++++ .../HiveDelegationTokenProvider.scala | 78 ++++----- .../HadoopDelegationTokenManagerSuite.scala | 116 ++++++++++++++ dev/.rat-excludes | 5 +- docs/running-on-yarn.md | 12 +- resource-managers/yarn/pom.xml | 14 +- ...oy.yarn.security.ServiceCredentialProvider | 3 - .../spark/deploy/yarn/ApplicationMaster.scala | 10 +- .../org/apache/spark/deploy/yarn/Client.scala | 9 +- .../deploy/yarn/YarnSparkHadoopUtil.scala | 31 +++- .../yarn/security/AMCredentialRenewer.scala | 6 +- .../ConfigurableCredentialManager.scala | 107 ------------- .../yarn/security/CredentialUpdater.scala | 2 +- .../security/HadoopFSCredentialProvider.scala | 120 -------------- .../security/ServiceCredentialProvider.scala | 3 +- .../YARNHadoopDelegationTokenManager.scala | 83 ++++++++++ ...oy.yarn.security.ServiceCredentialProvider | 2 +- .../ConfigurableCredentialManagerSuite.scala | 150 ------------------ .../HadoopFSCredentialProviderSuite.scala | 70 -------- ...ARNHadoopDelegationTokenManagerSuite.scala | 66 ++++++++ 24 files changed, 689 insertions(+), 532 deletions(-) rename resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala => core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala (88%) create mode 100644 core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala rename resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala => core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala (54%) create mode 100644 core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala delete mode 100644 resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider delete mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala delete mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProvider.scala create mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala delete mode 100644 resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala delete mode 100644 
resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProviderSuite.scala create mode 100644 resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala diff --git a/core/pom.xml b/core/pom.xml index 7f245b5b6384a..326dde4f274bb 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -357,6 +357,34 @@ org.apache.commons commons-crypto + + + + ${hive.group} + hive-exec + provided + + + ${hive.group} + hive-metastore + provided + + + org.apache.thrift + libthrift + provided + + + org.apache.thrift + libfb303 + provided + + target/scala-${scala.binary.version}/classes diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala similarity index 88% rename from resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala rename to core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala index 5adeb8e605ff4..35621daf9c0d7 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.deploy.yarn.security +package org.apache.spark.deploy.security import scala.reflect.runtime.universe import scala.util.control.NonFatal @@ -24,17 +24,16 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.token.{Token, TokenIdentifier} -import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -private[security] class HBaseCredentialProvider extends ServiceCredentialProvider with Logging { +private[security] class HBaseDelegationTokenProvider + extends HadoopDelegationTokenProvider with Logging { override def serviceName: String = "hbase" - override def obtainCredentials( + override def obtainDelegationTokens( hadoopConf: Configuration, - sparkConf: SparkConf, creds: Credentials): Option[Long] = { try { val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) @@ -55,7 +54,7 @@ private[security] class HBaseCredentialProvider extends ServiceCredentialProvide None } - override def credentialsRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired(hadoopConf: Configuration): Boolean = { hbaseConf(hadoopConf).get("hbase.security.authentication") == "kerberos" } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala new file mode 100644 index 0000000000000..89b6f52ba4bca --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.security + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.security.Credentials + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging + +/** + * Manages all the registered HadoopDelegationTokenProviders and offer APIs for other modules to + * obtain delegation tokens and their renewal time. By default [[HadoopFSDelegationTokenProvider]], + * [[HiveDelegationTokenProvider]] and [[HBaseDelegationTokenProvider]] will be loaded in if not + * explicitly disabled. + * + * Also, each HadoopDelegationTokenProvider is controlled by + * spark.security.credentials.{service}.enabled, and will not be loaded if this config is set to + * false. For example, Hive's delegation token provider [[HiveDelegationTokenProvider]] can be + * enabled/disabled by the configuration spark.security.credentials.hive.enabled. + * + * @param sparkConf Spark configuration + * @param hadoopConf Hadoop configuration + * @param fileSystems Delegation tokens will be fetched for these Hadoop filesystems. + */ +private[spark] class HadoopDelegationTokenManager( + sparkConf: SparkConf, + hadoopConf: Configuration, + fileSystems: Set[FileSystem]) + extends Logging { + + private val deprecatedProviderEnabledConfigs = List( + "spark.yarn.security.tokens.%s.enabled", + "spark.yarn.security.credentials.%s.enabled") + private val providerEnabledConfig = "spark.security.credentials.%s.enabled" + + // Maintain all the registered delegation token providers + private val delegationTokenProviders = getDelegationTokenProviders + logDebug(s"Using the following delegation token providers: " + + s"${delegationTokenProviders.keys.mkString(", ")}.") + + private def getDelegationTokenProviders: Map[String, HadoopDelegationTokenProvider] = { + val providers = List(new HadoopFSDelegationTokenProvider(fileSystems), + new HiveDelegationTokenProvider, + new HBaseDelegationTokenProvider) + + // Filter out providers for which spark.security.credentials.{service}.enabled is false. + providers + .filter { p => isServiceEnabled(p.serviceName) } + .map { p => (p.serviceName, p) } + .toMap + } + + def isServiceEnabled(serviceName: String): Boolean = { + val key = providerEnabledConfig.format(serviceName) + + deprecatedProviderEnabledConfigs.foreach { pattern => + val deprecatedKey = pattern.format(serviceName) + if (sparkConf.contains(deprecatedKey)) { + logWarning(s"${deprecatedKey} is deprecated. Please use ${key} instead.") + } + } + + val isEnabledDeprecated = deprecatedProviderEnabledConfigs.forall { pattern => + sparkConf + .getOption(pattern.format(serviceName)) + .map(_.toBoolean) + .getOrElse(true) + } + + sparkConf + .getOption(key) + .map(_.toBoolean) + .getOrElse(isEnabledDeprecated) + } + + /** + * Get delegation token provider for the specified service. + */ + def getServiceDelegationTokenProvider(service: String): Option[HadoopDelegationTokenProvider] = { + delegationTokenProviders.get(service) + } + + /** + * Writes delegation tokens to creds. 
Delegation tokens are fetched from all registered + * providers. + * + * @return Time after which the fetched delegation tokens should be renewed. + */ + def obtainDelegationTokens( + hadoopConf: Configuration, + creds: Credentials): Long = { + delegationTokenProviders.values.flatMap { provider => + if (provider.delegationTokensRequired(hadoopConf)) { + provider.obtainDelegationTokens(hadoopConf, creds) + } else { + logDebug(s"Service ${provider.serviceName} does not require a token." + + s" Check your configuration to see if security is disabled or not.") + None + } + }.foldLeft(Long.MaxValue)(math.min) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala new file mode 100644 index 0000000000000..f162e7e58c53a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.security + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.security.Credentials + +/** + * Hadoop delegation token provider. + */ +private[spark] trait HadoopDelegationTokenProvider { + + /** + * Name of the service to provide delegation tokens. This name should be unique. Spark will + * internally use this name to differentiate delegation token providers. + */ + def serviceName: String + + /** + * Returns true if delegation tokens are required for this service. By default, it is based on + * whether Hadoop security is enabled. + */ + def delegationTokensRequired(hadoopConf: Configuration): Boolean + + /** + * Obtain delegation tokens for this service and get the time of the next renewal. + * @param hadoopConf Configuration of current Hadoop Compatible system. + * @param creds Credentials to add tokens and security keys to. + * @return If the returned tokens are renewable and can be renewed, return the time of the next + * renewal, otherwise None should be returned. + */ + def obtainDelegationTokens( + hadoopConf: Configuration, + creds: Credentials): Option[Long] +} diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala new file mode 100644 index 0000000000000..13157f33e2bf9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.security + +import scala.collection.JavaConverters._ +import scala.util.Try + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.mapred.Master +import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdentifier + +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging + +private[deploy] class HadoopFSDelegationTokenProvider(fileSystems: Set[FileSystem]) + extends HadoopDelegationTokenProvider with Logging { + + // This tokenRenewalInterval will be set in the first call to obtainDelegationTokens. + // If None, no token renewer is specified or no token can be renewed, + // so we cannot get the token renewal interval. + private var tokenRenewalInterval: Option[Long] = null + + override val serviceName: String = "hadoopfs" + + override def obtainDelegationTokens( + hadoopConf: Configuration, + creds: Credentials): Option[Long] = { + + val newCreds = fetchDelegationTokens( + getTokenRenewer(hadoopConf), + fileSystems) + + // Get the token renewal interval if it is not set. It will only be called once. + if (tokenRenewalInterval == null) { + tokenRenewalInterval = getTokenRenewalInterval(hadoopConf, fileSystems) + } + + // Get the time of next renewal. + val nextRenewalDate = tokenRenewalInterval.flatMap { interval => + val nextRenewalDates = newCreds.getAllTokens.asScala + .filter(_.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier]) + .map { token => + val identifier = token + .decodeIdentifier() + .asInstanceOf[AbstractDelegationTokenIdentifier] + identifier.getIssueDate + interval + } + if (nextRenewalDates.isEmpty) None else Some(nextRenewalDates.min) + } + + creds.addAll(newCreds) + nextRenewalDate + } + + def delegationTokensRequired(hadoopConf: Configuration): Boolean = { + UserGroupInformation.isSecurityEnabled + } + + private def getTokenRenewer(hadoopConf: Configuration): String = { + val tokenRenewer = Master.getMasterPrincipal(hadoopConf) + logDebug("Delegation token renewer is: " + tokenRenewer) + + if (tokenRenewer == null || tokenRenewer.length() == 0) { + val errorMessage = "Can't get Master Kerberos principal for use as renewer." + logError(errorMessage) + throw new SparkException(errorMessage) + } + + tokenRenewer + } + + private def fetchDelegationTokens( + renewer: String, + filesystems: Set[FileSystem]): Credentials = { + + val creds = new Credentials() + + filesystems.foreach { fs => + logInfo("getting token for: " + fs) + fs.addDelegationTokens(renewer, creds) + } + + creds + } + + private def getTokenRenewalInterval( + hadoopConf: Configuration, + filesystems: Set[FileSystem]): Option[Long] = { + // We cannot use the tokens generated with renewer yarn. Trying to renew + // those will fail with an access control issue. 
So create new tokens with the logged in + // user as renewer. + val creds = fetchDelegationTokens( + UserGroupInformation.getCurrentUser.getUserName, + filesystems) + + val renewIntervals = creds.getAllTokens.asScala.filter { + _.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier] + }.flatMap { token => + Try { + val newExpiration = token.renew(hadoopConf) + val identifier = token.decodeIdentifier().asInstanceOf[AbstractDelegationTokenIdentifier] + val interval = newExpiration - identifier.getIssueDate + logInfo(s"Renewal interval is $interval for token ${token.getKind.toString}") + interval + }.toOption + } + if (renewIntervals.isEmpty) None else Some(renewIntervals.min) + } +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala similarity index 54% rename from resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala rename to core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala index 16d8fc32bb42d..53b9f898c6e7d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala @@ -15,97 +15,89 @@ * limitations under the License. */ -package org.apache.spark.deploy.yarn.security +package org.apache.spark.deploy.security import java.lang.reflect.UndeclaredThrowableException import java.security.PrivilegedExceptionAction -import scala.reflect.runtime.universe import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.io.Text import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.Token -import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -private[security] class HiveCredentialProvider extends ServiceCredentialProvider with Logging { +private[security] class HiveDelegationTokenProvider + extends HadoopDelegationTokenProvider with Logging { override def serviceName: String = "hive" + private val classNotFoundErrorStr = s"You are attempting to use the " + + s"${getClass.getCanonicalName}, but your Spark distribution is not built with Hive libraries." + private def hiveConf(hadoopConf: Configuration): Configuration = { try { - val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) - // the hive configuration class is a subclass of Hadoop Configuration, so can be cast down - // to a Configuration and used without reflection - val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") - // using the (Configuration, Class) constructor allows the current configuration to be - // included in the hive config. 
- val ctor = hiveConfClass.getDeclaredConstructor(classOf[Configuration], - classOf[Object].getClass) - ctor.newInstance(hadoopConf, hiveConfClass).asInstanceOf[Configuration] + new HiveConf(hadoopConf, classOf[HiveConf]) } catch { case NonFatal(e) => logDebug("Fail to create Hive Configuration", e) hadoopConf + case e: NoClassDefFoundError => + logWarning(classNotFoundErrorStr) + hadoopConf } } - override def credentialsRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired(hadoopConf: Configuration): Boolean = { UserGroupInformation.isSecurityEnabled && hiveConf(hadoopConf).getTrimmed("hive.metastore.uris", "").nonEmpty } - override def obtainCredentials( + override def obtainDelegationTokens( hadoopConf: Configuration, - sparkConf: SparkConf, creds: Credentials): Option[Long] = { - val conf = hiveConf(hadoopConf) - - val principalKey = "hive.metastore.kerberos.principal" - val principal = conf.getTrimmed(principalKey, "") - require(principal.nonEmpty, s"Hive principal $principalKey undefined") - val metastoreUri = conf.getTrimmed("hive.metastore.uris", "") - require(metastoreUri.nonEmpty, "Hive metastore uri undefined") - - val currentUser = UserGroupInformation.getCurrentUser() - logDebug(s"Getting Hive delegation token for ${currentUser.getUserName()} against " + - s"$principal at $metastoreUri") + try { + val conf = hiveConf(hadoopConf) - val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) - val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") - val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") - val closeCurrent = hiveClass.getMethod("closeCurrent") + val principalKey = "hive.metastore.kerberos.principal" + val principal = conf.getTrimmed(principalKey, "") + require(principal.nonEmpty, s"Hive principal $principalKey undefined") + val metastoreUri = conf.getTrimmed("hive.metastore.uris", "") + require(metastoreUri.nonEmpty, "Hive metastore uri undefined") - try { - // get all the instance methods before invoking any - val getDelegationToken = hiveClass.getMethod("getDelegationToken", - classOf[String], classOf[String]) - val getHive = hiveClass.getMethod("get", hiveConfClass) + val currentUser = UserGroupInformation.getCurrentUser() + logDebug(s"Getting Hive delegation token for ${currentUser.getUserName()} against " + + s"$principal at $metastoreUri") doAsRealUser { - val hive = getHive.invoke(null, conf) - val tokenStr = getDelegationToken.invoke(hive, currentUser.getUserName(), principal) - .asInstanceOf[String] + val hive = Hive.get(conf, classOf[HiveConf]) + val tokenStr = hive.getDelegationToken(currentUser.getUserName(), principal) + val hive2Token = new Token[DelegationTokenIdentifier]() hive2Token.decodeFromUrlString(tokenStr) logInfo(s"Get Token from hive metastore: ${hive2Token.toString}") creds.addToken(new Text("hive.server2.delegation.token"), hive2Token) } + + None } catch { case NonFatal(e) => - logDebug(s"Fail to get token from service $serviceName", e) + logDebug(s"Failed to get token from service $serviceName", e) + None + case e: NoClassDefFoundError => + logWarning(classNotFoundErrorStr) + None } finally { Utils.tryLogNonFatalError { - closeCurrent.invoke(null) + Hive.closeCurrent() } } - - None } /** diff --git a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala new file mode 100644 index 
0000000000000..335f3449cb782 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.security + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.security.Credentials +import org.scalatest.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class HadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { + private var delegationTokenManager: HadoopDelegationTokenManager = null + private var sparkConf: SparkConf = null + private var hadoopConf: Configuration = null + + override def beforeAll(): Unit = { + super.beforeAll() + + sparkConf = new SparkConf() + hadoopConf = new Configuration() + } + + test("Correctly load default credential providers") { + delegationTokenManager = new HadoopDelegationTokenManager( + sparkConf, + hadoopConf, + hadoopFSsToAccess(hadoopConf)) + + delegationTokenManager.getServiceDelegationTokenProvider("hadoopfs") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hbase") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hive") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("bogus") should be (None) + } + + test("disable hive credential provider") { + sparkConf.set("spark.security.credentials.hive.enabled", "false") + delegationTokenManager = new HadoopDelegationTokenManager( + sparkConf, + hadoopConf, + hadoopFSsToAccess(hadoopConf)) + + delegationTokenManager.getServiceDelegationTokenProvider("hadoopfs") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hbase") should not be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hive") should be (None) + } + + test("using deprecated configurations") { + sparkConf.set("spark.yarn.security.tokens.hadoopfs.enabled", "false") + sparkConf.set("spark.yarn.security.credentials.hive.enabled", "false") + delegationTokenManager = new HadoopDelegationTokenManager( + sparkConf, + hadoopConf, + hadoopFSsToAccess(hadoopConf)) + + delegationTokenManager.getServiceDelegationTokenProvider("hadoopfs") should be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hive") should be (None) + delegationTokenManager.getServiceDelegationTokenProvider("hbase") should not be (None) + } + + test("verify no credentials are obtained") { + delegationTokenManager = new HadoopDelegationTokenManager( + sparkConf, + hadoopConf, + hadoopFSsToAccess(hadoopConf)) + val creds = new Credentials() + + // Tokens cannot be obtained from HDFS, Hive, HBase in unit tests. 
+ delegationTokenManager.obtainDelegationTokens(hadoopConf, creds) + val tokens = creds.getAllTokens + tokens.size() should be (0) + } + + test("obtain tokens For HiveMetastore") { + val hadoopConf = new Configuration() + hadoopConf.set("hive.metastore.kerberos.principal", "bob") + // thrift picks up on port 0 and bails out, without trying to talk to endpoint + hadoopConf.set("hive.metastore.uris", "http://localhost:0") + + val hiveCredentialProvider = new HiveDelegationTokenProvider() + val credentials = new Credentials() + hiveCredentialProvider.obtainDelegationTokens(hadoopConf, credentials) + + credentials.getAllTokens.size() should be (0) + } + + test("Obtain tokens For HBase") { + val hadoopConf = new Configuration() + hadoopConf.set("hbase.security.authentication", "kerberos") + + val hbaseTokenProvider = new HBaseDelegationTokenProvider() + val creds = new Credentials() + hbaseTokenProvider.obtainDelegationTokens(hadoopConf, creds) + + creds.getAllTokens.size should be (0) + } + + private[spark] def hadoopFSsToAccess(hadoopConf: Configuration): Set[FileSystem] = { + Set(FileSystem.get(hadoopConf)) + } +} diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 2355d40d1e6fe..607234b4068d0 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -93,16 +93,13 @@ INDEX .lintr gen-java.* .*avpr -org.apache.spark.sql.sources.DataSourceRegister -org.apache.spark.scheduler.SparkHistoryListenerFactory .*parquet spark-deps-.* .*csv .*tsv -org.apache.spark.scheduler.ExternalClusterManager .*\.sql .Rbuildignore -org.apache.spark.deploy.yarn.security.ServiceCredentialProvider +META-INF/* spark-warehouse structured-streaming/* kafka-source-initial-offset-version-2.1.0.bin diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 2d56123028f2b..e4a74556d4f26 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -419,7 +419,7 @@ To use a custom metrics.properties for the application master and executors, upd - spark.yarn.security.credentials.${service}.enabled + spark.security.credentials.${service}.enabled true Controls whether to obtain credentials for services when security is enabled. @@ -482,11 +482,11 @@ token for the cluster's default Hadoop filesystem, and potentially for HBase and An HBase token will be obtained if HBase is in on classpath, the HBase configuration declares the application is secure (i.e. `hbase-site.xml` sets `hbase.security.authentication` to `kerberos`), -and `spark.yarn.security.credentials.hbase.enabled` is not set to `false`. +and `spark.security.credentials.hbase.enabled` is not set to `false`. Similarly, a Hive token will be obtained if Hive is on the classpath, its configuration includes a URI of the metadata store in `"hive.metastore.uris`, and -`spark.yarn.security.credentials.hive.enabled` is not set to `false`. +`spark.security.credentials.hive.enabled` is not set to `false`. If an application needs to interact with other secure Hadoop filesystems, then the tokens needed to access these clusters must be explicitly requested at @@ -500,7 +500,7 @@ Spark supports integrating with other security-aware services through Java Servi `java.util.ServiceLoader`). To do that, implementations of `org.apache.spark.deploy.yarn.security.ServiceCredentialProvider` should be available to Spark by listing their names in the corresponding file in the jar's `META-INF/services` directory. 
These plug-ins can be disabled by setting -`spark.yarn.security.credentials.{service}.enabled` to `false`, where `{service}` is the name of +`spark.security.credentials.{service}.enabled` to `false`, where `{service}` is the name of credential provider. ## Configuring the External Shuffle Service @@ -564,8 +564,8 @@ the Spark configuration must be set to disable token collection for the services The Spark configuration must include the lines: ``` -spark.yarn.security.credentials.hive.enabled false -spark.yarn.security.credentials.hbase.enabled false +spark.security.credentials.hive.enabled false +spark.security.credentials.hbase.enabled false ``` The configuration option `spark.yarn.access.hadoopFileSystems` must be unset. diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 71d4ad681e169..43a7ce95bd3de 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -167,29 +167,27 @@ ${jersey-1.version} - + ${hive.group} hive-exec - test + provided ${hive.group} hive-metastore - test + provided org.apache.thrift libthrift - test + provided org.apache.thrift libfb303 - test + provided diff --git a/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider b/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider deleted file mode 100644 index f5a807ecac9d7..0000000000000 --- a/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider +++ /dev/null @@ -1,3 +0,0 @@ -org.apache.spark.deploy.yarn.security.HadoopFSCredentialProvider -org.apache.spark.deploy.yarn.security.HBaseCredentialProvider -org.apache.spark.deploy.yarn.security.HiveCredentialProvider diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 6da2c0b5f330a..4f71a1606312d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -38,7 +38,7 @@ import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.{AMCredentialRenewer, ConfigurableCredentialManager} +import org.apache.spark.deploy.yarn.security.{AMCredentialRenewer, YARNHadoopDelegationTokenManager} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rpc._ @@ -247,8 +247,12 @@ private[spark] class ApplicationMaster( if (sparkConf.contains(CREDENTIALS_FILE_PATH.key)) { // If a principal and keytab have been set, use that to create new credentials for executors // periodically - credentialRenewer = - new ConfigurableCredentialManager(sparkConf, yarnConf).credentialRenewer() + val credentialManager = new YARNHadoopDelegationTokenManager( + sparkConf, + yarnConf, + YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, yarnConf)) + + val credentialRenewer = new AMCredentialRenewer(sparkConf, yarnConf, credentialManager) credentialRenewer.scheduleLoginFromKeytab() } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala 
index 1fb7edf2a6e30..e5131e636dc04 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -49,7 +49,7 @@ import org.apache.hadoop.yarn.util.Records import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.ConfigurableCredentialManager +import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils} @@ -121,7 +121,10 @@ private[spark] class Client( private val appStagingBaseDir = sparkConf.get(STAGING_DIR).map { new Path(_) } .getOrElse(FileSystem.get(hadoopConf).getHomeDirectory()) - private val credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf) + private val credentialManager = new YARNHadoopDelegationTokenManager( + sparkConf, + hadoopConf, + YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, hadoopConf)) def reportLauncherState(state: SparkAppHandle.State): Unit = { launcherBackend.setState(state) @@ -368,7 +371,7 @@ private[spark] class Client( val fs = destDir.getFileSystem(hadoopConf) // Merge credentials obtained from registered providers - val nearestTimeOfNextRenewal = credentialManager.obtainCredentials(hadoopConf, credentials) + val nearestTimeOfNextRenewal = credentialManager.obtainDelegationTokens(hadoopConf, credentials) if (credentials != null) { // Add credentials to current user's UGI, so that following operations don't need to use the diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 0fc994d629ccb..4522071bd92e2 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -24,8 +24,9 @@ import java.util.regex.Pattern import scala.collection.mutable.{HashMap, ListBuffer} import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.Text -import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.{JobConf, Master} import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api.ApplicationConstants @@ -35,11 +36,14 @@ import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.yarn.security.{ConfigurableCredentialManager, CredentialUpdater} +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.deploy.yarn.security.CredentialUpdater +import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.config._ import org.apache.spark.launcher.YarnCommandBuilderUtils import org.apache.spark.util.Utils + /** * Contains util methods to interact with Hadoop from spark. 
*/ @@ -87,8 +91,12 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } private[spark] override def startCredentialUpdater(sparkConf: SparkConf): Unit = { - credentialUpdater = - new ConfigurableCredentialManager(sparkConf, newConfiguration(sparkConf)).credentialUpdater() + val hadoopConf = newConfiguration(sparkConf) + val credentialManager = new YARNHadoopDelegationTokenManager( + sparkConf, + hadoopConf, + YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, hadoopConf)) + credentialUpdater = new CredentialUpdater(sparkConf, hadoopConf, credentialManager) credentialUpdater.start() } @@ -103,6 +111,21 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) ConverterUtils.toContainerId(containerIdString) } + + /** The filesystems for which YARN should fetch delegation tokens. */ + private[spark] def hadoopFSsToAccess( + sparkConf: SparkConf, + hadoopConf: Configuration): Set[FileSystem] = { + val filesystemsToAccess = sparkConf.get(FILESYSTEMS_TO_ACCESS) + .map(new Path(_).getFileSystem(hadoopConf)) + .toSet + + val stagingFS = sparkConf.get(STAGING_DIR) + .map(new Path(_).getFileSystem(hadoopConf)) + .getOrElse(FileSystem.get(hadoopConf)) + + filesystemsToAccess + stagingFS + } } object YarnSparkHadoopUtil { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala index 7e76f402db249..68a2e9e70a78b 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala @@ -54,7 +54,7 @@ import org.apache.spark.util.ThreadUtils private[yarn] class AMCredentialRenewer( sparkConf: SparkConf, hadoopConf: Configuration, - credentialManager: ConfigurableCredentialManager) extends Logging { + credentialManager: YARNHadoopDelegationTokenManager) extends Logging { private var lastCredentialsFileSuffix = 0 @@ -174,7 +174,9 @@ private[yarn] class AMCredentialRenewer( keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] { // Get a copy of the credentials override def run(): Void = { - nearestNextRenewalTime = credentialManager.obtainCredentials(freshHadoopConf, tempCreds) + nearestNextRenewalTime = credentialManager.obtainDelegationTokens( + freshHadoopConf, + tempCreds) null } }) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala deleted file mode 100644 index 4f4be52a0d691..0000000000000 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn.security - -import java.util.ServiceLoader - -import scala.collection.JavaConverters._ - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.security.Credentials - -import org.apache.spark.SparkConf -import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils - -/** - * A ConfigurableCredentialManager to manage all the registered credential providers and offer - * APIs for other modules to obtain credentials as well as renewal time. By default - * [[HadoopFSCredentialProvider]], [[HiveCredentialProvider]] and [[HBaseCredentialProvider]] will - * be loaded in if not explicitly disabled, any plugged-in credential provider wants to be - * managed by ConfigurableCredentialManager needs to implement [[ServiceCredentialProvider]] - * interface and put into resources/META-INF/services to be loaded by ServiceLoader. - * - * Also each credential provider is controlled by - * spark.yarn.security.credentials.{service}.enabled, it will not be loaded in if set to false. - * For example, Hive's credential provider [[HiveCredentialProvider]] can be enabled/disabled by - * the configuration spark.yarn.security.credentials.hive.enabled. - */ -private[yarn] final class ConfigurableCredentialManager( - sparkConf: SparkConf, hadoopConf: Configuration) extends Logging { - private val deprecatedProviderEnabledConfig = "spark.yarn.security.tokens.%s.enabled" - private val providerEnabledConfig = "spark.yarn.security.credentials.%s.enabled" - - // Maintain all the registered credential providers - private val credentialProviders = { - val providers = ServiceLoader.load(classOf[ServiceCredentialProvider], - Utils.getContextOrSparkClassLoader).asScala - - // Filter out credentials in which spark.yarn.security.credentials.{service}.enabled is false. - providers.filter { p => - sparkConf.getOption(providerEnabledConfig.format(p.serviceName)) - .orElse { - sparkConf.getOption(deprecatedProviderEnabledConfig.format(p.serviceName)).map { c => - logWarning(s"${deprecatedProviderEnabledConfig.format(p.serviceName)} is deprecated, " + - s"using ${providerEnabledConfig.format(p.serviceName)} instead") - c - } - }.map(_.toBoolean).getOrElse(true) - }.map { p => (p.serviceName, p) }.toMap - } - - /** - * Get credential provider for the specified service. - */ - def getServiceCredentialProvider(service: String): Option[ServiceCredentialProvider] = { - credentialProviders.get(service) - } - - /** - * Obtain credentials from all the registered providers. - * @return nearest time of next renewal, Long.MaxValue if all the credentials aren't renewable, - * otherwise the nearest renewal time of any credentials will be returned. - */ - def obtainCredentials(hadoopConf: Configuration, creds: Credentials): Long = { - credentialProviders.values.flatMap { provider => - if (provider.credentialsRequired(hadoopConf)) { - provider.obtainCredentials(hadoopConf, sparkConf, creds) - } else { - logDebug(s"Service ${provider.serviceName} does not require a token." 
+ - s" Check your configuration to see if security is disabled or not.") - None - } - }.foldLeft(Long.MaxValue)(math.min) - } - - /** - * Create an [[AMCredentialRenewer]] instance, caller should be responsible to stop this - * instance when it is not used. AM will use it to renew credentials periodically. - */ - def credentialRenewer(): AMCredentialRenewer = { - new AMCredentialRenewer(sparkConf, hadoopConf, this) - } - - /** - * Create an [[CredentialUpdater]] instance, caller should be resposible to stop this intance - * when it is not used. Executors and driver (client mode) will use it to update credentials. - * periodically. - */ - def credentialUpdater(): CredentialUpdater = { - new CredentialUpdater(sparkConf, hadoopConf, this) - } -} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala index 41b7b5d60b038..fe173dffc22a8 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.{ThreadUtils, Utils} private[spark] class CredentialUpdater( sparkConf: SparkConf, hadoopConf: Configuration, - credentialManager: ConfigurableCredentialManager) extends Logging { + credentialManager: YARNHadoopDelegationTokenManager) extends Logging { @volatile private var lastCredentialsFileSuffix = 0 diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProvider.scala deleted file mode 100644 index f65c886db944e..0000000000000 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProvider.scala +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.deploy.yarn.security - -import scala.collection.JavaConverters._ -import scala.util.Try - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.mapred.Master -import org.apache.hadoop.security.Credentials -import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdentifier - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config._ - -private[security] class HadoopFSCredentialProvider - extends ServiceCredentialProvider with Logging { - // Token renewal interval, this value will be set in the first call, - // if None means no token renewer specified or no token can be renewed, - // so cannot get token renewal interval. - private var tokenRenewalInterval: Option[Long] = null - - override val serviceName: String = "hadoopfs" - - override def obtainCredentials( - hadoopConf: Configuration, - sparkConf: SparkConf, - creds: Credentials): Option[Long] = { - // NameNode to access, used to get tokens from different FileSystems - val tmpCreds = new Credentials() - val tokenRenewer = getTokenRenewer(hadoopConf) - hadoopFSsToAccess(hadoopConf, sparkConf).foreach { dst => - val dstFs = dst.getFileSystem(hadoopConf) - logInfo("getting token for: " + dst) - dstFs.addDelegationTokens(tokenRenewer, tmpCreds) - } - - // Get the token renewal interval if it is not set. It will only be called once. - if (tokenRenewalInterval == null) { - tokenRenewalInterval = getTokenRenewalInterval(hadoopConf, sparkConf) - } - - // Get the time of next renewal. - val nextRenewalDate = tokenRenewalInterval.flatMap { interval => - val nextRenewalDates = tmpCreds.getAllTokens.asScala - .filter(_.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier]) - .map { t => - val identifier = t.decodeIdentifier().asInstanceOf[AbstractDelegationTokenIdentifier] - identifier.getIssueDate + interval - } - if (nextRenewalDates.isEmpty) None else Some(nextRenewalDates.min) - } - - creds.addAll(tmpCreds) - nextRenewalDate - } - - private def getTokenRenewalInterval( - hadoopConf: Configuration, sparkConf: SparkConf): Option[Long] = { - // We cannot use the tokens generated with renewer yarn. Trying to renew - // those will fail with an access control issue. So create new tokens with the logged in - // user as renewer. 
- sparkConf.get(PRINCIPAL).flatMap { renewer => - val creds = new Credentials() - hadoopFSsToAccess(hadoopConf, sparkConf).foreach { dst => - val dstFs = dst.getFileSystem(hadoopConf) - dstFs.addDelegationTokens(renewer, creds) - } - - val renewIntervals = creds.getAllTokens.asScala.filter { - _.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier] - }.flatMap { token => - Try { - val newExpiration = token.renew(hadoopConf) - val identifier = token.decodeIdentifier().asInstanceOf[AbstractDelegationTokenIdentifier] - val interval = newExpiration - identifier.getIssueDate - logInfo(s"Renewal interval is $interval for token ${token.getKind.toString}") - interval - }.toOption - } - if (renewIntervals.isEmpty) None else Some(renewIntervals.min) - } - } - - private def getTokenRenewer(conf: Configuration): String = { - val delegTokenRenewer = Master.getMasterPrincipal(conf) - logDebug("delegation token renewer is: " + delegTokenRenewer) - if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { - val errorMessage = "Can't get Master Kerberos principal for use as renewer" - logError(errorMessage) - throw new SparkException(errorMessage) - } - - delegTokenRenewer - } - - private def hadoopFSsToAccess(hadoopConf: Configuration, sparkConf: SparkConf): Set[Path] = { - sparkConf.get(FILESYSTEMS_TO_ACCESS).map(new Path(_)).toSet + - sparkConf.get(STAGING_DIR).map(new Path(_)) - .getOrElse(FileSystem.get(hadoopConf).getHomeDirectory) - } -} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala index 4e3fcce8dbb1d..cc24ac4d9bcf6 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala @@ -35,7 +35,7 @@ trait ServiceCredentialProvider { def serviceName: String /** - * To decide whether credential is required for this service. By default it based on whether + * Returns true if credentials are required by this service. By default, it is based on whether * Hadoop security is enabled. */ def credentialsRequired(hadoopConf: Configuration): Boolean = { @@ -44,6 +44,7 @@ trait ServiceCredentialProvider { /** * Obtain credentials for this service and get the time of the next renewal. + * * @param hadoopConf Configuration of current Hadoop Compatible system. * @param sparkConf Spark configuration. * @param creds Credentials to add tokens and security keys to. diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala new file mode 100644 index 0000000000000..bbd17c8fc1272 --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn.security + +import java.util.ServiceLoader + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.security.Credentials + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.security.HadoopDelegationTokenManager +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * This class loads delegation token providers registered under the YARN-specific + * [[ServiceCredentialProvider]] interface, as well as the builtin providers defined + * in [[HadoopDelegationTokenManager]]. + */ +private[yarn] class YARNHadoopDelegationTokenManager( + sparkConf: SparkConf, + hadoopConf: Configuration, + fileSystems: Set[FileSystem]) extends Logging { + + private val delegationTokenManager = + new HadoopDelegationTokenManager(sparkConf, hadoopConf, fileSystems) + + // public for testing + val credentialProviders = getCredentialProviders + + /** + * Writes delegation tokens to creds. Delegation tokens are fetched from all registered + * providers. + * + * @return Time after which the fetched delegation tokens should be renewed. + */ + def obtainDelegationTokens(hadoopConf: Configuration, creds: Credentials): Long = { + val superInterval = delegationTokenManager.obtainDelegationTokens(hadoopConf, creds) + + credentialProviders.values.flatMap { provider => + if (provider.credentialsRequired(hadoopConf)) { + provider.obtainCredentials(hadoopConf, sparkConf, creds) + } else { + logDebug(s"Service ${provider.serviceName} does not require a token." + + s" Check your configuration to see if security is disabled or not.") + None + } + }.foldLeft(superInterval)(math.min) + } + + private def getCredentialProviders: Map[String, ServiceCredentialProvider] = { + val providers = loadCredentialProviders + + providers. 
+ filter { p => delegationTokenManager.isServiceEnabled(p.serviceName) } + .map { p => (p.serviceName, p) } + .toMap + } + + private def loadCredentialProviders: List[ServiceCredentialProvider] = { + ServiceLoader.load(classOf[ServiceCredentialProvider], Utils.getContextOrSparkClassLoader) + .asScala + .toList + } +} diff --git a/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider b/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider index d0ef5efa36e86..f31c232693133 100644 --- a/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider +++ b/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider @@ -1 +1 @@ -org.apache.spark.deploy.yarn.security.TestCredentialProvider +org.apache.spark.deploy.yarn.security.YARNTestCredentialProvider diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala deleted file mode 100644 index b0067aa4517c7..0000000000000 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.deploy.yarn.security - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.io.Text -import org.apache.hadoop.security.Credentials -import org.apache.hadoop.security.token.Token -import org.scalatest.{BeforeAndAfter, Matchers} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.yarn.config._ - -class ConfigurableCredentialManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { - private var credentialManager: ConfigurableCredentialManager = null - private var sparkConf: SparkConf = null - private var hadoopConf: Configuration = null - - override def beforeAll(): Unit = { - super.beforeAll() - - sparkConf = new SparkConf() - hadoopConf = new Configuration() - System.setProperty("SPARK_YARN_MODE", "true") - } - - override def afterAll(): Unit = { - System.clearProperty("SPARK_YARN_MODE") - - super.afterAll() - } - - test("Correctly load default credential providers") { - credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf) - - credentialManager.getServiceCredentialProvider("hadoopfs") should not be (None) - credentialManager.getServiceCredentialProvider("hbase") should not be (None) - credentialManager.getServiceCredentialProvider("hive") should not be (None) - } - - test("disable hive credential provider") { - sparkConf.set("spark.yarn.security.credentials.hive.enabled", "false") - credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf) - - credentialManager.getServiceCredentialProvider("hadoopfs") should not be (None) - credentialManager.getServiceCredentialProvider("hbase") should not be (None) - credentialManager.getServiceCredentialProvider("hive") should be (None) - } - - test("using deprecated configurations") { - sparkConf.set("spark.yarn.security.tokens.hadoopfs.enabled", "false") - sparkConf.set("spark.yarn.security.tokens.hive.enabled", "false") - credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf) - - credentialManager.getServiceCredentialProvider("hadoopfs") should be (None) - credentialManager.getServiceCredentialProvider("hive") should be (None) - credentialManager.getServiceCredentialProvider("test") should not be (None) - credentialManager.getServiceCredentialProvider("hbase") should not be (None) - } - - test("verify obtaining credentials from provider") { - credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf) - val creds = new Credentials() - - // Tokens can only be obtained from TestTokenProvider, for hdfs, hbase and hive tokens cannot - // be obtained. 
- credentialManager.obtainCredentials(hadoopConf, creds) - val tokens = creds.getAllTokens - tokens.size() should be (1) - tokens.iterator().next().getService should be (new Text("test")) - } - - test("verify getting credential renewal info") { - credentialManager = new ConfigurableCredentialManager(sparkConf, hadoopConf) - val creds = new Credentials() - - val testCredentialProvider = credentialManager.getServiceCredentialProvider("test").get - .asInstanceOf[TestCredentialProvider] - // Only TestTokenProvider can get the time of next token renewal - val nextRenewal = credentialManager.obtainCredentials(hadoopConf, creds) - nextRenewal should be (testCredentialProvider.timeOfNextTokenRenewal) - } - - test("obtain tokens For HiveMetastore") { - val hadoopConf = new Configuration() - hadoopConf.set("hive.metastore.kerberos.principal", "bob") - // thrift picks up on port 0 and bails out, without trying to talk to endpoint - hadoopConf.set("hive.metastore.uris", "http://localhost:0") - - val hiveCredentialProvider = new HiveCredentialProvider() - val credentials = new Credentials() - hiveCredentialProvider.obtainCredentials(hadoopConf, sparkConf, credentials) - - credentials.getAllTokens.size() should be (0) - } - - test("Obtain tokens For HBase") { - val hadoopConf = new Configuration() - hadoopConf.set("hbase.security.authentication", "kerberos") - - val hbaseTokenProvider = new HBaseCredentialProvider() - val creds = new Credentials() - hbaseTokenProvider.obtainCredentials(hadoopConf, sparkConf, creds) - - creds.getAllTokens.size should be (0) - } -} - -class TestCredentialProvider extends ServiceCredentialProvider { - val tokenRenewalInterval = 86400 * 1000L - var timeOfNextTokenRenewal = 0L - - override def serviceName: String = "test" - - override def credentialsRequired(conf: Configuration): Boolean = true - - override def obtainCredentials( - hadoopConf: Configuration, - sparkConf: SparkConf, - creds: Credentials): Option[Long] = { - if (creds == null) { - // Guard out other unit test failures. - return None - } - - val emptyToken = new Token() - emptyToken.setService(new Text("test")) - creds.addToken(emptyToken.getService, emptyToken) - - val currTime = System.currentTimeMillis() - timeOfNextTokenRenewal = (currTime - currTime % tokenRenewalInterval) + tokenRenewalInterval - - Some(timeOfNextTokenRenewal) - } -} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProviderSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProviderSuite.scala deleted file mode 100644 index f50ee193c258f..0000000000000 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HadoopFSCredentialProviderSuite.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn.security - -import org.apache.hadoop.conf.Configuration -import org.scalatest.{Matchers, PrivateMethodTester} - -import org.apache.spark.{SparkException, SparkFunSuite} - -class HadoopFSCredentialProviderSuite - extends SparkFunSuite - with PrivateMethodTester - with Matchers { - private val _getTokenRenewer = PrivateMethod[String]('getTokenRenewer) - - private def getTokenRenewer( - fsCredentialProvider: HadoopFSCredentialProvider, conf: Configuration): String = { - fsCredentialProvider invokePrivate _getTokenRenewer(conf) - } - - private var hadoopFsCredentialProvider: HadoopFSCredentialProvider = null - - override def beforeAll() { - super.beforeAll() - - if (hadoopFsCredentialProvider == null) { - hadoopFsCredentialProvider = new HadoopFSCredentialProvider() - } - } - - override def afterAll() { - if (hadoopFsCredentialProvider != null) { - hadoopFsCredentialProvider = null - } - - super.afterAll() - } - - test("check token renewer") { - val hadoopConf = new Configuration() - hadoopConf.set("yarn.resourcemanager.address", "myrm:8033") - hadoopConf.set("yarn.resourcemanager.principal", "yarn/myrm:8032@SPARKTEST.COM") - val renewer = getTokenRenewer(hadoopFsCredentialProvider, hadoopConf) - renewer should be ("yarn/myrm:8032@SPARKTEST.COM") - } - - test("check token renewer default") { - val hadoopConf = new Configuration() - val caught = - intercept[SparkException] { - getTokenRenewer(hadoopFsCredentialProvider, hadoopConf) - } - assert(caught.getMessage === "Can't get Master Kerberos principal for use as renewer") - } -} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala new file mode 100644 index 0000000000000..2b226eff5ce19 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.yarn.security + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.security.Credentials +import org.scalatest.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil + +class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { + private var credentialManager: YARNHadoopDelegationTokenManager = null + private var sparkConf: SparkConf = null + private var hadoopConf: Configuration = null + + override def beforeAll(): Unit = { + super.beforeAll() + + System.setProperty("SPARK_YARN_MODE", "true") + + sparkConf = new SparkConf() + hadoopConf = new Configuration() + } + + override def afterAll(): Unit = { + super.afterAll() + + System.clearProperty("SPARK_YARN_MODE") + } + + test("Correctly loads credential providers") { + credentialManager = new YARNHadoopDelegationTokenManager( + sparkConf, + hadoopConf, + YarnSparkHadoopUtil.get.hadoopFSsToAccess(sparkConf, hadoopConf)) + + credentialManager.credentialProviders.get("yarn-test") should not be (None) + } +} + +class YARNTestCredentialProvider extends ServiceCredentialProvider { + override def serviceName: String = "yarn-test" + + override def credentialsRequired(conf: Configuration): Boolean = true + + override def obtainCredentials( + hadoopConf: Configuration, + sparkConf: SparkConf, + creds: Credentials): Option[Long] = None +} From 5d35d5c15c63debaa79202708c6e6481980a6a7f Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 16 Jun 2017 10:11:23 +0800 Subject: [PATCH 010/118] [SPARK-21112][SQL] ALTER TABLE SET TBLPROPERTIES should not overwrite COMMENT ### What changes were proposed in this pull request? `ALTER TABLE SET TBLPROPERTIES` should not overwrite `COMMENT` even if the input property does not have the property of `COMMENT`. This PR is to fix the issue. ### How was this patch tested? Covered by the existing tests. Author: Xiao Li Closes #18318 from gatorsmile/fixTableComment. --- .../main/scala/org/apache/spark/sql/execution/command/ddl.scala | 2 +- sql/core/src/test/resources/sql-tests/results/describe.sql.out | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 5a7f8cf1eb59e..f924b3d914635 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -235,7 +235,7 @@ case class AlterTableSetPropertiesCommand( // direct property. 
val newTable = table.copy( properties = table.properties ++ properties, - comment = properties.get("comment")) + comment = properties.get("comment").orElse(table.comment)) catalog.alterTable(newTable) Seq.empty[Row] } diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index 329532cd7c842..ab9f2783f06bb 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -127,6 +127,7 @@ Provider parquet Num Buckets 2 Bucket Columns [`a`] Sort Columns [`b`] +Comment table_comment Table Properties [e=3] Location [not included in comparison]sql/core/spark-warehouse/t Storage Properties [a=1, b=2] @@ -157,6 +158,7 @@ Provider parquet Num Buckets 2 Bucket Columns [`a`] Sort Columns [`b`] +Comment table_comment Table Properties [e=3] Location [not included in comparison]sql/core/spark-warehouse/t Storage Properties [a=1, b=2] From 87ab0cec65b50584a627037b9d1b6fdecaee725c Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Fri, 16 Jun 2017 12:10:09 +0800 Subject: [PATCH 011/118] [SPARK-21072][SQL] TreeNode.mapChildren should only apply to child nodes. ## What changes were proposed in this pull request? As the name and comments of `TreeNode.mapChildren` indicate, the function should be applied only to the current node's children. The code below should therefore check whether each argument is actually a child node before applying the function. https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala#L342 ## How was this patch tested? Existing tests. Author: Xianyang Liu Closes #18284 from ConeyLiu/treenode. --- .../spark/sql/catalyst/trees/TreeNode.scala | 14 +++++++++++-- .../sql/catalyst/trees/TreeNodeSuite.scala | 21 ++++++++++++++++++- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index df66f9a082aee..7375a0bcbae75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -340,8 +340,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { arg } case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => - val newChild1 = f(arg1.asInstanceOf[BaseType]) - val newChild2 = f(arg2.asInstanceOf[BaseType]) + val newChild1 = if (containsChild(arg1)) { + f(arg1.asInstanceOf[BaseType]) + } else { + arg1.asInstanceOf[BaseType] + } + + val newChild2 = if (containsChild(arg2)) { + f(arg2.asInstanceOf[BaseType]) + } else { + arg2.asInstanceOf[BaseType] + } + if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { changed = true (newChild1, newChild2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 712841835acd5..819078218c546 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -54,13 +54,21 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]]) override def output: Seq[Attribute] = Nil } -case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable { +case class ExpressionInMap(map:
Map[String, Expression]) extends Unevaluable { override def children: Seq[Expression] = map.values.toSeq override def nullable: Boolean = true override def dataType: NullType = NullType override lazy val resolved = true } +case class SeqTupleExpression(sons: Seq[(Expression, Expression)], + nonSons: Seq[(Expression, Expression)]) extends Unevaluable { + override def children: Seq[Expression] = sons.flatMap(t => Iterator(t._1, t._2)) + override def nullable: Boolean = true + override def dataType: NullType = NullType + override lazy val resolved = true +} + case class JsonTestTreeNode(arg: Any) extends LeafNode { override def output: Seq[Attribute] = Seq.empty[Attribute] } @@ -146,6 +154,17 @@ class TreeNodeSuite extends SparkFunSuite { assert(actual === Dummy(None)) } + test("mapChildren should only works on children") { + val children = Seq((Literal(1), Literal(2))) + val nonChildren = Seq((Literal(3), Literal(4))) + val before = SeqTupleExpression(children, nonChildren) + val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } + val expect = SeqTupleExpression(Seq((Literal(0), Literal(0))), nonChildren) + + val actual = before mapChildren toZero + assert(actual === expect) + } + test("preserves origin") { CurrentOrigin.setPosition(1, 1) val add = Add(Literal(1), Literal(1)) From 7a3e5dc28b67ac1630c5a578a27a5a5acf80aa51 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 15 Jun 2017 23:06:58 -0700 Subject: [PATCH 012/118] [SPARK-20749][SQL] Built-in SQL Function Support - all variants of LEN[GTH] ## What changes were proposed in this pull request? This PR adds the built-in SQL functions `BIT_LENGTH()`, `CHAR_LENGTH()`, and `OCTET_LENGTH()`. `BIT_LENGTH()` returns the bit length of the given string or binary expression. `CHAR_LENGTH()` returns the character length of the given string or binary expression (i.e. it is equivalent to `LENGTH()`). `OCTET_LENGTH()` returns the byte length of the given string or binary expression. ## How was this patch tested? Added new test suites for these three functions. Author: Kazuaki Ishizaki Closes #18046 from kiszk/SPARK-20749.
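For quick reference, the sketch below shows how the three variants differ. It is only an editor's illustration, not part of the patch; every value is taken from the expectations in the new tests (operators.sql and StringExpressionsSuite), and it assumes Spark's UTF-8 string encoding, under which '花' occupies 3 bytes.

```sql
-- ASCII input: characters, bytes and bits line up (3 chars, 3 bytes, 24 bits).
SELECT CHAR_LENGTH('abc'), OCTET_LENGTH('abc'), BIT_LENGTH('abc');
-- 3, 3, 24

-- Multi-byte input: CHAR_LENGTH counts characters, OCTET_LENGTH counts bytes,
-- BIT_LENGTH counts bits (bytes * 8).
SELECT CHAR_LENGTH('a花花c'), OCTET_LENGTH('a花花c'), BIT_LENGTH('a花花c');
-- 4, 8, 64

-- CHAR_LENGTH is registered as an alias of the existing LENGTH expression.
SELECT LENGTH('a花花c');
-- 4
```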
--- .../catalyst/analysis/FunctionRegistry.scala | 3 + .../expressions/stringExpressions.scala | 61 ++++++++++++++++++- .../expressions/StringExpressionsSuite.scala | 20 ++++++ .../resources/sql-tests/inputs/operators.sql | 5 ++ .../sql-tests/results/operators.sql.out | 26 +++++++- 5 files changed, 112 insertions(+), 3 deletions(-) mode change 100644 => 100755 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 877328164a8a9..e4e9918a3a887 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -305,6 +305,8 @@ object FunctionRegistry { expression[Chr]("char"), expression[Chr]("chr"), expression[Base64]("base64"), + expression[BitLength]("bit_length"), + expression[Length]("char_length"), expression[Concat]("concat"), expression[ConcatWs]("concat_ws"), expression[Decode]("decode"), @@ -321,6 +323,7 @@ object FunctionRegistry { expression[Levenshtein]("levenshtein"), expression[Like]("like"), expression[Lower]("lower"), + expression[OctetLength]("octet_length"), expression[StringLocate]("locate"), expression[StringLPad]("lpad"), expression[StringTrimLeft]("ltrim"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala old mode 100644 new mode 100755 index 717ada225a4f1..908fdb8f7e68f --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1199,15 +1199,18 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } /** - * A function that return the length of the given string or binary expression. + * A function that returns the char length of the given string expression or + * number of bytes of the given binary expression. */ +// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the length of `expr` or number of bytes in binary data.", + usage = "_FUNC_(expr) - Returns the character length of `expr` or number of bytes in binary data.", extended = """ Examples: > SELECT _FUNC_('Spark SQL'); 9 """) +// scalastyle:on line.size.limit case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1225,6 +1228,60 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastIn } } +/** + * A function that returns the bit length of the given string or binary expression. 
+ */ +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the bit length of `expr` or number of bits in binary data.", + extended = """ + Examples: + > SELECT _FUNC_('Spark SQL'); + 72 + """) +case class BitLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + override def dataType: DataType = IntegerType + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) + + protected override def nullSafeEval(value: Any): Any = child.dataType match { + case StringType => value.asInstanceOf[UTF8String].numBytes * 8 + case BinaryType => value.asInstanceOf[Array[Byte]].length * 8 + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + child.dataType match { + case StringType => defineCodeGen(ctx, ev, c => s"($c).numBytes() * 8") + case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length * 8") + } + } +} + +/** + * A function that returns the byte length of the given string or binary expression. + */ +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the byte length of `expr` or number of bytes in binary data.", + extended = """ + Examples: + > SELECT _FUNC_('Spark SQL'); + 9 + """) +case class OctetLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + override def dataType: DataType = IntegerType + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) + + protected override def nullSafeEval(value: Any): Any = child.dataType match { + case StringType => value.asInstanceOf[UTF8String].numBytes + case BinaryType => value.asInstanceOf[Array[Byte]].length + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + child.dataType match { + case StringType => defineCodeGen(ctx, ev, c => s"($c).numBytes()") + case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") + } + } +} + /** * A function that return the Levenshtein distance between the two given strings. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 4bdb43bfed8b5..4f08031153ab0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -558,20 +558,40 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:off // non ascii characters are not allowed in the source code, so we disable the scalastyle. 
checkEvaluation(Length(Literal("a花花c")), 4, create_row(string)) + checkEvaluation(OctetLength(Literal("a花花c")), 8, create_row(string)) + checkEvaluation(BitLength(Literal("a花花c")), 8 * 8, create_row(string)) // scalastyle:on checkEvaluation(Length(Literal(bytes)), 5, create_row(Array.empty[Byte])) + checkEvaluation(OctetLength(Literal(bytes)), 5, create_row(Array.empty[Byte])) + checkEvaluation(BitLength(Literal(bytes)), 5 * 8, create_row(Array.empty[Byte])) checkEvaluation(Length(a), 5, create_row(string)) + checkEvaluation(OctetLength(a), 5, create_row(string)) + checkEvaluation(BitLength(a), 5 * 8, create_row(string)) checkEvaluation(Length(b), 5, create_row(bytes)) + checkEvaluation(OctetLength(b), 5, create_row(bytes)) + checkEvaluation(BitLength(b), 5 * 8, create_row(bytes)) checkEvaluation(Length(a), 0, create_row("")) + checkEvaluation(OctetLength(a), 0, create_row("")) + checkEvaluation(BitLength(a), 0, create_row("")) checkEvaluation(Length(b), 0, create_row(Array.empty[Byte])) + checkEvaluation(OctetLength(b), 0, create_row(Array.empty[Byte])) + checkEvaluation(BitLength(b), 0, create_row(Array.empty[Byte])) checkEvaluation(Length(a), null, create_row(null)) + checkEvaluation(OctetLength(a), null, create_row(null)) + checkEvaluation(BitLength(a), null, create_row(null)) checkEvaluation(Length(b), null, create_row(null)) + checkEvaluation(OctetLength(b), null, create_row(null)) + checkEvaluation(BitLength(b), null, create_row(null)) checkEvaluation(Length(Literal.create(null, StringType)), null, create_row(string)) + checkEvaluation(OctetLength(Literal.create(null, StringType)), null, create_row(string)) + checkEvaluation(BitLength(Literal.create(null, StringType)), null, create_row(string)) checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes)) + checkEvaluation(OctetLength(Literal.create(null, BinaryType)), null, create_row(bytes)) + checkEvaluation(BitLength(Literal.create(null, BinaryType)), null, create_row(bytes)) } test("format_number / FormatNumber") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index 3934620577e99..a8de23e73892c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -80,3 +80,8 @@ select 1 > 0.00001; -- mod select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, null); + +-- length +select BIT_LENGTH('abc'); +select CHAR_LENGTH('abc'); +select OCTET_LENGTH('abc'); diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 51ccf764d952f..85ee10b4d274f 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 51 +-- Number of queries: 54 -- !query 0 @@ -420,3 +420,27 @@ select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, nu struct<(7 % 2):int,(7 % 0):int,(0 % 2):int,(7 % CAST(NULL AS INT)):int,(CAST(NULL AS INT) % 2):int,(CAST(NULL AS DOUBLE) % CAST(NULL AS DOUBLE)):double> -- !query 50 output 1 NULL 0 NULL NULL NULL + + +-- !query 51 +select BIT_LENGTH('abc') +-- !query 51 schema +struct +-- !query 51 output +24 + + +-- !query 52 +select CHAR_LENGTH('abc') +-- !query 52 schema +struct +-- !query 52 output +3 + + +-- !query 53 +select 
OCTET_LENGTH('abc') -- !query 53 schema struct -- !query 53 output 3 From 2837b14cdc42f096dce07e383caa30c7469c5d6b Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 16 Jun 2017 14:24:15 +0800 Subject: [PATCH 013/118] [SPARK-12552][FOLLOWUP] Fix flaky test for "o.a.s.deploy.master.MasterSuite.master correctly recover the application" ## What changes were proposed in this pull request? Due to asynchronous RPC event processing, the test "correctly recover the application" could potentially fail. The issue can be seen here: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/78126/testReport/org.apache.spark.deploy.master/MasterSuite/master_correctly_recover_the_application/. This patch fixes that flaky test. ## How was this patch tested? Existing UT. CC cloud-fan jiangxb1987, please help to review, thanks! Author: jerryshao Closes #18321 from jerryshao/SPARK-12552-followup. --- .../test/scala/org/apache/spark/deploy/master/MasterSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 6bb0eec040787..a2232126787f6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -214,7 +214,7 @@ class MasterSuite extends SparkFunSuite master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) // Wait until Master recover from checkpoint data. eventually(timeout(5 seconds), interval(100 milliseconds)) { - master.idToApp.size should be(1) + master.workers.size should be(1) } master.idToApp.keySet should be(Set(fakeAppInfo.id)) From 45824fb608930eb461e7df53bb678c9534c183a9 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 16 Jun 2017 11:03:54 +0100 Subject: [PATCH 014/118] [MINOR][DOCS] Improve Running R Tests docs ## What changes were proposed in this pull request? Update the Running R Tests dependency packages to: ```bash R -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'e1071', 'survival'), repos='http://cran.us.r-project.org')" ``` ## How was this patch tested? Manual tests. Author: Yuming Wang Closes #18271 from wangyum/building-spark. --- R/README.md | 6 +----- R/WINDOWS.md | 3 +-- docs/building-spark.md | 8 +++++--- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/R/README.md b/R/README.md index 4c40c5963db70..1152b1e8e5f9f 100644 --- a/R/README.md +++ b/R/README.md @@ -66,11 +66,7 @@ To run one of them, use `./bin/spark-submit `. For example: ```bash ./bin/spark-submit examples/src/main/r/dataframe.R ``` -You can also run the unit tests for SparkR by running. You need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first: -```bash -R -e 'install.packages("testthat", repos="http://cran.us.r-project.org")' -./R/run-tests.sh -``` +You can run R unit tests by following the instructions under [Running R Tests](http://spark.apache.org/docs/latest/building-spark.html#running-r-tests). ### Running on YARN diff --git a/R/WINDOWS.md b/R/WINDOWS.md index 9ca7e58e20cd2..124bc631be9cd 100644 --- a/R/WINDOWS.md +++ b/R/WINDOWS.md @@ -34,10 +34,9 @@ To run the SparkR unit tests on Windows, the following steps are required —ass 4. Set the environment variable `HADOOP_HOME` to the full path to the newly created `hadoop` directory. -5. Run unit tests for SparkR by running the command below.
You need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first: +5. Run unit tests for SparkR by running the command below. You need to install the needed packages following the instructions under [Running R Tests](http://spark.apache.org/docs/latest/building-spark.html#running-r-tests) first: ``` - R -e "install.packages('testthat', repos='http://cran.us.r-project.org')" .\bin\spark-submit2.cmd --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R ``` diff --git a/docs/building-spark.md b/docs/building-spark.md index 0f551bc66b8c9..777635a64f83c 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -218,9 +218,11 @@ The run-tests script also can be limited to a specific Python version or a speci ## Running R Tests -To run the SparkR tests you will need to install the R package `testthat` -(run `install.packages(testthat)` from R shell). You can run just the SparkR tests using -the command: +To run the SparkR tests you will need to install the [knitr](https://cran.r-project.org/package=knitr), [rmarkdown](https://cran.r-project.org/package=rmarkdown), [testthat](https://cran.r-project.org/package=testthat), [e1071](https://cran.r-project.org/package=e1071) and [survival](https://cran.r-project.org/package=survival) packages first: + + R -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'e1071', 'survival'), repos='http://cran.us.r-project.org')" + +You can run just the SparkR tests using the command: ./R/run-tests.sh From 93dd0c518d040155b04e5ab258c5835aec7776fc Mon Sep 17 00:00:00 2001 From: jinxing Date: Fri, 16 Jun 2017 20:09:45 +0800 Subject: [PATCH 015/118] [SPARK-20994] Remove redundant characters in OpenBlocks to save memory for shuffle service. ## What changes were proposed in this pull request? In current code, blockIds in `OpenBlocks` are stored in the iterator on shuffle service. There are some redundant characters in blockId(`"shuffle_" + shuffleId + "_" + mapId + "_" + reduceId`). This pr proposes to improve the footprint and alleviate the memory pressure on shuffle service. Author: jinxing Closes #18231 from jinxing64/SPARK-20994-v2. --- .../shuffle/ExternalShuffleBlockHandler.java | 70 +++++++++++++------ .../shuffle/ExternalShuffleBlockResolver.java | 23 +++--- .../network/sasl/SaslIntegrationSuite.java | 2 +- .../ExternalShuffleBlockHandlerSuite.java | 11 +-- .../ExternalShuffleBlockResolverSuite.java | 10 +-- .../ExternalShuffleIntegrationSuite.java | 8 +-- 6 files changed, 73 insertions(+), 51 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index c0f1da50f5e65..fc7bba41185f0 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -44,7 +44,6 @@ import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; import org.apache.spark.network.util.TransportConf; - /** * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process. 
* @@ -91,26 +90,8 @@ protected void handleMessage( try { OpenBlocks msg = (OpenBlocks) msgObj; checkAuth(client, msg.appId); - - Iterator iter = new Iterator() { - private int index = 0; - - @Override - public boolean hasNext() { - return index < msg.blockIds.length; - } - - @Override - public ManagedBuffer next() { - final ManagedBuffer block = blockManager.getBlockData(msg.appId, msg.execId, - msg.blockIds[index]); - index++; - metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0); - return block; - } - }; - - long streamId = streamManager.registerStream(client.getClientId(), iter); + long streamId = streamManager.registerStream(client.getClientId(), + new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds)); if (logger.isTraceEnabled()) { logger.trace("Registered streamId {} with {} buffers for client {} from host {}", streamId, @@ -209,4 +190,51 @@ public Map getMetrics() { } } + private class ManagedBufferIterator implements Iterator { + + private int index = 0; + private final String appId; + private final String execId; + private final int shuffleId; + // An array containing mapId and reduceId pairs. + private final int[] mapIdAndReduceIds; + + ManagedBufferIterator(String appId, String execId, String[] blockIds) { + this.appId = appId; + this.execId = execId; + String[] blockId0Parts = blockIds[0].split("_"); + if (blockId0Parts.length != 4 || !blockId0Parts[0].equals("shuffle")) { + throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[0]); + } + this.shuffleId = Integer.parseInt(blockId0Parts[1]); + mapIdAndReduceIds = new int[2 * blockIds.length]; + for (int i = 0; i < blockIds.length; i++) { + String[] blockIdParts = blockIds[i].split("_"); + if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) { + throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]); + } + if (Integer.parseInt(blockIdParts[1]) != shuffleId) { + throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + + ", got:" + blockIds[i]); + } + mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]); + mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]); + } + } + + @Override + public boolean hasNext() { + return index < mapIdAndReduceIds.length; + } + + @Override + public ManagedBuffer next() { + final ManagedBuffer block = blockManager.getBlockData(appId, execId, shuffleId, + mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]); + index += 2; + metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0); + return block; + } + } + } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 62d58aba4c1e7..d7ec0e299dead 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -150,27 +150,20 @@ public void registerExecutor( } /** - * Obtains a FileSegmentManagedBuffer from a shuffle block id. We expect the blockId has the - * format "shuffle_ShuffleId_MapId_ReduceId" (from ShuffleBlockId), and additionally make - * assumptions about how the hash and sort based shuffles store their data. + * Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, reduceId). 
We make assumptions + * about how the hash and sort based shuffles store their data. */ - public ManagedBuffer getBlockData(String appId, String execId, String blockId) { - String[] blockIdParts = blockId.split("_"); - if (blockIdParts.length < 4) { - throw new IllegalArgumentException("Unexpected block id format: " + blockId); - } else if (!blockIdParts[0].equals("shuffle")) { - throw new IllegalArgumentException("Expected shuffle block id, got: " + blockId); - } - int shuffleId = Integer.parseInt(blockIdParts[1]); - int mapId = Integer.parseInt(blockIdParts[2]); - int reduceId = Integer.parseInt(blockIdParts[3]); - + public ManagedBuffer getBlockData( + String appId, + String execId, + int shuffleId, + int mapId, + int reduceId) { ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); if (executor == null) { throw new RuntimeException( String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); } - return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 0c054fc5db8f4..8110f1e004c73 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -202,7 +202,7 @@ public void onBlockFetchFailure(String blockId, Throwable t) { } }; - String[] blockIds = { "shuffle_2_3_4", "shuffle_6_7_8" }; + String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" }; OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf, null); fetcher.start(); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index 4d48b18970386..7846b71d5a8b1 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -83,9 +83,10 @@ public void testOpenShuffleBlocks() { ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3])); ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); - when(blockResolver.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); - when(blockResolver.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); - ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }) + when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)).thenReturn(block0Marker); + when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(block1Marker); + ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", + new String[] { "shuffle_0_0_0", "shuffle_0_0_1" }) .toByteBuffer(); handler.receive(client, openBlocks, callback); @@ -105,8 +106,8 @@ public void testOpenShuffleBlocks() { assertEquals(block0Marker, buffers.next()); assertEquals(block1Marker, buffers.next()); assertFalse(buffers.hasNext()); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0); + 
verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1); // Verify open block request latency metrics Timer openBlockRequestLatencyMillis = (Timer) ((ExternalShuffleBlockHandler) handler) diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index bc97594903bef..23438a08fa094 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -65,7 +65,7 @@ public void testBadRequests() throws IOException { ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); // Unregistered executor try { - resolver.getBlockData("app0", "exec1", "shuffle_1_1_0"); + resolver.getBlockData("app0", "exec1", 1, 1, 0); fail("Should have failed"); } catch (RuntimeException e) { assertTrue("Bad error message: " + e, e.getMessage().contains("not registered")); @@ -74,7 +74,7 @@ public void testBadRequests() throws IOException { // Invalid shuffle manager try { resolver.registerExecutor("app0", "exec2", dataContext.createExecutorInfo("foobar")); - resolver.getBlockData("app0", "exec2", "shuffle_1_1_0"); + resolver.getBlockData("app0", "exec2", 1, 1, 0); fail("Should have failed"); } catch (UnsupportedOperationException e) { // pass @@ -84,7 +84,7 @@ public void testBadRequests() throws IOException { resolver.registerExecutor("app0", "exec3", dataContext.createExecutorInfo(SORT_MANAGER)); try { - resolver.getBlockData("app0", "exec3", "shuffle_1_1_0"); + resolver.getBlockData("app0", "exec3", 1, 1, 0); fail("Should have failed"); } catch (Exception e) { // pass @@ -98,14 +98,14 @@ public void testSortShuffleBlocks() throws IOException { dataContext.createExecutorInfo(SORT_MANAGER)); InputStream block0Stream = - resolver.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream(); + resolver.getBlockData("app0", "exec0", 0, 0, 0).createInputStream(); String block0 = CharStreams.toString( new InputStreamReader(block0Stream, StandardCharsets.UTF_8)); block0Stream.close(); assertEquals(sortBlock0, block0); InputStream block1Stream = - resolver.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream(); + resolver.getBlockData("app0", "exec0", 0, 0, 1).createInputStream(); String block1 = CharStreams.toString( new InputStreamReader(block1Stream, StandardCharsets.UTF_8)); block1Stream.close(); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index d1d8f5b4e188a..4391e3023491b 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -214,10 +214,10 @@ public void testFetchNonexistent() throws Exception { @Test public void testFetchWrongExecutor() throws Exception { registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); - FetchResult execFetch = fetchBlocks("exec-0", - new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ }); - assertEquals(Sets.newHashSet("shuffle_0_0_0"), execFetch.successBlocks); - 
assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks); + FetchResult execFetch0 = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" /* right */}); + FetchResult execFetch1 = fetchBlocks("exec-0", new String[] { "shuffle_1_0_0" /* wrong */ }); + assertEquals(Sets.newHashSet("shuffle_0_0_0"), execFetch0.successBlocks); + assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch1.failedBlocks); } @Test From d1c333ac77e2554832477fd9ec56fb0b2015cde6 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 16 Jun 2017 08:05:43 -0700 Subject: [PATCH 016/118] [SPARK-21119][SQL] unset table properties should keep the table comment ## What changes were proposed in this pull request? Previous code mistakenly use `table.properties.get("comment")` to read the existing table comment, we should use `table.comment` ## How was this patch tested? new regression test Author: Wenchen Fan Closes #18325 from cloud-fan/unset. --- .../spark/sql/execution/command/ddl.scala | 4 +- .../resources/sql-tests/inputs/describe.sql | 8 + .../sql-tests/results/describe.sql.out | 201 ++++++++++++------ 3 files changed, 148 insertions(+), 65 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index f924b3d914635..413f5f3ba539c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -264,14 +264,14 @@ case class AlterTableUnsetPropertiesCommand( DDLUtils.verifyAlterTableType(catalog, table, isView) if (!ifExists) { propKeys.foreach { k => - if (!table.properties.contains(k)) { + if (!table.properties.contains(k) && k != "comment") { throw new AnalysisException( s"Attempted to unset non-existent property '$k' in table '${table.identifier}'") } } } // If comment is in the table property, we reset it to None - val tableComment = if (propKeys.contains("comment")) None else table.properties.get("comment") + val tableComment = if (propKeys.contains("comment")) None else table.comment val newProperties = table.properties.filter { case (k, _) => !propKeys.contains(k) } val newTable = table.copy(properties = newProperties, comment = tableComment) catalog.alterTable(newTable) diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe.sql b/sql/core/src/test/resources/sql-tests/inputs/describe.sql index 91b966829f8fb..a222e11916cda 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/describe.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/describe.sql @@ -28,6 +28,14 @@ DESC FORMATTED t; DESC EXTENDED t; +ALTER TABLE t UNSET TBLPROPERTIES (e); + +DESC EXTENDED t; + +ALTER TABLE t UNSET TBLPROPERTIES (comment); + +DESC EXTENDED t; + DESC t PARTITION (c='Us', d=1); DESC EXTENDED t PARTITION (c='Us', d=1); diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index ab9f2783f06bb..e2b79e8f7801d 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 32 +-- Number of queries: 36 -- !query 0 @@ -166,10 +166,85 @@ Partition Provider Catalog -- !query 11 -DESC t PARTITION (c='Us', d=1) +ALTER TABLE t UNSET TBLPROPERTIES (e) -- !query 11 schema -struct +struct<> -- !query 11 output + + + +-- !query 12 +DESC EXTENDED t +-- 
!query 12 schema +struct +-- !query 12 output +a string +b int +c string +d string +# Partition Information +# col_name data_type comment +c string +d string + +# Detailed Table Information +Database default +Table t +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Num Buckets 2 +Bucket Columns [`a`] +Sort Columns [`b`] +Comment table_comment +Location [not included in comparison]sql/core/spark-warehouse/t +Storage Properties [a=1, b=2] +Partition Provider Catalog + + +-- !query 13 +ALTER TABLE t UNSET TBLPROPERTIES (comment) +-- !query 13 schema +struct<> +-- !query 13 output + + + +-- !query 14 +DESC EXTENDED t +-- !query 14 schema +struct +-- !query 14 output +a string +b int +c string +d string +# Partition Information +# col_name data_type comment +c string +d string + +# Detailed Table Information +Database default +Table t +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Num Buckets 2 +Bucket Columns [`a`] +Sort Columns [`b`] +Location [not included in comparison]sql/core/spark-warehouse/t +Storage Properties [a=1, b=2] +Partition Provider Catalog + + +-- !query 15 +DESC t PARTITION (c='Us', d=1) +-- !query 15 schema +struct +-- !query 15 output a string b int c string @@ -180,11 +255,11 @@ c string d string --- !query 12 +-- !query 16 DESC EXTENDED t PARTITION (c='Us', d=1) --- !query 12 schema +-- !query 16 schema struct --- !query 12 output +-- !query 16 output a string b int c string @@ -209,11 +284,11 @@ Location [not included in comparison]sql/core/spark-warehouse/t Storage Properties [a=1, b=2] --- !query 13 +-- !query 17 DESC FORMATTED t PARTITION (c='Us', d=1) --- !query 13 schema +-- !query 17 schema struct --- !query 13 output +-- !query 17 output a string b int c string @@ -238,31 +313,31 @@ Location [not included in comparison]sql/core/spark-warehouse/t Storage Properties [a=1, b=2] --- !query 14 +-- !query 18 DESC t PARTITION (c='Us', d=2) --- !query 14 schema +-- !query 18 schema struct<> --- !query 14 output +-- !query 18 output org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException Partition not found in table 't' database 'default': c -> Us d -> 2; --- !query 15 +-- !query 19 DESC t PARTITION (c='Us') --- !query 15 schema +-- !query 19 schema struct<> --- !query 15 output +-- !query 19 output org.apache.spark.sql.AnalysisException Partition spec is invalid. 
The spec (c) must match the partition spec (c, d) defined in table '`default`.`t`'; --- !query 16 +-- !query 20 DESC t PARTITION (c='Us', d) --- !query 16 schema +-- !query 20 schema struct<> --- !query 16 output +-- !query 20 output org.apache.spark.sql.catalyst.parser.ParseException PARTITION specification is incomplete: `d`(line 1, pos 0) @@ -272,55 +347,55 @@ DESC t PARTITION (c='Us', d) ^^^ --- !query 17 +-- !query 21 DESC temp_v --- !query 17 schema +-- !query 21 schema struct --- !query 17 output +-- !query 21 output a string b int c string d string --- !query 18 +-- !query 22 DESC TABLE temp_v --- !query 18 schema +-- !query 22 schema struct --- !query 18 output +-- !query 22 output a string b int c string d string --- !query 19 +-- !query 23 DESC FORMATTED temp_v --- !query 19 schema +-- !query 23 schema struct --- !query 19 output +-- !query 23 output a string b int c string d string --- !query 20 +-- !query 24 DESC EXTENDED temp_v --- !query 20 schema +-- !query 24 schema struct --- !query 20 output +-- !query 24 output a string b int c string d string --- !query 21 +-- !query 25 DESC temp_Data_Source_View --- !query 21 schema +-- !query 25 schema struct --- !query 21 output +-- !query 25 output intType int test comment test1 stringType string dateType date @@ -339,42 +414,42 @@ arrayType array structType struct --- !query 22 +-- !query 26 DESC temp_v PARTITION (c='Us', d=1) --- !query 22 schema +-- !query 26 schema struct<> --- !query 22 output +-- !query 26 output org.apache.spark.sql.AnalysisException DESC PARTITION is not allowed on a temporary view: temp_v; --- !query 23 +-- !query 27 DESC v --- !query 23 schema +-- !query 27 schema struct --- !query 23 output +-- !query 27 output a string b int c string d string --- !query 24 +-- !query 28 DESC TABLE v --- !query 24 schema +-- !query 28 schema struct --- !query 24 output +-- !query 28 output a string b int c string d string --- !query 25 +-- !query 29 DESC FORMATTED v --- !query 25 schema +-- !query 29 schema struct --- !query 25 output +-- !query 29 output a string b int c string @@ -392,11 +467,11 @@ View Query Output Columns [a, b, c, d] Table Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] --- !query 26 +-- !query 30 DESC EXTENDED v --- !query 26 schema +-- !query 30 schema struct --- !query 26 output +-- !query 30 output a string b int c string @@ -414,42 +489,42 @@ View Query Output Columns [a, b, c, d] Table Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] --- !query 27 +-- !query 31 DESC v PARTITION (c='Us', d=1) --- !query 27 schema +-- !query 31 schema struct<> --- !query 27 output +-- !query 31 output org.apache.spark.sql.AnalysisException DESC PARTITION is not allowed on a view: v; --- !query 28 +-- !query 32 DROP TABLE t --- !query 28 schema +-- !query 32 schema struct<> --- !query 28 output +-- !query 32 output --- !query 29 +-- !query 33 DROP VIEW temp_v --- !query 29 schema +-- !query 33 schema struct<> --- !query 29 output +-- !query 33 output --- !query 30 +-- !query 34 DROP VIEW temp_Data_Source_View --- !query 30 schema +-- !query 34 schema struct<> --- !query 30 output +-- !query 34 output --- !query 31 +-- !query 35 DROP VIEW v --- !query 31 schema +-- !query 35 schema struct<> --- !query 31 output +-- !query 35 output From 53e48f73e42bb3eea075894ff08494e0abe9d60a Mon Sep 
17 00:00:00 2001 From: Yuming Wang Date: Fri, 16 Jun 2017 09:40:58 -0700 Subject: [PATCH 017/118] [SPARK-20931][SQL] ABS function support string type. ## What changes were proposed in this pull request? ABS function support string type. Hive/MySQL support this feature. Ref: https://github.com/apache/hive/blob/4ba713ccd85c3706d195aeef9476e6e6363f1c21/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFAbs.java#L93 ## How was this patch tested? unit tests Author: Yuming Wang Closes #18153 from wangyum/SPARK-20931. --- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 1 + .../analysis/ExpressionTypeCheckingSuite.scala | 1 - .../src/test/resources/sql-tests/inputs/operators.sql | 3 +++ .../test/resources/sql-tests/results/operators.sql.out | 10 +++++++++- 4 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 1f217390518a6..6082c58e2c53a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -357,6 +357,7 @@ object TypeCoercion { val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType))) + case Abs(e @ StringType()) => Abs(Cast(e, DoubleType)) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 744057b7c5f4c..2239bf815de71 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -57,7 +57,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { test("check types for unary arithmetic") { assertError(UnaryMinus('stringField), "(numeric or calendarinterval) type") - assertError(Abs('stringField), "requires numeric type") assertError(BitwiseNot('stringField), "requires integral type") } diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index a8de23e73892c..a1e8a32ed8f66 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -85,3 +85,6 @@ select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, nu select BIT_LENGTH('abc'); select CHAR_LENGTH('abc'); select OCTET_LENGTH('abc'); + +-- abs +select abs(-3.13), abs('-2.19'); diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 85ee10b4d274f..eac3080bec67d 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 54 +-- Number of queries: 55 -- !query 0 @@ -444,3 +444,11 @@ select OCTET_LENGTH('abc') struct -- !query 53 output 3 + + +-- !query 54 
+select abs(-3.13), abs('-2.19') +-- !query 54 schema +struct +-- !query 54 output +3.13 2.19 From edcb878e2fbd0d85bf70614fed37f4cbf0caa95e Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Fri, 16 Jun 2017 10:34:52 -0700 Subject: [PATCH 018/118] [SPARK-20338][CORE] Spaces in spark.eventLog.dir are not correctly handled MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? “spark.eventLog.dir” supports with space characters. 1. Update EventLoggingListenerSuite like `testDir = Utils.createTempDir(namePrefix = s"history log")` 2. Fix EventLoggingListenerSuite tests ## How was this patch tested? update unit tests Author: zuotingbing Closes #18285 from zuotingbing/spark-resolveURI. --- .../org/apache/spark/scheduler/EventLoggingListener.scala | 4 ++-- .../spark/deploy/history/FsHistoryProviderSuite.scala | 5 ++--- .../apache/spark/scheduler/EventLoggingListenerSuite.scala | 7 +++---- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index f481436332249..35690b2783ad3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -96,8 +96,8 @@ private[spark] class EventLoggingListener( } val workingPath = logPath + IN_PROGRESS - val uri = new URI(workingPath) val path = new Path(workingPath) + val uri = path.toUri val defaultFs = FileSystem.getDefaultUri(hadoopConf).getScheme val isDefaultLocal = defaultFs == null || defaultFs == "file" @@ -320,7 +320,7 @@ private[spark] object EventLoggingListener extends Logging { appId: String, appAttemptId: Option[String], compressionCodecName: Option[String] = None): String = { - val base = logBaseDir.toString.stripSuffix("/") + "/" + sanitize(appId) + val base = new Path(logBaseDir).toString.stripSuffix("/") + "/" + sanitize(appId) val codec = compressionCodecName.map("." 
+ _).getOrElse("") if (appAttemptId.isDefined) { base + "_" + sanitize(appAttemptId.get) + codec diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 9b3e4ec793825..7109146ece371 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy.history import java.io._ -import java.net.URI import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit import java.util.zip.{ZipInputStream, ZipOutputStream} @@ -27,7 +26,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} -import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ import org.mockito.Matchers.any @@ -63,7 +62,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc codec: Option[String] = None): File = { val ip = if (inProgress) EventLoggingListener.IN_PROGRESS else "" val logUri = EventLoggingListener.getLogPath(testDir.toURI, appId, appAttemptId) - val logPath = new URI(logUri).getPath + ip + val logPath = new Path(logUri).toUri.getPath + ip new File(logPath) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 4cae6c61118a8..0afd07b851cf9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.scheduler import java.io.{File, FileOutputStream, InputStream, IOException} -import java.net.URI import scala.collection.mutable import scala.io.Source @@ -52,7 +51,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit private var testDirPath: Path = _ before { - testDir = Utils.createTempDir() + testDir = Utils.createTempDir(namePrefix = s"history log") testDir.deleteOnExit() testDirPath = new Path(testDir.getAbsolutePath()) } @@ -111,7 +110,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit test("Log overwriting") { val logUri = EventLoggingListener.getLogPath(testDir.toURI, "test", None) - val logPath = new URI(logUri).getPath + val logPath = new Path(logUri).toUri.getPath // Create file before writing the event log new FileOutputStream(new File(logPath)).close() // Expected IOException, since we haven't enabled log overwrite. 
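    // Why deriving the URI via Hadoop's Path matters for this suite (illustrative, the path
    // below is made up): the single-argument java.net.URI constructor rejects the space that
    // Utils.createTempDir(namePrefix = s"history log") now puts into the log directory, e.g.
    //   new java.net.URI("/tmp/history log/app-1")                            // URISyntaxException
    //   new org.apache.hadoop.fs.Path("/tmp/history log/app-1").toUri.getPath // "/tmp/history log/app-1"
    // so both EventLoggingListener and this test build a Path first and take its URI.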
@@ -293,7 +292,7 @@ object EventLoggingListenerSuite { val conf = new SparkConf conf.set("spark.eventLog.enabled", "true") conf.set("spark.eventLog.testing", "true") - conf.set("spark.eventLog.dir", logDir.toUri.toString) + conf.set("spark.eventLog.dir", logDir.toString) compressionCodec.foreach { codec => conf.set("spark.eventLog.compress", "true") conf.set("spark.io.compression.codec", codec) From 0d8604bb849b3370cc21966cdd773238f3a29f84 Mon Sep 17 00:00:00 2001 From: liuzhaokun Date: Sun, 18 Jun 2017 08:32:29 +0100 Subject: [PATCH 019/118] [SPARK-21126] The configuration which named "spark.core.connection.auth.wait.timeout" hasn't been used in spark [https://issues.apache.org/jira/browse/SPARK-21126](https://issues.apache.org/jira/browse/SPARK-21126) The configuration which named "spark.core.connection.auth.wait.timeout" hasn't been used in spark,so I think it should be removed from configuration.md. Author: liuzhaokun Closes #18333 from liu-zhaokun/new3. --- docs/configuration.md | 8 -------- 1 file changed, 8 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index f777811a93f62..c1464741ebb6f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1774,14 +1774,6 @@ Apart from these, the following properties are also available, and may be useful you can set larger value. - - spark.core.connection.auth.wait.timeout - 30s - - How long for the connection to wait for authentication to occur before timing - out and giving up. - - spark.modify.acls Empty From 75a6d05853fea13f88e3c941b1959b24e4640824 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 18 Jun 2017 08:43:47 +0100 Subject: [PATCH 020/118] [MINOR][R] Add knitr and rmarkdown packages/improve output for version info in AppVeyor tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR proposes three things as below: **Install packages per documentation** - this does not affect the tests itself (but CRAN which we are not doing via AppVeyor) up to my knowledge. This adds `knitr` and `rmarkdown` per https://github.com/apache/spark/blob/45824fb608930eb461e7df53bb678c9534c183a9/R/WINDOWS.md#unit-tests (please see https://github.com/apache/spark/commit/45824fb608930eb461e7df53bb678c9534c183a9) **Improve logs/shorten logs** - actually, long logs can be a problem on AppVeyor (e.g., see https://github.com/apache/spark/pull/17873) `R -e ...` repeats printing R information for each invocation as below: ``` R version 3.3.1 (2016-06-21) -- "Bug in Your Hair" Copyright (C) 2016 The R Foundation for Statistical Computing Platform: i386-w64-mingw32/i386 (32-bit) R is free software and comes with ABSOLUTELY NO WARRANTY. You are welcome to redistribute it under certain conditions. Type 'license()' or 'licence()' for distribution details. Natural language support but running in an English locale R is a collaborative project with many contributors. Type 'contributors()' for more information and 'citation()' on how to cite R or R packages in publications. Type 'demo()' for some demos, 'help()' for on-line help, or 'help.start()' for an HTML browser interface to help. Type 'q()' to quit R. ``` It looks reducing the call might be slightly better and print out the versions together looks more readable. Before: ``` # R information ... > packageVersion('testthat') [1] '1.0.2' > > # R information ... > packageVersion('e1071') [1] '1.6.8' > > ... 3 more times ``` After: ``` # R information ... 
> packageVersion('knitr'); packageVersion('rmarkdown'); packageVersion('testthat'); packageVersion('e1071'); packageVersion('survival') [1] ‘1.16’ [1] ‘1.6’ [1] ‘1.0.2’ [1] ‘1.6.8’ [1] ‘2.41.3’ ``` **Add`appveyor.yml`/`dev/appveyor-install-dependencies.ps1` for triggering the test** Changing this file might break the test, e.g., https://github.com/apache/spark/pull/16927 ## How was this patch tested? Before (please see https://ci.appveyor.com/project/HyukjinKwon/spark/build/169-master) After (please see the AppVeyor build in this PR): Author: hyukjinkwon Closes #18336 from HyukjinKwon/minor-add-knitr-and-rmarkdown. --- appveyor.yml | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index 58c2e98289e96..43dad9bce60ac 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -26,6 +26,8 @@ branches: only_commits: files: + - appveyor.yml + - dev/appveyor-install-dependencies.ps1 - R/ - sql/core/src/main/scala/org/apache/spark/sql/api/r/ - core/src/main/scala/org/apache/spark/api/r/ @@ -38,12 +40,8 @@ install: # Install maven and dependencies - ps: .\dev\appveyor-install-dependencies.ps1 # Required package for R unit tests - - cmd: R -e "install.packages('testthat', repos='http://cran.us.r-project.org')" - - cmd: R -e "packageVersion('testthat')" - - cmd: R -e "install.packages('e1071', repos='http://cran.us.r-project.org')" - - cmd: R -e "packageVersion('e1071')" - - cmd: R -e "install.packages('survival', repos='http://cran.us.r-project.org')" - - cmd: R -e "packageVersion('survival')" + - cmd: R -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'e1071', 'survival'), repos='http://cran.us.r-project.org')" + - cmd: R -e "packageVersion('knitr'); packageVersion('rmarkdown'); packageVersion('testthat'); packageVersion('e1071'); packageVersion('survival')" build_script: - cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package From 05f83c532a96ead8dec1c046f985164b7f7205c0 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 18 Jun 2017 11:26:27 -0700 Subject: [PATCH 021/118] [SPARK-21128][R] Remove both "spark-warehouse" and "metastore_db" before listing files in R tests ## What changes were proposed in this pull request? This PR proposes to list the files in test _after_ removing both "spark-warehouse" and "metastore_db" so that the next run of R tests pass fine. This is sometimes a bit annoying. ## How was this patch tested? Manually running multiple times R tests via `./R/run-tests.sh`. **Before** Second run: ``` SparkSQL functions: Spark package found in SPARK_HOME: .../spark ............................................................................................................................................................... ............................................................................................................................................................... ............................................................................................................................................................... ............................................................................................................................................................... ............................................................................................................................................................... ....................................................................................................1234....................... 
Failed ------------------------------------------------------------------------- 1. Failure: No extra files are created in SPARK_HOME by starting session and making calls (test_sparkSQL.R#3384) length(list1) not equal to length(list2). 1/1 mismatches [1] 25 - 23 == 2 2. Failure: No extra files are created in SPARK_HOME by starting session and making calls (test_sparkSQL.R#3384) sort(list1, na.last = TRUE) not equal to sort(list2, na.last = TRUE). 10/25 mismatches x[16]: "metastore_db" y[16]: "pkg" x[17]: "pkg" y[17]: "R" x[18]: "R" y[18]: "README.md" x[19]: "README.md" y[19]: "run-tests.sh" x[20]: "run-tests.sh" y[20]: "SparkR_2.2.0.tar.gz" x[21]: "metastore_db" y[21]: "pkg" x[22]: "pkg" y[22]: "R" x[23]: "R" y[23]: "README.md" x[24]: "README.md" y[24]: "run-tests.sh" x[25]: "run-tests.sh" y[25]: "SparkR_2.2.0.tar.gz" 3. Failure: No extra files are created in SPARK_HOME by starting session and making calls (test_sparkSQL.R#3388) length(list1) not equal to length(list2). 1/1 mismatches [1] 25 - 23 == 2 4. Failure: No extra files are created in SPARK_HOME by starting session and making calls (test_sparkSQL.R#3388) sort(list1, na.last = TRUE) not equal to sort(list2, na.last = TRUE). 10/25 mismatches x[16]: "metastore_db" y[16]: "pkg" x[17]: "pkg" y[17]: "R" x[18]: "R" y[18]: "README.md" x[19]: "README.md" y[19]: "run-tests.sh" x[20]: "run-tests.sh" y[20]: "SparkR_2.2.0.tar.gz" x[21]: "metastore_db" y[21]: "pkg" x[22]: "pkg" y[22]: "R" x[23]: "R" y[23]: "README.md" x[24]: "README.md" y[24]: "run-tests.sh" x[25]: "run-tests.sh" y[25]: "SparkR_2.2.0.tar.gz" DONE =========================================================================== ``` **After** Second run: ``` SparkSQL functions: Spark package found in SPARK_HOME: .../spark ............................................................................................................................................................... ............................................................................................................................................................... ............................................................................................................................................................... ............................................................................................................................................................... ............................................................................................................................................................... ............................................................................................................................... ``` Author: hyukjinkwon Closes #18335 from HyukjinKwon/SPARK-21128. 
--- R/pkg/tests/run-all.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index f00a610679752..0aefd8006caa4 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -30,10 +30,10 @@ if (.Platform$OS.type == "windows") { install.spark() sparkRDir <- file.path(Sys.getenv("SPARK_HOME"), "R") -sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") invisible(lapply(sparkRWhitelistSQLDirs, function(x) { unlink(file.path(sparkRDir, x), recursive = TRUE, force = TRUE)})) +sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) sparkRTestMaster <- "local[1]" if (identical(Sys.getenv("NOT_CRAN"), "true")) { From 110ce1f27b66905afada6b5fd63c34fbf7602739 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Sun, 18 Jun 2017 18:00:27 -0700 Subject: [PATCH 022/118] [SPARK-20892][SPARKR] Add SQL trunc function to SparkR ## What changes were proposed in this pull request? Add SQL trunc function ## How was this patch tested? standard test Author: actuaryzhang Closes #18291 from actuaryzhang/sparkRTrunc2. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 29 +++++++++++++++++++++++++++ R/pkg/tests/fulltests/test_sparkSQL.R | 2 ++ 3 files changed, 32 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 4e3fe00a2e9bd..229de4a997eef 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -357,6 +357,7 @@ exportMethods("%<=>%", "to_utc_timestamp", "translate", "trim", + "trunc", "unbase64", "unhex", "unix_timestamp", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 06a90192bb12f..7128c3b9adff4 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -4015,3 +4015,32 @@ setMethod("input_file_name", signature("missing"), jc <- callJStatic("org.apache.spark.sql.functions", "input_file_name") column(jc) }) + +#' trunc +#' +#' Returns date truncated to the unit specified by the format. +#' +#' @param x Column to compute on. +#' @param format string used for specify the truncation method. For example, "year", "yyyy", +#' "yy" for truncate by year, or "month", "mon", "mm" for truncate by month. +#' +#' @rdname trunc +#' @name trunc +#' @family date time functions +#' @aliases trunc,Column-method +#' @export +#' @examples +#' \dontrun{ +#' trunc(df$c, "year") +#' trunc(df$c, "yy") +#' trunc(df$c, "month") +#' trunc(df$c, "mon") +#' } +#' @note trunc since 2.3.0 +setMethod("trunc", + signature(x = "Column"), + function(x, format) { + jc <- callJStatic("org.apache.spark.sql.functions", "trunc", + x@jc, as.character(format)) + column(jc) + }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index af529067f43e0..911b73b9ee551 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1382,6 +1382,8 @@ test_that("column functions", { c20 <- to_timestamp(c) + to_timestamp(c, "yyyy") + to_date(c, "yyyy") c21 <- posexplode_outer(c) + explode_outer(c) c22 <- not(c) + c23 <- trunc(c, "year") + trunc(c, "yyyy") + trunc(c, "yy") + + trunc(c, "month") + trunc(c, "mon") + trunc(c, "mm") # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) From ce49428ef7d640c1734e91ffcddc49dbc8547ba7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 18 Jun 2017 18:56:53 -0700 Subject: [PATCH 023/118] [SPARK-20749][SQL][FOLLOWUP] Support character_length ## What changes were proposed in this pull request? 
The function `char_length` is shorthand for `character_length` function. Both Hive and Postgresql support `character_length`, This PR add support for `character_length`. Ref: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-StringFunctions https://www.postgresql.org/docs/current/static/functions-string.html ## How was this patch tested? unit tests Author: Yuming Wang Closes #18330 from wangyum/SPARK-20749-character_length. --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/stringExpressions.scala | 4 ++++ .../resources/sql-tests/inputs/operators.sql | 1 + .../sql-tests/results/operators.sql.out | 18 +++++++++++++----- 4 files changed, 19 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e4e9918a3a887..f4b3e86052d8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -307,6 +307,7 @@ object FunctionRegistry { expression[Base64]("base64"), expression[BitLength]("bit_length"), expression[Length]("char_length"), + expression[Length]("character_length"), expression[Concat]("concat"), expression[ConcatWs]("concat_ws"), expression[Decode]("decode"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 908fdb8f7e68f..83fdcfce9c3bd 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1209,6 +1209,10 @@ case class Substring(str: Expression, pos: Expression, len: Expression) Examples: > SELECT _FUNC_('Spark SQL'); 9 + > SELECT CHAR_LENGTH('Spark SQL'); + 9 + > SELECT CHARACTER_LENGTH('Spark SQL'); + 9 """) // scalastyle:on line.size.limit case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index a1e8a32ed8f66..9841ec4b65983 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -84,6 +84,7 @@ select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, nu -- length select BIT_LENGTH('abc'); select CHAR_LENGTH('abc'); +select CHARACTER_LENGTH('abc'); select OCTET_LENGTH('abc'); -- abs diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index eac3080bec67d..4a6ef27c3be42 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 55 +-- Number of queries: 56 -- !query 0 @@ -439,16 +439,24 @@ struct -- !query 53 -select OCTET_LENGTH('abc') +select CHARACTER_LENGTH('abc') -- !query 53 schema -struct +struct -- !query 53 output 3 -- !query 54 -select abs(-3.13), abs('-2.19') +select OCTET_LENGTH('abc') -- !query 54 schema -struct +struct -- !query 54 output +3 + + +-- !query 55 +select 
abs(-3.13), abs('-2.19') +-- !query 55 schema +struct +-- !query 55 output 3.13 2.19 From f913f158ec41bd3de9dc229b908aaab0dbd60d27 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 18 Jun 2017 20:14:05 -0700 Subject: [PATCH 024/118] [SPARK-20948][SQL] Built-in SQL Function UnaryMinus/UnaryPositive support string type ## What changes were proposed in this pull request? Built-in SQL Function UnaryMinus/UnaryPositive support string type, if it's string type, convert it to double type, after this PR: ```sql spark-sql> select positive('-1.11'), negative('-1.11'); -1.11 1.11 spark-sql> ``` ## How was this patch tested? unit tests Author: Yuming Wang Closes #18173 from wangyum/SPARK-20948. --- .../apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 2 ++ .../catalyst/analysis/ExpressionTypeCheckingSuite.scala | 1 - .../src/test/resources/sql-tests/inputs/operators.sql | 3 +++ .../test/resources/sql-tests/results/operators.sql.out | 8 ++++++++ 4 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 6082c58e2c53a..a78e1c98e89de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -362,6 +362,8 @@ object TypeCoercion { case Average(e @ StringType()) => Average(Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) + case UnaryMinus(e @ StringType()) => UnaryMinus(Cast(e, DoubleType)) + case UnaryPositive(e @ StringType()) => UnaryPositive(Cast(e, DoubleType)) case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType)) case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType)) case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 2239bf815de71..30459f173ab52 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -56,7 +56,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for unary arithmetic") { - assertError(UnaryMinus('stringField), "(numeric or calendarinterval) type") assertError(BitwiseNot('stringField), "requires integral type") } diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index 9841ec4b65983..a766275192492 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -89,3 +89,6 @@ select OCTET_LENGTH('abc'); -- abs select abs(-3.13), abs('-2.19'); + +-- positive/negative +select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11); diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 4a6ef27c3be42..5cb6ed3e27bf2 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -460,3 +460,11 @@ select abs(-3.13), abs('-2.19') struct -- !query 55 output 3.13 2.19 + + +-- !query 55 +select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11) +-- !query 55 schema +struct<(+ CAST(-1.11 AS DOUBLE)):double,(+ -1.11):decimal(3,2),(- CAST(-1.11 AS DOUBLE)):double,(- -1.11):decimal(3,2)> +-- !query 55 output +-1.11 -1.11 1.11 1.11 From 112bd9bfc5b9729f6f86518998b5d80c5e79fe5e Mon Sep 17 00:00:00 2001 From: liuxian Date: Mon, 19 Jun 2017 11:46:58 +0800 Subject: [PATCH 025/118] [SPARK-21090][CORE] Optimize the unified memory manager code ## What changes were proposed in this pull request? 1.In `acquireStorageMemory`, when the Memory Mode is OFF_HEAP ,the `maxOffHeapMemory` should be modified to `maxOffHeapStorageMemory`. after this PR,it will same as ON_HEAP Memory Mode. Because when acquire memory is between `maxOffHeapStorageMemory` and `maxOffHeapMemory`,it will fail surely, so if acquire memory is greater than `maxOffHeapStorageMemory`(not greater than `maxOffHeapMemory`),we should fail fast. 2. Borrow memory from execution, `numBytes` modified to `numBytes - storagePool.memoryFree` will be more reasonable. Because we just acquire `(numBytes - storagePool.memoryFree)`, unnecessary borrowed `numBytes` from execution ## How was this patch tested? added unit test case Author: liuxian Closes #18296 from 10110346/wip-lx-0614. --- .../spark/memory/UnifiedMemoryManager.scala | 5 +-- .../spark/memory/MemoryManagerSuite.scala | 2 +- .../memory/UnifiedMemoryManagerSuite.scala | 32 +++++++++++++++++++ 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index fea2808218a53..df193552bed3c 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -160,7 +160,7 @@ private[spark] class UnifiedMemoryManager private[memory] ( case MemoryMode.OFF_HEAP => ( offHeapExecutionMemoryPool, offHeapStorageMemoryPool, - maxOffHeapMemory) + maxOffHeapStorageMemory) } if (numBytes > maxMemory) { // Fail fast if the block simply won't fit @@ -171,7 +171,8 @@ private[spark] class UnifiedMemoryManager private[memory] ( if (numBytes > storagePool.memoryFree) { // There is not enough free memory in the storage pool, so try to borrow free memory from // the execution pool. 
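      // Illustrative arithmetic for the change below (numbers are made up): with
      // storagePool.memoryFree = 50 and a request of numBytes = 100, only the shortfall of
      // 100 - 50 = 50 bytes needs to come out of the execution pool; the old
      // Math.min(executionPool.memoryFree, numBytes) could move the full 100 bytes even
      // though half of that was already free in the storage pool.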
- val memoryBorrowedFromExecution = Math.min(executionPool.memoryFree, numBytes) + val memoryBorrowedFromExecution = Math.min(executionPool.memoryFree, + numBytes - storagePool.memoryFree) executionPool.decrementPoolSize(memoryBorrowedFromExecution) storagePool.incrementPoolSize(memoryBorrowedFromExecution) } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index eb2b3ffd1509a..85eeb5055ae03 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -117,7 +117,7 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft evictBlocksToFreeSpaceCalled.set(numBytesToFree) if (numBytesToFree <= mm.storageMemoryUsed) { // We can evict enough blocks to fulfill the request for space - mm.releaseStorageMemory(numBytesToFree, MemoryMode.ON_HEAP) + mm.releaseStorageMemory(numBytesToFree, mm.tungstenMemoryMode) evictedBlocks += Tuple2(null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L)) numBytesToFree } else { diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index c821054412d7d..02b04cdbb2a5f 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -303,4 +303,36 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes mm.invokePrivate[Unit](assertInvariants()) } + test("not enough free memory in the storage pool --OFF_HEAP") { + val conf = new SparkConf() + .set("spark.memory.offHeap.size", "1000") + .set("spark.testing.memory", "1000") + .set("spark.memory.offHeap.enabled", "true") + val taskAttemptId = 0L + val mm = UnifiedMemoryManager(conf, numCores = 1) + val ms = makeMemoryStore(mm) + val memoryMode = MemoryMode.OFF_HEAP + + assert(mm.acquireExecutionMemory(400L, taskAttemptId, memoryMode) === 400L) + assert(mm.storageMemoryUsed === 0L) + assert(mm.executionMemoryUsed === 400L) + + // Fail fast + assert(!mm.acquireStorageMemory(dummyBlock, 700L, memoryMode)) + assert(mm.storageMemoryUsed === 0L) + + assert(mm.acquireStorageMemory(dummyBlock, 100L, memoryMode)) + assert(mm.storageMemoryUsed === 100L) + assertEvictBlocksToFreeSpaceNotCalled(ms) + + // Borrow 50 from execution memory + assert(mm.acquireStorageMemory(dummyBlock, 450L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 550L) + + // Borrow 50 from execution memory and evict 50 to free space + assert(mm.acquireStorageMemory(dummyBlock, 100L, memoryMode)) + assertEvictBlocksToFreeSpaceCalled(ms, 50) + assert(mm.storageMemoryUsed === 600L) + } } From ea542d29b2ae99cfff47fed40b7a9ab77d41b391 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sun, 18 Jun 2017 22:05:06 -0700 Subject: [PATCH 026/118] [SPARK-19824][CORE] Update JsonProtocol to keep consistent with the UI ## What changes were proposed in this pull request? Fix any inconsistent part in JsonProtocol with the UI. This PR also contains the modifications in #17181 ## How was this patch tested? Updated JsonProtocolSuite. 
Before this change, localhost:8080/json shows: ``` { "url" : "spark://xingbos-MBP.local:7077", "workers" : [ { "id" : "worker-20170615172946-192.168.0.101-49450", "host" : "192.168.0.101", "port" : 49450, "webuiaddress" : "http://192.168.0.101:8081", "cores" : 8, "coresused" : 8, "coresfree" : 0, "memory" : 15360, "memoryused" : 1024, "memoryfree" : 14336, "state" : "ALIVE", "lastheartbeat" : 1497519481722 }, { "id" : "worker-20170615172948-192.168.0.101-49452", "host" : "192.168.0.101", "port" : 49452, "webuiaddress" : "http://192.168.0.101:8082", "cores" : 8, "coresused" : 8, "coresfree" : 0, "memory" : 15360, "memoryused" : 1024, "memoryfree" : 14336, "state" : "ALIVE", "lastheartbeat" : 1497519484160 }, { "id" : "worker-20170615172951-192.168.0.101-49469", "host" : "192.168.0.101", "port" : 49469, "webuiaddress" : "http://192.168.0.101:8083", "cores" : 8, "coresused" : 8, "coresfree" : 0, "memory" : 15360, "memoryused" : 1024, "memoryfree" : 14336, "state" : "ALIVE", "lastheartbeat" : 1497519486905 } ], "cores" : 24, "coresused" : 24, "memory" : 46080, "memoryused" : 3072, "activeapps" : [ { "starttime" : 1497519426990, "id" : "app-20170615173706-0001", "name" : "Spark shell", "user" : "xingbojiang", "memoryperslave" : 1024, "submitdate" : "Thu Jun 15 17:37:06 CST 2017", "state" : "RUNNING", "duration" : 65362 } ], "completedapps" : [ { "starttime" : 1497519250893, "id" : "app-20170615173410-0000", "name" : "Spark shell", "user" : "xingbojiang", "memoryperslave" : 1024, "submitdate" : "Thu Jun 15 17:34:10 CST 2017", "state" : "FINISHED", "duration" : 116895 } ], "activedrivers" : [ ], "status" : "ALIVE" } ``` After the change: ``` { "url" : "spark://xingbos-MBP.local:7077", "workers" : [ { "id" : "worker-20170615175032-192.168.0.101-49951", "host" : "192.168.0.101", "port" : 49951, "webuiaddress" : "http://192.168.0.101:8081", "cores" : 8, "coresused" : 8, "coresfree" : 0, "memory" : 15360, "memoryused" : 1024, "memoryfree" : 14336, "state" : "ALIVE", "lastheartbeat" : 1497520292900 }, { "id" : "worker-20170615175034-192.168.0.101-49953", "host" : "192.168.0.101", "port" : 49953, "webuiaddress" : "http://192.168.0.101:8082", "cores" : 8, "coresused" : 8, "coresfree" : 0, "memory" : 15360, "memoryused" : 1024, "memoryfree" : 14336, "state" : "ALIVE", "lastheartbeat" : 1497520280301 }, { "id" : "worker-20170615175037-192.168.0.101-49955", "host" : "192.168.0.101", "port" : 49955, "webuiaddress" : "http://192.168.0.101:8083", "cores" : 8, "coresused" : 8, "coresfree" : 0, "memory" : 15360, "memoryused" : 1024, "memoryfree" : 14336, "state" : "ALIVE", "lastheartbeat" : 1497520282884 } ], "aliveworkers" : 3, "cores" : 24, "coresused" : 24, "memory" : 46080, "memoryused" : 3072, "activeapps" : [ { "id" : "app-20170615175122-0001", "starttime" : 1497520282115, "name" : "Spark shell", "cores" : 24, "user" : "xingbojiang", "memoryperslave" : 1024, "submitdate" : "Thu Jun 15 17:51:22 CST 2017", "state" : "RUNNING", "duration" : 10805 } ], "completedapps" : [ { "id" : "app-20170615175058-0000", "starttime" : 1497520258766, "name" : "Spark shell", "cores" : 24, "user" : "xingbojiang", "memoryperslave" : 1024, "submitdate" : "Thu Jun 15 17:50:58 CST 2017", "state" : "FINISHED", "duration" : 9876 } ], "activedrivers" : [ ], "completeddrivers" : [ ], "status" : "ALIVE" } ``` Author: Xingbo Jiang Closes #18303 from jiangxb1987/json-protocol. 
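As a small consumer-side sketch of the fields called out above (notably the added `aliveworkers` and `completeddrivers`), the snippet below reads them back with json4s, the same library `JsonProtocol` uses. The object name and the JSON literal are illustrative only: the literal is a trimmed, made-up fragment shaped like the new `/json` output, not captured from a real cluster.

```scala
import org.json4s._
import org.json4s.jackson.JsonMethods.parse

object MasterStateJsonExample {
  // DefaultFormats is needed for extract[...]
  implicit val formats: Formats = DefaultFormats

  def main(args: Array[String]): Unit = {
    // Stand-in for the body served by the standalone master's /json endpoint after this patch.
    val body =
      """{"url":"spark://host:7077","workers":[],"aliveworkers":3,"cores":24,"coresused":24,
        |"activeapps":[],"completedapps":[],"activedrivers":[],"completeddrivers":[],
        |"status":"ALIVE"}""".stripMargin
    val state = parse(body)
    val aliveWorkers = (state \ "aliveworkers").extract[Int]       // field added by this patch
    val completedDrivers = (state \ "completeddrivers").children   // list now exported as well
    println(s"alive workers = $aliveWorkers, completed drivers = ${completedDrivers.size}")
  }
}
```

Nothing here is new API; it only shows that a consumer of the endpoint keys off the exact field names that the scaladoc added in this patch now documents.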
--- .../apache/spark/deploy/JsonProtocol.scala | 158 +++++++++++++++--- .../apache/spark/deploy/DeployTestUtils.scala | 4 +- .../spark/deploy/JsonProtocolSuite.scala | 15 +- 3 files changed, 149 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index 220b20bf7cbd1..7212696166570 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -21,30 +21,65 @@ import org.json4s.JsonAST.JObject import org.json4s.JsonDSL._ import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} -import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} +import org.apache.spark.deploy.master._ +import org.apache.spark.deploy.master.RecoveryState.MasterState import org.apache.spark.deploy.worker.ExecutorRunner private[deploy] object JsonProtocol { - def writeWorkerInfo(obj: WorkerInfo): JObject = { - ("id" -> obj.id) ~ - ("host" -> obj.host) ~ - ("port" -> obj.port) ~ - ("webuiaddress" -> obj.webUiAddress) ~ - ("cores" -> obj.cores) ~ - ("coresused" -> obj.coresUsed) ~ - ("coresfree" -> obj.coresFree) ~ - ("memory" -> obj.memory) ~ - ("memoryused" -> obj.memoryUsed) ~ - ("memoryfree" -> obj.memoryFree) ~ - ("state" -> obj.state.toString) ~ - ("lastheartbeat" -> obj.lastHeartbeat) - } + /** + * Export the [[WorkerInfo]] to a Json object. A [[WorkerInfo]] consists of the information of a + * worker. + * + * @return a Json object containing the following fields: + * `id` a string identifier of the worker + * `host` the host that the worker is running on + * `port` the port that the worker is bound to + * `webuiaddress` the address used in web UI + * `cores` total cores of the worker + * `coresused` allocated cores of the worker + * `coresfree` free cores of the worker + * `memory` total memory of the worker + * `memoryused` allocated memory of the worker + * `memoryfree` free memory of the worker + * `state` state of the worker, see [[WorkerState]] + * `lastheartbeat` time in milliseconds that the latest heart beat message from the + * worker is received + */ + def writeWorkerInfo(obj: WorkerInfo): JObject = { + ("id" -> obj.id) ~ + ("host" -> obj.host) ~ + ("port" -> obj.port) ~ + ("webuiaddress" -> obj.webUiAddress) ~ + ("cores" -> obj.cores) ~ + ("coresused" -> obj.coresUsed) ~ + ("coresfree" -> obj.coresFree) ~ + ("memory" -> obj.memory) ~ + ("memoryused" -> obj.memoryUsed) ~ + ("memoryfree" -> obj.memoryFree) ~ + ("state" -> obj.state.toString) ~ + ("lastheartbeat" -> obj.lastHeartbeat) + } + /** + * Export the [[ApplicationInfo]] to a Json objec. An [[ApplicationInfo]] consists of the + * information of an application. 
+ * + * @return a Json object containing the following fields: + * `id` a string identifier of the application + * `starttime` time in milliseconds that the application starts + * `name` the description of the application + * `cores` total cores granted to the application + * `user` name of the user who submitted the application + * `memoryperslave` minimal memory in MB required to each executor + * `submitdate` time in Date that the application is submitted + * `state` state of the application, see [[ApplicationState]] + * `duration` time in milliseconds that the application has been running + */ def writeApplicationInfo(obj: ApplicationInfo): JObject = { - ("starttime" -> obj.startTime) ~ ("id" -> obj.id) ~ + ("starttime" -> obj.startTime) ~ ("name" -> obj.desc.name) ~ - ("cores" -> obj.desc.maxCores) ~ + ("cores" -> obj.coresGranted) ~ ("user" -> obj.desc.user) ~ ("memoryperslave" -> obj.desc.memoryPerExecutorMB) ~ ("submitdate" -> obj.submitDate.toString) ~ @@ -52,14 +87,36 @@ private[deploy] object JsonProtocol { ("duration" -> obj.duration) } + /** + * Export the [[ApplicationDescription]] to a Json object. An [[ApplicationDescription]] consists + * of the description of an application. + * + * @return a Json object containing the following fields: + * `name` the description of the application + * `cores` max cores that can be allocated to the application, 0 means unlimited + * `memoryperslave` minimal memory in MB required to each executor + * `user` name of the user who submitted the application + * `command` the command string used to submit the application + */ def writeApplicationDescription(obj: ApplicationDescription): JObject = { ("name" -> obj.name) ~ - ("cores" -> obj.maxCores) ~ + ("cores" -> obj.maxCores.getOrElse(0)) ~ ("memoryperslave" -> obj.memoryPerExecutorMB) ~ ("user" -> obj.user) ~ ("command" -> obj.command.toString) } + /** + * Export the [[ExecutorRunner]] to a Json object. An [[ExecutorRunner]] consists of the + * information of an executor. + * + * @return a Json object containing the following fields: + * `id` an integer identifier of the executor + * `memory` memory in MB allocated to the executor + * `appid` a string identifier of the application that the executor is working on + * `appdesc` a Json object of the [[ApplicationDescription]] of the application that the + * executor is working on + */ def writeExecutorRunner(obj: ExecutorRunner): JObject = { ("id" -> obj.execId) ~ ("memory" -> obj.memory) ~ @@ -67,18 +124,59 @@ private[deploy] object JsonProtocol { ("appdesc" -> writeApplicationDescription(obj.appDesc)) } + /** + * Export the [[DriverInfo]] to a Json object. A [[DriverInfo]] consists of the information of a + * driver. 
+ * + * @return a Json object containing the following fields: + * `id` a string identifier of the driver + * `starttime` time in milliseconds that the driver starts + * `state` state of the driver, see [[DriverState]] + * `cores` cores allocated to the driver + * `memory` memory in MB allocated to the driver + * `submitdate` time in Date that the driver is created + * `worker` identifier of the worker that the driver is running on + * `mainclass` main class of the command string that started the driver + */ def writeDriverInfo(obj: DriverInfo): JObject = { ("id" -> obj.id) ~ ("starttime" -> obj.startTime.toString) ~ ("state" -> obj.state.toString) ~ ("cores" -> obj.desc.cores) ~ - ("memory" -> obj.desc.mem) + ("memory" -> obj.desc.mem) ~ + ("submitdate" -> obj.submitDate.toString) ~ + ("worker" -> obj.worker.map(_.id).getOrElse("None")) ~ + ("mainclass" -> obj.desc.command.arguments(2)) } + /** + * Export the [[MasterStateResponse]] to a Json object. A [[MasterStateResponse]] consists the + * information of a master node. + * + * @return a Json object containing the following fields: + * `url` the url of the master node + * `workers` a list of Json objects of [[WorkerInfo]] of the workers allocated to the + * master + * `aliveworkers` size of alive workers allocated to the master + * `cores` total cores available of the master + * `coresused` cores used by the master + * `memory` total memory available of the master + * `memoryused` memory used by the master + * `activeapps` a list of Json objects of [[ApplicationInfo]] of the active applications + * running on the master + * `completedapps` a list of Json objects of [[ApplicationInfo]] of the applications + * completed in the master + * `activedrivers` a list of Json objects of [[DriverInfo]] of the active drivers of the + * master + * `completeddrivers` a list of Json objects of [[DriverInfo]] of the completed drivers + * of the master + * `status` status of the master, see [[MasterState]] + */ def writeMasterState(obj: MasterStateResponse): JObject = { val aliveWorkers = obj.workers.filter(_.isAlive()) ("url" -> obj.uri) ~ ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~ + ("aliveworkers" -> aliveWorkers.length) ~ ("cores" -> aliveWorkers.map(_.cores).sum) ~ ("coresused" -> aliveWorkers.map(_.coresUsed).sum) ~ ("memory" -> aliveWorkers.map(_.memory).sum) ~ @@ -86,9 +184,27 @@ private[deploy] object JsonProtocol { ("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~ ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~ ("activedrivers" -> obj.activeDrivers.toList.map(writeDriverInfo)) ~ + ("completeddrivers" -> obj.completedDrivers.toList.map(writeDriverInfo)) ~ ("status" -> obj.status.toString) } + /** + * Export the [[WorkerStateResponse]] to a Json object. A [[WorkerStateResponse]] consists the + * information of a worker node. 
+ * + * @return a Json object containing the following fields: + * `id` a string identifier of the worker node + * `masterurl` url of the master node of the worker + * `masterwebuiurl` the address used in web UI of the master node of the worker + * `cores` total cores of the worker + * `coreused` used cores of the worker + * `memory` total memory of the worker + * `memoryused` used memory of the worker + * `executors` a list of Json objects of [[ExecutorRunner]] of the executors running on + * the worker + * `finishedexecutors` a list of Json objects of [[ExecutorRunner]] of the finished + * executors of the worker + */ def writeWorkerState(obj: WorkerStateResponse): JObject = { ("id" -> obj.workerId) ~ ("masterurl" -> obj.masterUrl) ~ @@ -97,7 +213,7 @@ private[deploy] object JsonProtocol { ("coresused" -> obj.coresUsed) ~ ("memory" -> obj.memory) ~ ("memoryused" -> obj.memoryUsed) ~ - ("executors" -> obj.executors.toList.map(writeExecutorRunner)) ~ - ("finishedexecutors" -> obj.finishedExecutors.toList.map(writeExecutorRunner)) + ("executors" -> obj.executors.map(writeExecutorRunner)) ~ + ("finishedexecutors" -> obj.finishedExecutors.map(writeExecutorRunner)) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala index 9c13c15281a42..55a541d60ea3c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala @@ -39,7 +39,7 @@ private[deploy] object DeployTestUtils { } def createDriverCommand(): Command = new Command( - "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), + "org.apache.spark.FakeClass", Seq("WORKER_URL", "USER_JAR", "mainClass"), Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") ) @@ -47,7 +47,7 @@ private[deploy] object DeployTestUtils { new DriverDescription("hdfs://some-dir/some.jar", 100, 3, false, createDriverCommand()) def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", - createDriverDesc(), new Date()) + createDriverDesc(), JsonConstants.submitDate) def createWorkerInfo(): WorkerInfo = { val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, "http://publicAddress:80") diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 7093dad05c5f6..1903130cb694a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -104,8 +104,8 @@ object JsonConstants { val submitDate = new Date(123456789) val appInfoJsonStr = """ - |{"starttime":3,"id":"id","name":"name", - |"cores":4,"user":"%s", + |{"id":"id","starttime":3,"name":"name", + |"cores":0,"user":"%s", |"memoryperslave":1234,"submitdate":"%s", |"state":"WAITING","duration":%d} """.format(System.getProperty("user.name", ""), @@ -134,19 +134,24 @@ object JsonConstants { val driverInfoJsonStr = """ - |{"id":"driver-3","starttime":"3","state":"SUBMITTED","cores":3,"memory":100} - """.stripMargin + |{"id":"driver-3","starttime":"3", + |"state":"SUBMITTED","cores":3,"memory":100, + |"submitdate":"%s","worker":"None", + |"mainclass":"mainClass"} + """.format(submitDate.toString).stripMargin val masterStateJsonStr = """ |{"url":"spark://host:8080", |"workers":[%s,%s], + |"aliveworkers":2, |"cores":8,"coresused":0,"memory":2468,"memoryused":0, 
|"activeapps":[%s],"completedapps":[], |"activedrivers":[%s], + |"completeddrivers":[%s], |"status":"ALIVE"} """.format(workerInfoJsonStr, workerInfoJsonStr, - appInfoJsonStr, driverInfoJsonStr).stripMargin + appInfoJsonStr, driverInfoJsonStr, driverInfoJsonStr).stripMargin val workerStateJsonStr = """ From 9413b84b5a99e264816c61f72905b392c2f9cd35 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 19 Jun 2017 15:51:21 +0800 Subject: [PATCH 027/118] [SPARK-21132][SQL] DISTINCT modifier of function arguments should not be silently ignored ### What changes were proposed in this pull request? We should not silently ignore `DISTINCT` when they are not supported in the function arguments. This PR is to block these cases and issue the error messages. ### How was this patch tested? Added test cases for both regular functions and window functions Author: Xiao Li Closes #18340 from gatorsmile/firstCount. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 14 ++++++++++++-- .../catalyst/analysis/AnalysisErrorSuite.scala | 15 +++++++++++++-- .../sql/catalyst/analysis/AnalysisTest.scala | 8 ++++++-- 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 196b4a9bada3c..647fc0b9342c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1206,11 +1206,21 @@ class Analyzer( // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within // the context of a Window clause. They do not need to be wrapped in an // AggregateExpression. - case wf: AggregateWindowFunction => wf + case wf: AggregateWindowFunction => + if (isDistinct) { + failAnalysis(s"${wf.prettyName} does not support the modifier DISTINCT") + } else { + wf + } // We get an aggregate function, we need to wrap it in an AggregateExpression. case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct) // This function is not an aggregate function, just return the resolved one. 
- case other => other + case other => + if (isDistinct) { + failAnalysis(s"${other.prettyName} does not support the modifier DISTINCT") + } else { + other + } } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index d2ebca5a83dd3..5050318d96358 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -24,7 +24,8 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Max} -import org.apache.spark.sql.catalyst.plans.{Cross, Inner, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.{Cross, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -152,7 +153,7 @@ class AnalysisErrorSuite extends AnalysisTest { "not supported within a window function" :: Nil) errorTest( - "distinct window function", + "distinct aggregate function in window", testRelation2.select( WindowExpression( AggregateExpression(Count(UnresolvedAttribute("b")), Complete, isDistinct = true), @@ -162,6 +163,16 @@ class AnalysisErrorSuite extends AnalysisTest { UnspecifiedFrame)).as('window)), "Distinct window functions are not supported" :: Nil) + errorTest( + "distinct function", + CatalystSqlParser.parsePlan("SELECT hex(DISTINCT a) FROM TaBlE"), + "hex does not support the modifier DISTINCT" :: Nil) + + errorTest( + "distinct window function", + CatalystSqlParser.parsePlan("SELECT percent_rank(DISTINCT a) over () FROM TaBlE"), + "percent_rank does not support the modifier DISTINCT" :: Nil) + errorTest( "nested aggregate functions", testRelation.groupBy('a)( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index afc7ce4195a8b..edfa8c45f9867 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.analysis +import java.net.URI import java.util.Locale import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf @@ -32,7 +33,10 @@ trait AnalysisTest extends PlanTest { private def makeAnalyzer(caseSensitive: Boolean): Analyzer = { val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) - val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) + val catalog = new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin, conf) + catalog.createDatabase( + CatalogDatabase("default", "", new URI("loc"), Map.empty), + ignoreIfExists = false) 
catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true) catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true) catalog.createTempView("TaBlE3", TestRelations.testRelation3, overrideIfExists = true) From 9a145fd796145d1386fd75c01e4103deadb97ac9 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 19 Jun 2017 11:13:03 +0100 Subject: [PATCH 028/118] [MINOR] Bump SparkR and PySpark version to 2.3.0. ## What changes were proposed in this pull request? #17753 bumps master branch version to 2.3.0-SNAPSHOT, but it seems SparkR and PySpark version were omitted. ditto of https://github.com/apache/spark/pull/16488 / https://github.com/apache/spark/pull/17523 ## How was this patch tested? N/A Author: hyukjinkwon Closes #18341 from HyukjinKwon/r-version. --- R/pkg/DESCRIPTION | 2 +- python/pyspark/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 879c1f80f2c5d..b739d423a36cc 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,6 +1,6 @@ Package: SparkR Type: Package -Version: 2.2.0 +Version: 2.3.0 Title: R Frontend for Apache Spark Description: The SparkR package provides an R Frontend for Apache Spark. Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), diff --git a/python/pyspark/version.py b/python/pyspark/version.py index 41bf8c269b795..12dd53b9d2902 100644 --- a/python/pyspark/version.py +++ b/python/pyspark/version.py @@ -16,4 +16,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.2.0.dev0" +__version__ = "2.3.0.dev0" From e92ffe6f1771e3fe9ea2e62ba552c1b5cf255368 Mon Sep 17 00:00:00 2001 From: saturday_s Date: Mon, 19 Jun 2017 10:24:29 -0700 Subject: [PATCH 029/118] [SPARK-19688][STREAMING] Not to read `spark.yarn.credentials.file` from checkpoint. ## What changes were proposed in this pull request? Reload the `spark.yarn.credentials.file` property when restarting a streaming application from checkpoint. ## How was this patch tested? Manual tested with 1.6.3 and 2.1.1. I didn't test this with master because of some compile problems, but I think it will be the same result. ## Notice This should be merged into maintenance branches too. jira: [SPARK-21008](https://issues.apache.org/jira/browse/SPARK-21008) Author: saturday_s Closes #18230 from saturday-shi/SPARK-21008. --- .../src/main/scala/org/apache/spark/streaming/Checkpoint.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 5cbad8bf3ce6e..b8c780db07c98 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -55,6 +55,9 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.master", "spark.yarn.keytab", "spark.yarn.principal", + "spark.yarn.credentials.file", + "spark.yarn.credentials.renewalTime", + "spark.yarn.credentials.updateTime", "spark.ui.filters") val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) From 66a792cd88c63cc0a1d20cbe14ac5699afbb3662 Mon Sep 17 00:00:00 2001 From: assafmendelson Date: Mon, 19 Jun 2017 10:58:58 -0700 Subject: [PATCH 030/118] [SPARK-21123][DOCS][STRUCTURED STREAMING] Options for file stream source are in a wrong table ## What changes were proposed in this pull request? 
The description for several options of File Source for structured streaming appeared in the File Sink description instead. This pull request has two commits: The first includes changes to the version as it appeared in spark 2.1 and the second handled an additional option added for spark 2.2 ## How was this patch tested? Built the documentation by SKIP_API=1 jekyll build and visually inspected the structured streaming programming guide. The original documentation was written by tdas and lw-lin Author: assafmendelson Closes #18342 from assafmendelson/spark-21123. --- .../structured-streaming-programming-guide.md | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 9b9177d44145f..d478042dea5c8 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -510,7 +510,20 @@ Here are the details of all the sources in Spark. File source path: path to the input directory, and common to all file formats. -

    +
    + maxFilesPerTrigger: maximum number of new files to be considered in every trigger (default: no max) +
    + latestFirst: whether to processs the latest new files first, useful when there is a large backlog of files (default: false) +
    + fileNameOnly: whether to check new files based on only the filename instead of on the full path (default: false). With this set to `true`, the following files would be considered as the same file, because their filenames, "dataset.txt", are the same: +
    + · "file:///dataset.txt"
    + · "s3://a/dataset.txt"
    + · "s3n://a/b/dataset.txt"
    + · "s3a://a/b/c/dataset.txt"
    +
    + +
    For file-format-specific options, see the related methods in DataStreamReader (Scala/Java/Python/R). @@ -1234,18 +1247,7 @@ Here are the details of all the sinks in Spark. Append path: path to the output directory, must be specified. -
    - maxFilesPerTrigger: maximum number of new files to be considered in every trigger (default: no max) -
    - latestFirst: whether to processs the latest new files first, useful when there is a large backlog of files (default: false) -
    - fileNameOnly: whether to check new files based on only the filename instead of on the full path (default: false). With this set to `true`, the following files would be considered as the same file, because their filenames, "dataset.txt", are the same: -
    - · "file:///dataset.txt"
    - · "s3://a/dataset.txt"
    - · "s3n://a/b/dataset.txt"
    - · "s3a://a/b/c/dataset.txt"
    -
    +

    For file-format-specific options, see the related methods in DataFrameWriter (Scala/Java/Python/R). From e5387018e76a9af1318e78c4133ee68232e6a159 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 19 Jun 2017 11:40:07 -0700 Subject: [PATCH 031/118] [SPARK-19975][PYTHON][SQL] Add map_keys and map_values functions to Python ## What changes were proposed in this pull request? This fix tries to address the issue in SPARK-19975 where we have `map_keys` and `map_values` functions in SQL yet there is no Python equivalent functions. This fix adds `map_keys` and `map_values` functions to Python. ## How was this patch tested? This fix is tested manually (See Python docs for examples). Author: Yong Tang Closes #17328 from yongtang/SPARK-19975. --- python/pyspark/sql/functions.py | 40 +++++++++++++++++++ .../org/apache/spark/sql/functions.scala | 14 +++++++ 2 files changed, 54 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d9b86aff63fa0..240ae65a61785 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1855,6 +1855,46 @@ def sort_array(col, asc=True): return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) +@since(2.3) +def map_keys(col): + """ + Collection function: Returns an unordered array containing the keys of the map. + + :param col: name of column or expression + + >>> from pyspark.sql.functions import map_keys + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data") + >>> df.select(map_keys("data").alias("keys")).show() + +------+ + | keys| + +------+ + |[1, 2]| + +------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_keys(_to_java_column(col))) + + +@since(2.3) +def map_values(col): + """ + Collection function: Returns an unordered array containing the values of the map. + + :param col: name of column or expression + + >>> from pyspark.sql.functions import map_values + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data") + >>> df.select(map_values("data").alias("values")).show() + +------+ + |values| + +------+ + |[a, b]| + +------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_values(_to_java_column(col))) + + # ---------------------------- User Defined Function ---------------------------------- def _wrap_function(sc, func, returnType): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8d2e1f32da059..9a35a5c4658e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3161,6 +3161,20 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + /** + * Returns an unordered array containing the keys of the map. + * @group collection_funcs + * @since 2.3.0 + */ + def map_keys(e: Column): Column = withExpr { MapKeys(e.expr) } + + /** + * Returns an unordered array containing the values of the map. 
+ * @group collection_funcs + * @since 2.3.0 + */ + def map_values(e: Column): Column = withExpr { MapValues(e.expr) } + ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// From ecc5631351e81bbee4befb213f3053a4f31532a7 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 19 Jun 2017 20:17:54 +0100 Subject: [PATCH 032/118] [MINOR][BUILD] Fix Java linter errors ## What changes were proposed in this pull request? This PR cleans up a few Java linter errors for Apache Spark 2.2 release. ## How was this patch tested? ```bash $ dev/lint-java Using `mvn` from path: /usr/local/bin/mvn Checkstyle checks passed. ``` We can check the result at Travis CI, [here](https://travis-ci.org/dongjoon-hyun/spark/builds/244297894). Author: Dongjoon Hyun Closes #18345 from dongjoon-hyun/fix_lint_java_2. --- .../src/main/java/org/apache/spark/kvstore/KVIndex.java | 2 +- .../src/main/java/org/apache/spark/kvstore/KVStore.java | 7 ++----- .../main/java/org/apache/spark/kvstore/KVStoreView.java | 3 --- .../main/java/org/apache/spark/kvstore/KVTypeInfo.java | 2 -- .../src/main/java/org/apache/spark/kvstore/LevelDB.java | 1 - .../java/org/apache/spark/kvstore/LevelDBIterator.java | 1 - .../java/org/apache/spark/kvstore/LevelDBTypeInfo.java | 5 ----- .../java/org/apache/spark/kvstore/DBIteratorSuite.java | 4 +--- .../test/java/org/apache/spark/kvstore/LevelDBSuite.java | 2 -- .../spark/network/shuffle/OneForOneBlockFetcher.java | 2 +- .../apache/spark/shuffle/sort/UnsafeShuffleWriter.java | 8 +++++--- .../java/org/apache/spark/examples/ml/JavaALSExample.java | 2 +- .../spark/examples/sql/JavaSQLDataSourceExample.java | 6 +++++- .../java/org/apache/spark/sql/streaming/OutputMode.java | 1 - 14 files changed, 16 insertions(+), 30 deletions(-) diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java index 8b8899023c938..0cffefe07c25d 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java @@ -50,7 +50,7 @@ @Target({ElementType.FIELD, ElementType.METHOD}) public @interface KVIndex { - public static final String NATURAL_INDEX_NAME = "__main__"; + String NATURAL_INDEX_NAME = "__main__"; /** * The name of the index to be created for the annotated entity. Must be unique within diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java index 3be4b829b4d8d..c7808ea3c3881 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java @@ -18,9 +18,6 @@ package org.apache.spark.kvstore; import java.io.Closeable; -import java.util.Iterator; -import java.util.Map; -import java.util.NoSuchElementException; /** * Abstraction for a local key/value store for storing app data. @@ -84,7 +81,7 @@ public interface KVStore extends Closeable { * * @param naturalKey The object's "natural key", which uniquely identifies it. Null keys * are not allowed. - * @throws NoSuchElementException If an element with the given key does not exist. + * @throws java.util.NoSuchElementException If an element with the given key does not exist. 
*/ T read(Class klass, Object naturalKey) throws Exception; @@ -107,7 +104,7 @@ public interface KVStore extends Closeable { * @param type The object's type. * @param naturalKey The object's "natural key", which uniquely identifies it. Null keys * are not allowed. - * @throws NoSuchElementException If an element with the given key does not exist. + * @throws java.util.NoSuchElementException If an element with the given key does not exist. */ void delete(Class type, Object naturalKey) throws Exception; diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java index b761640e6da8b..8cd1f52892293 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java @@ -17,9 +17,6 @@ package org.apache.spark.kvstore; -import java.util.Iterator; -import java.util.Map; - import com.google.common.base.Preconditions; /** diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java index 90f2ff0079b8a..e1cc0ba3f5aa7 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java @@ -19,8 +19,6 @@ import java.lang.reflect.Field; import java.lang.reflect.Method; -import java.util.ArrayList; -import java.util.Collection; import java.util.HashMap; import java.util.Map; import java.util.stream.Stream; diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java index 08b22fd8265d8..27141358dc0f2 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java @@ -29,7 +29,6 @@ import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Objects; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import org.fusesource.leveldbjni.JniDBFactory; diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java index a5d0f9f4fb373..263d45c242106 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java @@ -18,7 +18,6 @@ package org.apache.spark.kvstore; import java.io.IOException; -import java.util.Arrays; import java.util.ArrayList; import java.util.List; import java.util.Map; diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java index 3ab17dbd03ca7..722f54e6f9c66 100644 --- a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java @@ -18,17 +18,12 @@ package org.apache.spark.kvstore; import java.lang.reflect.Array; -import java.lang.reflect.Field; -import java.lang.reflect.Method; -import java.io.ByteArrayOutputStream; -import java.io.IOException; import java.util.Collection; import java.util.HashMap; import java.util.Map; import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.base.Preconditions; -import 
com.google.common.base.Throwables; import org.iq80.leveldb.WriteBatch; /** diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java index 8549712213393..3a418189ecfec 100644 --- a/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java @@ -25,11 +25,9 @@ import java.util.List; import java.util.Random; -import com.google.common.base.Predicate; import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; import com.google.common.collect.Lists; -import org.apache.commons.io.FileUtils; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; @@ -50,7 +48,7 @@ public abstract class DBIteratorSuite { private static List clashingEntries; private static KVStore db; - private static interface BaseComparator extends Comparator { + private interface BaseComparator extends Comparator { /** * Returns a comparator that falls back to natural order if this comparator's ordering * returns equality for two elements. Used to mimic how the index sorts things internally. diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java index ee1c397c08573..42bff610457e7 100644 --- a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java @@ -20,9 +20,7 @@ import java.io.File; import java.util.Arrays; import java.util.List; -import java.util.Map; import java.util.NoSuchElementException; -import static java.nio.charset.StandardCharsets.UTF_8; import org.apache.commons.io.FileUtils; import org.iq80.leveldb.DBIterator; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 5f428759252aa..d46ce2e0e6b78 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -157,7 +157,7 @@ private class DownloadCallback implements StreamCallback { private File targetFile = null; private int chunkIndex; - public DownloadCallback(File targetFile, int chunkIndex) throws IOException { + DownloadCallback(File targetFile, int chunkIndex) throws IOException { this.targetFile = targetFile; this.channel = Channels.newChannel(new FileOutputStream(targetFile)); this.chunkIndex = chunkIndex; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 857ec8a4dadd2..34c179990214f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -364,7 +364,8 @@ private long[] mergeSpillsWithFileStream( // Use a counting output stream to avoid having to close the underlying file and ask // the file system for its size after each partition is written. 
final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(bos); - final int inputBufferSizeInBytes = (int) sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + final int inputBufferSizeInBytes = + (int) sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; boolean threwException = true; try { @@ -375,8 +376,9 @@ private long[] mergeSpillsWithFileStream( } for (int partition = 0; partition < numPartitions; partition++) { final long initialFileLength = mergedFileOutputStream.getByteCount(); - // Shield the underlying output stream from close() and flush() calls, so that we can close the higher - // level streams to make sure all data is really flushed and internal state is cleaned. + // Shield the underlying output stream from close() and flush() calls, so that we can close + // the higher level streams to make sure all data is really flushed and internal state is + // cleaned. OutputStream partitionOutput = new CloseAndFlushShieldOutputStream( new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream)); partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java index 60ef03d89d17b..fe4d6bc83f04a 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java @@ -121,7 +121,7 @@ public static void main(String[] args) { // $example off$ userRecs.show(); movieRecs.show(); - + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index 706856b5215e4..95859c52c2aeb 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -124,7 +124,11 @@ private static void runBasicDataSourceExample(SparkSession spark) { peopleDF.write().bucketBy(42, "name").sortBy("age").saveAsTable("people_bucketed"); // $example off:write_sorting_and_bucketing$ // $example on:write_partitioning$ - usersDF.write().partitionBy("favorite_color").format("parquet").save("namesPartByColor.parquet"); + usersDF + .write() + .partitionBy("favorite_color") + .format("parquet") + .save("namesPartByColor.parquet"); // $example off:write_partitioning$ // $example on:write_partition_and_bucket$ peopleDF diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java index 8410abd14fd59..2800b3068f87b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.streaming; -import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.streaming.InternalOutputModes; From 0a4b7e4f81109cff651d2afb94f9f8bf734abdeb Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Mon, 19 Jun 2017 20:35:58 +0100 Subject: [PATCH 033/118] [MINOR] Fix some typo of the document ## What changes were proposed in this pull request? Fix some typo of the document. ## How was this patch tested? Existing tests. 
Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Xianyang Liu Closes #18350 from ConeyLiu/fixtypo. --- dev/change-version-to-2.10.sh | 2 +- dev/change-version-to-2.11.sh | 2 +- python/pyspark/__init__.py | 2 +- .../apache/spark/sql/catalyst/expressions/ExpressionSet.scala | 2 +- .../apache/spark/sql/execution/streaming/BatchCommitLog.scala | 2 +- .../scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 2 +- .../sql/execution/datasources/FileSourceStrategySuite.scala | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dev/change-version-to-2.10.sh b/dev/change-version-to-2.10.sh index 0962d34c52f28..b718d94f849dd 100755 --- a/dev/change-version-to-2.10.sh +++ b/dev/change-version-to-2.10.sh @@ -17,7 +17,7 @@ # limitations under the License. # -# This script exists for backwards compability. Use change-scala-version.sh instead. +# This script exists for backwards compatibility. Use change-scala-version.sh instead. echo "This script is deprecated. Please instead run: change-scala-version.sh 2.10" $(dirname $0)/change-scala-version.sh 2.10 diff --git a/dev/change-version-to-2.11.sh b/dev/change-version-to-2.11.sh index 4ccfeef09fd04..93087959a38dd 100755 --- a/dev/change-version-to-2.11.sh +++ b/dev/change-version-to-2.11.sh @@ -17,7 +17,7 @@ # limitations under the License. # -# This script exists for backwards compability. Use change-scala-version.sh instead. +# This script exists for backwards compatibility. Use change-scala-version.sh instead. echo "This script is deprecated. Please instead run: change-scala-version.sh 2.11" $(dirname $0)/change-scala-version.sh 2.11 diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 14c51a306e1c2..4d142c91629cc 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -35,7 +35,7 @@ - :class:`StorageLevel`: Finer-grained cache persistence levels. - :class:`TaskContext`: - Information about the current running task, avaialble on the workers and experimental. + Information about the current running task, available on the workers and experimental. """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala index f93e5736de401..ede0b1654bbd6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala @@ -39,7 +39,7 @@ object ExpressionSet { * guaranteed to see at least one such expression. For example: * * {{{ - * val set = AttributeSet(a + 1, 1 + a) + * val set = ExpressionSet(a + 1, 1 + a) * * set.iterator => Iterator(a + 1) * set.contains(a + 1) => true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala index a34938f911f76..5e24e8fc4e3cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.SparkSession * - process batch 1 * - write batch 1 to completion log * - trigger batch 2 - * - obtain bactch 2 offsets and write to offset log + * - obtain batch 2 offsets and write to offset log * - process batch 2 * - write batch 2 to completion log * .... 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 8569c2d76b694..5db354d79bb6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -507,7 +507,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil) } - test("SQL decimal test (used for catching certain demical handling bugs in aggregates)") { + test("SQL decimal test (used for catching certain decimal handling bugs in aggregates)") { checkAnswer( decimalData.groupBy('a cast DecimalType(10, 2)).agg(avg('b cast DecimalType(10, 2))), Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(1.5)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 9a2dcafb5e4b3..d77f0c298ffe3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -244,7 +244,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi val df2 = table.where("(p1 + c2) = 2 AND c1 = 1") // Filter on data only are advisory so we have to reevaluate. assert(getPhysicalFilters(df2) contains resolve(df2, "c1 = 1")) - // Need to evalaute filters that are not pushed down. + // Need to evaluate filters that are not pushed down. assert(getPhysicalFilters(df2) contains resolve(df2, "(p1 + c2) = 2")) } From 581565dd871ca51507603d19b2d4203993c2636d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 19 Jun 2017 14:41:58 -0700 Subject: [PATCH 034/118] [SPARK-21124][UI] Show correct application user in UI. The jobs page currently shows the application user, but it assumes the OS user is the same as the user running the application, which may not be true in all scenarios (e.g., kerberos). While it might be useful to show both in the UI, this change just chooses the application user over the OS user, since the latter can be found in the environment page if needed. Tested in live application and in history server. Author: Marcelo Vanzin Closes #18331 from vanzin/SPARK-21124. 
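A minimal sketch (not part of this patch; names are mine) of the precedence the change introduces for the jobs page: the user reported at application start wins, and the OS-level `user.name` property is kept only as a fallback, mirroring the new `getSparkUser` below.

```scala
object ShownUserPrecedence {
  // Prefer the user captured from the application-start event; fall back to
  // the JVM/OS user, and finally to an empty string.
  def resolveShownUser(appUser: Option[String], osUser: Option[String]): String =
    appUser.orElse(osUser).getOrElse("")

  def main(args: Array[String]): Unit = {
    // e.g. a kerberized job submitted as "alice" from a gateway running as "yarn"
    println(resolveShownUser(Some("alice"), Some("yarn"))) // alice (what the UI now shows)
    println(resolveShownUser(None, Some("yarn")))          // yarn  (old behaviour kept as fallback)
  }
}
```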
--- core/src/main/scala/org/apache/spark/ui/SparkUI.scala | 4 +++- .../main/scala/org/apache/spark/ui/env/EnvironmentTab.scala | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index f271c56021e95..589f811145519 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -86,7 +86,9 @@ private[spark] class SparkUI private ( initialize() def getSparkUser: String = { - environmentListener.systemProperties.toMap.getOrElse("user.name", "") + environmentListener.sparkUser + .orElse(environmentListener.systemProperties.toMap.get("user.name")) + .getOrElse("") } def getAppName: String = appName diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala index 8c18464e6477a..61b12aaa32bb6 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala @@ -34,11 +34,16 @@ private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "en @DeveloperApi @deprecated("This class will be removed in a future release.", "2.2.0") class EnvironmentListener extends SparkListener { + var sparkUser: Option[String] = None var jvmInformation = Seq[(String, String)]() var sparkProperties = Seq[(String, String)]() var systemProperties = Seq[(String, String)]() var classpathEntries = Seq[(String, String)]() + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { + sparkUser = Some(event.sparkUser) + } + override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { synchronized { val environmentDetails = environmentUpdate.environmentDetails From 3d4d11a80fe8953d48d8bfac2ce112e37d38dc90 Mon Sep 17 00:00:00 2001 From: sharkdtu Date: Mon, 19 Jun 2017 14:54:54 -0700 Subject: [PATCH 035/118] [SPARK-21138][YARN] Cannot delete staging dir when the clusters of "spark.yarn.stagingDir" and "spark.hadoop.fs.defaultFS" are different MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? When I set different clusters for "spark.hadoop.fs.defaultFS" and "spark.yarn.stagingDir" as follows: ``` spark.hadoop.fs.defaultFS hdfs://tl-nn-tdw.tencent-distribute.com:54310 spark.yarn.stagingDir hdfs://ss-teg-2-v2/tmp/spark ``` The staging dir can not be deleted, it will prompt following message: ``` java.lang.IllegalArgumentException: Wrong FS: hdfs://ss-teg-2-v2/tmp/spark/.sparkStaging/application_1496819138021_77618, expected: hdfs://tl-nn-tdw.tencent-distribute.com:54310 ``` ## How was this patch tested? Existing tests Author: sharkdtu Closes #18352 from sharkdtu/master. 
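A minimal sketch of the failure mode, not part of this patch: the local filesystem stands in for the `fs.defaultFS` cluster and the host name is a placeholder. A `FileSystem` obtained from the default configuration can only operate on paths of its own scheme and authority, which is exactly the "Wrong FS" `IllegalArgumentException` quoted above; asking the staging path itself for its `FileSystem`, as the patch does, avoids it. Assumes only hadoop-common on the classpath.

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

object WrongFsSketch {
  def main(args: Array[String]): Unit = {
    val conf = new Configuration()
    val defaultFs = FileSystem.getLocal(conf) // stand-in for the fs.defaultFS cluster
    val stagingDir = new Path("hdfs://other-cluster/tmp/spark/.sparkStaging/app_1")

    // Deleting a path that belongs to a different cluster through the default
    // FileSystem fails with: Wrong FS: hdfs://other-cluster/..., expected: file:///
    try defaultFs.delete(stagingDir, true)
    catch { case e: IllegalArgumentException => println(e.getMessage) }

    // The fix resolves the FileSystem from the path instead (needs the HDFS
    // client on the classpath and a reachable cluster to actually delete):
    // val stagingFs = stagingDir.getFileSystem(conf)
    // stagingFs.delete(stagingDir, true)
  }
}
```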
--- .../org/apache/spark/deploy/yarn/ApplicationMaster.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 4f71a1606312d..4868180569778 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -209,8 +209,6 @@ private[spark] class ApplicationMaster( logInfo("ApplicationAttemptId: " + appAttemptId) - val fs = FileSystem.get(yarnConf) - // This shutdown hook should run *after* the SparkContext is shut down. val priority = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1 ShutdownHookManager.addShutdownHook(priority) { () => @@ -232,7 +230,7 @@ private[spark] class ApplicationMaster( // we only want to unregister if we don't want the RM to retry if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) { unregister(finalStatus, finalMsg) - cleanupStagingDir(fs) + cleanupStagingDir() } } } @@ -533,7 +531,7 @@ private[spark] class ApplicationMaster( /** * Clean up the staging directory. */ - private def cleanupStagingDir(fs: FileSystem) { + private def cleanupStagingDir(): Unit = { var stagingDirPath: Path = null try { val preserveFiles = sparkConf.get(PRESERVE_STAGING_FILES) @@ -544,6 +542,7 @@ private[spark] class ApplicationMaster( return } logInfo("Deleting staging directory " + stagingDirPath) + val fs = stagingDirPath.getFileSystem(yarnConf) fs.delete(stagingDirPath, true) } } catch { From 9eacc5e4384de26eaf1d6475bcc698c4e86c996d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 19 Jun 2017 15:14:33 -0700 Subject: [PATCH 036/118] [INFRA] Close stale PRs. Closes #18311 Closes #18278 From 9b57cd8d5c594731a7b3c90ce59bcddb05193d79 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 20 Jun 2017 09:22:30 +0800 Subject: [PATCH 037/118] [SPARK-21133][CORE] Fix HighlyCompressedMapStatus#writeExternal throws NPE ## What changes were proposed in this pull request? 
Fix HighlyCompressedMapStatus#writeExternal NPE: ``` 17/06/18 15:00:27 ERROR Utils: Exception encountered java.lang.NullPointerException at org.apache.spark.scheduler.HighlyCompressedMapStatus$$anonfun$writeExternal$2.apply$mcV$sp(MapStatus.scala:171) at org.apache.spark.scheduler.HighlyCompressedMapStatus$$anonfun$writeExternal$2.apply(MapStatus.scala:167) at org.apache.spark.scheduler.HighlyCompressedMapStatus$$anonfun$writeExternal$2.apply(MapStatus.scala:167) at org.apache.spark.util.Utils$.tryOrIOException(Utils.scala:1303) at org.apache.spark.scheduler.HighlyCompressedMapStatus.writeExternal(MapStatus.scala:167) at java.io.ObjectOutputStream.writeExternalData(ObjectOutputStream.java:1459) at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1430) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178) at java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1378) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1174) at java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:348) at org.apache.spark.MapOutputTracker$$anonfun$serializeMapStatuses$1.apply$mcV$sp(MapOutputTracker.scala:617) at org.apache.spark.MapOutputTracker$$anonfun$serializeMapStatuses$1.apply(MapOutputTracker.scala:616) at org.apache.spark.MapOutputTracker$$anonfun$serializeMapStatuses$1.apply(MapOutputTracker.scala:616) at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1337) at org.apache.spark.MapOutputTracker$.serializeMapStatuses(MapOutputTracker.scala:619) at org.apache.spark.MapOutputTrackerMaster.getSerializedMapOutputStatuses(MapOutputTracker.scala:562) at org.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:351) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) 17/06/18 15:00:27 ERROR MapOutputTrackerMaster: java.lang.NullPointerException java.io.IOException: java.lang.NullPointerException at org.apache.spark.util.Utils$.tryOrIOException(Utils.scala:1310) at org.apache.spark.scheduler.HighlyCompressedMapStatus.writeExternal(MapStatus.scala:167) at java.io.ObjectOutputStream.writeExternalData(ObjectOutputStream.java:1459) at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1430) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178) at java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1378) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1174) at java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:348) at org.apache.spark.MapOutputTracker$$anonfun$serializeMapStatuses$1.apply$mcV$sp(MapOutputTracker.scala:617) at org.apache.spark.MapOutputTracker$$anonfun$serializeMapStatuses$1.apply(MapOutputTracker.scala:616) at org.apache.spark.MapOutputTracker$$anonfun$serializeMapStatuses$1.apply(MapOutputTracker.scala:616) at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1337) at org.apache.spark.MapOutputTracker$.serializeMapStatuses(MapOutputTracker.scala:619) at org.apache.spark.MapOutputTrackerMaster.getSerializedMapOutputStatuses(MapOutputTracker.scala:562) at org.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:351) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) Caused by: 
java.lang.NullPointerException at org.apache.spark.scheduler.HighlyCompressedMapStatus$$anonfun$writeExternal$2.apply$mcV$sp(MapStatus.scala:171) at org.apache.spark.scheduler.HighlyCompressedMapStatus$$anonfun$writeExternal$2.apply(MapStatus.scala:167) at org.apache.spark.scheduler.HighlyCompressedMapStatus$$anonfun$writeExternal$2.apply(MapStatus.scala:167) at org.apache.spark.util.Utils$.tryOrIOException(Utils.scala:1303) ... 17 more 17/06/18 15:00:27 INFO MapOutputTrackerMasterEndpoint: Asked to send map output locations for shuffle 0 to 10.17.47.20:50188 17/06/18 15:00:27 ERROR Utils: Exception encountered java.lang.NullPointerException at org.apache.spark.scheduler.HighlyCompressedMapStatus$$anonfun$writeExternal$2.apply$mcV$sp(MapStatus.scala:171) at org.apache.spark.scheduler.HighlyCompressedMapStatus$$anonfun$writeExternal$2.apply(MapStatus.scala:167) at org.apache.spark.scheduler.HighlyCompressedMapStatus$$anonfun$writeExternal$2.apply(MapStatus.scala:167) at org.apache.spark.util.Utils$.tryOrIOException(Utils.scala:1303) at org.apache.spark.scheduler.HighlyCompressedMapStatus.writeExternal(MapStatus.scala:167) at java.io.ObjectOutputStream.writeExternalData(ObjectOutputStream.java:1459) at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1430) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178) at java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1378) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1174) at java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:348) at org.apache.spark.MapOutputTracker$$anonfun$serializeMapStatuses$1.apply$mcV$sp(MapOutputTracker.scala:617) at org.apache.spark.MapOutputTracker$$anonfun$serializeMapStatuses$1.apply(MapOutputTracker.scala:616) at org.apache.spark.MapOutputTracker$$anonfun$serializeMapStatuses$1.apply(MapOutputTracker.scala:616) at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1337) at org.apache.spark.MapOutputTracker$.serializeMapStatuses(MapOutputTracker.scala:619) at org.apache.spark.MapOutputTrackerMaster.getSerializedMapOutputStatuses(MapOutputTracker.scala:562) at org.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:351) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` ## How was this patch tested? manual tests Author: Yuming Wang Closes #18343 from wangyum/SPARK-21133. 
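A self-contained sketch of the reproduction, essentially the regression test added below packaged as a small driver program (the object name is mine; assumes spark-core on the classpath): with Kryo configured and more than 2000 shuffle partitions, the map statuses become `HighlyCompressedMapStatus` instances, and serializing them for the reducers used to hit the `writeExternal` NullPointerException shown in the trace above.

```scala
import org.apache.spark.{SparkConf, SparkContext}

object Spark21133Repro {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("SPARK-21133 repro")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val sc = new SparkContext(conf)
    try {
      // 2001 output partitions crosses the 2000-partition threshold, so the
      // shuffle map output is tracked with HighlyCompressedMapStatus.
      val count = sc.parallelize(0 until 3000, 10).repartition(2001).collect().length
      println(s"collected $count rows") // NPE before this fix, 3000 afterwards
    } finally {
      sc.stop()
    }
  }
}
```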
--- .../org/apache/spark/scheduler/MapStatus.scala | 2 +- .../spark/serializer/KryoSerializer.scala | 1 + .../spark/scheduler/MapStatusSuite.scala | 18 ++++++++++++++++-- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 048e0d0186594..5e45b375ddd45 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -141,7 +141,7 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, - @transient private var hugeBlockSizes: Map[Int, Byte]) + private var hugeBlockSizes: Map[Int, Byte]) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index e15166d11c243..4f03e54e304f6 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -175,6 +175,7 @@ class KryoSerializer(conf: SparkConf) kryo.register(None.getClass) kryo.register(Nil.getClass) kryo.register(Utils.classForName("scala.collection.immutable.$colon$colon")) + kryo.register(Utils.classForName("scala.collection.immutable.Map$EmptyMap$")) kryo.register(classOf[ArrayBuffer[Any]]) kryo.setClassLoader(classLoader) diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 3ec37f674c77b..e6120139f4958 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -24,9 +24,9 @@ import scala.util.Random import org.mockito.Mockito._ import org.roaringbitmap.RoaringBitmap -import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.internal.config -import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.storage.BlockManagerId class MapStatusSuite extends SparkFunSuite { @@ -154,4 +154,18 @@ class MapStatusSuite extends SparkFunSuite { case part => assert(status2.getSizeForBlock(part) >= sizes(part)) } } + + test("SPARK-21133 HighlyCompressedMapStatus#writeExternal throws NPE") { + val conf = new SparkConf() + .set("spark.serializer", classOf[KryoSerializer].getName) + .setMaster("local") + .setAppName("SPARK-21133") + val sc = new SparkContext(conf) + try { + val count = sc.parallelize(0 until 3000, 10).repartition(2001).collect().length + assert(count === 3000) + } finally { + sc.stop() + } + } } From 8965fe764a4218d944938aa4828072f1ad9dbda7 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Mon, 19 Jun 2017 19:41:24 -0700 Subject: [PATCH 038/118] [SPARK-20889][SPARKR] Grouped documentation for AGGREGATE column methods ## What changes were proposed in this pull request? Grouped documentation for the aggregate functions for Column. Author: actuaryzhang Closes #18025 from actuaryzhang/sparkRDoc4. 
--- R/pkg/R/functions.R | 427 ++++++++++++++++++-------------------------- R/pkg/R/generics.R | 56 ++++-- R/pkg/R/stats.R | 22 +-- 3 files changed, 219 insertions(+), 286 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 7128c3b9adff4..01ca8b8c4527d 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -18,6 +18,22 @@ #' @include generics.R column.R NULL +#' Aggregate functions for Column operations +#' +#' Aggregate functions defined for \code{Column}. +#' +#' @param x Column to compute on. +#' @param y,na.rm,use currently not used. +#' @param ... additional argument(s). For example, it could be used to pass additional Columns. +#' @name column_aggregate_functions +#' @rdname column_aggregate_functions +#' @family aggregate functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))} +NULL + #' lit #' #' A new \linkS4class{Column} is created to represent the literal value. @@ -85,17 +101,20 @@ setMethod("acos", column(jc) }) -#' Returns the approximate number of distinct items in a group +#' @details +#' \code{approxCountDistinct}: Returns the approximate number of distinct items in a group. #' -#' Returns the approximate number of distinct items in a group. This is a column -#' aggregate function. -#' -#' @rdname approxCountDistinct -#' @name approxCountDistinct -#' @return the approximate number of distinct items in a group. +#' @rdname column_aggregate_functions #' @export -#' @aliases approxCountDistinct,Column-method -#' @examples \dontrun{approxCountDistinct(df$c)} +#' @aliases approxCountDistinct approxCountDistinct,Column-method +#' @examples +#' +#' \dontrun{ +#' head(select(df, approxCountDistinct(df$gear))) +#' head(select(df, approxCountDistinct(df$gear, 0.02))) +#' head(select(df, countDistinct(df$gear, df$cyl))) +#' head(select(df, n_distinct(df$gear))) +#' head(distinct(select(df, "gear")))} #' @note approxCountDistinct(Column) since 1.4.0 setMethod("approxCountDistinct", signature(x = "Column"), @@ -342,10 +361,13 @@ setMethod("column", #' #' @rdname corr #' @name corr -#' @family math functions +#' @family aggregate functions #' @export #' @aliases corr,Column-method -#' @examples \dontrun{corr(df$c, df$d)} +#' @examples +#' \dontrun{ +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' head(select(df, corr(df$mpg, df$hp)))} #' @note corr since 1.6.0 setMethod("corr", signature(x = "Column"), function(x, col2) { @@ -356,20 +378,22 @@ setMethod("corr", signature(x = "Column"), #' cov #' -#' Compute the sample covariance between two expressions. +#' Compute the covariance between two expressions. +#' +#' @details +#' \code{cov}: Compute the sample covariance between two expressions. #' #' @rdname cov #' @name cov -#' @family math functions +#' @family aggregate functions #' @export #' @aliases cov,characterOrColumn-method #' @examples #' \dontrun{ -#' cov(df$c, df$d) -#' cov("c", "d") -#' covar_samp(df$c, df$d) -#' covar_samp("c", "d") -#' } +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' head(select(df, cov(df$mpg, df$hp), cov("mpg", "hp"), +#' covar_samp(df$mpg, df$hp), covar_samp("mpg", "hp"), +#' covar_pop(df$mpg, df$hp), covar_pop("mpg", "hp")))} #' @note cov since 1.6.0 setMethod("cov", signature(x = "characterOrColumn"), function(x, col2) { @@ -377,6 +401,9 @@ setMethod("cov", signature(x = "characterOrColumn"), covar_samp(x, col2) }) +#' @details +#' \code{covar_sample}: Alias for \code{cov}. 
+#' #' @rdname cov #' #' @param col1 the first Column. @@ -395,23 +422,13 @@ setMethod("covar_samp", signature(col1 = "characterOrColumn", col2 = "characterO column(jc) }) -#' covar_pop +#' @details +#' \code{covar_pop}: Computes the population covariance between two expressions. #' -#' Compute the population covariance between two expressions. -#' -#' @param col1 First column to compute cov_pop. -#' @param col2 Second column to compute cov_pop. -#' -#' @rdname covar_pop +#' @rdname cov #' @name covar_pop -#' @family math functions #' @export #' @aliases covar_pop,characterOrColumn,characterOrColumn-method -#' @examples -#' \dontrun{ -#' covar_pop(df$c, df$d) -#' covar_pop("c", "d") -#' } #' @note covar_pop since 2.0.0 setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOrColumn"), function(col1, col2) { @@ -823,18 +840,16 @@ setMethod("isnan", column(jc) }) -#' kurtosis -#' -#' Aggregate function: returns the kurtosis of the values in a group. +#' @details +#' \code{kurtosis}: Returns the kurtosis of the values in a group. #' -#' @param x Column to compute on. -#' -#' @rdname kurtosis -#' @name kurtosis -#' @aliases kurtosis,Column-method -#' @family aggregate functions +#' @rdname column_aggregate_functions +#' @aliases kurtosis kurtosis,Column-method #' @export -#' @examples \dontrun{kurtosis(df$c)} +#' @examples +#' +#' \dontrun{ +#' head(select(df, mean(df$mpg), sd(df$mpg), skewness(df$mpg), kurtosis(df$mpg)))} #' @note kurtosis since 1.6.0 setMethod("kurtosis", signature(x = "Column"), @@ -1040,18 +1055,11 @@ setMethod("ltrim", column(jc) }) -#' max -#' -#' Aggregate function: returns the maximum value of the expression in a group. -#' -#' @param x Column to compute on. +#' @details +#' \code{max}: Returns the maximum value of the expression in a group. #' -#' @rdname max -#' @name max -#' @family aggregate functions -#' @aliases max,Column-method -#' @export -#' @examples \dontrun{max(df$c)} +#' @rdname column_aggregate_functions +#' @aliases max max,Column-method #' @note max since 1.5.0 setMethod("max", signature(x = "Column"), @@ -1081,19 +1089,24 @@ setMethod("md5", column(jc) }) -#' mean +#' @details +#' \code{mean}: Returns the average of the values in a group. Alias for \code{avg}. #' -#' Aggregate function: returns the average of the values in a group. -#' Alias for avg. +#' @rdname column_aggregate_functions +#' @aliases mean mean,Column-method +#' @export +#' @examples #' -#' @param x Column to compute on. +#' \dontrun{ +#' head(select(df, avg(df$mpg), mean(df$mpg), sum(df$mpg), min(df$wt), max(df$qsec))) #' -#' @rdname mean -#' @name mean -#' @family aggregate functions -#' @aliases mean,Column-method -#' @export -#' @examples \dontrun{mean(df$c)} +#' # metrics by num of cylinders +#' tmp <- agg(groupBy(df, "cyl"), avg(df$mpg), avg(df$hp), avg(df$wt), avg(df$qsec)) +#' head(orderBy(tmp, "cyl")) +#' +#' # car with the max mpg +#' mpg_max <- as.numeric(collect(agg(df, max(df$mpg)))) +#' head(where(df, df$mpg == mpg_max))} #' @note mean since 1.5.0 setMethod("mean", signature(x = "Column"), @@ -1102,18 +1115,12 @@ setMethod("mean", column(jc) }) -#' min -#' -#' Aggregate function: returns the minimum value of the expression in a group. -#' -#' @param x Column to compute on. +#' @details +#' \code{min}: Returns the minimum value of the expression in a group. 
#' -#' @rdname min -#' @name min -#' @aliases min,Column-method -#' @family aggregate functions +#' @rdname column_aggregate_functions +#' @aliases min min,Column-method #' @export -#' @examples \dontrun{min(df$c)} #' @note min since 1.5.0 setMethod("min", signature(x = "Column"), @@ -1338,24 +1345,17 @@ setMethod("rtrim", column(jc) }) -#' sd -#' -#' Aggregate function: alias for \link{stddev_samp} + +#' @details +#' \code{sd}: Alias for \code{stddev_samp}. #' -#' @param x Column to compute on. -#' @param na.rm currently not used. -#' @rdname sd -#' @name sd -#' @family aggregate functions -#' @aliases sd,Column-method -#' @seealso \link{stddev_pop}, \link{stddev_samp} +#' @rdname column_aggregate_functions +#' @aliases sd sd,Column-method #' @export #' @examples -#'\dontrun{ -#'stddev(df$c) -#'select(df, stddev(df$age)) -#'agg(df, sd(df$age)) -#'} +#' +#' \dontrun{ +#' head(select(df, sd(df$mpg), stddev(df$mpg), stddev_pop(df$wt), stddev_samp(df$qsec)))} #' @note sd since 1.6.0 setMethod("sd", signature(x = "Column"), @@ -1465,18 +1465,12 @@ setMethod("sinh", column(jc) }) -#' skewness -#' -#' Aggregate function: returns the skewness of the values in a group. -#' -#' @param x Column to compute on. +#' @details +#' \code{skewness}: Returns the skewness of the values in a group. #' -#' @rdname skewness -#' @name skewness -#' @family aggregate functions -#' @aliases skewness,Column-method +#' @rdname column_aggregate_functions +#' @aliases skewness skewness,Column-method #' @export -#' @examples \dontrun{skewness(df$c)} #' @note skewness since 1.6.0 setMethod("skewness", signature(x = "Column"), @@ -1527,9 +1521,11 @@ setMethod("spark_partition_id", column(jc) }) -#' @rdname sd -#' @aliases stddev,Column-method -#' @name stddev +#' @details +#' \code{stddev}: Alias for \code{std_dev}. +#' +#' @rdname column_aggregate_functions +#' @aliases stddev stddev,Column-method #' @note stddev since 1.6.0 setMethod("stddev", signature(x = "Column"), @@ -1538,19 +1534,12 @@ setMethod("stddev", column(jc) }) -#' stddev_pop -#' -#' Aggregate function: returns the population standard deviation of the expression in a group. -#' -#' @param x Column to compute on. +#' @details +#' \code{stddev_pop}: Returns the population standard deviation of the expression in a group. #' -#' @rdname stddev_pop -#' @name stddev_pop -#' @family aggregate functions -#' @aliases stddev_pop,Column-method -#' @seealso \link{sd}, \link{stddev_samp} +#' @rdname column_aggregate_functions +#' @aliases stddev_pop stddev_pop,Column-method #' @export -#' @examples \dontrun{stddev_pop(df$c)} #' @note stddev_pop since 1.6.0 setMethod("stddev_pop", signature(x = "Column"), @@ -1559,19 +1548,12 @@ setMethod("stddev_pop", column(jc) }) -#' stddev_samp -#' -#' Aggregate function: returns the unbiased sample standard deviation of the expression in a group. -#' -#' @param x Column to compute on. +#' @details +#' \code{stddev_samp}: Returns the unbiased sample standard deviation of the expression in a group. #' -#' @rdname stddev_samp -#' @name stddev_samp -#' @family aggregate functions -#' @aliases stddev_samp,Column-method -#' @seealso \link{stddev_pop}, \link{sd} +#' @rdname column_aggregate_functions +#' @aliases stddev_samp stddev_samp,Column-method #' @export -#' @examples \dontrun{stddev_samp(df$c)} #' @note stddev_samp since 1.6.0 setMethod("stddev_samp", signature(x = "Column"), @@ -1630,18 +1612,12 @@ setMethod("sqrt", column(jc) }) -#' sum -#' -#' Aggregate function: returns the sum of all values in the expression. 
-#' -#' @param x Column to compute on. +#' @details +#' \code{sum}: Returns the sum of all values in the expression. #' -#' @rdname sum -#' @name sum -#' @family aggregate functions -#' @aliases sum,Column-method +#' @rdname column_aggregate_functions +#' @aliases sum sum,Column-method #' @export -#' @examples \dontrun{sum(df$c)} #' @note sum since 1.5.0 setMethod("sum", signature(x = "Column"), @@ -1650,18 +1626,17 @@ setMethod("sum", column(jc) }) -#' sumDistinct -#' -#' Aggregate function: returns the sum of distinct values in the expression. +#' @details +#' \code{sumDistinct}: Returns the sum of distinct values in the expression. #' -#' @param x Column to compute on. -#' -#' @rdname sumDistinct -#' @name sumDistinct -#' @family aggregate functions -#' @aliases sumDistinct,Column-method +#' @rdname column_aggregate_functions +#' @aliases sumDistinct sumDistinct,Column-method #' @export -#' @examples \dontrun{sumDistinct(df$c)} +#' @examples +#' +#' \dontrun{ +#' head(select(df, sumDistinct(df$gear))) +#' head(distinct(select(df, "gear")))} #' @note sumDistinct since 1.4.0 setMethod("sumDistinct", signature(x = "Column"), @@ -1952,24 +1927,16 @@ setMethod("upper", column(jc) }) -#' var -#' -#' Aggregate function: alias for \link{var_samp}. +#' @details +#' \code{var}: Alias for \code{var_samp}. #' -#' @param x a Column to compute on. -#' @param y,na.rm,use currently not used. -#' @rdname var -#' @name var -#' @family aggregate functions -#' @aliases var,Column-method -#' @seealso \link{var_pop}, \link{var_samp} +#' @rdname column_aggregate_functions +#' @aliases var var,Column-method #' @export #' @examples +#' #'\dontrun{ -#'variance(df$c) -#'select(df, var_pop(df$age)) -#'agg(df, var(df$age)) -#'} +#'head(agg(df, var(df$mpg), variance(df$mpg), var_pop(df$mpg), var_samp(df$mpg)))} #' @note var since 1.6.0 setMethod("var", signature(x = "Column"), @@ -1978,9 +1945,9 @@ setMethod("var", var_samp(x) }) -#' @rdname var -#' @aliases variance,Column-method -#' @name variance +#' @rdname column_aggregate_functions +#' @aliases variance variance,Column-method +#' @export #' @note variance since 1.6.0 setMethod("variance", signature(x = "Column"), @@ -1989,19 +1956,12 @@ setMethod("variance", column(jc) }) -#' var_pop +#' @details +#' \code{var_pop}: Returns the population variance of the values in a group. #' -#' Aggregate function: returns the population variance of the values in a group. -#' -#' @param x Column to compute on. -#' -#' @rdname var_pop -#' @name var_pop -#' @family aggregate functions -#' @aliases var_pop,Column-method -#' @seealso \link{var}, \link{var_samp} +#' @rdname column_aggregate_functions +#' @aliases var_pop var_pop,Column-method #' @export -#' @examples \dontrun{var_pop(df$c)} #' @note var_pop since 1.5.0 setMethod("var_pop", signature(x = "Column"), @@ -2010,19 +1970,12 @@ setMethod("var_pop", column(jc) }) -#' var_samp +#' @details +#' \code{var_samp}: Returns the unbiased variance of the values in a group. #' -#' Aggregate function: returns the unbiased variance of the values in a group. -#' -#' @param x Column to compute on. 
-#' -#' @rdname var_samp -#' @name var_samp -#' @aliases var_samp,Column-method -#' @family aggregate functions -#' @seealso \link{var_pop}, \link{var} +#' @rdname column_aggregate_functions +#' @aliases var_samp var_samp,Column-method #' @export -#' @examples \dontrun{var_samp(df$c)} #' @note var_samp since 1.6.0 setMethod("var_samp", signature(x = "Column"), @@ -2235,17 +2188,11 @@ setMethod("pmod", signature(y = "Column"), column(jc) }) - -#' @rdname approxCountDistinct -#' @name approxCountDistinct -#' -#' @param x Column to compute on. #' @param rsd maximum estimation error allowed (default = 0.05) -#' @param ... further arguments to be passed to or from other methods. #' +#' @rdname column_aggregate_functions #' @aliases approxCountDistinct,Column-method #' @export -#' @examples \dontrun{approxCountDistinct(df$c, 0.02)} #' @note approxCountDistinct(Column, numeric) since 1.4.0 setMethod("approxCountDistinct", signature(x = "Column"), @@ -2254,18 +2201,12 @@ setMethod("approxCountDistinct", column(jc) }) -#' Count Distinct Values +#' @details +#' \code{countDistinct}: Returns the number of distinct items in a group. #' -#' @param x Column to compute on -#' @param ... other columns -#' -#' @family aggregate functions -#' @rdname countDistinct -#' @name countDistinct -#' @aliases countDistinct,Column-method -#' @return the number of distinct items in a group. +#' @rdname column_aggregate_functions +#' @aliases countDistinct countDistinct,Column-method #' @export -#' @examples \dontrun{countDistinct(df$c)} #' @note countDistinct since 1.4.0 setMethod("countDistinct", signature(x = "Column"), @@ -2384,15 +2325,12 @@ setMethod("sign", signature(x = "Column"), signum(x) }) -#' n_distinct -#' -#' Aggregate function: returns the number of distinct items in a group. +#' @details +#' \code{n_distinct}: Returns the number of distinct items in a group. #' -#' @rdname countDistinct -#' @name n_distinct -#' @aliases n_distinct,Column-method +#' @rdname column_aggregate_functions +#' @aliases n_distinct n_distinct,Column-method #' @export -#' @examples \dontrun{n_distinct(df$c)} #' @note n_distinct since 1.4.0 setMethod("n_distinct", signature(x = "Column"), function(x, ...) { @@ -3717,18 +3655,18 @@ setMethod("create_map", column(jc) }) -#' collect_list +#' @details +#' \code{collect_list}: Creates a list of objects with duplicates. #' -#' Creates a list of objects with duplicates. -#' -#' @param x Column to compute on -#' -#' @rdname collect_list -#' @name collect_list -#' @family aggregate functions -#' @aliases collect_list,Column-method +#' @rdname column_aggregate_functions +#' @aliases collect_list collect_list,Column-method #' @export -#' @examples \dontrun{collect_list(df$x)} +#' @examples +#' +#' \dontrun{ +#' df2 = df[df$mpg > 20, ] +#' collect(select(df2, collect_list(df2$gear))) +#' collect(select(df2, collect_set(df2$gear)))} #' @note collect_list since 2.3.0 setMethod("collect_list", signature(x = "Column"), @@ -3737,18 +3675,12 @@ setMethod("collect_list", column(jc) }) -#' collect_set -#' -#' Creates a list of objects with duplicate elements eliminated. +#' @details +#' \code{collect_set}: Creates a list of objects with duplicate elements eliminated. 
#' -#' @param x Column to compute on -#' -#' @rdname collect_set -#' @name collect_set -#' @family aggregate functions -#' @aliases collect_set,Column-method +#' @rdname column_aggregate_functions +#' @aliases collect_set collect_set,Column-method #' @export -#' @examples \dontrun{collect_set(df$x)} #' @note collect_set since 2.3.0 setMethod("collect_set", signature(x = "Column"), @@ -3908,24 +3840,17 @@ setMethod("not", column(jc) }) -#' grouping_bit -#' -#' Indicates whether a specified column in a GROUP BY list is aggregated or not, -#' returns 1 for aggregated or 0 for not aggregated in the result set. +#' @details +#' \code{grouping_bit}: Indicates whether a specified column in a GROUP BY list is aggregated or not, +#' returns 1 for aggregated or 0 for not aggregated in the result set. Same as \code{GROUPING} in SQL +#' and \code{grouping} function in Scala. #' -#' Same as \code{GROUPING} in SQL and \code{grouping} function in Scala. -#' -#' @param x Column to compute on -#' -#' @rdname grouping_bit -#' @name grouping_bit -#' @family aggregate functions -#' @aliases grouping_bit,Column-method +#' @rdname column_aggregate_functions +#' @aliases grouping_bit grouping_bit,Column-method #' @export #' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) #' +#' \dontrun{ #' # With cube #' agg( #' cube(df, "cyl", "gear", "am"), @@ -3938,8 +3863,7 @@ setMethod("not", #' rollup(df, "cyl", "gear", "am"), #' mean(df$mpg), #' grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am) -#' ) -#' } +#' )} #' @note grouping_bit since 2.3.0 setMethod("grouping_bit", signature(x = "Column"), @@ -3948,26 +3872,18 @@ setMethod("grouping_bit", column(jc) }) -#' grouping_id -#' -#' Returns the level of grouping. -#' +#' @details +#' \code{grouping_id}: Returns the level of grouping. #' Equals to \code{ #' grouping_bit(c1) * 2^(n - 1) + grouping_bit(c2) * 2^(n - 2) + ... + grouping_bit(cn) #' } #' -#' @param x Column to compute on -#' @param ... additional Column(s) (optional). -#' -#' @rdname grouping_id -#' @name grouping_id -#' @family aggregate functions -#' @aliases grouping_id,Column-method +#' @rdname column_aggregate_functions +#' @aliases grouping_id grouping_id,Column-method #' @export #' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) #' +#' \dontrun{ #' # With cube #' agg( #' cube(df, "cyl", "gear", "am"), @@ -3980,8 +3896,7 @@ setMethod("grouping_bit", #' rollup(df, "cyl", "gear", "am"), #' mean(df$mpg), #' grouping_id(df$cyl, df$gear, df$am) -#' ) -#' } +#' )} #' @note grouping_id since 2.3.0 setMethod("grouping_id", signature(x = "Column"), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5630d0c8a0df9..b3cc4868a0b33 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -479,7 +479,7 @@ setGeneric("corr", function(x, ...) {standardGeneric("corr") }) #' @export setGeneric("covar_samp", function(col1, col2) {standardGeneric("covar_samp") }) -#' @rdname covar_pop +#' @rdname cov #' @export setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") }) @@ -907,8 +907,9 @@ setGeneric("windowOrderBy", function(col, ...) { standardGeneric("windowOrderBy" #' @export setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) -#' @rdname approxCountDistinct +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("approxCountDistinct", function(x, ...) 
{ standardGeneric("approxCountDistinct") }) #' @rdname array_contains @@ -949,12 +950,14 @@ setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) #' @export setGeneric("ceil", function(x) { standardGeneric("ceil") }) -#' @rdname collect_list +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("collect_list", function(x) { standardGeneric("collect_list") }) -#' @rdname collect_set +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) #' @rdname column @@ -973,8 +976,9 @@ setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) #' @export setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) -#' @rdname countDistinct +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) #' @rdname crc32 @@ -1071,12 +1075,14 @@ setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") #' @export setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) -#' @rdname grouping_bit +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("grouping_bit", function(x) { standardGeneric("grouping_bit") }) -#' @rdname grouping_id +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") }) #' @rdname hex @@ -1109,8 +1115,9 @@ setGeneric("instr", function(y, x) { standardGeneric("instr") }) #' @export setGeneric("isnan", function(x) { standardGeneric("isnan") }) -#' @rdname kurtosis +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) #' @rdname lag @@ -1203,8 +1210,9 @@ setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) #' @export setGeneric("ntile", function(x) { standardGeneric("ntile") }) -#' @rdname countDistinct +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) #' @param x empty. Should be used with no argument. 
@@ -1274,8 +1282,9 @@ setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) #' @export setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) -#' @rdname sd +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") }) #' @rdname second @@ -1310,8 +1319,9 @@ setGeneric("signum", function(x) { standardGeneric("signum") }) #' @export setGeneric("size", function(x) { standardGeneric("size") }) -#' @rdname skewness +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("skewness", function(x) { standardGeneric("skewness") }) #' @rdname sort_array @@ -1331,16 +1341,19 @@ setGeneric("soundex", function(x) { standardGeneric("soundex") }) #' @export setGeneric("spark_partition_id", function(x = "missing") { standardGeneric("spark_partition_id") }) -#' @rdname sd +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("stddev", function(x) { standardGeneric("stddev") }) -#' @rdname stddev_pop +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) -#' @rdname stddev_samp +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) #' @rdname struct @@ -1351,8 +1364,9 @@ setGeneric("struct", function(x, ...) { standardGeneric("struct") }) #' @export setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) -#' @rdname sumDistinct +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) #' @rdname toDegrees @@ -1403,20 +1417,24 @@ setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timesta #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) -#' @rdname var +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("var", function(x, y = NULL, na.rm = FALSE, use) { standardGeneric("var") }) -#' @rdname var +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("variance", function(x) { standardGeneric("variance") }) -#' @rdname var_pop +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("var_pop", function(x) { standardGeneric("var_pop") }) -#' @rdname var_samp +#' @rdname column_aggregate_functions #' @export +#' @name NULL setGeneric("var_samp", function(x) { standardGeneric("var_samp") }) #' @rdname weekofyear diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index d78a10893f92e..9a9fa84044ce6 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -52,22 +52,23 @@ setMethod("crosstab", collect(dataFrame(sct)) }) -#' Calculate the sample covariance of two numerical columns of a SparkDataFrame. +#' @details +#' \code{cov}: When applied to SparkDataFrame, this calculates the sample covariance of two numerical +#' columns of \emph{one} SparkDataFrame. #' #' @param colName1 the name of the first column #' @param colName2 the name of the second column #' @return The covariance of the two columns. 
#' #' @rdname cov -#' @name cov #' @aliases cov,SparkDataFrame-method #' @family stat functions #' @export #' @examples -#'\dontrun{ -#' df <- read.json("/path/to/file.json") -#' cov <- cov(df, "title", "gender") -#' } +#' +#' \dontrun{ +#' cov(df, "mpg", "hp") +#' cov(df, df$mpg, df$hp)} #' @note cov since 1.6.0 setMethod("cov", signature(x = "SparkDataFrame"), @@ -93,11 +94,10 @@ setMethod("cov", #' @family stat functions #' @export #' @examples -#'\dontrun{ -#' df <- read.json("/path/to/file.json") -#' corr <- corr(df, "title", "gender") -#' corr <- corr(df, "title", "gender", method = "pearson") -#' } +#' +#' \dontrun{ +#' corr(df, "mpg", "hp") +#' corr(df, "mpg", "hp", method = "pearson")} #' @note corr since 1.6.0 setMethod("corr", signature(x = "SparkDataFrame"), From cc67bd573264c9046c4a034927ed8deb2a732110 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 19 Jun 2017 23:04:17 -0700 Subject: [PATCH 039/118] [SPARK-20929][ML] LinearSVC should use its own threshold param ## What changes were proposed in this pull request? LinearSVC should use its own threshold param, rather than the shared one, since it applies to rawPrediction instead of probability. This PR changes the param in the Scala, Python and R APIs. ## How was this patch tested? New unit test to make sure the threshold can be set to any Double value. Author: Joseph K. Bradley Closes #18151 from jkbradley/ml-2.2-linearsvc-cleanup. --- R/pkg/R/mllib_classification.R | 4 ++- .../spark/ml/classification/LinearSVC.scala | 25 +++++++++++-- .../ml/classification/LinearSVCSuite.scala | 35 ++++++++++++++++++- python/pyspark/ml/classification.py | 20 ++++++++++- 4 files changed, 79 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index 306a9b8676539..bdcc0818d139d 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -62,7 +62,9 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' of models will be always returned on the original scale, so it will be transparent for #' users. Note that with/without standardization, the models should be always converged #' to the same solution when no regularization is applied. -#' @param threshold The threshold in binary classification, in range [0, 1]. +#' @param threshold The threshold in binary classification applied to the linear model prediction. +#' This threshold can be any real number, where Inf will make all predictions 0.0 +#' and -Inf will make all predictions 1.0. #' @param weightCol The weight column name. #' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features #' or the number of partitions are large, this param could be adjusted to a larger size. diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 9900fbc9edda7..d6ed6a4570a4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -42,7 +42,23 @@ import org.apache.spark.sql.functions.{col, lit} /** Params for linear SVM Classifier. 
*/ private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam with HasMaxIter with HasFitIntercept with HasTol with HasStandardization with HasWeightCol - with HasThreshold with HasAggregationDepth + with HasAggregationDepth { + + /** + * Param for threshold in binary classification prediction. + * For LinearSVC, this threshold is applied to the rawPrediction, rather than a probability. + * This threshold can be any real number, where Inf will make all predictions 0.0 + * and -Inf will make all predictions 1.0. + * Default: 0.0 + * + * @group param + */ + final val threshold: DoubleParam = new DoubleParam(this, "threshold", + "threshold in binary classification prediction applied to rawPrediction") + + /** @group getParam */ + def getThreshold: Double = $(threshold) +} /** * :: Experimental :: @@ -126,7 +142,7 @@ class LinearSVC @Since("2.2.0") ( def setWeightCol(value: String): this.type = set(weightCol, value) /** - * Set threshold in binary classification, in range [0, 1]. + * Set threshold in binary classification. * * @group setParam */ @@ -284,6 +300,7 @@ class LinearSVCModel private[classification] ( @Since("2.2.0") def setThreshold(value: Double): this.type = set(threshold, value) + setDefault(threshold, 0.0) @Since("2.2.0") def setWeightCol(value: Double): this.type = set(threshold, value) @@ -301,6 +318,10 @@ class LinearSVCModel private[classification] ( Vectors.dense(-m, m) } + override protected def raw2prediction(rawPrediction: Vector): Double = { + if (rawPrediction(1) > $(threshold)) 1.0 else 0.0 + } + @Since("2.2.0") override def copy(extra: ParamMap): LinearSVCModel = { copyValues(new LinearSVCModel(uid, coefficients, intercept), extra).setParent(parent) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index 2f87afc23fe7e..f2b00d0bae1d6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LinearSVCSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -127,6 +127,39 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau MLTestingUtils.checkCopyAndUids(lsvc, model) } + test("LinearSVC threshold acts on rawPrediction") { + val lsvc = + new LinearSVCModel(uid = "myLSVCM", coefficients = Vectors.dense(1.0), intercept = 0.0) + val df = spark.createDataFrame(Seq( + (1, Vectors.dense(1e-7)), + (0, Vectors.dense(0.0)), + (-1, Vectors.dense(-1e-7)))).toDF("id", "features") + + def checkOneResult( + model: LinearSVCModel, + threshold: Double, + expected: Set[(Int, Double)]): Unit = { + model.setThreshold(threshold) + val results = model.transform(df).select("id", "prediction").collect() + .map(r => (r.getInt(0), r.getDouble(1))) + .toSet + assert(results === expected, s"Failed for threshold = $threshold") + } + + def checkResults(threshold: Double, expected: Set[(Int, Double)]): Unit = { + // Check via code path using 
Classifier.raw2prediction + lsvc.setRawPredictionCol("rawPrediction") + checkOneResult(lsvc, threshold, expected) + // Check via code path using Classifier.predict + lsvc.setRawPredictionCol("") + checkOneResult(lsvc, threshold, expected) + } + + checkResults(0.0, Set((1, 1.0), (0, 0.0), (-1, 0.0))) + checkResults(Double.PositiveInfinity, Set((1, 0.0), (0, 0.0), (-1, 0.0))) + checkResults(Double.NegativeInfinity, Set((1, 1.0), (0, 1.0), (-1, 1.0))) + } + test("linear svc doesn't fit intercept when fitIntercept is off") { val lsvc = new LinearSVC().setFitIntercept(false).setMaxIter(5) val model = lsvc.fit(smallBinaryDataset) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 60bdeedd6a144..9b345ac73f3d9 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -63,7 +63,7 @@ def numClasses(self): @inherit_doc class LinearSVC(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasRawPredictionCol, HasFitIntercept, HasStandardization, - HasThreshold, HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable): + HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable): """ .. note:: Experimental @@ -109,6 +109,12 @@ class LinearSVC(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, Ha .. versionadded:: 2.2.0 """ + threshold = Param(Params._dummy(), "threshold", + "The threshold in binary classification applied to the linear model" + " prediction. This threshold can be any real number, where Inf will make" + " all predictions 0.0 and -Inf will make all predictions 1.0.", + typeConverter=TypeConverters.toFloat) + @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", @@ -147,6 +153,18 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LinearSVCModel(java_model) + def setThreshold(self, value): + """ + Sets the value of :py:attr:`threshold`. + """ + return self._set(threshold=value) + + def getThreshold(self): + """ + Gets the value of threshold or its default value. + """ + return self.getOrDefault(self.threshold) + class LinearSVCModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): """ From ef1622899ffc6ab136102ffc6bcc714402e6f334 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 20 Jun 2017 17:17:21 +0800 Subject: [PATCH 040/118] [SPARK-20989][CORE] Fail to start multiple workers on one host if external shuffle service is enabled in standalone mode ## What changes were proposed in this pull request? In standalone mode, if we enable external shuffle service by setting `spark.shuffle.service.enabled` to true, and then we try to start multiple workers on one host(by setting `SPARK_WORKER_INSTANCES=3` in spark-env.sh, and then run `sbin/start-slaves.sh`), we can only launch one worker on each host successfully and the rest of the workers fail to launch. The reason is the port of external shuffle service if configed by `spark.shuffle.service.port`, so currently we could start no more than one external shuffle service on each host. In our case, each worker tries to start a external shuffle service, and only one of them succeeded doing this. We should give explicit reason of failure instead of fail silently. ## How was this patch tested? Manually test by the following steps: 1. SET `SPARK_WORKER_INSTANCES=1` in `conf/spark-env.sh`; 2. 
SET `spark.shuffle.service.enabled` to `true` in `conf/spark-defaults.conf`; 3. Run `sbin/start-all.sh`. Before the change, you will see no error in the command line, as the following: ``` starting org.apache.spark.deploy.master.Master, logging to /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.master.Master-1-xxx.local.out localhost: starting org.apache.spark.deploy.worker.Worker, logging to /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.worker.Worker-1-xxx.local.out localhost: starting org.apache.spark.deploy.worker.Worker, logging to /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.worker.Worker-2-xxx.local.out localhost: starting org.apache.spark.deploy.worker.Worker, logging to /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.worker.Worker-3-xxx.local.out ``` And you can see in the webUI that only one worker is running. After the change, you get explicit error messages in the command line: ``` starting org.apache.spark.deploy.master.Master, logging to /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.master.Master-1-xxx.local.out localhost: starting org.apache.spark.deploy.worker.Worker, logging to /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.worker.Worker-1-xxx.local.out localhost: failed to launch: nice -n 0 /Users/xxx/workspace/spark/bin/spark-class org.apache.spark.deploy.worker.Worker --webui-port 8081 spark://xxx.local:7077 localhost: 17/06/13 23:24:53 INFO SecurityManager: Changing view acls to: xxx localhost: 17/06/13 23:24:53 INFO SecurityManager: Changing modify acls to: xxx localhost: 17/06/13 23:24:53 INFO SecurityManager: Changing view acls groups to: localhost: 17/06/13 23:24:53 INFO SecurityManager: Changing modify acls groups to: localhost: 17/06/13 23:24:53 INFO SecurityManager: SecurityManager: authentication disabled; ui acls disabled; users with view permissions: Set(xxx); groups with view permissions: Set(); users with modify permissions: Set(xxx); groups with modify permissions: Set() localhost: 17/06/13 23:24:54 INFO Utils: Successfully started service 'sparkWorker' on port 63354. localhost: Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: Start multiple worker on one host failed because we may launch no more than one external shuffle service on each host, please set spark.shuffle.service.enabled to false or set SPARK_WORKER_INSTANCES to 1 to resolve the conflict. 
localhost: at scala.Predef$.require(Predef.scala:224) localhost: at org.apache.spark.deploy.worker.Worker$.main(Worker.scala:752) localhost: at org.apache.spark.deploy.worker.Worker.main(Worker.scala) localhost: full log in /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.worker.Worker-1-xxx.local.out localhost: starting org.apache.spark.deploy.worker.Worker, logging to /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.worker.Worker-2-xxx.local.out localhost: failed to launch: nice -n 0 /Users/xxx/workspace/spark/bin/spark-class org.apache.spark.deploy.worker.Worker --webui-port 8082 spark://xxx.local:7077 localhost: 17/06/13 23:24:56 INFO SecurityManager: Changing view acls to: xxx localhost: 17/06/13 23:24:56 INFO SecurityManager: Changing modify acls to: xxx localhost: 17/06/13 23:24:56 INFO SecurityManager: Changing view acls groups to: localhost: 17/06/13 23:24:56 INFO SecurityManager: Changing modify acls groups to: localhost: 17/06/13 23:24:56 INFO SecurityManager: SecurityManager: authentication disabled; ui acls disabled; users with view permissions: Set(xxx); groups with view permissions: Set(); users with modify permissions: Set(xxx); groups with modify permissions: Set() localhost: 17/06/13 23:24:56 INFO Utils: Successfully started service 'sparkWorker' on port 63359. localhost: Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: Start multiple worker on one host failed because we may launch no more than one external shuffle service on each host, please set spark.shuffle.service.enabled to false or set SPARK_WORKER_INSTANCES to 1 to resolve the conflict. localhost: at scala.Predef$.require(Predef.scala:224) localhost: at org.apache.spark.deploy.worker.Worker$.main(Worker.scala:752) localhost: at org.apache.spark.deploy.worker.Worker.main(Worker.scala) localhost: full log in /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.worker.Worker-2-xxx.local.out localhost: starting org.apache.spark.deploy.worker.Worker, logging to /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.worker.Worker-3-xxx.local.out localhost: failed to launch: nice -n 0 /Users/xxx/workspace/spark/bin/spark-class org.apache.spark.deploy.worker.Worker --webui-port 8083 spark://xxx.local:7077 localhost: 17/06/13 23:24:59 INFO SecurityManager: Changing view acls to: xxx localhost: 17/06/13 23:24:59 INFO SecurityManager: Changing modify acls to: xxx localhost: 17/06/13 23:24:59 INFO SecurityManager: Changing view acls groups to: localhost: 17/06/13 23:24:59 INFO SecurityManager: Changing modify acls groups to: localhost: 17/06/13 23:24:59 INFO SecurityManager: SecurityManager: authentication disabled; ui acls disabled; users with view permissions: Set(xxx); groups with view permissions: Set(); users with modify permissions: Set(xxx); groups with modify permissions: Set() localhost: 17/06/13 23:24:59 INFO Utils: Successfully started service 'sparkWorker' on port 63360. localhost: Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: Start multiple worker on one host failed because we may launch no more than one external shuffle service on each host, please set spark.shuffle.service.enabled to false or set SPARK_WORKER_INSTANCES to 1 to resolve the conflict. 
localhost: at scala.Predef$.require(Predef.scala:224) localhost: at org.apache.spark.deploy.worker.Worker$.main(Worker.scala:752) localhost: at org.apache.spark.deploy.worker.Worker.main(Worker.scala) localhost: full log in /Users/xxx/workspace/spark/logs/spark-xxx-org.apache.spark.deploy.worker.Worker-3-xxx.local.out ``` Author: Xingbo Jiang Closes #18290 from jiangxb1987/start-slave. --- .../scala/org/apache/spark/deploy/worker/Worker.scala | 11 +++++++++++ sbin/spark-daemon.sh | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 1198e3cb05eaa..bed47455680dd 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -742,6 +742,17 @@ private[deploy] object Worker extends Logging { val args = new WorkerArguments(argStrings, conf) val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores, args.memory, args.masters, args.workDir, conf = conf) + // With external shuffle service enabled, if we request to launch multiple workers on one host, + // we can only successfully launch the first worker and the rest fails, because with the port + // bound, we may launch no more than one external shuffle service on each host. + // When this happens, we should give explicit reason of failure instead of fail silently. For + // more detail see SPARK-20989. + val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + val sparkWorkerInstances = scala.sys.env.getOrElse("SPARK_WORKER_INSTANCES", "1").toInt + require(externalShuffleServiceEnabled == false || sparkWorkerInstances <= 1, + "Starting multiple workers on one host is failed because we may launch no more than one " + + "external shuffle service on each host, please set spark.shuffle.service.enabled to " + + "false or set SPARK_WORKER_INSTANCES to 1 to resolve the conflict.") rpcEnv.awaitTermination() } diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index c227c9828e6ac..6de67e039b48f 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -143,7 +143,7 @@ execute_command() { # Check if the process has died; in that case we'll tail the log so the user can see if [[ ! $(ps -p "$newpid" -o comm=) =~ "java" ]]; then echo "failed to launch: $@" - tail -2 "$log" | sed 's/^/ /' + tail -10 "$log" | sed 's/^/ /' echo "full log in $log" fi else From e862dc904963cf7832bafc1d3d0ea9090bbddd81 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 20 Jun 2017 09:15:33 -0700 Subject: [PATCH 041/118] [SPARK-21150][SQL] Persistent view stored in Hive metastore should be case preserving ## What changes were proposed in this pull request? This is a regression in Spark 2.2. In Spark 2.2, we introduced a new way to resolve persisted view: https://issues.apache.org/jira/browse/SPARK-18209 , but this makes the persisted view non case-preserving because we store the schema in hive metastore directly. We should follow data source table and store schema in table properties. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #18360 from cloud-fan/view. 
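
For reference, a minimal spark-shell sketch (not part of the patch) that mirrors the regression test added to SQLViewSuite; it assumes the usual `spark` SparkSession in a Hive-backed deployment, where persisted view column names were previously lower-cased on the way through the Hive metastore:

```scala
// Persisted views should keep the caller's column casing.
spark.sql("CREATE VIEW v AS SELECT 1 AS aBc")
assert(spark.table("v").schema.head.name == "aBc")  // before this fix, the schema came back as "abc"

// CREATE OR REPLACE drops and recreates the view, so the replacement schema is case-preserving too.
spark.sql("CREATE OR REPLACE VIEW v AS SELECT 2 AS cBa")
assert(spark.table("v").schema.head.name == "cBa")

spark.sql("DROP VIEW v")
```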
--- .../spark/sql/execution/command/views.scala | 4 +- .../spark/sql/execution/SQLViewSuite.scala | 10 +++ .../spark/sql/hive/HiveExternalCatalog.scala | 84 ++++++++++--------- 3 files changed, 56 insertions(+), 42 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 1945d68241343..a6d56ca91a3ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -159,7 +159,9 @@ case class CreateViewCommand( checkCyclicViewReference(analyzedPlan, Seq(viewIdent), viewIdent) // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` - catalog.alterTable(prepareTable(sparkSession, analyzedPlan)) + // Nothing we need to retain from the old view, so just drop and create a new one + catalog.dropTable(viewIdent, ignoreIfNotExists = false, purge = false) + catalog.createTable(prepareTable(sparkSession, analyzedPlan), ignoreIfExists = false) } else { // Handles `CREATE VIEW v0 AS SELECT ...`. Throws exception when the target view already // exists. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index d32716c18ddfb..6761f05bb462a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -669,4 +669,14 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { "positive.")) } } + + test("permanent view should be case-preserving") { + withView("v") { + sql("CREATE VIEW v AS SELECT 1 as aBc") + assert(spark.table("v").schema.head.name == "aBc") + + sql("CREATE OR REPLACE VIEW v AS SELECT 2 as cBa") + assert(spark.table("v").schema.head.name == "cBa") + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 19453679a30df..6e7c475fa34c9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -224,39 +224,36 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat throw new TableAlreadyExistsException(db = db, table = table) } - if (tableDefinition.tableType == VIEW) { - client.createTable(tableDefinition, ignoreIfExists) + // Ideally we should not create a managed table with location, but Hive serde table can + // specify location for managed table. And in [[CreateDataSourceTableAsSelectCommand]] we have + // to create the table directory and write out data before we create this table, to avoid + // exposing a partial written table. + val needDefaultTableLocation = tableDefinition.tableType == MANAGED && + tableDefinition.storage.locationUri.isEmpty + + val tableLocation = if (needDefaultTableLocation) { + Some(CatalogUtils.stringToURI(defaultTablePath(tableDefinition.identifier))) } else { - // Ideally we should not create a managed table with location, but Hive serde table can - // specify location for managed table. And in [[CreateDataSourceTableAsSelectCommand]] we have - // to create the table directory and write out data before we create this table, to avoid - // exposing a partial written table. 
- val needDefaultTableLocation = tableDefinition.tableType == MANAGED && - tableDefinition.storage.locationUri.isEmpty - - val tableLocation = if (needDefaultTableLocation) { - Some(CatalogUtils.stringToURI(defaultTablePath(tableDefinition.identifier))) - } else { - tableDefinition.storage.locationUri - } + tableDefinition.storage.locationUri + } - if (DDLUtils.isHiveTable(tableDefinition)) { - val tableWithDataSourceProps = tableDefinition.copy( - // We can't leave `locationUri` empty and count on Hive metastore to set a default table - // location, because Hive metastore uses hive.metastore.warehouse.dir to generate default - // table location for tables in default database, while we expect to use the location of - // default database. - storage = tableDefinition.storage.copy(locationUri = tableLocation), - // Here we follow data source tables and put table metadata like table schema, partition - // columns etc. in table properties, so that we can work around the Hive metastore issue - // about not case preserving and make Hive serde table support mixed-case column names. - properties = tableDefinition.properties ++ tableMetaToTableProps(tableDefinition)) - client.createTable(tableWithDataSourceProps, ignoreIfExists) - } else { - createDataSourceTable( - tableDefinition.withNewStorage(locationUri = tableLocation), - ignoreIfExists) - } + if (DDLUtils.isDatasourceTable(tableDefinition)) { + createDataSourceTable( + tableDefinition.withNewStorage(locationUri = tableLocation), + ignoreIfExists) + } else { + val tableWithDataSourceProps = tableDefinition.copy( + // We can't leave `locationUri` empty and count on Hive metastore to set a default table + // location, because Hive metastore uses hive.metastore.warehouse.dir to generate default + // table location for tables in default database, while we expect to use the location of + // default database. + storage = tableDefinition.storage.copy(locationUri = tableLocation), + // Here we follow data source tables and put table metadata like table schema, partition + // columns etc. in table properties, so that we can work around the Hive metastore issue + // about not case preserving and make Hive serde table and view support mixed-case column + // names. + properties = tableDefinition.properties ++ tableMetaToTableProps(tableDefinition)) + client.createTable(tableWithDataSourceProps, ignoreIfExists) } } @@ -679,16 +676,21 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat var table = inputTable - if (table.tableType != VIEW) { - table.properties.get(DATASOURCE_PROVIDER) match { - // No provider in table properties, which means this is a Hive serde table. - case None => - table = restoreHiveSerdeTable(table) - - // This is a regular data source table. - case Some(provider) => - table = restoreDataSourceTable(table, provider) - } + table.properties.get(DATASOURCE_PROVIDER) match { + case None if table.tableType == VIEW => + // If this is a view created by Spark 2.2 or higher versions, we should restore its schema + // from table properties. + if (table.properties.contains(DATASOURCE_SCHEMA_NUMPARTS)) { + table = table.copy(schema = getSchemaFromTableProperties(table)) + } + + // No provider in table properties, which means this is a Hive serde table. + case None => + table = restoreHiveSerdeTable(table) + + // This is a regular data source table. + case Some(provider) => + table = restoreDataSourceTable(table, provider) } // Restore Spark's statistics from information in Metastore. 
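
A rough sketch of the mechanism the view fix above relies on, under the assumption that views now follow the same convention as data source tables: the analyzed, case-sensitive schema is serialized to JSON in the table properties at create time and rebuilt from those properties on read, so the lower-cased column names stored by the Hive metastore are never trusted. This is not the actual HiveExternalCatalog code, and the literal property keys below are illustrative, patterned after the `spark.sql.sources.schema.*` properties it references:

```scala
import org.apache.spark.sql.types._

// Create path: persist the case-sensitive schema as JSON in table properties (assumed key names).
val viewSchema = StructType(Seq(StructField("aBc", IntegerType)))
val props = Map(
  "spark.sql.sources.schema.numParts" -> "1",
  "spark.sql.sources.schema.part.0" -> viewSchema.json)

// Read path: if the schema properties are present, rebuild the schema from JSON
// instead of using the columns reported by the Hive metastore.
val restored = DataType.fromJson(props("spark.sql.sources.schema.part.0"))
  .asInstanceOf[StructType]
assert(restored.head.name == "aBc")  // case preserved
```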
From b6b108826a5dd5c889a70180365f9320452557fc Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 20 Jun 2017 11:34:22 -0700 Subject: [PATCH 042/118] [SPARK-21103][SQL] QueryPlanConstraints should be part of LogicalPlan ## What changes were proposed in this pull request? QueryPlanConstraints should be part of LogicalPlan, rather than QueryPlan, since the constraint framework is only used for query plan rewriting and not for physical planning. ## How was this patch tested? Should be covered by existing tests, since it is a simple refactoring. Author: Reynold Xin Closes #18310 from rxin/SPARK-21103. --- .../org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 5 +---- .../spark/sql/catalyst/plans/logical/LogicalPlan.scala | 2 +- .../plans/{ => logical}/QueryPlanConstraints.scala | 7 ++++--- 3 files changed, 6 insertions(+), 8 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/{ => logical}/QueryPlanConstraints.scala (96%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 9130b14763e24..1f6d05bc8d816 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -22,10 +22,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} -abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] - extends TreeNode[PlanType] - with QueryPlanConstraints[PlanType] { - +abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] { self: PlanType => def conf: SQLConf = SQLConf.get diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 2ebb2ff323c6b..95b4165f6b10d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType -abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { +abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstraints with Logging { private var _analyzed: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala similarity index 96% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index b08a009f0dca1..8bffbd0c208cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -15,12 +15,12 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.catalyst.plans +package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ -trait QueryPlanConstraints[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[PlanType] => +trait QueryPlanConstraints { self: LogicalPlan => /** * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For @@ -99,7 +99,8 @@ trait QueryPlanConstraints[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[Pl private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( expressions.collect { case a: Alias => (a.toAttribute, a.child) - } ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints[PlanType]].aliasMap)) + } ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap)) + // Note: the explicit cast is necessary, since Scala compiler fails to infer the type. /** * Infers an additional set of constraints from a given set of equality constraints. From 9ce714dca272315ef7f50d791563f22e8d5922ac Mon Sep 17 00:00:00 2001 From: sureshthalamati Date: Tue, 20 Jun 2017 22:35:42 -0700 Subject: [PATCH 043/118] [SPARK-10655][SQL] Adding additional data type mappings to jdbc DB2dialect. This patch adds DB2 specific data type mappings for decfloat, real, xml , and timestamp with time zone (DB2Z specific type) types on read and for byte, short data types on write to the to jdbc data source DB2 dialect. Default mapping does not work for these types when reading/writing from DB2 database. Added docker test, and a JDBC unit test case. Author: sureshthalamati Closes #9162 from sureshthalamati/db2dialect_enhancements-spark-10655. --- .../spark/sql/jdbc/DB2IntegrationSuite.scala | 47 +++++++++++++++---- .../apache/spark/sql/jdbc/DB2Dialect.scala | 21 ++++++++- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 9 ++++ 3 files changed, 66 insertions(+), 11 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala index 3da34b1b382d7..f5930bc281e8c 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala @@ -21,10 +21,13 @@ import java.math.BigDecimal import java.sql.{Connection, Date, Timestamp} import java.util.Properties -import org.scalatest._ +import org.scalatest.Ignore +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{BooleanType, ByteType, ShortType, StructType} import org.apache.spark.tags.DockerTest + @DockerTest @Ignore // AMPLab Jenkins needs to be updated before shared memory works on docker class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { @@ -47,19 +50,22 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { conn.prepareStatement("INSERT INTO tbl VALUES (17,'dave')").executeUpdate() conn.prepareStatement("CREATE TABLE numbers ( small SMALLINT, med INTEGER, big BIGINT, " - + "deci DECIMAL(31,20), flt FLOAT, dbl DOUBLE)").executeUpdate() + + "deci DECIMAL(31,20), flt FLOAT, dbl DOUBLE, real REAL, " + + "decflt DECFLOAT, decflt16 DECFLOAT(16), decflt34 DECFLOAT(34))").executeUpdate() conn.prepareStatement("INSERT INTO numbers VALUES (17, 77777, 922337203685477580, " - + "123456745.56789012345000000000, 42.75, 5.4E-70)").executeUpdate() + + "123456745.56789012345000000000, 42.75, 5.4E-70, " + + "3.4028234663852886e+38, 4.2999, 
DECFLOAT('9.999999999999999E19', 16), " + + "DECFLOAT('1234567891234567.123456789123456789', 34))").executeUpdate() conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, ts TIMESTAMP )").executeUpdate() conn.prepareStatement("INSERT INTO dates VALUES ('1991-11-09', '13:31:24', " + "'2009-02-13 23:31:30')").executeUpdate() // TODO: Test locale conversion for strings. - conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c CLOB, d BLOB)") - .executeUpdate() - conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', BLOB('fox'))") + conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c CLOB, d BLOB, e XML)") .executeUpdate() + conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', BLOB('fox')," + + "'Kathy')").executeUpdate() } test("Basic test") { @@ -77,13 +83,17 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 6) + assert(types.length == 10) assert(types(0).equals("class java.lang.Integer")) assert(types(1).equals("class java.lang.Integer")) assert(types(2).equals("class java.lang.Long")) assert(types(3).equals("class java.math.BigDecimal")) assert(types(4).equals("class java.lang.Double")) assert(types(5).equals("class java.lang.Double")) + assert(types(6).equals("class java.lang.Float")) + assert(types(7).equals("class java.math.BigDecimal")) + assert(types(8).equals("class java.math.BigDecimal")) + assert(types(9).equals("class java.math.BigDecimal")) assert(rows(0).getInt(0) == 17) assert(rows(0).getInt(1) == 77777) assert(rows(0).getLong(2) == 922337203685477580L) @@ -91,6 +101,10 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { assert(rows(0).getAs[BigDecimal](3).equals(bd)) assert(rows(0).getDouble(4) == 42.75) assert(rows(0).getDouble(5) == 5.4E-70) + assert(rows(0).getFloat(6) == 3.4028234663852886e+38) + assert(rows(0).getDecimal(7) == new BigDecimal("4.299900000000000000")) + assert(rows(0).getDecimal(8) == new BigDecimal("99999999999999990000.000000000000000000")) + assert(rows(0).getDecimal(9) == new BigDecimal("1234567891234567.123456789123456789")) } test("Date types") { @@ -112,7 +126,7 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 4) + assert(types.length == 5) assert(types(0).equals("class java.lang.String")) assert(types(1).equals("class java.lang.String")) assert(types(2).equals("class java.lang.String")) @@ -121,14 +135,27 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { assert(rows(0).getString(1).equals("quick")) assert(rows(0).getString(2).equals("brown")) assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](3), Array[Byte](102, 111, 120))) + assert(rows(0).getString(4).equals("""Kathy""")) } test("Basic write test") { - // val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + // cast decflt column with precision value of 38 to DB2 max decimal precision value of 31. 
+ val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + .selectExpr("small", "med", "big", "deci", "flt", "dbl", "real", + "cast(decflt as decimal(31, 5)) as decflt") val df2 = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) val df3 = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) - // df1.write.jdbc(jdbcUrl, "numberscopy", new Properties) + df1.write.jdbc(jdbcUrl, "numberscopy", new Properties) df2.write.jdbc(jdbcUrl, "datescopy", new Properties) df3.write.jdbc(jdbcUrl, "stringscopy", new Properties) + // spark types that does not have exact matching db2 table types. + val df4 = sqlContext.createDataFrame( + sparkContext.parallelize(Seq(Row("1".toShort, "20".toByte, true))), + new StructType().add("c1", ShortType).add("b", ByteType).add("c3", BooleanType)) + df4.write.jdbc(jdbcUrl, "otherscopy", new Properties) + val rows = sqlContext.read.jdbc(jdbcUrl, "otherscopy", new Properties).collect() + assert(rows(0).getInt(0) == 1) + assert(rows(0).getInt(1) == 20) + assert(rows(0).getString(2) == "1") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index 190463df0d928..d160ad82888a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -17,15 +17,34 @@ package org.apache.spark.sql.jdbc -import org.apache.spark.sql.types.{BooleanType, DataType, StringType} +import java.sql.Types + +import org.apache.spark.sql.types._ private object DB2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2") + override def getCatalystType( + sqlType: Int, + typeName: String, + size: Int, + md: MetadataBuilder): Option[DataType] = sqlType match { + case Types.REAL => Option(FloatType) + case Types.OTHER => + typeName match { + case "DECFLOAT" => Option(DecimalType(38, 18)) + case "XML" => Option(StringType) + case t if (t.startsWith("TIMESTAMP")) => Option(TimestampType) // TIMESTAMP WITH TIMEZONE + case _ => None + } + case _ => None + } + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB)) case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) + case ShortType | ByteType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) case _ => None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 70bee929b31da..d1daf860fdfff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -713,6 +713,15 @@ class JDBCSuite extends SparkFunSuite val db2Dialect = JdbcDialects.get("jdbc:db2://127.0.0.1/db") assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)") + assert(db2Dialect.getJDBCType(ShortType).map(_.databaseTypeDefinition).get == "SMALLINT") + assert(db2Dialect.getJDBCType(ByteType).map(_.databaseTypeDefinition).get == "SMALLINT") + // test db2 dialect mappings on read + assert(db2Dialect.getCatalystType(java.sql.Types.REAL, "REAL", 1, null) == Option(FloatType)) + assert(db2Dialect.getCatalystType(java.sql.Types.OTHER, "DECFLOAT", 1, null) == + Option(DecimalType(38, 18))) + 
assert(db2Dialect.getCatalystType(java.sql.Types.OTHER, "XML", 1, null) == Option(StringType)) + assert(db2Dialect.getCatalystType(java.sql.Types.OTHER, "TIMESTAMP WITH TIME ZONE", 1, null) == + Option(TimestampType)) } test("PostgresDialect type mapping") { From d107b3b910d8f434fb15b663a9db4c2dfe0a9f43 Mon Sep 17 00:00:00 2001 From: Li Yichao Date: Wed, 21 Jun 2017 21:54:29 +0800 Subject: [PATCH 044/118] [SPARK-20640][CORE] Make rpc timeout and retry for shuffle registration configurable. ## What changes were proposed in this pull request? Currently the shuffle service registration timeout and retry has been hardcoded. This works well for small workloads but under heavy workload when the shuffle service is busy transferring large amount of data we see significant delay in responding to the registration request, as a result we often see the executors fail to register with the shuffle service, eventually failing the job. We need to make these two parameters configurable. ## How was this patch tested? * Updated `BlockManagerSuite` to test registration timeout and max attempts configuration actually works. cc sitalkedia Author: Li Yichao Closes #18092 from liyichao/SPARK-20640. --- .../shuffle/ExternalShuffleClient.java | 7 +- .../mesos/MesosExternalShuffleClient.java | 5 +- .../ExternalShuffleIntegrationSuite.java | 4 +- .../shuffle/ExternalShuffleSecuritySuite.java | 2 +- .../spark/internal/config/package.scala | 13 ++++ .../apache/spark/storage/BlockManager.scala | 7 +- .../spark/storage/BlockManagerSuite.scala | 68 +++++++++++++++++-- docs/configuration.md | 14 ++++ .../MesosCoarseGrainedSchedulerBackend.scala | 4 +- 9 files changed, 109 insertions(+), 15 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 269fa72dad5f5..6ac9302517ee0 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -49,6 +49,7 @@ public class ExternalShuffleClient extends ShuffleClient { private final TransportConf conf; private final boolean authEnabled; private final SecretKeyHolder secretKeyHolder; + private final long registrationTimeoutMs; protected TransportClientFactory clientFactory; protected String appId; @@ -60,10 +61,12 @@ public class ExternalShuffleClient extends ShuffleClient { public ExternalShuffleClient( TransportConf conf, SecretKeyHolder secretKeyHolder, - boolean authEnabled) { + boolean authEnabled, + long registrationTimeoutMs) { this.conf = conf; this.secretKeyHolder = secretKeyHolder; this.authEnabled = authEnabled; + this.registrationTimeoutMs = registrationTimeoutMs; } protected void checkInit() { @@ -132,7 +135,7 @@ public void registerWithShuffleServer( checkInit(); try (TransportClient client = clientFactory.createUnmanagedClient(host, port)) { ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer(); - client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); + client.sendRpcSync(registerMessage, registrationTimeoutMs); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java index dbc1010847fb1..60179f126bc44 100644 --- 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -60,8 +60,9 @@ public class MesosExternalShuffleClient extends ExternalShuffleClient { public MesosExternalShuffleClient( TransportConf conf, SecretKeyHolder secretKeyHolder, - boolean authEnabled) { - super(conf, secretKeyHolder, authEnabled); + boolean authEnabled, + long registrationTimeoutMs) { + super(conf, secretKeyHolder, authEnabled, registrationTimeoutMs); } public void registerDriverWithShuffleService( diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 4391e3023491b..a6a1b8d0ac3f1 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -133,7 +133,7 @@ private FetchResult fetchBlocks( final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false); + ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false, 5000); client.init(APP_ID); client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, new BlockFetchingListener() { @@ -242,7 +242,7 @@ public void testFetchNoServer() throws Exception { private static void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) throws IOException, InterruptedException { - ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, 5000); client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), executorId, executorInfo); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index bf20c577ed420..16bad9f1b319d 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -97,7 +97,7 @@ private void validate(String appId, String secretKey, boolean encrypt) } ExternalShuffleClient client = - new ExternalShuffleClient(testConf, new TestSecretKeyHolder(appId, secretKey), true); + new ExternalShuffleClient(testConf, new TestSecretKeyHolder(appId, secretKey), true, 5000); client.init(appId); // Registration either succeeds or throws an exception. 
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 84ef57f2d271b..615497d36fd14 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -303,6 +303,19 @@ package object config { .bytesConf(ByteUnit.BYTE) .createWithDefault(100 * 1024 * 1024) + private[spark] val SHUFFLE_REGISTRATION_TIMEOUT = + ConfigBuilder("spark.shuffle.registration.timeout") + .doc("Timeout in milliseconds for registration to the external shuffle service.") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(5000) + + private[spark] val SHUFFLE_REGISTRATION_MAX_ATTEMPTS = + ConfigBuilder("spark.shuffle.registration.maxAttempts") + .doc("When we fail to register to the external shuffle service, we will " + + "retry for maxAttempts times.") + .intConf + .createWithDefault(3) + private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM = ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem") .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " + diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1689baa832d52..74be70348305c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -31,7 +31,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer @@ -174,7 +174,8 @@ private[spark] class BlockManager( // standard BlockTransferService to directly connect to other Executors. 
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) - new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled()) + new ExternalShuffleClient(transConf, securityManager, + securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)) } else { blockTransferService } @@ -254,7 +255,7 @@ private[spark] class BlockManager( diskBlockManager.subDirsPerLocalDir, shuffleManager.getClass.getName) - val MAX_ATTEMPTS = 3 + val MAX_ATTEMPTS = conf.get(config.SHUFFLE_REGISTRATION_MAX_ATTEMPTS) val SLEEP_TIME_SECS = 5 for (i <- 1 to MAX_ATTEMPTS) { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 9d52b488b223e..88f18294aa015 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -20,13 +20,15 @@ package org.apache.spark.storage import java.io.File import java.nio.ByteBuffer +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.concurrent.Future -import scala.language.implicitConversions -import scala.language.postfixOps +import scala.language.{implicitConversions, postfixOps} import scala.reflect.ClassTag +import org.apache.commons.lang3.RandomUtils import org.mockito.{Matchers => mc} import org.mockito.Mockito.{mock, times, verify, when} import org.scalatest._ @@ -38,10 +40,13 @@ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.DataReadMethod import org.apache.spark.internal.config._ import org.apache.spark.memory.UnifiedMemoryManager -import org.apache.spark.network.{BlockDataManager, BlockTransferService} +import org.apache.spark.network.{BlockDataManager, BlockTransferService, TransportContext} import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.netty.NettyBlockTransferService +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} +import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} @@ -1281,6 +1286,61 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(master.getLocations("item").isEmpty) } + test("SPARK-20640: Shuffle registration timeout and maxAttempts conf are working") { + val tryAgainMsg = "test_spark_20640_try_again" + // a server which delays response 50ms and must try twice for success. 
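+      // The RPC handler built below sleeps 50 ms before answering and fails the first
+      // RegisterExecutor attempt per executor with tryAgainMsg, so the assertions later in
+      // this test can exercise a registration timeout, a failure without retries, and a
+      // successful second attempt, in that order.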
+ def newShuffleServer(port: Int): (TransportServer, Int) = { + val attempts = new mutable.HashMap[String, Int]() + val handler = new NoOpRpcHandler { + override def receive( + client: TransportClient, + message: ByteBuffer, + callback: RpcResponseCallback): Unit = { + val msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message) + msgObj match { + case exec: RegisterExecutor => + Thread.sleep(50) + val attempt = attempts.getOrElse(exec.execId, 0) + 1 + attempts(exec.execId) = attempt + if (attempt < 2) { + callback.onFailure(new Exception(tryAgainMsg)) + return + } + callback.onSuccess(ByteBuffer.wrap(new Array[Byte](0))) + } + } + } + + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 0) + val transCtx = new TransportContext(transConf, handler, true) + (transCtx.createServer(port, Seq.empty[TransportServerBootstrap].asJava), port) + } + val candidatePort = RandomUtils.nextInt(1024, 65536) + val (server, shufflePort) = Utils.startServiceOnPort(candidatePort, + newShuffleServer, conf, "ShuffleServer") + + conf.set("spark.shuffle.service.enabled", "true") + conf.set("spark.shuffle.service.port", shufflePort.toString) + conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "40") + conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") + var e = intercept[SparkException]{ + makeBlockManager(8000, "executor1") + }.getMessage + assert(e.contains("TimeoutException")) + + conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000") + conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") + e = intercept[SparkException]{ + makeBlockManager(8000, "executor2") + }.getMessage + assert(e.contains(tryAgainMsg)) + + conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000") + conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "2") + makeBlockManager(8000, "executor3") + server.close() + } + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 diff --git a/docs/configuration.md b/docs/configuration.md index c1464741ebb6f..f1c6d04115ab0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -638,6 +638,20 @@ Apart from these, the following properties are also available, and may be useful underestimating shuffle block size when fetch shuffle blocks. + + spark.shuffle.registration.timeout + 5000 + + Timeout in milliseconds for registration to the external shuffle service. + + + + spark.shuffle.registration.maxAttempts + 3 + + When we fail to register to the external shuffle service, we will retry for maxAttempts times. 
+ + spark.io.encryption.enabled false diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 871685c6cccc0..7dd42c41aa7c2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -29,6 +29,7 @@ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.SchedulerDriver import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState} +import org.apache.spark.internal.config import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.RpcEndpointAddress @@ -150,7 +151,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( new MesosExternalShuffleClient( SparkTransportConf.fromSparkConf(conf, "shuffle"), securityManager, - securityManager.isAuthenticationEnabled()) + securityManager.isAuthenticationEnabled(), + conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)) } private var nextMesosTaskId = 0 From 987eb8faddbb533e006c769d382a3e4fda3dd6ee Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 21 Jun 2017 15:30:31 +0100 Subject: [PATCH 045/118] [MINOR][DOCS] Add lost tag for configuration.md ## What changes were proposed in this pull request? Add lost `` tag for `configuration.md`. ## How was this patch tested? N/A Author: Yuming Wang Closes #18372 from wangyum/docs-missing-tr. --- docs/configuration.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index f1c6d04115ab0..f4bec589208be 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1566,6 +1566,8 @@ Apart from these, the following properties are also available, and may be useful of this setting is to act as a safety-net to prevent runaway uncancellable tasks from rendering an executor unusable. + + spark.stage.maxConsecutiveAttempts 4 From e92befcb4b57c3e4afe57b6de1622ac72e7d819c Mon Sep 17 00:00:00 2001 From: Marcos P Date: Wed, 21 Jun 2017 15:34:10 +0100 Subject: [PATCH 046/118] [MINOR][DOC] modified issue link and updated status ## What changes were proposed in this pull request? This PR aims to clarify some outdated comments that i found at **spark-catalyst** and **spark-sql** pom files. Maven bug still happening and in order to track it I have updated the issue link and also the status of the issue. Author: Marcos P Closes #18374 from mpenate/fix/mng-3559-comment. --- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 36948ba52b064..0bbf7a95124cf 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -109,7 +109,7 @@ so that the tests classes of external modules can use them. The two execution profiles are necessary - first one for 'mvn package', second one for 'mvn test-compile'. Ideally, 'mvn compile' should not compile test classes and therefore should not need this. 
- However, an open Maven bug (http://jira.codehaus.org/browse/MNG-3559) + However, a closed due to "Cannot Reproduce" Maven bug (https://issues.apache.org/jira/browse/MNG-3559) causes the compilation to fail if catalyst test-jar is not generated. Hence, the second execution profile for 'mvn test-compile'. --> diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 7327c9b0c9c50..1bc34a6b069d9 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -161,7 +161,7 @@ so that the tests classes of external modules can use them. The two execution profiles are necessary - first one for 'mvn package', second one for 'mvn test-compile'. Ideally, 'mvn compile' should not compile test classes and therefore should not need this. - However, an open Maven bug (http://jira.codehaus.org/browse/MNG-3559) + However, a closed due to "Cannot Reproduce" Maven bug (https://issues.apache.org/jira/browse/MNG-3559) causes the compilation to fail if catalyst test-jar is not generated. Hence, the second execution profile for 'mvn test-compile'. --> From cad88f17e87e6cb96550b70e35d3ed75305dc59d Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Wed, 21 Jun 2017 09:40:06 -0700 Subject: [PATCH 047/118] [SPARK-17851][SQL][TESTS] Make sure all test sqls in catalyst pass checkAnalysis ## What changes were proposed in this pull request? Currently we have several tens of test sqls in catalyst will fail at `SimpleAnalyzer.checkAnalysis`, we should make sure they are valid. This PR makes the following changes: 1. Apply `checkAnalysis` on plans that tests `Optimizer` rules, but don't require the testcases for `Parser`/`Analyzer` pass `checkAnalysis`; 2. Fix testcases for `Optimizer` that would have fall. ## How was this patch tested? Apply `SimpleAnalyzer.checkAnalysis` on plans in `PlanTest.comparePlans`, update invalid test cases. Author: Xingbo Jiang Author: jiangxingbo Closes #15417 from jiangxb1987/cptest. 
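For illustration only (not part of this patch), a minimal sketch of the shape an optimizer test takes once `comparePlans` runs `checkAnalysis` by default. It assumes the Catalyst test DSL imports and the `Optimize` batch defined in `BooleanSimplificationSuite`; the point is that predicates must now be well-typed for both plans to pass analysis:

```scala
// Boolean-typed columns keep both plans resolvable, so the new
// SimpleAnalyzer.checkAnalysis call inside comparePlans accepts them; filtering on
// the int column 'a instead would now fail the test before the plans are compared.
val testRelation = LocalRelation('a.int, 'e.boolean, 'f.boolean)
val optimized = Optimize.execute(testRelation.where('e && (!'e || 'f)).analyze)
comparePlans(optimized, testRelation.where('e && 'f).analyze)
```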
--- .../sql/catalyst/analysis/AnalysisTest.scala | 8 +++ .../analysis/DecimalPrecisionSuite.scala | 2 +- .../catalyst/analysis/TypeCoercionSuite.scala | 2 +- .../catalog/SessionCatalogSuite.scala | 2 +- .../optimizer/AggregateOptimizeSuite.scala | 4 +- .../BooleanSimplificationSuite.scala | 57 ++++++++++--------- .../optimizer/ColumnPruningSuite.scala | 4 +- .../optimizer/ConstantPropagationSuite.scala | 9 ++- .../optimizer/FilterPushdownSuite.scala | 11 ++-- .../optimizer/LimitPushdownSuite.scala | 12 ++-- .../optimizer/OptimizeCodegenSuite.scala | 4 +- .../optimizer/OuterJoinEliminationSuite.scala | 4 +- .../optimizer/SimplifyCastsSuite.scala | 9 ++- .../sql/catalyst/parser/PlanParserSuite.scala | 6 +- .../spark/sql/catalyst/plans/PlanTest.scala | 14 ++++- .../apache/spark/sql/DataFrameHintSuite.scala | 4 +- .../sql/execution/SparkSqlParserSuite.scala | 5 +- .../spark/sql/hive/HiveDDLCommandSuite.scala | 20 +++---- 18 files changed, 101 insertions(+), 76 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index edfa8c45f9867..549a4355dfba3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -59,6 +59,14 @@ trait AnalysisTest extends PlanTest { comparePlans(actualPlan, expectedPlan) } + protected override def comparePlans( + plan1: LogicalPlan, + plan2: LogicalPlan, + checkAnalysis: Boolean = false): Unit = { + // Analysis tests may have not been fully resolved, so skip checkAnalysis. + super.comparePlans(plan1, plan2, checkAnalysis) + } + protected def assertAnalysisSuccess( inputPlan: LogicalPlan, caseSensitive: Boolean = true): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 8f43171f309a9..ccf3c3fb0949d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Unio import org.apache.spark.sql.types._ -class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { +class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter { private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) private val analyzer = new Analyzer(catalog, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 7358f401ed520..b3994ab0828ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval -class TypeCoercionSuite extends PlanTest { +class TypeCoercionSuite extends AnalysisTest { // scalastyle:off line.size.limit // The following table shows all implicit data type conversions that are not visible to the user. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index dce73b3635e72..a6dc21b03d446 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -44,7 +44,7 @@ class InMemorySessionCatalogSuite extends SessionCatalogSuite { * signatures but do not extend a common parent. This is largely by design but * unfortunately leads to very similar test code in two places. */ -abstract class SessionCatalogSuite extends PlanTest { +abstract class SessionCatalogSuite extends AnalysisTest { protected val utils: CatalogTestUtils protected val isHiveExternalCatalog = false diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index e6132ab2e4d17..a3184a4266c7c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -59,9 +59,9 @@ class AggregateOptimizeSuite extends PlanTest { } test("Remove aliased literals") { - val query = testRelation.select('a, Literal(1).as('y)).groupBy('a, 'y)(sum('b)) + val query = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a, 'y)(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) - val correctAnswer = testRelation.select('a, Literal(1).as('y)).groupBy('a)(sum('b)).analyze + val correctAnswer = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a)(sum('b)).analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 1df0a89cf0bf1..c6345b60b744b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -41,7 +41,8 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { PruneFilters) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string) + val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string, + 'e.boolean, 'f.boolean, 'g.boolean, 'h.boolean) val testRelationWithData = LocalRelation.fromExternalRows( testRelation.output, Seq(Row(1, 2, 3, "abc")) @@ -101,52 +102,52 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { 'a === 'b || 'b > 3 && 'a > 3 && 'a < 5) } - test("a && (!a || b)") { - checkCondition('a && (!'a || 'b ), 'a && 'b) + test("e && (!e || f)") { + checkCondition('e && (!'e || 'f ), 'e && 'f) - checkCondition('a && ('b || !'a ), 'a && 'b) + checkCondition('e && ('f || !'e ), 'e && 'f) - checkCondition((!'a || 'b ) && 'a, 'b && 'a) + checkCondition((!'e || 'f ) && 'e, 'f && 'e) - checkCondition(('b || !'a ) && 'a, 'b && 'a) + checkCondition(('f || !'e ) && 'e, 'f && 'e) } - test("a < 1 && (!(a < 1) || b)") { - checkCondition('a < 1 && (!('a < 1) || 'b), ('a < 1) && 'b) - checkCondition('a < 1 && ('b || !('a < 1)), ('a < 1) && 'b) + test("a < 1 && (!(a < 1) || f)") { + checkCondition('a < 1 && 
(!('a < 1) || 'f), ('a < 1) && 'f) + checkCondition('a < 1 && ('f || !('a < 1)), ('a < 1) && 'f) - checkCondition('a <= 1 && (!('a <= 1) || 'b), ('a <= 1) && 'b) - checkCondition('a <= 1 && ('b || !('a <= 1)), ('a <= 1) && 'b) + checkCondition('a <= 1 && (!('a <= 1) || 'f), ('a <= 1) && 'f) + checkCondition('a <= 1 && ('f || !('a <= 1)), ('a <= 1) && 'f) - checkCondition('a > 1 && (!('a > 1) || 'b), ('a > 1) && 'b) - checkCondition('a > 1 && ('b || !('a > 1)), ('a > 1) && 'b) + checkCondition('a > 1 && (!('a > 1) || 'f), ('a > 1) && 'f) + checkCondition('a > 1 && ('f || !('a > 1)), ('a > 1) && 'f) - checkCondition('a >= 1 && (!('a >= 1) || 'b), ('a >= 1) && 'b) - checkCondition('a >= 1 && ('b || !('a >= 1)), ('a >= 1) && 'b) + checkCondition('a >= 1 && (!('a >= 1) || 'f), ('a >= 1) && 'f) + checkCondition('a >= 1 && ('f || !('a >= 1)), ('a >= 1) && 'f) } - test("a < 1 && ((a >= 1) || b)") { - checkCondition('a < 1 && ('a >= 1 || 'b ), ('a < 1) && 'b) - checkCondition('a < 1 && ('b || 'a >= 1), ('a < 1) && 'b) + test("a < 1 && ((a >= 1) || f)") { + checkCondition('a < 1 && ('a >= 1 || 'f ), ('a < 1) && 'f) + checkCondition('a < 1 && ('f || 'a >= 1), ('a < 1) && 'f) - checkCondition('a <= 1 && ('a > 1 || 'b ), ('a <= 1) && 'b) - checkCondition('a <= 1 && ('b || 'a > 1), ('a <= 1) && 'b) + checkCondition('a <= 1 && ('a > 1 || 'f ), ('a <= 1) && 'f) + checkCondition('a <= 1 && ('f || 'a > 1), ('a <= 1) && 'f) - checkCondition('a > 1 && (('a <= 1) || 'b), ('a > 1) && 'b) - checkCondition('a > 1 && ('b || ('a <= 1)), ('a > 1) && 'b) + checkCondition('a > 1 && (('a <= 1) || 'f), ('a > 1) && 'f) + checkCondition('a > 1 && ('f || ('a <= 1)), ('a > 1) && 'f) - checkCondition('a >= 1 && (('a < 1) || 'b), ('a >= 1) && 'b) - checkCondition('a >= 1 && ('b || ('a < 1)), ('a >= 1) && 'b) + checkCondition('a >= 1 && (('a < 1) || 'f), ('a >= 1) && 'f) + checkCondition('a >= 1 && ('f || ('a < 1)), ('a >= 1) && 'f) } test("DeMorgan's law") { - checkCondition(!('a && 'b), !'a || !'b) + checkCondition(!('e && 'f), !'e || !'f) - checkCondition(!('a || 'b), !'a && !'b) + checkCondition(!('e || 'f), !'e && !'f) - checkCondition(!(('a && 'b) || ('c && 'd)), (!'a || !'b) && (!'c || !'d)) + checkCondition(!(('e && 'f) || ('g && 'h)), (!'e || !'f) && (!'g || !'h)) - checkCondition(!(('a || 'b) && ('c || 'd)), (!'a && !'b) || (!'c && !'d)) + checkCondition(!(('e || 'f) && ('g || 'h)), (!'e && !'f) || (!'g && !'h)) } private val caseInsensitiveConf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index a0a0daea7d075..0b419e9631b29 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -266,8 +266,8 @@ class ColumnPruningSuite extends PlanTest { test("Column pruning on Window with useless aggregate functions") { val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int) - val winSpec = windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame) - val winExpr = windowExpr(count('b), winSpec) + val winSpec = windowSpec('a :: Nil, 'd.asc :: Nil, UnspecifiedFrame) + val winExpr = windowExpr(count('d), winSpec) val originalQuery = input.groupBy('a, 'c, 'd)('a, 'c, 'd, winExpr.as('window)).select('a, 'c) val correctAnswer = input.select('a, 'c, 'd).groupBy('a, 'c, 'd)('a, 
'c).analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala index 81d2f3667e2d0..94174eec8fd0f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala @@ -35,7 +35,6 @@ class ConstantPropagationSuite extends PlanTest { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("ConstantPropagation", FixedPoint(10), - ColumnPruning, ConstantPropagation, ConstantFolding, BooleanSimplification) :: Nil @@ -43,9 +42,9 @@ class ConstantPropagationSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - private val columnA = 'a.int - private val columnB = 'b.int - private val columnC = 'c.int + private val columnA = 'a + private val columnB = 'b + private val columnC = 'c test("basic test") { val query = testRelation @@ -160,7 +159,7 @@ class ConstantPropagationSuite extends PlanTest { val correctAnswer = testRelation .select(columnA) - .where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)) + .where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)).analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index d4d281e7e05db..3553d23560dad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -629,14 +629,14 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { testRelationWithArrayType .generate(Explode('c_arr), true, false, Some("arr")) - .where(('b >= 5) && ('a + Rand(10).as("rnd") > 6) && ('c > 6)) + .where(('b >= 5) && ('a + Rand(10).as("rnd") > 6) && ('col > 6)) } val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = { testRelationWithArrayType .where('b >= 5) .generate(Explode('c_arr), true, false, Some("arr")) - .where('a + Rand(10).as("rnd") > 6 && 'c > 6) + .where('a + Rand(10).as("rnd") > 6 && 'col > 6) .analyze } @@ -676,7 +676,7 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { testRelationWithArrayType .generate(Explode('c_arr), true, false, Some("arr")) - .where(('c > 6) || ('b > 5)).analyze + .where(('col > 6) || ('b > 5)).analyze } val optimized = Optimize.execute(originalQuery) @@ -1129,6 +1129,9 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = x.where("x.a".attr === 5).join(y.where("y.a".attr === 5), condition = Some("x.a".attr === Rand(10) && "y.b".attr === 5)) - comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + // CheckAnalysis will ensure nondeterministic expressions not appear in join condition. + // TODO support nondeterministic expressions in join condition. 
+ comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze, + checkAnalysis = false) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index 2885fd6841e9d..fb34c82de468b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -70,19 +70,21 @@ class LimitPushdownSuite extends PlanTest { } test("Union: no limit to both sides if children having smaller limit values") { - val unionQuery = Union(testRelation.limit(1), testRelation2.select('d).limit(1)).limit(2) + val unionQuery = + Union(testRelation.limit(1), testRelation2.select('d, 'e, 'f).limit(1)).limit(2) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = - Limit(2, Union(testRelation.limit(1), testRelation2.select('d).limit(1))).analyze + Limit(2, Union(testRelation.limit(1), testRelation2.select('d, 'e, 'f).limit(1))).analyze comparePlans(unionOptimized, unionCorrectAnswer) } test("Union: limit to each sides if children having larger limit values") { - val testLimitUnion = Union(testRelation.limit(3), testRelation2.select('d).limit(4)) - val unionQuery = testLimitUnion.limit(2) + val unionQuery = + Union(testRelation.limit(3), testRelation2.select('d, 'e, 'f).limit(4)).limit(2) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = - Limit(2, Union(LocalLimit(2, testRelation), LocalLimit(2, testRelation2.select('d)))).analyze + Limit(2, Union( + LocalLimit(2, testRelation), LocalLimit(2, testRelation2.select('d, 'e, 'f)))).analyze comparePlans(unionOptimized, unionCorrectAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala index f3b65cc797ec4..9dc6738ba04b3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala @@ -50,10 +50,10 @@ class OptimizeCodegenSuite extends PlanTest { test("Nested CaseWhen Codegen.") { assertEquivalent( CaseWhen( - Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), Literal(3))), + Seq((CaseWhen(Seq((TrueLiteral, TrueLiteral)), FalseLiteral), Literal(3))), CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5))), CaseWhen( - Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), Literal(3))), + Seq((CaseWhen(Seq((TrueLiteral, TrueLiteral)), FalseLiteral).toCodegen(), Literal(3))), CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5)).toCodegen()).toCodegen()) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala index a37bc4bca2422..623ff3d446a5f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -201,7 +201,7 @@ class OuterJoinEliminationSuite extends PlanTest { val originalQuery = x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) - .where(Coalesce("y.e".attr :: 
"x.a".attr :: Nil)) + .where(Coalesce("y.e".attr :: "x.a".attr :: Nil) === 0) val optimized = Optimize.execute(originalQuery.analyze) @@ -209,7 +209,7 @@ class OuterJoinEliminationSuite extends PlanTest { val right = testRelation1 val correctAnswer = left.join(right, FullOuter, Option("a".attr === "d".attr)) - .where(Coalesce("e".attr :: "a".attr :: Nil)).analyze + .where(Coalesce("e".attr :: "a".attr :: Nil) === 0).analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala index e84f11272d214..7b3f5b084b015 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala @@ -44,7 +44,9 @@ class SimplifyCastsSuite extends PlanTest { val input = LocalRelation('a.array(ArrayType(IntegerType, true))) val plan = input.select('a.cast(ArrayType(IntegerType, false)).as("casted")).analyze val optimized = Optimize.execute(plan) - comparePlans(optimized, plan) + // Though cast from `ArrayType(IntegerType, true)` to `ArrayType(IntegerType, false)` is not + // allowed, here we just ensure that `SimplifyCasts` rule respect the plan. + comparePlans(optimized, plan, checkAnalysis = false) } test("non-nullable value map to nullable value map cast") { @@ -61,7 +63,10 @@ class SimplifyCastsSuite extends PlanTest { val plan = input.select('m.cast(MapType(StringType, StringType, false)) .as("casted")).analyze val optimized = Optimize.execute(plan) - comparePlans(optimized, plan) + // Though cast from `MapType(StringType, StringType, true)` to + // `MapType(StringType, StringType, false)` is not allowed, here we just ensure that + // `SimplifyCasts` rule respect the plan. + comparePlans(optimized, plan, checkAnalysis = false) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index fef39a5b6a32f..0a4ae098d65cc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedTableValuedFunction} +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedTableValuedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -29,13 +29,13 @@ import org.apache.spark.sql.types.IntegerType * * There is also SparkSqlParserSuite in sql/core module for parser rules defined in sql/core module. 
*/ -class PlanParserSuite extends PlanTest { +class PlanParserSuite extends AnalysisTest { import CatalystSqlParser._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ private def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { - comparePlans(parsePlan(sqlCommand), plan) + comparePlans(parsePlan(sqlCommand), plan, checkAnalysis = false) } private def intercept(sqlCommand: String, messages: String*): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index f44428c3512a9..25313af2be184 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical._ @@ -90,7 +91,16 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { } /** Fails the test if the two plans do not match */ - protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { + protected def comparePlans( + plan1: LogicalPlan, + plan2: LogicalPlan, + checkAnalysis: Boolean = true): Unit = { + if (checkAnalysis) { + // Make sure both plan pass checkAnalysis. + SimpleAnalyzer.checkAnalysis(plan1) + SimpleAnalyzer.checkAnalysis(plan2) + } + val normalized1 = normalizePlan(normalizeExprIds(plan1)) val normalized2 = normalizePlan(normalizeExprIds(plan2)) if (normalized1 != normalized2) { @@ -104,7 +114,7 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { /** Fails the test if the two expressions do not match */ protected def compareExpressions(e1: Expression, e2: Expression): Unit = { - comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation)) + comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation), checkAnalysis = false) } /** Fails the test if the join order in the two plans do not match */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala index 60f6f23860ed9..0dd5bdcba2e4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.test.SharedSQLContext -class DataFrameHintSuite extends PlanTest with SharedSQLContext { +class DataFrameHintSuite extends AnalysisTest with SharedSQLContext { import testImplicits._ lazy val df = spark.range(10) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index b32fb90e10072..bd9c2ebd6fab9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.sql.execution 
import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.expressions.{Ascending, Concat, SortOrder} import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, RepartitionByExpression, Sort} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.CreateTable @@ -36,7 +35,7 @@ import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType * See [[org.apache.spark.sql.catalyst.parser.PlanParserSuite]] for rules * defined in the Catalyst module. */ -class SparkSqlParserSuite extends PlanTest { +class SparkSqlParserSuite extends AnalysisTest { val newConf = new SQLConf private lazy val parser = new SparkSqlParser(newConf) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index d97b11e447fe2..bee470d8e1382 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan import org.apache.spark.sql.catalyst.expressions.JsonTuple import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} +import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, ScriptTransformation} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} @@ -59,6 +59,11 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle }.head } + private def compareTransformQuery(sql: String, expected: LogicalPlan): Unit = { + val plan = parser.parsePlan(sql).asInstanceOf[ScriptTransformation].copy(ioschema = null) + comparePlans(plan, expected, checkAnalysis = false) + } + test("Test CTAS #1") { val s1 = """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view @@ -253,22 +258,15 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle } test("transform query spec") { - val plan1 = parser.parsePlan("select transform(a, b) using 'func' from e where f < 10") - .asInstanceOf[ScriptTransformation].copy(ioschema = null) - val plan2 = parser.parsePlan("map a, b using 'func' as c, d from e") - .asInstanceOf[ScriptTransformation].copy(ioschema = null) - val plan3 = parser.parsePlan("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e") - .asInstanceOf[ScriptTransformation].copy(ioschema = null) - val p = ScriptTransformation( Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), "func", Seq.empty, plans.table("e"), null) - comparePlans(plan1, + compareTransformQuery("select transform(a, b) using 'func' from e where f < 10", p.copy(child = p.child.where('f < 10), output = 
Seq('key.string, 'value.string))) - comparePlans(plan2, + compareTransformQuery("map a, b using 'func' as c, d from e", p.copy(output = Seq('c.string, 'd.string))) - comparePlans(plan3, + compareTransformQuery("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e", p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) } From ad459cfb1d169d8dd7b9e039ca135ba5cafcab83 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Wed, 21 Jun 2017 10:35:16 -0700 Subject: [PATCH 048/118] [SPARK-20917][ML][SPARKR] SparkR supports string encoding consistent with R ## What changes were proposed in this pull request? Add `stringIndexerOrderType` to `spark.glm` and `spark.survreg` to support string encoding that is consistent with default R. ## How was this patch tested? new tests Author: actuaryzhang Closes #18140 from actuaryzhang/sparkRFormula. --- R/pkg/R/mllib_regression.R | 52 +++++++++++++--- R/pkg/tests/fulltests/test_mllib_regression.R | 62 +++++++++++++++++++ .../ml/r/AFTSurvivalRegressionWrapper.scala | 4 +- .../GeneralizedLinearRegressionWrapper.scala | 6 +- 4 files changed, 115 insertions(+), 9 deletions(-) diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index d59c890f3e5fd..9ecd887f2c127 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -70,6 +70,12 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' the relationship between the variance and mean of the distribution. Only #' applicable to the Tweedie family. #' @param link.power the index in the power link function. Only applicable to the Tweedie family. +#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to +#' decide the base level of a string feature as the last category after +#' ordering is dropped when encoding strings. Supported options are +#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". +#' The default value is "frequencyDesc". When the ordering is set to +#' "alphabetDesc", this drops the same category as R when encoding strings. #' @param ... additional arguments passed to the method. #' @aliases spark.glm,SparkDataFrame,formula-method #' @return \code{spark.glm} returns a fitted generalized linear model. 
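On the Scala side, the R-level option documented in this file is simply forwarded to `RFormula` by the wrappers changed later in this patch. A minimal sketch of that call, with the import spelled out and `data` assumed to be a DataFrame holding the Titanic columns used in the examples:

```scala
import org.apache.spark.ml.feature.RFormula

// Sketch only: mirrors what the updated wrappers do with the new argument.
// "alphabetDesc" drops the same base category as R does when encoding strings.
val rFormula = new RFormula()
  .setFormula("Freq ~ Sex + Age")
  .setStringIndexerOrderType("alphabetDesc")
val rFormulaModel = rFormula.fit(data)
```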
@@ -79,7 +85,7 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' @examples #' \dontrun{ #' sparkR.session() -#' t <- as.data.frame(Titanic) +#' t <- as.data.frame(Titanic, stringsAsFactors = FALSE) #' df <- createDataFrame(t) #' model <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian") #' summary(model) @@ -96,6 +102,15 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' savedModel <- read.ml(path) #' summary(savedModel) #' +#' # note that the default string encoding is different from R's glm +#' model2 <- glm(Freq ~ Sex + Age, family = "gaussian", data = t) +#' summary(model2) +#' # use stringIndexerOrderType = "alphabetDesc" to force string encoding +#' # to be consistent with R +#' model3 <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian", +#' stringIndexerOrderType = "alphabetDesc") +#' summary(model3) +#' #' # fit tweedie model #' model <- spark.glm(df, Freq ~ Sex + Age, family = "tweedie", #' var.power = 1.2, link.power = 0) @@ -110,8 +125,11 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' @seealso \link{glm}, \link{read.ml} setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL, - regParam = 0.0, var.power = 0.0, link.power = 1.0 - var.power) { + regParam = 0.0, var.power = 0.0, link.power = 1.0 - var.power, + stringIndexerOrderType = c("frequencyDesc", "frequencyAsc", + "alphabetDesc", "alphabetAsc")) { + stringIndexerOrderType <- match.arg(stringIndexerOrderType) if (is.character(family)) { # Handle when family = "tweedie" if (tolower(family) == "tweedie") { @@ -145,7 +163,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", "fit", formula, data@sdf, tolower(family$family), family$link, tol, as.integer(maxIter), weightCol, regParam, - as.double(var.power), as.double(link.power)) + as.double(var.power), as.double(link.power), + stringIndexerOrderType) new("GeneralizedLinearRegressionModel", jobj = jobj) }) @@ -167,6 +186,12 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @param maxit integer giving the maximal number of IRLS iterations. #' @param var.power the index of the power variance function in the Tweedie family. #' @param link.power the index of the power link function in the Tweedie family. +#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to +#' decide the base level of a string feature as the last category after +#' ordering is dropped when encoding strings. Supported options are +#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". +#' The default value is "frequencyDesc". When the ordering is set to +#' "alphabetDesc", this drops the same category as R when encoding strings. #' @return \code{glm} returns a fitted generalized linear model. 
#' @rdname glm #' @export @@ -182,9 +207,12 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @seealso \link{spark.glm} setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"), function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL, - var.power = 0.0, link.power = 1.0 - var.power) { + var.power = 0.0, link.power = 1.0 - var.power, + stringIndexerOrderType = c("frequencyDesc", "frequencyAsc", + "alphabetDesc", "alphabetAsc")) { spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol, - var.power = var.power, link.power = link.power) + var.power = var.power, link.power = link.power, + stringIndexerOrderType = stringIndexerOrderType) }) # Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). @@ -418,6 +446,12 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features #' or the number of partitions are large, this param could be adjusted to a larger size. #' This is an expert parameter. Default value should be good for most cases. +#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to +#' decide the base level of a string feature as the last category after +#' ordering is dropped when encoding strings. Supported options are +#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". +#' The default value is "frequencyDesc". When the ordering is set to +#' "alphabetDesc", this drops the same category as R when encoding strings. #' @param ... additional arguments passed to the method. #' @return \code{spark.survreg} returns a fitted AFT survival regression model. 
#' @rdname spark.survreg @@ -443,10 +477,14 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' } #' @note spark.survreg since 2.0.0 setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, aggregationDepth = 2) { + function(data, formula, aggregationDepth = 2, + stringIndexerOrderType = c("frequencyDesc", "frequencyAsc", + "alphabetDesc", "alphabetAsc")) { + stringIndexerOrderType <- match.arg(stringIndexerOrderType) formula <- paste(deparse(formula), collapse = "") jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper", - "fit", formula, data@sdf, as.integer(aggregationDepth)) + "fit", formula, data@sdf, as.integer(aggregationDepth), + stringIndexerOrderType) new("AFTSurvivalRegressionModel", jobj = jobj) }) diff --git a/R/pkg/tests/fulltests/test_mllib_regression.R b/R/pkg/tests/fulltests/test_mllib_regression.R index 82472c92b9965..6b72a09b200d6 100644 --- a/R/pkg/tests/fulltests/test_mllib_regression.R +++ b/R/pkg/tests/fulltests/test_mllib_regression.R @@ -367,6 +367,49 @@ test_that("glm save/load", { unlink(modelPath) }) +test_that("spark.glm and glm with string encoding", { + t <- as.data.frame(Titanic, stringsAsFactors = FALSE) + df <- createDataFrame(t) + + # base R + rm <- stats::glm(Freq ~ Sex + Age, family = "gaussian", data = t) + # spark.glm with default stringIndexerOrderType = "frequencyDesc" + sm0 <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian") + # spark.glm with stringIndexerOrderType = "alphabetDesc" + sm1 <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian", + stringIndexerOrderType = "alphabetDesc") + # glm with stringIndexerOrderType = "alphabetDesc" + sm2 <- glm(Freq ~ Sex + Age, family = "gaussian", data = df, + stringIndexerOrderType = "alphabetDesc") + + rStats <- summary(rm) + rCoefs <- rStats$coefficients + sStats <- lapply(list(sm0, sm1, sm2), summary) + # order by coefficient size since column rendering may be different + o <- order(rCoefs[, 1]) + + # default encoding does not produce same results as R + expect_false(all(abs(rCoefs[o, ] - sStats[[1]]$coefficients[o, ]) < 1e-4)) + + # all estimates should be the same as R with stringIndexerOrderType = "alphabetDesc" + test <- lapply(sStats[2:3], function(stats) { + expect_true(all(abs(rCoefs[o, ] - stats$coefficients[o, ]) < 1e-4)) + expect_equal(stats$dispersion, rStats$dispersion) + expect_equal(stats$null.deviance, rStats$null.deviance) + expect_equal(stats$deviance, rStats$deviance) + expect_equal(stats$df.null, rStats$df.null) + expect_equal(stats$df.residual, rStats$df.residual) + expect_equal(stats$aic, rStats$aic) + }) + + # fitted values should be equal regardless of string encoding + rVals <- predict(rm, t) + test <- lapply(list(sm0, sm1, sm2), function(sm) { + vals <- collect(select(predict(sm, df), "prediction")) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + }) +}) + test_that("spark.isoreg", { label <- c(7.0, 5.0, 3.0, 5.0, 1.0) feature <- c(0.0, 1.0, 2.0, 3.0, 4.0) @@ -462,6 +505,25 @@ test_that("spark.survreg", { model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData), NA) expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4) + + # Test stringIndexerOrderType + rData <- as.data.frame(rData) + rData$sex2 <- c("female", "male")[rData$sex + 1] + df <- createDataFrame(rData) + expect_error( + rModel <- survival::survreg(survival::Surv(time, status) ~ x + sex2, rData), NA) + rCoefs <- as.numeric(summary(rModel)$table[, 1]) 
+ model <- spark.survreg(df, Surv(time, status) ~ x + sex2) + coefs <- as.vector(summary(model)$coefficients[, 1]) + o <- order(rCoefs) + # stringIndexerOrderType = "frequencyDesc" produces different estimates from R + expect_false(all(abs(rCoefs[o] - coefs[o]) < 1e-4)) + + # stringIndexerOrderType = "alphabetDesc" produces the same estimates as R + model <- spark.survreg(df, Surv(time, status) ~ x + sex2, + stringIndexerOrderType = "alphabetDesc") + coefs <- as.vector(summary(model)$coefficients[, 1]) + expect_true(all(abs(rCoefs[o] - coefs[o]) < 1e-4)) } }) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index 0bf543d88894e..80d03ab03c87d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -85,11 +85,13 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg def fit( formula: String, data: DataFrame, - aggregationDepth: Int): AFTSurvivalRegressionWrapper = { + aggregationDepth: Int, + stringIndexerOrderType: String): AFTSurvivalRegressionWrapper = { val (rewritedFormula, censorCol) = formulaRewrite(formula) val rFormula = new RFormula().setFormula(rewritedFormula) + .setStringIndexerOrderType(stringIndexerOrderType) RWrapperUtils.checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 4bd4aa7113f68..ee1fc9b14ceaa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -65,6 +65,7 @@ private[r] class GeneralizedLinearRegressionWrapper private ( private[r] object GeneralizedLinearRegressionWrapper extends MLReadable[GeneralizedLinearRegressionWrapper] { + // scalastyle:off def fit( formula: String, data: DataFrame, @@ -75,8 +76,11 @@ private[r] object GeneralizedLinearRegressionWrapper weightCol: String, regParam: Double, variancePower: Double, - linkPower: Double): GeneralizedLinearRegressionWrapper = { + linkPower: Double, + stringIndexerOrderType: String): GeneralizedLinearRegressionWrapper = { + // scalastyle:on val rFormula = new RFormula().setFormula(formula) + .setStringIndexerOrderType(stringIndexerOrderType) checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema From 7a00c658d44139d950b7d3ecd670d79f76e2e747 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 21 Jun 2017 10:51:17 -0700 Subject: [PATCH 049/118] [SPARK-21147][SS] Throws an analysis exception when a user-specified schema is given in socket/rate sources ## What changes were proposed in this pull request? 
This PR proposes to throw an exception if a schema is provided by user to socket source as below: **socket source** ```scala import org.apache.spark.sql.types._ val userSpecifiedSchema = StructType( StructField("name", StringType) :: StructField("area", StringType) :: Nil) val df = spark.readStream.format("socket").option("host", "localhost").option("port", 9999).schema(userSpecifiedSchema).load df.printSchema ``` Before ``` root |-- value: string (nullable = true) ``` After ``` org.apache.spark.sql.AnalysisException: The socket source does not support a user-specified schema.; at org.apache.spark.sql.execution.streaming.TextSocketSourceProvider.sourceSchema(socket.scala:199) at org.apache.spark.sql.execution.datasources.DataSource.sourceSchema(DataSource.scala:192) at org.apache.spark.sql.execution.datasources.DataSource.sourceInfo$lzycompute(DataSource.scala:87) at org.apache.spark.sql.execution.datasources.DataSource.sourceInfo(DataSource.scala:87) at org.apache.spark.sql.execution.streaming.StreamingRelation$.apply(StreamingRelation.scala:30) at org.apache.spark.sql.streaming.DataStreamReader.load(DataStreamReader.scala:150) ... 50 elided ``` **rate source** ```scala spark.readStream.format("rate").schema(spark.range(1).schema).load().printSchema() ``` Before ``` root |-- timestamp: timestamp (nullable = true) |-- value: long (nullable = true)` ``` After ``` org.apache.spark.sql.AnalysisException: The rate source does not support a user-specified schema.; at org.apache.spark.sql.execution.streaming.RateSourceProvider.sourceSchema(RateSourceProvider.scala:57) at org.apache.spark.sql.execution.datasources.DataSource.sourceSchema(DataSource.scala:192) at org.apache.spark.sql.execution.datasources.DataSource.sourceInfo$lzycompute(DataSource.scala:87) at org.apache.spark.sql.execution.datasources.DataSource.sourceInfo(DataSource.scala:87) at org.apache.spark.sql.execution.streaming.StreamingRelation$.apply(StreamingRelation.scala:30) at org.apache.spark.sql.streaming.DataStreamReader.load(DataStreamReader.scala:150) ... 48 elided ``` ## How was this patch tested? Unit test in `TextSocketStreamSuite` and `RateSourceSuite`. Author: hyukjinkwon Closes #18365 from HyukjinKwon/SPARK-21147. 
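For anyone verifying the new behaviour locally, a minimal sketch of the supported usage is shown below (assumptions: a running `SparkSession` named `spark`, and a process serving text on `localhost:9999`; both values are placeholders). Because the socket and rate sources define their own schemas, `.schema(...)` is simply omitted:

```scala
// Socket source: the schema is fixed by the source itself, so no .schema(...) call.
// Passing a user-specified schema now fails with the AnalysisException shown above.
val socketDf = spark.readStream
  .format("socket")
  .option("host", "localhost") // placeholder host
  .option("port", 9999)        // placeholder port
  .load()
socketDf.printSchema()         // root |-- value: string (nullable = true)

// Rate source: likewise, rely on the built-in timestamp/value schema.
val rateDf = spark.readStream.format("rate").load()
rateDf.printSchema()           // root |-- timestamp: timestamp, |-- value: long
```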
--- .../execution/streaming/RateSourceProvider.scala | 9 +++++++-- .../spark/sql/execution/streaming/socket.scala | 8 ++++++-- .../sql/execution/streaming/RateSourceSuite.scala | 12 ++++++++++++ .../streaming/TextSocketStreamSuite.scala | 15 +++++++++++++++ 4 files changed, 40 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index e61a8eb628891..e76d4dc6125df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -25,7 +25,7 @@ import org.apache.commons.io.IOUtils import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} @@ -52,8 +52,13 @@ class RateSourceProvider extends StreamSourceProvider with DataSourceRegister { sqlContext: SQLContext, schema: Option[StructType], providerName: String, - parameters: Map[String, String]): (String, StructType) = + parameters: Map[String, String]): (String, StructType) = { + if (schema.nonEmpty) { + throw new AnalysisException("The rate source does not support a user-specified schema.") + } + (shortName(), RateSourceProvider.SCHEMA) + } override def createSource( sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala index 58bff27a05bf3..8e63207959575 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala @@ -195,13 +195,17 @@ class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegis if (!parameters.contains("port")) { throw new AnalysisException("Set a port to read from with option(\"port\", ...).") } - val schema = + if (schema.nonEmpty) { + throw new AnalysisException("The socket source does not support a user-specified schema.") + } + + val sourceSchema = if (parseIncludeTimestamp(parameters)) { TextSocketSource.SCHEMA_TIMESTAMP } else { TextSocketSource.SCHEMA_REGULAR } - ("textSocket", schema) + ("textSocket", sourceSchema) } override def createSource( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala index bdba536425a43..03d0f63fa4d7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.TimeUnit +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.util.ManualClock @@ -179,4 +180,15 @@ class RateSourceSuite extends StreamTest { testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", 
"positive")) testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) } + + test("user-specified schema given") { + val exception = intercept[AnalysisException] { + spark.readStream + .format("rate") + .schema(spark.range(1).schema) + .load() + } + assert(exception.getMessage.contains( + "rate source does not support a user-specified schema")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala index 5174a0415304c..9ebf4d2835266 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala @@ -148,6 +148,21 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before } } + test("user-specified schema given") { + val provider = new TextSocketSourceProvider + val userSpecifiedSchema = StructType( + StructField("name", StringType) :: + StructField("area", StringType) :: Nil) + val exception = intercept[AnalysisException] { + provider.sourceSchema( + sqlContext, Some(userSpecifiedSchema), + "", + Map("host" -> "localhost", "port" -> "1234")) + } + assert(exception.getMessage.contains( + "socket source does not support a user-specified schema")) + } + test("no server up") { val provider = new TextSocketSourceProvider val parameters = Map("host" -> "localhost", "port" -> "0") From ba78514da7bf2132873270b8bf39b50e54f4b094 Mon Sep 17 00:00:00 2001 From: sjarvie Date: Wed, 21 Jun 2017 10:51:45 -0700 Subject: [PATCH 050/118] [SPARK-21125][PYTHON] Extend setJobDescription to PySpark and JavaSpark APIs ## What changes were proposed in this pull request? Extend setJobDescription to PySpark and JavaSpark APIs SPARK-21125 ## How was this patch tested? Testing was done by running a local Spark shell on the built UI. I originally had added a unit test but the PySpark context cannot easily access the Scala Spark Context's private variable with the Job Description key so I omitted the test, due to the simplicity of this addition. Also ran the existing tests. # Misc This contribution is my original work and that I license the work to the project under the project's open source license. Author: sjarvie Closes #18332 from sjarvie/add_python_set_job_description. --- .../scala/org/apache/spark/api/java/JavaSparkContext.scala | 6 ++++++ python/pyspark/context.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 9481156bc93a5..f1936bf587282 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -757,6 +757,12 @@ class JavaSparkContext(val sc: SparkContext) */ def getLocalProperty(key: String): String = sc.getLocalProperty(key) + /** + * Set a human readable description of the current job. + * @since 2.3.0 + */ + def setJobDescription(value: String): Unit = sc.setJobDescription(value) + /** Control our logLevel. This overrides any user-defined log settings. * @param logLevel The desired log level as a string. 
* Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 3be07325f4162..c4b7e6372d1a2 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -942,6 +942,12 @@ def getLocalProperty(self, key): """ return self._jsc.getLocalProperty(key) + def setJobDescription(self, value): + """ + Set a human readable description of the current job. + """ + self._jsc.setJobDescription(value) + def sparkUser(self): """ Get SPARK_USER for user who is running SparkContext. From 215281d88ed664547088309cb432da2fed18b8b7 Mon Sep 17 00:00:00 2001 From: zero323 Date: Wed, 21 Jun 2017 14:59:52 -0700 Subject: [PATCH 051/118] [SPARK-20830][PYSPARK][SQL] Add posexplode and posexplode_outer ## What changes were proposed in this pull request? Add Python wrappers for `o.a.s.sql.functions.explode_outer` and `o.a.s.sql.functions.posexplode_outer`. ## How was this patch tested? Unit tests, doctests. Author: zero323 Closes #18049 from zero323/SPARK-20830. --- python/pyspark/sql/functions.py | 65 +++++++++++++++++++++++++++++++++ python/pyspark/sql/tests.py | 20 +++++++++- 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 240ae65a61785..3416c4b118a07 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1727,6 +1727,71 @@ def posexplode(col): return Column(jc) +@since(2.3) +def explode_outer(col): + """Returns a new row for each element in the given array or map. + Unlike explode, if the array/map is null or empty then null is produced. + + >>> df = spark.createDataFrame( + ... [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)], + ... ("id", "an_array", "a_map") + ... ) + >>> df.select("id", "an_array", explode_outer("a_map")).show() + +---+----------+----+-----+ + | id| an_array| key|value| + +---+----------+----+-----+ + | 1|[foo, bar]| x| 1.0| + | 2| []|null| null| + | 3| null|null| null| + +---+----------+----+-----+ + + >>> df.select("id", "a_map", explode_outer("an_array")).show() + +---+-------------+----+ + | id| a_map| col| + +---+-------------+----+ + | 1|Map(x -> 1.0)| foo| + | 1|Map(x -> 1.0)| bar| + | 2| Map()|null| + | 3| null|null| + +---+-------------+----+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.explode_outer(_to_java_column(col)) + return Column(jc) + + +@since(2.3) +def posexplode_outer(col): + """Returns a new row for each element with position in the given array or map. + Unlike posexplode, if the array/map is null or empty then the row (null, null) is produced. + + >>> df = spark.createDataFrame( + ... [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)], + ... ("id", "an_array", "a_map") + ... 
) + >>> df.select("id", "an_array", posexplode_outer("a_map")).show() + +---+----------+----+----+-----+ + | id| an_array| pos| key|value| + +---+----------+----+----+-----+ + | 1|[foo, bar]| 0| x| 1.0| + | 2| []|null|null| null| + | 3| null|null|null| null| + +---+----------+----+----+-----+ + >>> df.select("id", "a_map", posexplode_outer("an_array")).show() + +---+-------------+----+----+ + | id| a_map| pos| col| + +---+-------------+----+----+ + | 1|Map(x -> 1.0)| 0| foo| + | 1|Map(x -> 1.0)| 1| bar| + | 2| Map()|null|null| + | 3| null|null|null| + +---+-------------+----+----+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.posexplode_outer(_to_java_column(col)) + return Column(jc) + + @ignore_unicode_prefix @since(1.6) def get_json_object(col, path): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 31f932a363225..3b308579a3778 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -258,8 +258,12 @@ def test_column_name_encoding(self): self.assertTrue(isinstance(columns[1], str)) def test_explode(self): - from pyspark.sql.functions import explode - d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})] + from pyspark.sql.functions import explode, explode_outer, posexplode_outer + d = [ + Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}), + Row(a=1, intlist=[], mapfield={}), + Row(a=1, intlist=None, mapfield=None), + ] rdd = self.sc.parallelize(d) data = self.spark.createDataFrame(rdd) @@ -272,6 +276,18 @@ def test_explode(self): self.assertEqual(result[0][0], "a") self.assertEqual(result[0][1], "b") + result = [tuple(x) for x in data.select(posexplode_outer("intlist")).collect()] + self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)]) + + result = [tuple(x) for x in data.select(posexplode_outer("mapfield")).collect()] + self.assertEqual(result, [(0, 'a', 'b'), (None, None, None), (None, None, None)]) + + result = [x[0] for x in data.select(explode_outer("intlist")).collect()] + self.assertEqual(result, [1, 2, 3, None, None]) + + result = [tuple(x) for x in data.select(explode_outer("mapfield")).collect()] + self.assertEqual(result, [('a', 'b'), (None, None), (None, None)]) + def test_and_in_expression(self): self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count()) self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2")) From 53543374ce0cf0cec26de2382fbc85b7d5c7e9d6 Mon Sep 17 00:00:00 2001 From: wangmiao1981 Date: Wed, 21 Jun 2017 20:42:45 -0700 Subject: [PATCH 052/118] [SPARK-20906][SPARKR] Constrained Logistic Regression for SparkR ## What changes were proposed in this pull request? PR https://github.com/apache/spark/pull/17715 Added Constrained Logistic Regression for ML. We should add it to SparkR. ## How was this patch tested? Add new unit tests. Author: wangmiao1981 Closes #18128 from wangmiao1981/test. --- R/pkg/R/mllib_classification.R | 61 ++++++++++++++++++- .../fulltests/test_mllib_classification.R | 40 ++++++++++++ .../classification/LogisticRegression.scala | 8 +-- .../ml/r/LogisticRegressionWrapper.scala | 34 ++++++++++- 4 files changed, 135 insertions(+), 8 deletions(-) diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index bdcc0818d139d..82d2428f3c444 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -204,6 +204,20 @@ function(object, path, overwrite = FALSE) { #' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). 
If the dimensions of features #' or the number of partitions are large, this param could be adjusted to a larger size. #' This is an expert parameter. Default value should be good for most cases. +#' @param lowerBoundsOnCoefficients The lower bounds on coefficients if fitting under bound constrained optimization. +#' The bound matrix must be compatible with the shape (1, number of features) for binomial +#' regression, or (number of classes, number of features) for multinomial regression. +#' It is a R matrix. +#' @param upperBoundsOnCoefficients The upper bounds on coefficients if fitting under bound constrained optimization. +#' The bound matrix must be compatible with the shape (1, number of features) for binomial +#' regression, or (number of classes, number of features) for multinomial regression. +#' It is a R matrix. +#' @param lowerBoundsOnIntercepts The lower bounds on intercepts if fitting under bound constrained optimization. +#' The bounds vector size must be equal to 1 for binomial regression, or the number +#' of classes for multinomial regression. +#' @param upperBoundsOnIntercepts The upper bounds on intercepts if fitting under bound constrained optimization. +#' The bound vector size must be equal to 1 for binomial regression, or the number +#' of classes for multinomial regression. #' @param ... additional arguments passed to the method. #' @return \code{spark.logit} returns a fitted logistic regression model. #' @rdname spark.logit @@ -241,8 +255,12 @@ function(object, path, overwrite = FALSE) { setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100, tol = 1E-6, family = "auto", standardization = TRUE, - thresholds = 0.5, weightCol = NULL, aggregationDepth = 2) { + thresholds = 0.5, weightCol = NULL, aggregationDepth = 2, + lowerBoundsOnCoefficients = NULL, upperBoundsOnCoefficients = NULL, + lowerBoundsOnIntercepts = NULL, upperBoundsOnIntercepts = NULL) { formula <- paste(deparse(formula), collapse = "") + row <- 0 + col <- 0 if (!is.null(weightCol) && weightCol == "") { weightCol <- NULL @@ -250,12 +268,51 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") weightCol <- as.character(weightCol) } + if (!is.null(lowerBoundsOnIntercepts)) { + lowerBoundsOnIntercepts <- as.array(lowerBoundsOnIntercepts) + } + + if (!is.null(upperBoundsOnIntercepts)) { + upperBoundsOnIntercepts <- as.array(upperBoundsOnIntercepts) + } + + if (!is.null(lowerBoundsOnCoefficients)) { + if (class(lowerBoundsOnCoefficients) != "matrix") { + stop("lowerBoundsOnCoefficients must be a matrix.") + } + row <- nrow(lowerBoundsOnCoefficients) + col <- ncol(lowerBoundsOnCoefficients) + lowerBoundsOnCoefficients <- as.array(as.vector(lowerBoundsOnCoefficients)) + } + + if (!is.null(upperBoundsOnCoefficients)) { + if (class(upperBoundsOnCoefficients) != "matrix") { + stop("upperBoundsOnCoefficients must be a matrix.") + } + + if (!is.null(lowerBoundsOnCoefficients) && (row != nrow(upperBoundsOnCoefficients) + || col != ncol(upperBoundsOnCoefficients))) { + stop(paste0("dimension of upperBoundsOnCoefficients ", + "is not the same as lowerBoundsOnCoefficients", sep = "")) + } + + if (is.null(lowerBoundsOnCoefficients)) { + row <- nrow(upperBoundsOnCoefficients) + col <- ncol(upperBoundsOnCoefficients) + } + + upperBoundsOnCoefficients <- as.array(as.vector(upperBoundsOnCoefficients)) + } + jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit", 
data@sdf, formula, as.numeric(regParam), as.numeric(elasticNetParam), as.integer(maxIter), as.numeric(tol), as.character(family), as.logical(standardization), as.array(thresholds), - weightCol, as.integer(aggregationDepth)) + weightCol, as.integer(aggregationDepth), + as.integer(row), as.integer(col), + lowerBoundsOnCoefficients, upperBoundsOnCoefficients, + lowerBoundsOnIntercepts, upperBoundsOnIntercepts) new("LogisticRegressionModel", jobj = jobj) }) diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R index 726e9d9a20b1c..3d75f4ce11ec8 100644 --- a/R/pkg/tests/fulltests/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -223,6 +223,46 @@ test_that("spark.logit", { model2 <- spark.logit(df2, label ~ feature, weightCol = "weight") prediction2 <- collect(select(predict(model2, df2), "prediction")) expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0")) + + # Test binomial logistic regression againt two classes with upperBoundsOnCoefficients + # and upperBoundsOnIntercepts + u <- matrix(c(1.0, 0.0, 1.0, 0.0), nrow = 1, ncol = 4) + model <- spark.logit(training, Species ~ ., upperBoundsOnCoefficients = u, + upperBoundsOnIntercepts = 1.0) + summary <- summary(model) + coefsR <- c(-11.13331, 1.00000, 0.00000, 1.00000, 0.00000) + coefs <- summary$coefficients[, "Estimate"] + expect_true(all(abs(coefsR - coefs) < 0.1)) + # Test upperBoundsOnCoefficients should be matrix + expect_error(spark.logit(training, Species ~ ., upperBoundsOnCoefficients = as.array(c(1, 2)), + upperBoundsOnIntercepts = 1.0)) + + # Test binomial logistic regression againt two classes with lowerBoundsOnCoefficients + # and lowerBoundsOnIntercepts + l <- matrix(c(0.0, -1.0, 0.0, -1.0), nrow = 1, ncol = 4) + model <- spark.logit(training, Species ~ ., lowerBoundsOnCoefficients = l, + lowerBoundsOnIntercepts = 0.0) + summary <- summary(model) + coefsR <- c(0, 0, -1, 0, 1.902192) + coefs <- summary$coefficients[, "Estimate"] + expect_true(all(abs(coefsR - coefs) < 0.1)) + # Test lowerBoundsOnCoefficients should be matrix + expect_error(spark.logit(training, Species ~ ., lowerBoundsOnCoefficients = as.array(c(1, 2)), + lowerBoundsOnIntercepts = 0.0)) + + # Test multinomial logistic regression with lowerBoundsOnCoefficients + # and lowerBoundsOnIntercepts + l <- matrix(c(0.0, -1.0, 0.0, -1.0, 0.0, -1.0, 0.0, -1.0), nrow = 2, ncol = 4) + model <- spark.logit(training, Species ~ ., family = "multinomial", + lowerBoundsOnCoefficients = l, + lowerBoundsOnIntercepts = as.array(c(0.0, 0.0))) + summary <- summary(model) + versicolorCoefsR <- c(42.639465, 7.258104, 14.330814, 16.298243, 11.716429) + virginicaCoefsR <- c(0.0002970796, 4.79274, 7.65047, 25.72793, 30.0021) + versicolorCoefs <- summary$coefficients[, "versicolor"] + virginicaCoefs <- summary$coefficients[, "virginica"] + expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) + expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) }) test_that("spark.mlp", { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 567af0488e1b4..b234bc4c2df4f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -214,7 +214,7 @@ private[classification] trait LogisticRegressionParams extends 
ProbabilisticClas /** * The lower bounds on intercepts if fitting under bound constrained optimization. - * The bounds vector size must be equal with 1 for binomial regression, or the number + * The bounds vector size must be equal to 1 for binomial regression, or the number * of classes for multinomial regression. Otherwise, it throws exception. * Default is none. * @@ -230,7 +230,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas /** * The upper bounds on intercepts if fitting under bound constrained optimization. - * The bound vector size must be equal with 1 for binomial regression, or the number + * The bound vector size must be equal to 1 for binomial regression, or the number * of classes for multinomial regression. Otherwise, it throws exception. * Default is none. * @@ -451,12 +451,12 @@ class LogisticRegression @Since("1.2.0") ( } if (isSet(lowerBoundsOnIntercepts)) { require($(lowerBoundsOnIntercepts).size == numCoefficientSets, "The size of " + - "lowerBoundsOnIntercepts must be equal with 1 for binomial regression, or the number of " + + "lowerBoundsOnIntercepts must be equal to 1 for binomial regression, or the number of " + s"classes for multinomial regression, but found: ${getLowerBoundsOnIntercepts.size}.") } if (isSet(upperBoundsOnIntercepts)) { require($(upperBoundsOnIntercepts).size == numCoefficientSets, "The size of " + - "upperBoundsOnIntercepts must be equal with 1 for binomial regression, or the number of " + + "upperBoundsOnIntercepts must be equal to 1 for binomial regression, or the number of " + s"classes for multinomial regression, but found: ${getUpperBoundsOnIntercepts.size}.") } if (isSet(lowerBoundsOnCoefficients) && isSet(upperBoundsOnCoefficients)) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala index 703bcdf4ca725..b96481acf46d7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala @@ -25,7 +25,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.feature.{IndexToString, RFormula} -import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.{Matrices, Vector, Vectors} import org.apache.spark.ml.r.RWrapperUtils._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -97,7 +97,13 @@ private[r] object LogisticRegressionWrapper standardization: Boolean, thresholds: Array[Double], weightCol: String, - aggregationDepth: Int + aggregationDepth: Int, + numRowsOfBoundsOnCoefficients: Int, + numColsOfBoundsOnCoefficients: Int, + lowerBoundsOnCoefficients: Array[Double], + upperBoundsOnCoefficients: Array[Double], + lowerBoundsOnIntercepts: Array[Double], + upperBoundsOnIntercepts: Array[Double] ): LogisticRegressionWrapper = { val rFormula = new RFormula() @@ -133,6 +139,30 @@ private[r] object LogisticRegressionWrapper if (weightCol != null) lr.setWeightCol(weightCol) + if (numRowsOfBoundsOnCoefficients != 0 && + numColsOfBoundsOnCoefficients != 0 && lowerBoundsOnCoefficients != null) { + val coef = Matrices.dense(numRowsOfBoundsOnCoefficients, + numColsOfBoundsOnCoefficients, lowerBoundsOnCoefficients) + lr.setLowerBoundsOnCoefficients(coef) + } + + if (numRowsOfBoundsOnCoefficients != 0 && + 
numColsOfBoundsOnCoefficients != 0 && upperBoundsOnCoefficients != null) { + val coef = Matrices.dense(numRowsOfBoundsOnCoefficients, + numColsOfBoundsOnCoefficients, upperBoundsOnCoefficients) + lr.setUpperBoundsOnCoefficients(coef) + } + + if (lowerBoundsOnIntercepts != null) { + val intercept = Vectors.dense(lowerBoundsOnIntercepts) + lr.setLowerBoundsOnIntercepts(intercept) + } + + if (upperBoundsOnIntercepts != null) { + val intercept = Vectors.dense(upperBoundsOnIntercepts) + lr.setUpperBoundsOnIntercepts(intercept) + } + val idxToStr = new IndexToString() .setInputCol(PREDICTED_LABEL_INDEX_COL) .setOutputCol(PREDICTED_LABEL_COL) From d66b143eec7f604595089f72d8786edbdcd74282 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 21 Jun 2017 23:43:21 -0700 Subject: [PATCH 053/118] [SPARK-21167][SS] Decode the path generated by File sink to handle special characters ## What changes were proposed in this pull request? Decode the path generated by File sink to handle special characters. ## How was this patch tested? The added unit test. Author: Shixiong Zhu Closes #18381 from zsxwing/SPARK-21167. --- .../streaming/FileStreamSinkLog.scala | 5 +++- .../sql/streaming/FileStreamSinkSuite.scala | 29 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala index 8d718b2164d22..c9939ac1db746 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.net.URI + import org.apache.hadoop.fs.{FileStatus, Path} import org.json4s.NoTypeHints import org.json4s.jackson.Serialization @@ -47,7 +49,8 @@ case class SinkFileStatus( action: String) { def toFileStatus: FileStatus = { - new FileStatus(size, isDir, blockReplication, blockSize, modificationTime, new Path(path)) + new FileStatus( + size, isDir, blockReplication, blockSize, modificationTime, new Path(new URI(path))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 1a2d3a13f3a4a..bb6a27803bb20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -64,6 +64,35 @@ class FileStreamSinkSuite extends StreamTest { } } + test("SPARK-21167: encode and decode path correctly") { + val inputData = MemoryStream[String] + val ds = inputData.toDS() + + val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath + val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath + + val query = ds.map(s => (s, s.length)) + .toDF("value", "len") + .writeStream + .partitionBy("value") + .option("checkpointLocation", checkpointDir) + .format("parquet") + .start(outputDir) + + try { + // The output is partitoned by "value", so the value will appear in the file path. + // This is to test if we handle spaces in the path correctly. 
+ inputData.addData("hello world") + failAfter(streamingTimeout) { + query.processAllAvailable() + } + val outputDf = spark.read.parquet(outputDir) + checkDatasetUnorderly(outputDf.as[(Int, String)], ("hello world".length, "hello world")) + } finally { + query.stop() + } + } + test("partitioned writing and batch reading") { val inputData = MemoryStream[Int] val ds = inputData.toDS() From 67c75021c59d93cda9b5d70c0ef6d547fff92083 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 22 Jun 2017 16:22:02 +0800 Subject: [PATCH 054/118] [SPARK-21163][SQL] DataFrame.toPandas should respect the data type ## What changes were proposed in this pull request? Currently we convert a spark DataFrame to Pandas Dataframe by `pd.DataFrame.from_records`. It infers the data type from the data and doesn't respect the spark DataFrame Schema. This PR fixes it. ## How was this patch tested? a new regression test Author: hyukjinkwon Author: Wenchen Fan Author: Wenchen Fan Closes #18378 from cloud-fan/to_pandas. --- python/pyspark/sql/dataframe.py | 31 ++++++++++++++++++++++++++++++- python/pyspark/sql/tests.py | 24 ++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8541403dfe2f1..0649271ed2246 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1721,7 +1721,18 @@ def toPandas(self): 1 5 Bob """ import pandas as pd - return pd.DataFrame.from_records(self.collect(), columns=self.columns) + + dtype = {} + for field in self.schema: + pandas_type = _to_corrected_pandas_type(field.dataType) + if pandas_type is not None: + dtype[field.name] = pandas_type + + pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) + + for f, t in dtype.items(): + pdf[f] = pdf[f].astype(t, copy=False) + return pdf ########################################################################################## # Pandas compatibility @@ -1750,6 +1761,24 @@ def _to_scala_map(sc, jm): return sc._jvm.PythonUtils.toScalaMap(jm) +def _to_corrected_pandas_type(dt): + """ + When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong. + This method gets the corrected data type for Pandas if that type may be inferred uncorrectly. + """ + import numpy as np + if type(dt) == ByteType: + return np.int8 + elif type(dt) == ShortType: + return np.int16 + elif type(dt) == IntegerType: + return np.int32 + elif type(dt) == FloatType: + return np.float32 + else: + return None + + class DataFrameNaFunctions(object): """Functionality for working with missing data in :class:`DataFrame`. 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3b308579a3778..0a1cd6856b8e8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -46,6 +46,14 @@ else: import unittest +_have_pandas = False +try: + import pandas + _have_pandas = True +except: + # No Pandas, but that's okay, we'll skip those tests + pass + from pyspark import SparkContext from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row from pyspark.sql.types import * @@ -2290,6 +2298,22 @@ def count_bucketed_cols(names, table="pyspark_bucket"): .mode("overwrite").saveAsTable("pyspark_bucket")) self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + @unittest.skipIf(not _have_pandas, "Pandas not installed") + def test_to_pandas(self): + import numpy as np + schema = StructType().add("a", IntegerType()).add("b", StringType())\ + .add("c", BooleanType()).add("d", FloatType()) + data = [ + (1, "foo", True, 3.0), (2, "foo", True, 5.0), + (3, "bar", False, -1.0), (4, "bar", False, 6.0), + ] + df = self.spark.createDataFrame(data, schema) + types = df.toPandas().dtypes + self.assertEquals(types[0], np.int32) + self.assertEquals(types[1], np.object) + self.assertEquals(types[2], np.bool) + self.assertEquals(types[3], np.float32) + class HiveSparkSubmitTests(SparkSubmitTests): From 97b307c87c0f262ea3e020bf3d72383deef76619 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Thu, 22 Jun 2017 10:12:33 +0100 Subject: [PATCH 055/118] [SQL][DOC] Fix documentation of lpad ## What changes were proposed in this pull request? Fix incomplete documentation for `lpad`. Author: actuaryzhang Closes #18367 from actuaryzhang/SQLDoc. --- .../src/main/scala/org/apache/spark/sql/functions.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9a35a5c4658e3..839cbf42024e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2292,7 +2292,8 @@ object functions { } /** - * Left-pad the string column with + * Left-pad the string column with pad to a length of len. If the string column is longer + * than len, the return value is shortened to len characters. * * @group string_funcs * @since 1.5.0 @@ -2350,7 +2351,8 @@ object functions { def unbase64(e: Column): Column = withExpr { UnBase64(e.expr) } /** - * Right-padded with pad to a length of len. + * Right-pad the string column with pad to a length of len. If the string column is longer + * than len, the return value is shortened to len characters. * * @group string_funcs * @since 1.5.0 From 2dadea95c8e2c727e97fca91b0060f666fc0c65b Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 22 Jun 2017 20:48:12 +0800 Subject: [PATCH 056/118] [SPARK-20832][CORE] Standalone master should explicitly inform drivers of worker deaths and invalidate external shuffle service outputs ## What changes were proposed in this pull request? In standalone mode, master should explicitly inform each active driver of any worker deaths, so the invalid external shuffle service outputs on the lost host would be removed from the shuffle mapStatus, thus we can avoid future `FetchFailure`s. ## How was this patch tested? Manually tested by the following steps: 1. Start a standalone Spark cluster with one driver node and two worker nodes; 2. 
Run a Job with ShuffleMapStage, ensure the outputs distribute on each worker; 3. Run another Job to make all executors exit, but the workers are all alive; 4. Kill one of the workers; 5. Run rdd.collect(), before this change, we should see `FetchFailure`s and failed Stages, while after the change, the job should complete without failure. Before the change: ![image](https://user-images.githubusercontent.com/4784782/27335366-c251c3d6-55fe-11e7-99dd-d1fdcb429210.png) After the change: ![image](https://user-images.githubusercontent.com/4784782/27335393-d1c71640-55fe-11e7-89ed-bd760f1f39af.png) Author: Xingbo Jiang Closes #18362 from jiangxb1987/removeWorker. --- .../apache/spark/deploy/DeployMessage.scala | 2 ++ .../deploy/client/StandaloneAppClient.scala | 4 +++ .../client/StandaloneAppClientListener.scala | 8 +++-- .../apache/spark/deploy/master/Master.scala | 15 ++++++---- .../apache/spark/scheduler/DAGScheduler.scala | 30 +++++++++++++++++++ .../spark/scheduler/DAGSchedulerEvent.scala | 3 ++ .../spark/scheduler/TaskScheduler.scala | 5 ++++ .../spark/scheduler/TaskSchedulerImpl.scala | 5 ++++ .../cluster/CoarseGrainedClusterMessage.scala | 3 ++ .../CoarseGrainedSchedulerBackend.scala | 25 +++++++++++++--- .../cluster/StandaloneSchedulerBackend.scala | 5 ++++ .../spark/deploy/client/AppClientSuite.scala | 2 ++ .../spark/scheduler/DAGSchedulerSuite.scala | 2 ++ .../ExternalClusterManagerSuite.scala | 1 + 14 files changed, 98 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index c1a91c27eef2d..49a319abb3238 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -158,6 +158,8 @@ private[deploy] object DeployMessages { case class ApplicationRemoved(message: String) + case class WorkerRemoved(id: String, host: String, message: String) + // DriverClient <-> Master case class RequestSubmitDriver(driverDescription: DriverDescription) extends DeployMessage diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala index 93f58ce63799f..757c930b84eb2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala @@ -182,6 +182,10 @@ private[spark] class StandaloneAppClient( listener.executorRemoved(fullId, message.getOrElse(""), exitStatus, workerLost) } + case WorkerRemoved(id, host, message) => + logInfo("Master removed worker %s: %s".format(id, message)) + listener.workerRemoved(id, host, message) + case MasterChanged(masterRef, masterWebUiUrl) => logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) master = Some(masterRef) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala index 64255ec92b72a..d8bc1a883def1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala @@ -18,9 +18,9 @@ package org.apache.spark.deploy.client /** - * Callbacks invoked by deploy client when various events happen. 
There are currently four events: - * connecting to the cluster, disconnecting, being given an executor, and having an executor - * removed (either due to failure or due to revocation). + * Callbacks invoked by deploy client when various events happen. There are currently five events: + * connecting to the cluster, disconnecting, being given an executor, having an executor removed + * (either due to failure or due to revocation), and having a worker removed. * * Users of this API should *not* block inside the callback methods. */ @@ -38,4 +38,6 @@ private[spark] trait StandaloneAppClientListener { def executorRemoved( fullId: String, message: String, exitStatus: Option[Int], workerLost: Boolean): Unit + + def workerRemoved(workerId: String, host: String, message: String): Unit } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index f10a41286c52f..c192a0cc82ef6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -498,7 +498,7 @@ private[deploy] class Master( override def onDisconnected(address: RpcAddress): Unit = { // The disconnected client could've been either a worker or an app; remove whichever it was logInfo(s"$address got disassociated, removing it.") - addressToWorker.get(address).foreach(removeWorker) + addressToWorker.get(address).foreach(removeWorker(_, s"${address} got disassociated")) addressToApp.get(address).foreach(finishApplication) if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } } @@ -544,7 +544,8 @@ private[deploy] class Master( state = RecoveryState.COMPLETING_RECOVERY // Kill off any workers and apps that didn't respond to us. - workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker) + workers.filter(_.state == WorkerState.UNKNOWN).foreach( + removeWorker(_, "Not responding for recovery")) apps.filter(_.state == ApplicationState.UNKNOWN).foreach(finishApplication) // Update the state of recovered apps to RUNNING @@ -755,7 +756,7 @@ private[deploy] class Master( if (oldWorker.state == WorkerState.UNKNOWN) { // A worker registering from UNKNOWN implies that the worker was restarted during recovery. // The old worker must thus be dead, so we will remove it and accept the new worker. 
- removeWorker(oldWorker) + removeWorker(oldWorker, "Worker replaced by a new worker with same address") } else { logInfo("Attempted to re-register worker at same address: " + workerAddress) return false @@ -771,7 +772,7 @@ private[deploy] class Master( true } - private def removeWorker(worker: WorkerInfo) { + private def removeWorker(worker: WorkerInfo, msg: String) { logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) worker.setState(WorkerState.DEAD) idToWorker -= worker.id @@ -795,6 +796,10 @@ private[deploy] class Master( removeDriver(driver.id, DriverState.ERROR, None) } } + logInfo(s"Telling app of lost worker: " + worker.id) + apps.filterNot(completedApps.contains(_)).foreach { app => + app.driver.send(WorkerRemoved(worker.id, worker.host, msg)) + } persistenceEngine.removeWorker(worker) } @@ -979,7 +984,7 @@ private[deploy] class Master( if (worker.state != WorkerState.DEAD) { logWarning("Removing %s because we got no heartbeat in %d seconds".format( worker.id, WORKER_TIMEOUT_MS / 1000)) - removeWorker(worker) + removeWorker(worker, s"Not receiving heartbeat for ${WORKER_TIMEOUT_MS / 1000} seconds") } else { if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT_MS)) { workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index fafe9cafdc18f..3422a5f204b12 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -259,6 +259,13 @@ class DAGScheduler( eventProcessLoop.post(ExecutorLost(execId, reason)) } + /** + * Called by TaskScheduler implementation when a worker is removed. + */ + def workerRemoved(workerId: String, host: String, message: String): Unit = { + eventProcessLoop.post(WorkerRemoved(workerId, host, message)) + } + /** * Called by TaskScheduler implementation when a host is added. */ @@ -1432,6 +1439,26 @@ class DAGScheduler( } } + /** + * Responds to a worker being removed. This is called inside the event loop, so it assumes it can + * modify the scheduler's internal state. Use workerRemoved() to post a loss event from outside. + * + * We will assume that we've lost all shuffle blocks associated with the host if a worker is + * removed, so we will remove them all from MapStatus. + * + * @param workerId identifier of the worker that is removed. + * @param host host of the worker that is removed. + * @param message the reason why the worker is removed. + */ + private[scheduler] def handleWorkerRemoved( + workerId: String, + host: String, + message: String): Unit = { + logInfo("Shuffle files lost for worker %s on host %s".format(workerId, host)) + mapOutputTracker.removeOutputsOnHost(host) + clearCacheLocs() + } + private[scheduler] def handleExecutorAdded(execId: String, host: String) { // remove from failedEpoch(execId) ? 
if (failedEpoch.contains(execId)) { @@ -1727,6 +1754,9 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler } dagScheduler.handleExecutorLost(execId, workerLost) + case WorkerRemoved(workerId, host, message) => + dagScheduler.handleWorkerRemoved(workerId, host, message) + case BeginEvent(task, taskInfo) => dagScheduler.handleBeginEvent(task, taskInfo) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index cda0585f154a9..3f8d5639a2b90 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -86,6 +86,9 @@ private[scheduler] case class ExecutorAdded(execId: String, host: String) extend private[scheduler] case class ExecutorLost(execId: String, reason: ExecutorLossReason) extends DAGSchedulerEvent +private[scheduler] case class WorkerRemoved(workerId: String, host: String, message: String) + extends DAGSchedulerEvent + private[scheduler] case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]) extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 3de7d1f7de22b..90644fea23ab1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -89,6 +89,11 @@ private[spark] trait TaskScheduler { */ def executorLost(executorId: String, reason: ExecutorLossReason): Unit + /** + * Process a removed worker + */ + def workerRemoved(workerId: String, host: String, message: String): Unit + /** * Get an application's attempt ID associated with the job. 
* diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 629cfc7c7a8ce..bba0b294f1afb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -569,6 +569,11 @@ private[spark] class TaskSchedulerImpl private[scheduler]( } } + override def workerRemoved(workerId: String, host: String, message: String): Unit = { + logInfo(s"Handle removed worker $workerId: $message") + dagScheduler.workerRemoved(workerId, host, message) + } + private def logExecutorLoss( executorId: String, hostPort: String, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 6b49bd699a13a..89a9ad6811e18 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -85,6 +85,9 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: ExecutorLossReason) extends CoarseGrainedClusterMessage + case class RemoveWorker(workerId: String, host: String, message: String) + extends CoarseGrainedClusterMessage + case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage // Exchanged between the driver and the AM in Yarn client mode diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index dc82bb7704727..0b396b794ddce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -219,6 +219,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp removeExecutor(executorId, reason) context.reply(true) + case RemoveWorker(workerId, host, message) => + removeWorker(workerId, host, message) + context.reply(true) + case RetrieveSparkAppConfig => val reply = SparkAppConfig(sparkProperties, SparkEnv.get.securityManager.getIOEncryptionKey()) @@ -231,8 +235,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val taskDescs = CoarseGrainedSchedulerBackend.this.synchronized { // Filter out executors under killing val activeExecutors = executorDataMap.filterKeys(executorIsAlive) - val workOffers = activeExecutors.map { case (id, executorData) => - new WorkerOffer(id, executorData.executorHost, executorData.freeCores) + val workOffers = activeExecutors.map { + case (id, executorData) => + new WorkerOffer(id, executorData.executorHost, executorData.freeCores) }.toIndexedSeq scheduler.resourceOffers(workOffers) } @@ -331,6 +336,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } + // Remove a lost worker from the cluster + private def removeWorker(workerId: String, host: String, message: String): Unit = { + logDebug(s"Asked to remove worker $workerId with reason $message") + scheduler.workerRemoved(workerId, host, message) + } + /** * Stop making resource offers for the given executor. The executor is marked as lost with * the loss reason still pending. 
@@ -449,8 +460,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp */ protected def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { // Only log the failure since we don't care about the result. - driverEndpoint.ask[Boolean](RemoveExecutor(executorId, reason)).onFailure { case t => - logError(t.getMessage, t) + driverEndpoint.ask[Boolean](RemoveExecutor(executorId, reason)).onFailure { + case t => logError(t.getMessage, t) + }(ThreadUtils.sameThread) + } + + protected def removeWorker(workerId: String, host: String, message: String): Unit = { + driverEndpoint.ask[Boolean](RemoveWorker(workerId, host, message)).onFailure { + case t => logError(t.getMessage, t) }(ThreadUtils.sameThread) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 0529fe9eed4da..fd8e64454bf70 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -161,6 +161,11 @@ private[spark] class StandaloneSchedulerBackend( removeExecutor(fullId.split("/")(1), reason) } + override def workerRemoved(workerId: String, host: String, message: String): Unit = { + logInfo("Worker %s removed: %s".format(workerId, message)) + removeWorker(workerId, host, message) + } + override def sufficientResourcesRegistered(): Boolean = { totalCoreCount.get() >= totalExpectedCores * minRegisteredRatio } diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index 936639b845789..a1707e6540b39 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -214,6 +214,8 @@ class AppClientSuite id: String, message: String, exitStatus: Option[Int], workerLost: Boolean): Unit = { execRemovedList.add(id) } + + def workerRemoved(workerId: String, host: String, message: String): Unit = {} } /** Create AppClient and supporting objects */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index ddd3281106745..453be26ed8d0c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -131,6 +131,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} + override def workerRemoved(workerId: String, host: String, message: String): Unit = {} override def applicationAttemptId(): Option[String] = None } @@ -632,6 +633,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], blockManagerId: BlockManagerId): Boolean = true override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} + override def workerRemoved(workerId: String, host: String, message: String): Unit = {} override def applicationAttemptId(): Option[String] = None } val noKillScheduler = new DAGScheduler( diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index ba56af8215cd7..a4e4ea7cd2894 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -84,6 +84,7 @@ private class DummyTaskScheduler extends TaskScheduler { override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} + override def workerRemoved(workerId: String, host: String, message: String): Unit = {} override def applicationAttemptId(): Option[String] = None def executorHeartbeatReceived( execId: String, From 19331b8e44ad910550f810b80e2a0caf0ef62cb3 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Thu, 22 Jun 2017 10:16:51 -0700 Subject: [PATCH 057/118] [SPARK-20889][SPARKR] Grouped documentation for DATETIME column methods ## What changes were proposed in this pull request? Grouped documentation for datetime column methods. Author: actuaryzhang Closes #18114 from actuaryzhang/sparkRDocDate. --- R/pkg/R/functions.R | 532 +++++++++++++++++++------------------------- R/pkg/R/generics.R | 69 ++++-- 2 files changed, 273 insertions(+), 328 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 01ca8b8c4527d..31028585aaa13 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -34,6 +34,58 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))} NULL +#' Date time functions for Column operations +#' +#' Date time functions defined for \code{Column}. +#' +#' @param x Column to compute on. +#' @param format For \code{to_date} and \code{to_timestamp}, it is the string to use to parse +#' x Column to DateType or TimestampType. For \code{trunc}, it is the string used +#' for specifying the truncation method. For example, "year", "yyyy", "yy" for +#' truncate by year, or "month", "mon", "mm" for truncate by month. +#' @param ... additional argument(s). +#' @name column_datetime_functions +#' @rdname column_datetime_functions +#' @family data time functions +#' @examples +#' \dontrun{ +#' dts <- c("2005-01-02 18:47:22", +#' "2005-12-24 16:30:58", +#' "2005-10-28 07:30:05", +#' "2005-12-28 07:01:05", +#' "2006-01-24 00:01:10") +#' y <- c(2.0, 2.2, 3.4, 2.5, 1.8) +#' df <- createDataFrame(data.frame(time = as.POSIXct(dts), y = y))} +NULL + +#' Date time arithmetic functions for Column operations +#' +#' Date time arithmetic functions defined for \code{Column}. +#' +#' @param y Column to compute on. +#' @param x For class \code{Column}, it is the column used to perform arithmetic operations +#' with column \code{y}. For class \code{numeric}, it is the number of months or +#' days to be added to or subtracted from \code{y}. For class \code{character}, it is +#' \itemize{ +#' \item \code{date_format}: date format specification. +#' \item \code{from_utc_timestamp}, \code{to_utc_timestamp}: time zone to use. +#' \item \code{next_day}: day of the week string. 
+#' } +#' +#' @name column_datetime_diff_functions +#' @rdname column_datetime_diff_functions +#' @family data time functions +#' @examples +#' \dontrun{ +#' dts <- c("2005-01-02 18:47:22", +#' "2005-12-24 16:30:58", +#' "2005-10-28 07:30:05", +#' "2005-12-28 07:01:05", +#' "2006-01-24 00:01:10") +#' y <- c(2.0, 2.2, 3.4, 2.5, 1.8) +#' df <- createDataFrame(data.frame(time = as.POSIXct(dts), y = y))} +NULL + #' lit #' #' A new \linkS4class{Column} is created to represent the literal value. @@ -546,18 +598,20 @@ setMethod("hash", column(jc) }) -#' dayofmonth -#' -#' Extracts the day of the month as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. +#' @details +#' \code{dayofmonth}: Extracts the day of the month as an integer from a +#' given date/timestamp/string. #' -#' @rdname dayofmonth -#' @name dayofmonth -#' @family date time functions -#' @aliases dayofmonth,Column-method +#' @rdname column_datetime_functions +#' @aliases dayofmonth dayofmonth,Column-method #' @export -#' @examples \dontrun{dayofmonth(df$c)} +#' @examples +#' +#' \dontrun{ +#' head(select(df, df$time, year(df$time), quarter(df$time), month(df$time), +#' dayofmonth(df$time), dayofyear(df$time), weekofyear(df$time))) +#' head(agg(groupBy(df, year(df$time)), count(df$y), avg(df$y))) +#' head(agg(groupBy(df, month(df$time)), avg(df$y)))} #' @note dayofmonth since 1.5.0 setMethod("dayofmonth", signature(x = "Column"), @@ -566,18 +620,13 @@ setMethod("dayofmonth", column(jc) }) -#' dayofyear -#' -#' Extracts the day of the year as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. +#' @details +#' \code{dayofyear}: Extracts the day of the year as an integer from a +#' given date/timestamp/string. #' -#' @rdname dayofyear -#' @name dayofyear -#' @family date time functions -#' @aliases dayofyear,Column-method +#' @rdname column_datetime_functions +#' @aliases dayofyear dayofyear,Column-method #' @export -#' @examples \dontrun{dayofyear(df$c)} #' @note dayofyear since 1.5.0 setMethod("dayofyear", signature(x = "Column"), @@ -763,18 +812,19 @@ setMethod("hex", column(jc) }) -#' hour -#' -#' Extracts the hours as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. +#' @details +#' \code{hour}: Extracts the hours as an integer from a given date/timestamp/string. #' -#' @rdname hour -#' @name hour -#' @aliases hour,Column-method -#' @family date time functions +#' @rdname column_datetime_functions +#' @aliases hour hour,Column-method #' @export -#' @examples \dontrun{hour(df$c)} +#' @examples +#' +#' \dontrun{ +#' head(select(df, hour(df$time), minute(df$time), second(df$time))) +#' head(agg(groupBy(df, dayofmonth(df$time)), avg(df$y))) +#' head(agg(groupBy(df, hour(df$time)), avg(df$y))) +#' head(agg(groupBy(df, minute(df$time)), avg(df$y)))} #' @note hour since 1.5.0 setMethod("hour", signature(x = "Column"), @@ -893,20 +943,18 @@ setMethod("last", column(jc) }) -#' last_day -#' -#' Given a date column, returns the last day of the month which the given date belongs to. -#' For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of the -#' month in July 2015. -#' -#' @param x Column to compute on. +#' @details +#' \code{last_day}: Given a date column, returns the last day of the month which the +#' given date belongs to. For example, input "2015-07-27" returns "2015-07-31" since +#' July 31 is the last day of the month in July 2015. 
#' -#' @rdname last_day -#' @name last_day -#' @aliases last_day,Column-method -#' @family date time functions +#' @rdname column_datetime_functions +#' @aliases last_day last_day,Column-method #' @export -#' @examples \dontrun{last_day(df$c)} +#' @examples +#' +#' \dontrun{ +#' head(select(df, df$time, last_day(df$time), month(df$time)))} #' @note last_day since 1.5.0 setMethod("last_day", signature(x = "Column"), @@ -1129,18 +1177,12 @@ setMethod("min", column(jc) }) -#' minute -#' -#' Extracts the minutes as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. +#' @details +#' \code{minute}: Extracts the minutes as an integer from a given date/timestamp/string. #' -#' @rdname minute -#' @name minute -#' @aliases minute,Column-method -#' @family date time functions +#' @rdname column_datetime_functions +#' @aliases minute minute,Column-method #' @export -#' @examples \dontrun{minute(df$c)} #' @note minute since 1.5.0 setMethod("minute", signature(x = "Column"), @@ -1177,18 +1219,12 @@ setMethod("monotonically_increasing_id", column(jc) }) -#' month -#' -#' Extracts the month as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. +#' @details +#' \code{month}: Extracts the month as an integer from a given date/timestamp/string. #' -#' @rdname month -#' @name month -#' @aliases month,Column-method -#' @family date time functions +#' @rdname column_datetime_functions +#' @aliases month month,Column-method #' @export -#' @examples \dontrun{month(df$c)} #' @note month since 1.5.0 setMethod("month", signature(x = "Column"), @@ -1217,18 +1253,12 @@ setMethod("negate", column(jc) }) -#' quarter -#' -#' Extracts the quarter as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. +#' @details +#' \code{quarter}: Extracts the quarter as an integer from a given date/timestamp/string. #' -#' @rdname quarter -#' @name quarter -#' @family date time functions -#' @aliases quarter,Column-method +#' @rdname column_datetime_functions +#' @aliases quarter quarter,Column-method #' @export -#' @examples \dontrun{quarter(df$c)} #' @note quarter since 1.5.0 setMethod("quarter", signature(x = "Column"), @@ -1364,18 +1394,12 @@ setMethod("sd", stddev_samp(x) }) -#' second -#' -#' Extracts the seconds as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. +#' @details +#' \code{second}: Extracts the seconds as an integer from a given date/timestamp/string. #' -#' @rdname second -#' @name second -#' @family date time functions -#' @aliases second,Column-method +#' @rdname column_datetime_functions +#' @aliases second second,Column-method #' @export -#' @examples \dontrun{second(df$c)} #' @note second since 1.5.0 setMethod("second", signature(x = "Column"), @@ -1725,29 +1749,28 @@ setMethod("toRadians", column(jc) }) -#' to_date -#' -#' Converts the column into a DateType. You may optionally specify a format -#' according to the rules in: +#' @details +#' \code{to_date}: Converts the column into a DateType. You may optionally specify +#' a format according to the rules in: #' \url{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}. #' If the string cannot be parsed according to the specified format (or default), #' the value of the column will be null. #' By default, it follows casting rules to a DateType if the format is omitted #' (equivalent to \code{cast(df$x, "date")}). #' -#' @param x Column to parse. 
-#' @param format string to use to parse x Column to DateType. (optional) -#' -#' @rdname to_date -#' @name to_date -#' @family date time functions -#' @aliases to_date,Column,missing-method +#' @rdname column_datetime_functions +#' @aliases to_date to_date,Column,missing-method #' @export #' @examples +#' #' \dontrun{ -#' to_date(df$c) -#' to_date(df$c, 'yyyy-MM-dd') -#' } +#' tmp <- createDataFrame(data.frame(time_string = dts)) +#' tmp2 <- mutate(tmp, date1 = to_date(tmp$time_string), +#' date2 = to_date(tmp$time_string, "yyyy-MM-dd"), +#' date3 = date_format(tmp$time_string, "MM/dd/yyy"), +#' time1 = to_timestamp(tmp$time_string), +#' time2 = to_timestamp(tmp$time_string, "yyyy-MM-dd")) +#' head(tmp2)} #' @note to_date(Column) since 1.5.0 setMethod("to_date", signature(x = "Column", format = "missing"), @@ -1756,9 +1779,7 @@ setMethod("to_date", column(jc) }) -#' @rdname to_date -#' @name to_date -#' @family date time functions +#' @rdname column_datetime_functions #' @aliases to_date,Column,character-method #' @export #' @note to_date(Column, character) since 2.2.0 @@ -1801,29 +1822,18 @@ setMethod("to_json", signature(x = "Column"), column(jc) }) -#' to_timestamp -#' -#' Converts the column into a TimestampType. You may optionally specify a format -#' according to the rules in: +#' @details +#' \code{to_timestamp}: Converts the column into a TimestampType. You may optionally specify +#' a format according to the rules in: #' \url{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}. #' If the string cannot be parsed according to the specified format (or default), #' the value of the column will be null. #' By default, it follows casting rules to a TimestampType if the format is omitted #' (equivalent to \code{cast(df$x, "timestamp")}). #' -#' @param x Column to parse. -#' @param format string to use to parse x Column to TimestampType. (optional) -#' -#' @rdname to_timestamp -#' @name to_timestamp -#' @family date time functions -#' @aliases to_timestamp,Column,missing-method +#' @rdname column_datetime_functions +#' @aliases to_timestamp to_timestamp,Column,missing-method #' @export -#' @examples -#' \dontrun{ -#' to_timestamp(df$c) -#' to_timestamp(df$c, 'yyyy-MM-dd') -#' } #' @note to_timestamp(Column) since 2.2.0 setMethod("to_timestamp", signature(x = "Column", format = "missing"), @@ -1832,9 +1842,7 @@ setMethod("to_timestamp", column(jc) }) -#' @rdname to_timestamp -#' @name to_timestamp -#' @family date time functions +#' @rdname column_datetime_functions #' @aliases to_timestamp,Column,character-method #' @export #' @note to_timestamp(Column, character) since 2.2.0 @@ -1984,18 +1992,12 @@ setMethod("var_samp", column(jc) }) -#' weekofyear -#' -#' Extracts the week number as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. +#' @details +#' \code{weekofyear}: Extracts the week number as an integer from a given date/timestamp/string. #' -#' @rdname weekofyear -#' @name weekofyear -#' @aliases weekofyear,Column-method -#' @family date time functions +#' @rdname column_datetime_functions +#' @aliases weekofyear weekofyear,Column-method #' @export -#' @examples \dontrun{weekofyear(df$c)} #' @note weekofyear since 1.5.0 setMethod("weekofyear", signature(x = "Column"), @@ -2004,18 +2006,12 @@ setMethod("weekofyear", column(jc) }) -#' year -#' -#' Extracts the year as an integer from a given date/timestamp/string. -#' -#' @param x Column to compute on. 
+#' @details +#' \code{year}: Extracts the year as an integer from a given date/timestamp/string. #' -#' @rdname year -#' @name year -#' @family date time functions -#' @aliases year,Column-method +#' @rdname column_datetime_functions +#' @aliases year year,Column-method #' @export -#' @examples \dontrun{year(df$c)} #' @note year since 1.5.0 setMethod("year", signature(x = "Column"), @@ -2048,19 +2044,20 @@ setMethod("atan2", signature(y = "Column"), column(jc) }) -#' datediff -#' -#' Returns the number of days from \code{start} to \code{end}. -#' -#' @param x start Column to use. -#' @param y end Column to use. +#' @details +#' \code{datediff}: Returns the number of days from \code{y} to \code{x}. #' -#' @rdname datediff -#' @name datediff -#' @aliases datediff,Column-method -#' @family date time functions +#' @rdname column_datetime_diff_functions +#' @aliases datediff datediff,Column-method #' @export -#' @examples \dontrun{datediff(df$c, x)} +#' @examples +#' +#' \dontrun{ +#' tmp <- createDataFrame(data.frame(time_string1 = as.POSIXct(dts), +#' time_string2 = as.POSIXct(dts[order(runif(length(dts)))]))) +#' tmp2 <- mutate(tmp, datediff = datediff(tmp$time_string1, tmp$time_string2), +#' monthdiff = months_between(tmp$time_string1, tmp$time_string2)) +#' head(tmp2)} #' @note datediff since 1.5.0 setMethod("datediff", signature(y = "Column"), function(y, x) { @@ -2117,19 +2114,12 @@ setMethod("levenshtein", signature(y = "Column"), column(jc) }) -#' months_between -#' -#' Returns number of months between dates \code{date1} and \code{date2}. -#' -#' @param x start Column to use. -#' @param y end Column to use. +#' @details +#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. #' -#' @rdname months_between -#' @name months_between -#' @family date time functions -#' @aliases months_between,Column-method +#' @rdname column_datetime_diff_functions +#' @aliases months_between months_between,Column-method #' @export -#' @examples \dontrun{months_between(df$c, x)} #' @note months_between since 1.5.0 setMethod("months_between", signature(y = "Column"), function(y, x) { @@ -2348,26 +2338,18 @@ setMethod("n", signature(x = "Column"), count(x) }) -#' date_format -#' -#' Converts a date/timestamp/string to a value of string in the format specified by the date -#' format given by the second argument. -#' -#' A pattern could be for instance \preformatted{dd.MM.yyyy} and could return a string like '18.03.1993'. All +#' @details +#' \code{date_format}: Converts a date/timestamp/string to a value of string in the format +#' specified by the date format given by the second argument. A pattern could be for instance +#' \code{dd.MM.yyyy} and could return a string like '18.03.1993'. All #' pattern letters of \code{java.text.SimpleDateFormat} can be used. -#' #' Note: Use when ever possible specialized functions like \code{year}. These benefit from a #' specialized implementation. #' -#' @param y Column to compute on. -#' @param x date format specification. 
+#' @rdname column_datetime_diff_functions #' -#' @family date time functions -#' @rdname date_format -#' @name date_format -#' @aliases date_format,Column,character-method +#' @aliases date_format date_format,Column,character-method #' @export -#' @examples \dontrun{date_format(df$t, 'MM/dd/yyy')} #' @note date_format since 1.5.0 setMethod("date_format", signature(y = "Column", x = "character"), function(y, x) { @@ -2414,20 +2396,20 @@ setMethod("from_json", signature(x = "Column", schema = "structType"), column(jc) }) -#' from_utc_timestamp -#' -#' Given a timestamp, which corresponds to a certain time of day in UTC, returns another timestamp -#' that corresponds to the same time of day in the given timezone. +#' @details +#' \code{from_utc_timestamp}: Given a timestamp, which corresponds to a certain time of day in UTC, +#' returns another timestamp that corresponds to the same time of day in the given timezone. #' -#' @param y Column to compute on. -#' @param x time zone to use. +#' @rdname column_datetime_diff_functions #' -#' @family date time functions -#' @rdname from_utc_timestamp -#' @name from_utc_timestamp -#' @aliases from_utc_timestamp,Column,character-method +#' @aliases from_utc_timestamp from_utc_timestamp,Column,character-method #' @export -#' @examples \dontrun{from_utc_timestamp(df$t, 'PST')} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, from_utc = from_utc_timestamp(df$time, 'PST'), +#' to_utc = to_utc_timestamp(df$time, 'PST')) +#' head(tmp)} #' @note from_utc_timestamp since 1.5.0 setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { @@ -2458,30 +2440,16 @@ setMethod("instr", signature(y = "Column", x = "character"), column(jc) }) -#' next_day -#' -#' Given a date column, returns the first date which is later than the value of the date column -#' that is on the specified day of the week. -#' -#' For example, \code{next_day('2015-07-27', "Sunday")} returns 2015-08-02 because that is the first -#' Sunday after 2015-07-27. -#' -#' Day of the week parameter is case insensitive, and accepts first three or two characters: -#' "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". -#' -#' @param y Column to compute on. -#' @param x Day of the week string. +#' @details +#' \code{next_day}: Given a date column, returns the first date which is later than the value of +#' the date column that is on the specified day of the week. For example, +#' \code{next_day('2015-07-27', "Sunday")} returns 2015-08-02 because that is the first Sunday +#' after 2015-07-27. Day of the week parameter is case insensitive, and accepts first three or +#' two characters: "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". #' -#' @family date time functions -#' @rdname next_day -#' @name next_day -#' @aliases next_day,Column,character-method +#' @rdname column_datetime_diff_functions +#' @aliases next_day next_day,Column,character-method #' @export -#' @examples -#'\dontrun{ -#'next_day(df$d, 'Sun') -#'next_day(df$d, 'Sunday') -#'} #' @note next_day since 1.5.0 setMethod("next_day", signature(y = "Column", x = "character"), function(y, x) { @@ -2489,20 +2457,13 @@ setMethod("next_day", signature(y = "Column", x = "character"), column(jc) }) -#' to_utc_timestamp -#' -#' Given a timestamp, which corresponds to a certain time of day in the given timezone, returns -#' another timestamp that corresponds to the same time of day in UTC. 
-#' -#' @param y Column to compute on -#' @param x timezone to use +#' @details +#' \code{to_utc_timestamp}: Given a timestamp, which corresponds to a certain time of day +#' in the given timezone, returns another timestamp that corresponds to the same time of day in UTC. #' -#' @family date time functions -#' @rdname to_utc_timestamp -#' @name to_utc_timestamp -#' @aliases to_utc_timestamp,Column,character-method +#' @rdname column_datetime_diff_functions +#' @aliases to_utc_timestamp to_utc_timestamp,Column,character-method #' @export -#' @examples \dontrun{to_utc_timestamp(df$t, 'PST')} #' @note to_utc_timestamp since 1.5.0 setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { @@ -2510,19 +2471,20 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), column(jc) }) -#' add_months -#' -#' Returns the date that is numMonths after startDate. -#' -#' @param y Column to compute on -#' @param x Number of months to add +#' @details +#' \code{add_months}: Returns the date that is numMonths (\code{x}) after startDate (\code{y}). #' -#' @name add_months -#' @family date time functions -#' @rdname add_months -#' @aliases add_months,Column,numeric-method +#' @rdname column_datetime_diff_functions +#' @aliases add_months add_months,Column,numeric-method #' @export -#' @examples \dontrun{add_months(df$d, 1)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, t1 = add_months(df$time, 1), +#' t2 = date_add(df$time, 2), +#' t3 = date_sub(df$time, 3), +#' t4 = next_day(df$time, 'Sun')) +#' head(tmp)} #' @note add_months since 1.5.0 setMethod("add_months", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2530,19 +2492,12 @@ setMethod("add_months", signature(y = "Column", x = "numeric"), column(jc) }) -#' date_add -#' -#' Returns the date that is \code{x} days after -#' -#' @param y Column to compute on -#' @param x Number of days to add +#' @details +#' \code{date_add}: Returns the date that is \code{x} days after. #' -#' @family date time functions -#' @rdname date_add -#' @name date_add -#' @aliases date_add,Column,numeric-method +#' @rdname column_datetime_diff_functions +#' @aliases date_add date_add,Column,numeric-method #' @export -#' @examples \dontrun{date_add(df$d, 1)} #' @note date_add since 1.5.0 setMethod("date_add", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2550,19 +2505,13 @@ setMethod("date_add", signature(y = "Column", x = "numeric"), column(jc) }) -#' date_sub -#' -#' Returns the date that is \code{x} days before +#' @details +#' \code{date_sub}: Returns the date that is \code{x} days before. #' -#' @param y Column to compute on -#' @param x Number of days to substract +#' @rdname column_datetime_diff_functions #' -#' @family date time functions -#' @rdname date_sub -#' @name date_sub -#' @aliases date_sub,Column,numeric-method +#' @aliases date_sub date_sub,Column,numeric-method #' @export -#' @examples \dontrun{date_sub(df$d, 1)} #' @note date_sub since 1.5.0 setMethod("date_sub", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2774,27 +2723,24 @@ setMethod("format_string", signature(format = "character", x = "Column"), column(jc) }) -#' from_unixtime +#' @details +#' \code{from_unixtime}: Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a +#' string representing the timestamp of that moment in the current system time zone in the JVM in the +#' given format. 
See \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ +#' Customizing Formats} for available options. #' -#' Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string -#' representing the timestamp of that moment in the current system time zone in the given -#' format. +#' @rdname column_datetime_functions #' -#' @param x a Column of unix timestamp. -#' @param format the target format. See -#' \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ -#' Customizing Formats} for available options. -#' @param ... further arguments to be passed to or from other methods. -#' @family date time functions -#' @rdname from_unixtime -#' @name from_unixtime -#' @aliases from_unixtime,Column-method +#' @aliases from_unixtime from_unixtime,Column-method #' @export #' @examples -#'\dontrun{ -#'from_unixtime(df$t) -#'from_unixtime(df$t, 'yyyy/MM/dd HH') -#'} +#' +#' \dontrun{ +#' tmp <- mutate(df, to_unix = unix_timestamp(df$time), +#' to_unix2 = unix_timestamp(df$time, 'yyyy-MM-dd HH'), +#' from_unix = from_unixtime(unix_timestamp(df$time)), +#' from_unix2 = from_unixtime(unix_timestamp(df$time), 'yyyy-MM-dd HH:mm')) +#' head(tmp)} #' @note from_unixtime since 1.5.0 setMethod("from_unixtime", signature(x = "Column"), function(x, format = "yyyy-MM-dd HH:mm:ss") { @@ -3111,21 +3057,12 @@ setMethod("translate", column(jc) }) -#' unix_timestamp -#' -#' Gets current Unix timestamp in seconds. +#' @details +#' \code{unix_timestamp}: Gets current Unix timestamp in seconds. #' -#' @family date time functions -#' @rdname unix_timestamp -#' @name unix_timestamp -#' @aliases unix_timestamp,missing,missing-method +#' @rdname column_datetime_functions +#' @aliases unix_timestamp unix_timestamp,missing,missing-method #' @export -#' @examples -#'\dontrun{ -#'unix_timestamp() -#'unix_timestamp(df$t) -#'unix_timestamp(df$t, 'yyyy-MM-dd HH') -#'} #' @note unix_timestamp since 1.5.0 setMethod("unix_timestamp", signature(x = "missing", format = "missing"), function(x, format) { @@ -3133,8 +3070,7 @@ setMethod("unix_timestamp", signature(x = "missing", format = "missing"), column(jc) }) -#' @rdname unix_timestamp -#' @name unix_timestamp +#' @rdname column_datetime_functions #' @aliases unix_timestamp,Column,missing-method #' @export #' @note unix_timestamp(Column) since 1.5.0 @@ -3144,12 +3080,7 @@ setMethod("unix_timestamp", signature(x = "Column", format = "missing"), column(jc) }) -#' @param x a Column of date, in string, date or timestamp type. -#' @param format the target format. See -#' \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ -#' Customizing Formats} for available options. -#' @rdname unix_timestamp -#' @name unix_timestamp +#' @rdname column_datetime_functions #' @aliases unix_timestamp,Column,character-method #' @export #' @note unix_timestamp(Column, character) since 1.5.0 @@ -3931,26 +3862,17 @@ setMethod("input_file_name", signature("missing"), column(jc) }) -#' trunc -#' -#' Returns date truncated to the unit specified by the format. -#' -#' @param x Column to compute on. -#' @param format string used for specify the truncation method. For example, "year", "yyyy", -#' "yy" for truncate by year, or "month", "mon", "mm" for truncate by month. +#' @details +#' \code{trunc}: Returns date truncated to the unit specified by the format. 
#' -#' @rdname trunc -#' @name trunc -#' @family date time functions -#' @aliases trunc,Column-method +#' @rdname column_datetime_functions +#' @aliases trunc trunc,Column-method #' @export #' @examples +#' #' \dontrun{ -#' trunc(df$c, "year") -#' trunc(df$c, "yy") -#' trunc(df$c, "month") -#' trunc(df$c, "mon") -#' } +#' head(select(df, df$time, trunc(df$time, "year"), trunc(df$time, "yy"), +#' trunc(df$time, "month"), trunc(df$time, "mon")))} #' @note trunc since 2.3.0 setMethod("trunc", signature(x = "Column"), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index b3cc4868a0b33..f105174cea70d 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -903,8 +903,9 @@ setGeneric("windowOrderBy", function(col, ...) { standardGeneric("windowOrderBy" ###################### Expression Function Methods ########################## -#' @rdname add_months +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) #' @rdname column_aggregate_functions @@ -1002,28 +1003,34 @@ setGeneric("hash", function(x, ...) { standardGeneric("hash") }) #' @export setGeneric("cume_dist", function(x = "missing") { standardGeneric("cume_dist") }) -#' @rdname datediff +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) -#' @rdname date_add +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("date_add", function(y, x) { standardGeneric("date_add") }) -#' @rdname date_format +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("date_format", function(y, x) { standardGeneric("date_format") }) -#' @rdname date_sub +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) -#' @rdname dayofmonth +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) -#' @rdname dayofyear +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) #' @rdname decode @@ -1051,8 +1058,9 @@ setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) #' @export setGeneric("expr", function(x) { standardGeneric("expr") }) -#' @rdname from_utc_timestamp +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) #' @rdname format_number @@ -1067,8 +1075,9 @@ setGeneric("format_string", function(format, x, ...) { standardGeneric("format_s #' @export setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") }) -#' @rdname from_unixtime +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) #' @rdname greatest @@ -1089,8 +1098,9 @@ setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") }) #' @export setGeneric("hex", function(x) { standardGeneric("hex") }) -#' @rdname hour +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("hour", function(x) { standardGeneric("hour") }) #' @rdname hypot @@ -1128,8 +1138,9 @@ setGeneric("lag", function(x, ...) { standardGeneric("lag") }) #' @export setGeneric("last", function(x, ...) 
{ standardGeneric("last") }) -#' @rdname last_day +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("last_day", function(x) { standardGeneric("last_day") }) #' @rdname lead @@ -1168,8 +1179,9 @@ setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) #' @export setGeneric("md5", function(x) { standardGeneric("md5") }) -#' @rdname minute +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("minute", function(x) { standardGeneric("minute") }) #' @param x empty. Should be used with no argument. @@ -1178,12 +1190,14 @@ setGeneric("minute", function(x) { standardGeneric("minute") }) setGeneric("monotonically_increasing_id", function(x = "missing") { standardGeneric("monotonically_increasing_id") }) -#' @rdname month +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("month", function(x) { standardGeneric("month") }) -#' @rdname months_between +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("months_between", function(y, x) { standardGeneric("months_between") }) #' @rdname count @@ -1202,8 +1216,9 @@ setGeneric("negate", function(x) { standardGeneric("negate") }) #' @export setGeneric("not", function(x) { standardGeneric("not") }) -#' @rdname next_day +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) #' @rdname ntile @@ -1232,8 +1247,9 @@ setGeneric("posexplode", function(x) { standardGeneric("posexplode") }) #' @export setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") }) -#' @rdname quarter +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("quarter", function(x) { standardGeneric("quarter") }) #' @rdname rand @@ -1287,8 +1303,9 @@ setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) #' @name NULL setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") }) -#' @rdname second +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("second", function(x) { standardGeneric("second") }) #' @rdname sha1 @@ -1377,20 +1394,23 @@ setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) #' @export setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) -#' @rdname to_date +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("to_date", function(x, format) { standardGeneric("to_date") }) #' @rdname to_json #' @export setGeneric("to_json", function(x, ...) 
{ standardGeneric("to_json") }) -#' @rdname to_timestamp +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("to_timestamp", function(x, format) { standardGeneric("to_timestamp") }) -#' @rdname to_utc_timestamp +#' @rdname column_datetime_diff_functions #' @export +#' @name NULL setGeneric("to_utc_timestamp", function(y, x) { standardGeneric("to_utc_timestamp") }) #' @rdname translate @@ -1409,8 +1429,9 @@ setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) #' @export setGeneric("unhex", function(x) { standardGeneric("unhex") }) -#' @rdname unix_timestamp +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timestamp") }) #' @rdname upper @@ -1437,16 +1458,18 @@ setGeneric("var_pop", function(x) { standardGeneric("var_pop") }) #' @name NULL setGeneric("var_samp", function(x) { standardGeneric("var_samp") }) -#' @rdname weekofyear +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) #' @rdname window #' @export setGeneric("window", function(x, ...) { standardGeneric("window") }) -#' @rdname year +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("year", function(x) { standardGeneric("year") }) From e55a105ae04f1d1c35ee8f02005a3ab71d789124 Mon Sep 17 00:00:00 2001 From: Lubo Zhang Date: Thu, 22 Jun 2017 11:18:58 -0700 Subject: [PATCH 058/118] [SPARK-20599][SS] ConsoleSink should work with (batch) ## What changes were proposed in this pull request? Currently, if we read a batch and want to display it on the console sink, it will lead a runtime exception. Changes: - In this PR, we add a match rule to check whether it is a ConsoleSinkProvider, we will display the Dataset if using console format. ## How was this patch tested? spark.read.schema().json(path).write.format("console").save Author: Lubo Zhang Author: lubozhan Closes #18347 from lubozhan/dev. 
--- .../sql/execution/streaming/console.scala | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 38c63191106d0..9e889ff679450 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} +import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.types.StructType class ConsoleSink(options: Map[String, String]) extends Sink with Logging { // Number of rows to display, by default 20 rows @@ -51,7 +53,14 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging { } } -class ConsoleSinkProvider extends StreamSinkProvider with DataSourceRegister { +case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) + extends BaseRelation { + override def schema: StructType = data.schema +} + +class ConsoleSinkProvider extends StreamSinkProvider + with DataSourceRegister + with CreatableRelationProvider { def createSink( sqlContext: SQLContext, parameters: Map[String, String], @@ -60,5 +69,20 @@ class ConsoleSinkProvider extends StreamSinkProvider with DataSourceRegister { new ConsoleSink(parameters) } + def createRelation( + sqlContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + // Number of rows to display, by default 20 rows + val numRowsToShow = parameters.get("numRows").map(_.toInt).getOrElse(20) + + // Truncate the displayed data if it is too long, by default it is true + val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true) + data.showInternal(numRowsToShow, isTruncated) + + ConsoleRelation(sqlContext, data) + } + def shortName(): String = "console" } From 58434acdd8cec0c762b4f09ace25e41d603af0a4 Mon Sep 17 00:00:00 2001 From: jinxing Date: Thu, 22 Jun 2017 14:10:51 -0700 Subject: [PATCH 059/118] [SPARK-19937] Collect metrics for remote bytes read to disk during shuffle. In current code(https://github.com/apache/spark/pull/16989), big blocks are shuffled to disk. This pr proposes to collect metrics for remote bytes fetched to disk. Author: jinxing Closes #18249 from jinxing64/SPARK-19937. 
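To make the new counter concrete, here is a hedged Scala sketch of one way to observe it from a SparkListener after this patch is applied. The listener class name is made up for illustration, and the null check on taskMetrics is defensive rather than something required by the patch:

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Hypothetical listener that logs the shuffle-read counters for each task.
class ShuffleReadToDiskListener extends SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    val metrics = taskEnd.taskMetrics
    if (metrics != null) {
      val read = metrics.shuffleReadMetrics
      // remoteBytesReadToDisk is the counter added by this patch; per the
      // ShuffleBlockFetcherIterator change, it only grows when a remote block
      // was fetched to disk (served as a FileSegmentManagedBuffer).
      println(s"task ${taskEnd.taskInfo.taskId}: " +
        s"remoteBytesRead=${read.remoteBytesRead}, " +
        s"remoteBytesReadToDisk=${read.remoteBytesReadToDisk}")
    }
  }
}

// Registration on an existing SparkContext (sc assumed to exist):
// sc.addSparkListener(new ShuffleReadToDiskListener)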
--- .../apache/spark/InternalAccumulator.scala | 1 + .../spark/executor/ShuffleReadMetrics.scala | 13 +++++ .../apache/spark/executor/TaskMetrics.scala | 1 + .../status/api/v1/AllStagesResource.scala | 2 + .../org/apache/spark/status/api/v1/api.scala | 2 + .../storage/ShuffleBlockFetcherIterator.scala | 6 +++ .../org/apache/spark/ui/jobs/UIData.scala | 4 +- .../org/apache/spark/util/JsonProtocol.scala | 3 ++ .../one_stage_attempt_json_expectation.json | 8 +++ .../one_stage_json_expectation.json | 8 +++ .../stage_task_list_expectation.json | 20 ++++++++ ...multi_attempt_app_json_1__expectation.json | 8 +++ ...multi_attempt_app_json_2__expectation.json | 8 +++ ...k_list_w__offset___length_expectation.json | 50 +++++++++++++++++++ ...stage_task_list_w__sortBy_expectation.json | 20 ++++++++ ...tBy_short_names___runtime_expectation.json | 20 ++++++++ ...rtBy_short_names__runtime_expectation.json | 20 ++++++++ ...mmary_w__custom_quantiles_expectation.json | 1 + ...sk_summary_w_shuffle_read_expectation.json | 1 + ...k_summary_w_shuffle_write_expectation.json | 1 + ...age_with_accumulable_json_expectation.json | 8 +++ .../spark/executor/TaskMetricsSuite.scala | 3 ++ .../apache/spark/util/JsonProtocolSuite.scala | 33 ++++++++---- project/MimaExcludes.scala | 6 ++- 24 files changed, 234 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala index 82d3098e2e055..18b10d23da94c 100644 --- a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala +++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala @@ -50,6 +50,7 @@ private[spark] object InternalAccumulator { val REMOTE_BLOCKS_FETCHED = SHUFFLE_READ_METRICS_PREFIX + "remoteBlocksFetched" val LOCAL_BLOCKS_FETCHED = SHUFFLE_READ_METRICS_PREFIX + "localBlocksFetched" val REMOTE_BYTES_READ = SHUFFLE_READ_METRICS_PREFIX + "remoteBytesRead" + val REMOTE_BYTES_READ_TO_DISK = SHUFFLE_READ_METRICS_PREFIX + "remoteBytesReadToDisk" val LOCAL_BYTES_READ = SHUFFLE_READ_METRICS_PREFIX + "localBytesRead" val FETCH_WAIT_TIME = SHUFFLE_READ_METRICS_PREFIX + "fetchWaitTime" val RECORDS_READ = SHUFFLE_READ_METRICS_PREFIX + "recordsRead" diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala index 8dd1a1ea059be..4be395c8358b2 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -31,6 +31,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { private[executor] val _remoteBlocksFetched = new LongAccumulator private[executor] val _localBlocksFetched = new LongAccumulator private[executor] val _remoteBytesRead = new LongAccumulator + private[executor] val _remoteBytesReadToDisk = new LongAccumulator private[executor] val _localBytesRead = new LongAccumulator private[executor] val _fetchWaitTime = new LongAccumulator private[executor] val _recordsRead = new LongAccumulator @@ -50,6 +51,11 @@ class ShuffleReadMetrics private[spark] () extends Serializable { */ def remoteBytesRead: Long = _remoteBytesRead.sum + /** + * Total number of remotes bytes read to disk from the shuffle by this task. + */ + def remoteBytesReadToDisk: Long = _remoteBytesReadToDisk.sum + /** * Shuffle data that was read from the local disk (as opposed to from a remote executor). 
*/ @@ -80,6 +86,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { private[spark] def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched.add(v) private[spark] def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched.add(v) private[spark] def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead.add(v) + private[spark] def incRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk.add(v) private[spark] def incLocalBytesRead(v: Long): Unit = _localBytesRead.add(v) private[spark] def incFetchWaitTime(v: Long): Unit = _fetchWaitTime.add(v) private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v) @@ -87,6 +94,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { private[spark] def setRemoteBlocksFetched(v: Int): Unit = _remoteBlocksFetched.setValue(v) private[spark] def setLocalBlocksFetched(v: Int): Unit = _localBlocksFetched.setValue(v) private[spark] def setRemoteBytesRead(v: Long): Unit = _remoteBytesRead.setValue(v) + private[spark] def setRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk.setValue(v) private[spark] def setLocalBytesRead(v: Long): Unit = _localBytesRead.setValue(v) private[spark] def setFetchWaitTime(v: Long): Unit = _fetchWaitTime.setValue(v) private[spark] def setRecordsRead(v: Long): Unit = _recordsRead.setValue(v) @@ -99,6 +107,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { _remoteBlocksFetched.setValue(0) _localBlocksFetched.setValue(0) _remoteBytesRead.setValue(0) + _remoteBytesReadToDisk.setValue(0) _localBytesRead.setValue(0) _fetchWaitTime.setValue(0) _recordsRead.setValue(0) @@ -106,6 +115,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { _remoteBlocksFetched.add(metric.remoteBlocksFetched) _localBlocksFetched.add(metric.localBlocksFetched) _remoteBytesRead.add(metric.remoteBytesRead) + _remoteBytesReadToDisk.add(metric.remoteBytesReadToDisk) _localBytesRead.add(metric.localBytesRead) _fetchWaitTime.add(metric.fetchWaitTime) _recordsRead.add(metric.recordsRead) @@ -122,6 +132,7 @@ private[spark] class TempShuffleReadMetrics { private[this] var _remoteBlocksFetched = 0L private[this] var _localBlocksFetched = 0L private[this] var _remoteBytesRead = 0L + private[this] var _remoteBytesReadToDisk = 0L private[this] var _localBytesRead = 0L private[this] var _fetchWaitTime = 0L private[this] var _recordsRead = 0L @@ -129,6 +140,7 @@ private[spark] class TempShuffleReadMetrics { def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched += v def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched += v def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead += v + def incRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk += v def incLocalBytesRead(v: Long): Unit = _localBytesRead += v def incFetchWaitTime(v: Long): Unit = _fetchWaitTime += v def incRecordsRead(v: Long): Unit = _recordsRead += v @@ -136,6 +148,7 @@ private[spark] class TempShuffleReadMetrics { def remoteBlocksFetched: Long = _remoteBlocksFetched def localBlocksFetched: Long = _localBlocksFetched def remoteBytesRead: Long = _remoteBytesRead + def remoteBytesReadToDisk: Long = _remoteBytesReadToDisk def localBytesRead: Long = _localBytesRead def fetchWaitTime: Long = _fetchWaitTime def recordsRead: Long = _recordsRead diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index a3ce3d1ccc5e3..341a6da8107ef 100644 --- 
a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -215,6 +215,7 @@ class TaskMetrics private[spark] () extends Serializable { shuffleRead.REMOTE_BLOCKS_FETCHED -> shuffleReadMetrics._remoteBlocksFetched, shuffleRead.LOCAL_BLOCKS_FETCHED -> shuffleReadMetrics._localBlocksFetched, shuffleRead.REMOTE_BYTES_READ -> shuffleReadMetrics._remoteBytesRead, + shuffleRead.REMOTE_BYTES_READ_TO_DISK -> shuffleReadMetrics._remoteBytesReadToDisk, shuffleRead.LOCAL_BYTES_READ -> shuffleReadMetrics._localBytesRead, shuffleRead.FETCH_WAIT_TIME -> shuffleReadMetrics._fetchWaitTime, shuffleRead.RECORDS_READ -> shuffleReadMetrics._recordsRead, diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 1818935392eb3..56028710ecc66 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -200,6 +200,7 @@ private[v1] object AllStagesResource { readBytes = submetricQuantiles(_.totalBytesRead), readRecords = submetricQuantiles(_.recordsRead), remoteBytesRead = submetricQuantiles(_.remoteBytesRead), + remoteBytesReadToDisk = submetricQuantiles(_.remoteBytesReadToDisk), remoteBlocksFetched = submetricQuantiles(_.remoteBlocksFetched), localBlocksFetched = submetricQuantiles(_.localBlocksFetched), totalBlocksFetched = submetricQuantiles(_.totalBlocksFetched), @@ -281,6 +282,7 @@ private[v1] object AllStagesResource { localBlocksFetched = internal.localBlocksFetched, fetchWaitTime = internal.fetchWaitTime, remoteBytesRead = internal.remoteBytesRead, + remoteBytesReadToDisk = internal.remoteBytesReadToDisk, localBytesRead = internal.localBytesRead, recordsRead = internal.recordsRead ) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index f6203271f3cd2..05948f2661056 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -208,6 +208,7 @@ class ShuffleReadMetrics private[spark]( val localBlocksFetched: Long, val fetchWaitTime: Long, val remoteBytesRead: Long, + val remoteBytesReadToDisk: Long, val localBytesRead: Long, val recordsRead: Long) @@ -249,6 +250,7 @@ class ShuffleReadMetricDistributions private[spark]( val localBlocksFetched: IndexedSeq[Double], val fetchWaitTime: IndexedSeq[Double], val remoteBytesRead: IndexedSeq[Double], + val remoteBytesReadToDisk: IndexedSeq[Double], val totalBlocksFetched: IndexedSeq[Double]) class ShuffleWriteMetricDistributions private[spark]( diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index bded3a1e4eb54..a10f1feadd0af 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -165,6 +165,9 @@ final class ShuffleBlockFetcherIterator( case SuccessFetchResult(_, address, _, buf, _) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } shuffleMetrics.incRemoteBlocksFetched(1) } buf.release() @@ -363,6 +366,9 @@ 
final class ShuffleBlockFetcherIterator( case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } shuffleMetrics.incRemoteBlocksFetched(1) } bytesInFlight -= size diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 6764daa0df529..9448baac096dc 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -251,6 +251,7 @@ private[spark] object UIData { remoteBlocksFetched: Long, localBlocksFetched: Long, remoteBytesRead: Long, + remoteBytesReadToDisk: Long, localBytesRead: Long, fetchWaitTime: Long, recordsRead: Long, @@ -274,6 +275,7 @@ private[spark] object UIData { remoteBlocksFetched = metrics.remoteBlocksFetched, localBlocksFetched = metrics.localBlocksFetched, remoteBytesRead = metrics.remoteBytesRead, + remoteBytesReadToDisk = metrics.remoteBytesReadToDisk, localBytesRead = metrics.localBytesRead, fetchWaitTime = metrics.fetchWaitTime, recordsRead = metrics.recordsRead, @@ -282,7 +284,7 @@ private[spark] object UIData { ) } } - private val EMPTY = ShuffleReadMetricsUIData(0, 0, 0, 0, 0, 0, 0, 0) + private val EMPTY = ShuffleReadMetricsUIData(0, 0, 0, 0, 0, 0, 0, 0, 0) } case class ShuffleWriteMetricsUIData( diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 8296c4294242c..806d14e7cc119 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -339,6 +339,7 @@ private[spark] object JsonProtocol { ("Local Blocks Fetched" -> taskMetrics.shuffleReadMetrics.localBlocksFetched) ~ ("Fetch Wait Time" -> taskMetrics.shuffleReadMetrics.fetchWaitTime) ~ ("Remote Bytes Read" -> taskMetrics.shuffleReadMetrics.remoteBytesRead) ~ + ("Remote Bytes Read To Disk" -> taskMetrics.shuffleReadMetrics.remoteBytesReadToDisk) ~ ("Local Bytes Read" -> taskMetrics.shuffleReadMetrics.localBytesRead) ~ ("Total Records Read" -> taskMetrics.shuffleReadMetrics.recordsRead) val shuffleWriteMetrics: JValue = @@ -804,6 +805,8 @@ private[spark] object JsonProtocol { readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks Fetched").extract[Int]) readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks Fetched").extract[Int]) readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes Read").extract[Long]) + Utils.jsonOption(readJson \ "Remote Bytes Read To Disk") + .foreach { v => readMetrics.incRemoteBytesReadToDisk(v.extract[Long])} readMetrics.incLocalBytesRead( Utils.jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L)) readMetrics.incFetchWaitTime((readJson \ "Fetch Wait Time").extract[Long]) diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index c2f450ba87c6d..6fb40f6f1713b 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -60,6 +60,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, 
"localBytesRead" : 0, "recordsRead" : 0 }, @@ -105,6 +106,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -150,6 +152,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -195,6 +198,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -240,6 +244,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -285,6 +290,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -330,6 +336,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -375,6 +382,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index 506859ae545b1..f5a89a2107646 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -60,6 +60,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -105,6 +106,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -150,6 +152,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -195,6 +198,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -240,6 +244,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -285,6 +290,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -330,6 +336,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -375,6 +382,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json index f4cec68fbfdf2..9b401b414f8d4 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json @@ -33,6 +33,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -77,6 +78,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + 
"remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -121,6 +123,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -165,6 +168,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -209,6 +213,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -253,6 +258,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -297,6 +303,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -341,6 +348,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -385,6 +393,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -429,6 +438,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -473,6 +483,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -517,6 +528,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -561,6 +573,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -605,6 +618,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -649,6 +663,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -693,6 +708,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -737,6 +753,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -781,6 +798,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -825,6 +843,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -869,6 +888,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json index 496a21c328da9..2ebee66a6d7c2 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json @@ 
-38,6 +38,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -87,6 +88,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -136,6 +138,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -185,6 +188,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -234,6 +238,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -283,6 +288,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -332,6 +338,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -381,6 +388,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json index 4328dc753c5d4..965a31a4104c3 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json @@ -38,6 +38,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -87,6 +88,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -136,6 +138,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -185,6 +188,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -234,6 +238,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -283,6 +288,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -332,6 +338,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -381,6 +388,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json index 8c571430f3a1f..31132e156937c 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json +++ 
b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json @@ -33,6 +33,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -77,6 +78,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -121,6 +123,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -165,6 +168,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -209,6 +213,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -253,6 +258,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -297,6 +303,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -341,6 +348,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -385,6 +393,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -429,6 +438,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -473,6 +483,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -517,6 +528,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -561,6 +573,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -605,6 +618,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -649,6 +663,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -693,6 +708,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -737,6 +753,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -781,6 +798,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -825,6 +843,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -869,6 +888,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -913,6 +933,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -957,6 
+978,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1001,6 +1023,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1045,6 +1068,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1089,6 +1113,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1133,6 +1158,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1177,6 +1203,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1221,6 +1248,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1265,6 +1293,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1309,6 +1338,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1353,6 +1383,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1397,6 +1428,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1441,6 +1473,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1485,6 +1518,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1529,6 +1563,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1573,6 +1608,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1617,6 +1653,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1661,6 +1698,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1705,6 +1743,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1749,6 +1788,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1793,6 +1833,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1837,6 +1878,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1881,6 +1923,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, 
"remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1925,6 +1968,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -1969,6 +2013,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -2013,6 +2058,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -2057,6 +2103,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -2101,6 +2148,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -2145,6 +2193,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -2189,6 +2238,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json index 0bd614bdc756e..6af1cfbeb8f7e 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json @@ -33,6 +33,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -77,6 +78,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -121,6 +123,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -165,6 +168,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -209,6 +213,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -253,6 +258,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -297,6 +303,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -341,6 +348,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -385,6 +393,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -429,6 +438,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -473,6 +483,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -517,6 +528,7 @@ "localBlocksFetched" : 0, 
"fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -561,6 +573,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -605,6 +618,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -649,6 +663,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -693,6 +708,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -737,6 +753,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -781,6 +798,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -825,6 +843,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -869,6 +888,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json index 0bd614bdc756e..6af1cfbeb8f7e 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json @@ -33,6 +33,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -77,6 +78,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -121,6 +123,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -165,6 +168,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -209,6 +213,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -253,6 +258,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -297,6 +303,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -341,6 +348,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -385,6 +393,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -429,6 +438,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, 
"localBytesRead" : 0, "recordsRead" : 0 }, @@ -473,6 +483,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -517,6 +528,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -561,6 +573,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -605,6 +618,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -649,6 +663,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -693,6 +708,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -737,6 +753,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -781,6 +798,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -825,6 +843,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -869,6 +888,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json index b58f1a51ba481..c26daf4b8d7bd 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json @@ -33,6 +33,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -77,6 +78,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -121,6 +123,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -165,6 +168,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -209,6 +213,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -253,6 +258,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -297,6 +303,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -341,6 +348,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -385,6 +393,7 @@ 
"localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -429,6 +438,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -473,6 +483,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -517,6 +528,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -561,6 +573,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -605,6 +618,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -649,6 +663,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -693,6 +708,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -737,6 +753,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -781,6 +798,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -825,6 +843,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -869,6 +888,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json index 0ed609d5b7f92..f8e27703c0def 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json @@ -24,6 +24,7 @@ "localBlocksFetched" : [ 0.0, 0.0, 0.0 ], "fetchWaitTime" : [ 0.0, 0.0, 0.0 ], "remoteBytesRead" : [ 0.0, 0.0, 0.0 ], + "remoteBytesReadToDisk" : [ 0.0, 0.0, 0.0 ], "totalBlocksFetched" : [ 0.0, 0.0, 0.0 ] }, "shuffleWriteMetrics" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json index 6d230ac653776..a28bda16a956e 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json @@ -24,6 +24,7 @@ "localBlocksFetched" : [ 100.0, 100.0, 100.0, 100.0, 100.0 ], "fetchWaitTime" : [ 0.0, 0.0, 0.0, 1.0, 1.0 ], "remoteBytesRead" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "remoteBytesReadToDisk" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "totalBlocksFetched" : [ 100.0, 100.0, 100.0, 100.0, 100.0 ] }, "shuffleWriteMetrics" : { diff --git 
a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json index aea0f5413d8b9..ede3eaed1d1d2 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json @@ -24,6 +24,7 @@ "localBlocksFetched" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "fetchWaitTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "remoteBytesRead" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "remoteBytesReadToDisk" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "totalBlocksFetched" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ] }, "shuffleWriteMetrics" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index a449926ee7dc6..44b5f66efe339 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -69,6 +69,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -119,6 +120,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -169,6 +171,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -219,6 +222,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -269,6 +273,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -319,6 +324,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -369,6 +375,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, @@ -419,6 +426,7 @@ "localBlocksFetched" : 0, "fetchWaitTime" : 0, "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, "localBytesRead" : 0, "recordsRead" : 0 }, diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index eae26fa742a23..7bcc2fb5231db 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -94,6 +94,8 @@ class TaskMetricsSuite extends SparkFunSuite { sr.setRemoteBytesRead(30L) sr.incRemoteBytesRead(3L) sr.incRemoteBytesRead(3L) + sr.setRemoteBytesReadToDisk(10L) + sr.incRemoteBytesReadToDisk(8L) sr.setLocalBytesRead(400L) sr.setLocalBytesRead(40L) sr.incLocalBytesRead(4L) @@ -110,6 +112,7 @@ class TaskMetricsSuite extends SparkFunSuite { assert(sr.remoteBlocksFetched == 12) assert(sr.localBlocksFetched == 24) assert(sr.remoteBytesRead == 36L) + assert(sr.remoteBytesReadToDisk == 18L) assert(sr.localBytesRead == 48L) assert(sr.fetchWaitTime == 60L) assert(sr.recordsRead == 72L) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala 
b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index a77c8e3cab4e8..57452d4912abe 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -848,6 +848,7 @@ private[spark] object JsonProtocolSuite extends Assertions { } else { val sr = t.createTempShuffleReadMetrics() sr.incRemoteBytesRead(b + d) + sr.incRemoteBytesReadToDisk(b) sr.incLocalBlocksFetched(e) sr.incFetchWaitTime(a + d) sr.incRemoteBlocksFetched(f) @@ -1128,6 +1129,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Local Blocks Fetched": 700, | "Fetch Wait Time": 900, | "Remote Bytes Read": 1000, + | "Remote Bytes Read To Disk": 400, | "Local Bytes Read": 1100, | "Total Records Read": 10 | }, @@ -1228,6 +1230,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Local Blocks Fetched" : 0, | "Fetch Wait Time" : 0, | "Remote Bytes Read" : 0, + | "Remote Bytes Read To Disk" : 0, | "Local Bytes Read" : 0, | "Total Records Read" : 0 | }, @@ -1328,10 +1331,11 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Local Blocks Fetched" : 0, | "Fetch Wait Time" : 0, | "Remote Bytes Read" : 0, + | "Remote Bytes Read To Disk" : 0, | "Local Bytes Read" : 0, | "Total Records Read" : 0 | }, - | "Shuffle Write Metrics" : { + | "Shuffle Write Metrics": { | "Shuffle Bytes Written" : 0, | "Shuffle Write Time" : 0, | "Shuffle Records Written" : 0 @@ -1915,76 +1919,83 @@ private[spark] object JsonProtocolSuite extends Assertions { | }, | { | "ID": 14, - | "Name": "${shuffleRead.LOCAL_BYTES_READ}", + | "Name": "${shuffleRead.REMOTE_BYTES_READ_TO_DISK}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 15, - | "Name": "${shuffleRead.FETCH_WAIT_TIME}", + | "Name": "${shuffleRead.LOCAL_BYTES_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 16, - | "Name": "${shuffleRead.RECORDS_READ}", + | "Name": "${shuffleRead.FETCH_WAIT_TIME}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 17, - | "Name": "${shuffleWrite.BYTES_WRITTEN}", + | "Name": "${shuffleRead.RECORDS_READ}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 18, - | "Name": "${shuffleWrite.RECORDS_WRITTEN}", + | "Name": "${shuffleWrite.BYTES_WRITTEN}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 19, - | "Name": "${shuffleWrite.WRITE_TIME}", + | "Name": "${shuffleWrite.RECORDS_WRITTEN}", | "Update": 0, | "Internal": true, | "Count Failed Values": true | }, | { | "ID": 20, + | "Name": "${shuffleWrite.WRITE_TIME}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 21, | "Name": "${input.BYTES_READ}", | "Update": 2100, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 21, + | "ID": 22, | "Name": "${input.RECORDS_READ}", | "Update": 21, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 22, + | "ID": 23, | "Name": "${output.BYTES_WRITTEN}", | "Update": 1200, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 23, + | "ID": 24, | "Name": "${output.RECORDS_WRITTEN}", | "Update": 12, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 24, + | "ID": 25, | "Name": "$TEST_ACCUM", | "Update": 0, | "Internal": true, diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3cc089dcede38..1793da03a2c3e 100644 --- a/project/MimaExcludes.scala 
+++ b/project/MimaExcludes.scala
@@ -37,7 +37,11 @@ object MimaExcludes {
   // Exclude rules for 2.3.x
   lazy val v23excludes = v22excludes ++ Seq(
     // [SPARK-20495][SQL] Add StorageLevel to cacheTable API
-    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable")
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"),
+
+    // [SPARK-19937] Add remote bytes read to disk.
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.this"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ShuffleReadMetricDistributions.this")
   )
 
   // Exclude rules for 2.2.x

From e44697606f429b01808c1a22cb44cb5b89585c5c Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Fri, 23 Jun 2017 09:01:13 +0800
Subject: [PATCH 060/118] [SPARK-13534][PYSPARK] Using Apache Arrow to increase
 performance of DataFrame.toPandas

## What changes were proposed in this pull request?

Integrate Apache Arrow with Spark to increase the performance of `DataFrame.toPandas`. Arrow is used to convert data partitions on the executor JVM into Arrow payload byte arrays, which are then served to the Python process; the Python side collects the Arrow payloads, combines them, and converts the result into a pandas DataFrame. All non-complex data types are currently supported; otherwise an `UnsupportedOperationException` is thrown.

Additions to Spark include a Scala package-private method `Dataset.toArrowPayload` that converts data partitions in the executor JVM to `ArrowPayload`s as byte arrays so they can be easily served, and a package-private object `ArrowConverters` that provides the data type mappings and conversion routines. In Python, a method `DataFrame._collectAsArrow` is added to collect the Arrow payloads, and `toPandas` uses the Arrow path when `spark.sql.execution.arrow.enable` is set to `true` (the existing conversion remains the default).

## How was this patch tested?

Added a new test suite, `ArrowConvertersSuite`, that tests conversion of Datasets to Arrow payloads for the supported types. The suite generates a Dataset together with matching Arrow JSON data, converts the Dataset to an Arrow payload, and validates it against the JSON data, ensuring that both the schema and the data are converted correctly. Added PySpark tests to verify that `toPandas` produces equal DataFrames with and without pyarrow, including a round-trip test that checks the pandas DataFrame produced by PySpark against one made directly with pandas.

Author: Bryan Cutler
Author: Li Jin
Author: Wes McKinney

Closes #15821 from BryanCutler/wip-toPandas_with_arrow-SPARK-13534.
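To make the behavior introduced by this patch concrete, here is a minimal usage sketch (not part of the patch itself). It assumes a local `SparkSession`, a driver Python environment with pyarrow 0.4.0 installed, and illustrative sample data; only the configuration key and the `toPandas()` call come from the patch.

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("arrow-topandas-demo").getOrCreate()

# The Arrow code path is off by default; opt in via the new internal conf.
spark.conf.set("spark.sql.execution.arrow.enable", "true")

# A small DataFrame with only non-complex column types (string, long, double),
# which are among the types supported by the Arrow conversion in this patch.
df = spark.createDataFrame(
    [("a", 1, 2.0), ("b", 2, 4.0), ("c", 3, 6.0)],
    ["name", "id", "score"])

# With the conf enabled, toPandas() collects Arrow payloads from the executors
# and combines them into a single pandas DataFrame on the driver.
pdf = df.toPandas()
print(pdf.dtypes)

spark.stop()
```

If pyarrow is not importable while the conf is enabled, the `dataframe.py` hunk below raises an `ImportError` that points the user at `spark.sql.execution.arrow.enable`; unsupported complex types such as arrays raise an exception, as exercised by `test_unsupported_datatype` in the new `ArrowTests`.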
--- bin/pyspark | 2 +- dev/deps/spark-deps-hadoop-2.6 | 5 + dev/deps/spark-deps-hadoop-2.7 | 5 + dev/run-pip-tests | 6 + pom.xml | 20 + python/pyspark/serializers.py | 17 + python/pyspark/sql/dataframe.py | 48 +- python/pyspark/sql/tests.py | 79 +- .../apache/spark/sql/internal/SQLConf.scala | 22 + sql/core/pom.xml | 4 + .../scala/org/apache/spark/sql/Dataset.scala | 20 + .../sql/execution/arrow/ArrowConverters.scala | 429 ++++++ .../arrow/ArrowConvertersSuite.scala | 1222 +++++++++++++++++ 13 files changed, 1866 insertions(+), 13 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala diff --git a/bin/pyspark b/bin/pyspark index 98387c2ec5b8a..8eeea7716cc98 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m "$1" + exec "$PYSPARK_DRIVER_PYTHON" -m "$@" exit fi diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 9287bd47cf113..9868c1ab7c2ab 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -13,6 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar +arrow-format-0.4.0.jar +arrow-memory-0.4.0.jar +arrow-vector-0.4.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -55,6 +58,7 @@ datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar +flatbuffers-1.2.0-3f79e055.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -77,6 +81,7 @@ hadoop-yarn-server-web-proxy-2.6.5.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar +hppc-0.7.1.jar htrace-core-3.0.4.jar httpclient-4.5.2.jar httpcore-4.4.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 9127413ab6c23..57c78cfe12087 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -13,6 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar +arrow-format-0.4.0.jar +arrow-memory-0.4.0.jar +arrow-vector-0.4.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -55,6 +58,7 @@ datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar +flatbuffers-1.2.0-3f79e055.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -77,6 +81,7 @@ hadoop-yarn-server-web-proxy-2.7.3.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar +hppc-0.7.1.jar htrace-core-3.1.0-incubating.jar httpclient-4.5.2.jar httpcore-4.4.4.jar diff --git a/dev/run-pip-tests b/dev/run-pip-tests index d51dde12a03c5..225e9209536f0 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -83,6 +83,8 @@ for python in "${PYTHON_EXECS[@]}"; do if [ -n "$USE_CONDA" ]; then conda create -y -p "$VIRTUALENV_PATH" python=$python numpy pandas pip setuptools source activate "$VIRTUALENV_PATH" + conda install -y -c conda-forge pyarrow=0.4.0 + TEST_PYARROW=1 else mkdir -p "$VIRTUALENV_PATH" virtualenv --python=$python "$VIRTUALENV_PATH" @@ -120,6 +122,10 @@ for python in "${PYTHON_EXECS[@]}"; do python "$FWDIR"/dev/pip-sanity-check.py echo "Run the tests for context.py" python "$FWDIR"/python/pyspark/context.py + if [ -n "$TEST_PYARROW" ]; then + echo "Run 
tests for pyarrow" + SPARK_TESTING=1 "$FWDIR"/bin/pyspark pyspark.sql.tests ArrowTests + fi cd "$FWDIR" diff --git a/pom.xml b/pom.xml index 5f524079495c0..f124ba45007b7 100644 --- a/pom.xml +++ b/pom.xml @@ -181,6 +181,7 @@ 2.6 1.8 1.0.0 + 0.4.0 ${java.home} @@ -1878,6 +1879,25 @@ paranamer ${paranamer.version} + + org.apache.arrow + arrow-vector + ${arrow.version} + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-databind + + + io.netty + netty-handler + + + diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ea5e00e9eeef5..d5c2a7518b18f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -182,6 +182,23 @@ def loads(self, obj): raise NotImplementedError +class ArrowSerializer(FramedSerializer): + """ + Serializes an Arrow stream. + """ + + def dumps(self, obj): + raise NotImplementedError + + def loads(self, obj): + import pyarrow as pa + reader = pa.RecordBatchFileReader(pa.BufferReader(obj)) + return reader.read_all() + + def __repr__(self): + return "ArrowSerializer" + + class BatchedSerializer(Serializer): """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0649271ed2246..760f113dfd197 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,7 +29,8 @@ from pyspark import copy_func, since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer +from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ + UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import _parse_datatype_json_string @@ -1708,7 +1709,8 @@ def toDF(self, *cols): @since(1.3) def toPandas(self): - """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + """ + Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. This is only available if Pandas is installed and available. @@ -1721,18 +1723,42 @@ def toPandas(self): 1 5 Bob """ import pandas as pd + if self.sql_ctx.getConf("spark.sql.execution.arrow.enable", "false").lower() == "true": + try: + import pyarrow + tables = self._collectAsArrow() + if tables: + table = pyarrow.concat_tables(tables) + return table.to_pandas() + else: + return pd.DataFrame.from_records([], columns=self.columns) + except ImportError as e: + msg = "note: pyarrow must be installed and available on calling Python process " \ + "if using spark.sql.execution.arrow.enable=true" + raise ImportError("%s\n%s" % (e.message, msg)) + else: + dtype = {} + for field in self.schema: + pandas_type = _to_corrected_pandas_type(field.dataType) + if pandas_type is not None: + dtype[field.name] = pandas_type - dtype = {} - for field in self.schema: - pandas_type = _to_corrected_pandas_type(field.dataType) - if pandas_type is not None: - dtype[field.name] = pandas_type + pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) - pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) + for f, t in dtype.items(): + pdf[f] = pdf[f].astype(t, copy=False) + return pdf - for f, t in dtype.items(): - pdf[f] = pdf[f].astype(t, copy=False) - return pdf + def _collectAsArrow(self): + """ + Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed + and available. + + .. note:: Experimental. 
+ """ + with SCCallSiteSync(self._sc) as css: + port = self._jdf.collectAsArrowToPython() + return list(_load_from_socket(port, ArrowSerializer())) ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 0a1cd6856b8e8..326e8548a617c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -58,12 +58,21 @@ from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type -from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests +from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException +_have_arrow = False +try: + import pyarrow + _have_arrow = True +except: + # No Arrow, but that's okay, we'll skip those tests + pass + + class UTCOffsetTimezone(datetime.tzinfo): """ Specifies timezone in UTC offset @@ -2620,6 +2629,74 @@ def range_frame_match(): importlib.reload(window) + +@unittest.skipIf(not _have_arrow, "Arrow not installed") +class ArrowTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + cls.spark.conf.set("spark.sql.execution.arrow.enable", "true") + cls.schema = StructType([ + StructField("1_str_t", StringType(), True), + StructField("2_int_t", IntegerType(), True), + StructField("3_long_t", LongType(), True), + StructField("4_float_t", FloatType(), True), + StructField("5_double_t", DoubleType(), True)]) + cls.data = [("a", 1, 10, 0.2, 2.0), + ("b", 2, 20, 0.4, 4.0), + ("c", 3, 30, 0.8, 6.0)] + + def assertFramesEqual(self, df_with_arrow, df_without): + msg = ("DataFrame from Arrow is not equal" + + ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) + + ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) + self.assertTrue(df_without.equals(df_with_arrow), msg=msg) + + def test_unsupported_datatype(self): + schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)]) + df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: df.toPandas()) + + def test_null_conversion(self): + df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + + self.data) + pdf = df_null.toPandas() + null_counts = pdf.isnull().sum().tolist() + self.assertTrue(all([c == 1 for c in null_counts])) + + def test_toPandas_arrow_toggle(self): + df = self.spark.createDataFrame(self.data, schema=self.schema) + self.spark.conf.set("spark.sql.execution.arrow.enable", "false") + pdf = df.toPandas() + self.spark.conf.set("spark.sql.execution.arrow.enable", "true") + pdf_arrow = df.toPandas() + self.assertFramesEqual(pdf_arrow, pdf) + + def test_pandas_round_trip(self): + import pandas as pd + import numpy as np + data_dict = {} + for j, name in enumerate(self.schema.names): + data_dict[name] = [self.data[i][j] for i in range(len(self.data))] + # need to convert these to numpy types first + data_dict["2_int_t"] = np.int32(data_dict["2_int_t"]) + data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) + pdf = pd.DataFrame(data=data_dict) + df = self.spark.createDataFrame(self.data, schema=self.schema) + pdf_arrow = 
df.toPandas() + self.assertFramesEqual(pdf_arrow, pdf) + + def test_filtered_frame(self): + df = self.spark.range(3).toDF("i") + pdf = df.filter("i < 0").toPandas() + self.assertEqual(len(pdf.columns), 1) + self.assertEqual(pdf.columns[0], "i") + self.assertTrue(pdf.empty) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6ab3a615e6cc0..e609256db2802 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -846,6 +846,24 @@ object SQLConf { .intConf .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + val ARROW_EXECUTION_ENABLE = + buildConf("spark.sql.execution.arrow.enable") + .internal() + .doc("Make use of Apache Arrow for columnar data transfers. Currently available " + + "for use with pyspark.sql.DataFrame.toPandas with the following data types: " + + "StringType, BinaryType, BooleanType, DoubleType, FloatType, ByteType, IntegerType, " + + "LongType, ShortType") + .booleanConf + .createWithDefault(false) + + val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = + buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") + .internal() + .doc("When using Apache Arrow, limit the maximum number of records that can be written " + + "to a single ArrowRecordBatch in memory. If set to zero or negative there is no limit.") + .intConf + .createWithDefault(10000) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1104,6 +1122,10 @@ class SQLConf extends Serializable with Logging { def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO) + def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE) + + def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 1bc34a6b069d9..661c31ded7148 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -103,6 +103,10 @@ jackson-databind ${fasterxml.jackson.version} + + org.apache.arrow + arrow-vector + org.apache.xbean xbean-asm5-shaded diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d28ff7888d127..a2af9c2efe2ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython @@ -2922,6 +2923,16 @@ class Dataset[T] private[sql]( } } + /** + * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. 
+ */ + private[sql] def collectAsArrowToPython(): Int = { + withNewExecutionId { + val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable) + PythonRDD.serveIterator(iter, "serve-Arrow") + } + } + private[sql] def toPythonIterator(): Int = { withNewExecutionId { PythonRDD.toLocalIteratorAndServe(javaToPython.rdd) @@ -3003,4 +3014,13 @@ class Dataset[T] private[sql]( Dataset(sparkSession, logicalPlan) } } + + /** Convert to an RDD of ArrowPayload byte arrays */ + private[sql] def toArrowPayload: RDD[ArrowPayload] = { + val schemaCaptured = this.schema + val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch + queryExecution.toRdd.mapPartitionsInternal { iter => + ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala new file mode 100644 index 0000000000000..6af5c73422377 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -0,0 +1,429 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.arrow + +import java.io.ByteArrayOutputStream +import java.nio.channels.Channels + +import scala.collection.JavaConverters._ + +import io.netty.buffer.ArrowBuf +import org.apache.arrow.memory.{BufferAllocator, RootAllocator} +import org.apache.arrow.vector._ +import org.apache.arrow.vector.BaseValueVector.BaseMutator +import org.apache.arrow.vector.file._ +import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} +import org.apache.arrow.vector.types.FloatingPointPrecision +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + + +/** + * Store Arrow data in a form that can be serialized by Spark and served to a Python process. + */ +private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable { + + /** + * Convert the ArrowPayload to an ArrowRecordBatch. + */ + def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { + ArrowConverters.byteArrayToBatch(payload, allocator) + } + + /** + * Get the ArrowPayload as a type that can be served to Python. + */ + def asPythonSerializable: Array[Byte] = payload +} + +private[sql] object ArrowPayload { + + /** + * Create an ArrowPayload from an ArrowRecordBatch and Spark schema. 
+ */ + def apply( + batch: ArrowRecordBatch, + schema: StructType, + allocator: BufferAllocator): ArrowPayload = { + new ArrowPayload(ArrowConverters.batchToByteArray(batch, schema, allocator)) + } +} + +private[sql] object ArrowConverters { + + /** + * Map a Spark DataType to ArrowType. + */ + private[arrow] def sparkTypeToArrowType(dataType: DataType): ArrowType = { + dataType match { + case BooleanType => ArrowType.Bool.INSTANCE + case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true) + case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true) + case LongType => new ArrowType.Int(8 * LongType.defaultSize, true) + case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case ByteType => new ArrowType.Int(8, true) + case StringType => ArrowType.Utf8.INSTANCE + case BinaryType => ArrowType.Binary.INSTANCE + case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") + } + } + + /** + * Convert a Spark Dataset schema to Arrow schema. + */ + private[arrow] def schemaToArrowSchema(schema: StructType): Schema = { + val arrowFields = schema.fields.map { f => + new Field(f.name, f.nullable, sparkTypeToArrowType(f.dataType), List.empty[Field].asJava) + } + new Schema(arrowFields.toList.asJava) + } + + /** + * Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload + * by setting maxRecordsPerBatch or use 0 to fully consume rowIter. + */ + private[sql] def toPayloadIterator( + rowIter: Iterator[InternalRow], + schema: StructType, + maxRecordsPerBatch: Int): Iterator[ArrowPayload] = { + new Iterator[ArrowPayload] { + private val _allocator = new RootAllocator(Long.MaxValue) + private var _nextPayload = if (rowIter.nonEmpty) convert() else null + + override def hasNext: Boolean = _nextPayload != null + + override def next(): ArrowPayload = { + val obj = _nextPayload + if (hasNext) { + if (rowIter.hasNext) { + _nextPayload = convert() + } else { + _allocator.close() + _nextPayload = null + } + } + obj + } + + private def convert(): ArrowPayload = { + val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator, maxRecordsPerBatch) + ArrowPayload(batch, schema, _allocator) + } + } + } + + /** + * Iterate over InternalRows and write to an ArrowRecordBatch, stopping when rowIter is consumed + * or the number of records in the batch equals maxRecordsInBatch. If maxRecordsPerBatch is 0, + * then rowIter will be fully consumed. 
+ */ + private def internalRowIterToArrowBatch( + rowIter: Iterator[InternalRow], + schema: StructType, + allocator: BufferAllocator, + maxRecordsPerBatch: Int = 0): ArrowRecordBatch = { + + val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) => + ColumnWriter(field.dataType, ordinal, allocator).init() + } + + val writerLength = columnWriters.length + var recordsInBatch = 0 + while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || recordsInBatch < maxRecordsPerBatch)) { + val row = rowIter.next() + var i = 0 + while (i < writerLength) { + columnWriters(i).write(row) + i += 1 + } + recordsInBatch += 1 + } + + val (fieldNodes, bufferArrays) = columnWriters.map(_.finish()).unzip + val buffers = bufferArrays.flatten + + val rowLength = if (fieldNodes.nonEmpty) fieldNodes.head.getLength else 0 + val recordBatch = new ArrowRecordBatch(rowLength, + fieldNodes.toList.asJava, buffers.toList.asJava) + + buffers.foreach(_.release()) + recordBatch + } + + /** + * Convert an ArrowRecordBatch to a byte array and close batch to release resources. Once closed, + * the batch can no longer be used. + */ + private[arrow] def batchToByteArray( + batch: ArrowRecordBatch, + schema: StructType, + allocator: BufferAllocator): Array[Byte] = { + val arrowSchema = ArrowConverters.schemaToArrowSchema(schema) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val out = new ByteArrayOutputStream() + val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) + + // Write a batch to byte stream, ensure the batch, allocator and writer are closed + Utils.tryWithSafeFinally { + val loader = new VectorLoader(root) + loader.load(batch) + writer.writeBatch() // writeBatch can throw IOException + } { + batch.close() + root.close() + writer.close() + } + out.toByteArray + } + + /** + * Convert a byte array to an ArrowRecordBatch. + */ + private[arrow] def byteArrayToBatch( + batchBytes: Array[Byte], + allocator: BufferAllocator): ArrowRecordBatch = { + val in = new ByteArrayReadableSeekableByteChannel(batchBytes) + val reader = new ArrowFileReader(in, allocator) + + // Read a batch from a byte stream, ensure the reader is closed + Utils.tryWithSafeFinally { + val root = reader.getVectorSchemaRoot // throws IOException + val unloader = new VectorUnloader(root) + reader.loadNextBatch() // throws IOException + unloader.getRecordBatch + } { + reader.close() + } + } +} + +/** + * Interface for writing InternalRows to Arrow Buffers. + */ +private[arrow] trait ColumnWriter { + def init(): this.type + def write(row: InternalRow): Unit + + /** + * Clear the column writer and return the ArrowFieldNode and ArrowBuf. + * This should be called only once after all the data is written. + */ + def finish(): (ArrowFieldNode, Array[ArrowBuf]) +} + +/** + * Base class for flat arrow column writer, i.e., column without children. 
+ */ +private[arrow] abstract class PrimitiveColumnWriter(val ordinal: Int) + extends ColumnWriter { + + def getFieldType(dtype: ArrowType): FieldType = FieldType.nullable(dtype) + + def valueVector: BaseDataValueVector + def valueMutator: BaseMutator + + def setNull(): Unit + def setValue(row: InternalRow): Unit + + protected var count = 0 + protected var nullCount = 0 + + override def init(): this.type = { + valueVector.allocateNew() + this + } + + override def write(row: InternalRow): Unit = { + if (row.isNullAt(ordinal)) { + setNull() + nullCount += 1 + } else { + setValue(row) + } + count += 1 + } + + override def finish(): (ArrowFieldNode, Array[ArrowBuf]) = { + valueMutator.setValueCount(count) + val fieldNode = new ArrowFieldNode(count, nullCount) + val valueBuffers = valueVector.getBuffers(true) + (fieldNode, valueBuffers) + } +} + +private[arrow] class BooleanColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableBitVector + = new NullableBitVector("BooleanValue", getFieldType(dtype), allocator) + override val valueMutator: NullableBitVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, if (row.getBoolean(ordinal)) 1 else 0 ) +} + +private[arrow] class ShortColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableSmallIntVector + = new NullableSmallIntVector("ShortValue", getFieldType(dtype: ArrowType), allocator) + override val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getShort(ordinal)) +} + +private[arrow] class IntegerColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableIntVector + = new NullableIntVector("IntValue", getFieldType(dtype), allocator) + override val valueMutator: NullableIntVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getInt(ordinal)) +} + +private[arrow] class LongColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableBigIntVector + = new NullableBigIntVector("LongValue", getFieldType(dtype), allocator) + override val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getLong(ordinal)) +} + +private[arrow] class FloatColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableFloat4Vector + = new NullableFloat4Vector("FloatValue", getFieldType(dtype), allocator) + override val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getFloat(ordinal)) +} + +private[arrow] class DoubleColumnWriter(dtype: ArrowType, ordinal: Int, 
allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableFloat8Vector + = new NullableFloat8Vector("DoubleValue", getFieldType(dtype), allocator) + override val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getDouble(ordinal)) +} + +private[arrow] class ByteColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableUInt1Vector + = new NullableUInt1Vector("ByteValue", getFieldType(dtype), allocator) + override val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit + = valueMutator.setSafe(count, row.getByte(ordinal)) +} + +private[arrow] class UTF8StringColumnWriter( + dtype: ArrowType, + ordinal: Int, + allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableVarCharVector + = new NullableVarCharVector("UTF8StringValue", getFieldType(dtype), allocator) + override val valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit = { + val str = row.getUTF8String(ordinal) + valueMutator.setSafe(count, str.getByteBuffer, 0, str.numBytes) + } +} + +private[arrow] class BinaryColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableVarBinaryVector + = new NullableVarBinaryVector("BinaryValue", getFieldType(dtype), allocator) + override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit = { + val bytes = row.getBinary(ordinal) + valueMutator.setSafe(count, bytes, 0, bytes.length) + } +} + +private[arrow] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableDateDayVector + = new NullableDateDayVector("DateValue", getFieldType(dtype), allocator) + override val valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit = { + valueMutator.setSafe(count, row.getInt(ordinal)) + } +} + +private[arrow] class TimeStampColumnWriter( + dtype: ArrowType, + ordinal: Int, + allocator: BufferAllocator) + extends PrimitiveColumnWriter(ordinal) { + override val valueVector: NullableTimeStampMicroVector + = new NullableTimeStampMicroVector("TimeStampValue", getFieldType(dtype), allocator) + override val valueMutator: NullableTimeStampMicroVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow): Unit = { + valueMutator.setSafe(count, row.getLong(ordinal)) + } +} + +private[arrow] object ColumnWriter { + + /** + * Create an Arrow ColumnWriter given the type and ordinal of row. 
+ */ + def apply(dataType: DataType, ordinal: Int, allocator: BufferAllocator): ColumnWriter = { + val dtype = ArrowConverters.sparkTypeToArrowType(dataType) + dataType match { + case BooleanType => new BooleanColumnWriter(dtype, ordinal, allocator) + case ShortType => new ShortColumnWriter(dtype, ordinal, allocator) + case IntegerType => new IntegerColumnWriter(dtype, ordinal, allocator) + case LongType => new LongColumnWriter(dtype, ordinal, allocator) + case FloatType => new FloatColumnWriter(dtype, ordinal, allocator) + case DoubleType => new DoubleColumnWriter(dtype, ordinal, allocator) + case ByteType => new ByteColumnWriter(dtype, ordinal, allocator) + case StringType => new UTF8StringColumnWriter(dtype, ordinal, allocator) + case BinaryType => new BinaryColumnWriter(dtype, ordinal, allocator) + case DateType => new DateColumnWriter(dtype, ordinal, allocator) + case TimestampType => new TimeStampColumnWriter(dtype, ordinal, allocator) + case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala new file mode 100644 index 0000000000000..159328cc0d958 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -0,0 +1,1222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.arrow + +import java.io.File +import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} +import java.text.SimpleDateFormat +import java.util.Locale + +import com.google.common.io.Files +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} +import org.apache.arrow.vector.file.json.JsonFileReader +import org.apache.arrow.vector.util.Validator +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkException +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{BinaryType, StructField, StructType} +import org.apache.spark.util.Utils + + +class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { + import testImplicits._ + + private var tempDataPath: String = _ + + override def beforeAll(): Unit = { + super.beforeAll() + tempDataPath = Utils.createTempDir(namePrefix = "arrow").getAbsolutePath + } + + test("collect to arrow record batch") { + val indexData = (1 to 6).toDF("i") + val arrowPayloads = indexData.toArrowPayload.collect() + assert(arrowPayloads.nonEmpty) + assert(arrowPayloads.length == indexData.rdd.getNumPartitions) + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val rowCount = arrowRecordBatches.map(_.getLength).sum + assert(rowCount === indexData.count()) + arrowRecordBatches.foreach(batch => assert(batch.getNodes.size() > 0)) + arrowRecordBatches.foreach(_.close()) + allocator.close() + } + + test("short conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_s", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 16 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 16 + | } ] + | } + | }, { + | "name" : "b_s", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 16 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 16 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_s", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 32767, -32768 ] + | }, { + | "name" : "b_s", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -32768 ] + | } ] + | } ] + |} + """.stripMargin + + val a_s = List[Short](1, -1, 2, -2, 32767, -32768) + val b_s = List[Option[Short]](Some(1), None, None, Some(-2), None, Some(-32768)) + val df = a_s.zip(b_s).toDF("a_s", "b_s") + + collectAndValidate(df, json, "integer-16bit.json") + } + + test("int conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : 
"VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] + | }, { + | "name" : "b_i", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] + | } ] + | } ] + |} + """.stripMargin + + val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) + val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) + val df = a_i.zip(b_i).toDF("a_i", "b_i") + + collectAndValidate(df, json, "integer-32bit.json") + } + + test("long conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_l", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 64 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "b_l", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 64 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_l", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 9223372036854775807, -9223372036854775808 ] + | }, { + | "name" : "b_l", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -9223372036854775808 ] + | } ] + | } ] + |} + """.stripMargin + + val a_l = List[Long](1, -1, 2, -2, 9223372036854775807L, -9223372036854775808L) + val b_l = List[Option[Long]](Some(1), None, None, Some(-2), None, Some(-9223372036854775808L)) + val df = a_l.zip(b_l).toDF("a_l", "b_l") + + collectAndValidate(df, json, "integer-64bit.json") + } + + test("float conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_f", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_f", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_f", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0 ] + | }, { + | "name" : "b_f", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ] + | } ] + | } ] + |} + """.stripMargin + + val a_f = List(1.0f, 2.0f, 0.01f, 200.0f, 0.0001f, 20000.0f) + val b_f = List[Option[Float]](Some(1.1f), None, None, Some(2.2f), None, Some(3.3f)) + val df = a_f.zip(b_f).toDF("a_f", "b_f") + + collectAndValidate(df, json, "floating_point-single_precision.json") + } + + test("double conversion") { + val json = + s""" + |{ + | "schema" : { + | 
"fields" : [ { + | "name" : "a_d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "b_d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_d", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 1.0E-4, 20000.0 ] + | }, { + | "name" : "b_d", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ] + | } ] + | } ] + |} + """.stripMargin + + val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0) + val b_d = List[Option[Double]](Some(1.1), None, None, Some(2.2), None, Some(3.3)) + val df = a_d.zip(b_d).toDF("a_d", "b_d") + + collectAndValidate(df, json, "floating_point-double_precision.json") + } + + test("index conversion") { + val data = List[Int](1, 2, 3, 4, 5, 6) + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | } ] + | } ] + |} + """.stripMargin + val df = data.toDF("i") + + collectAndValidate(df, json, "indexData-ints.json") + } + + test("mixed numeric type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 16 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 16 + | } ] + | } + | }, { + | "name" : "b", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "c", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "e", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 64 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | 
"type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | }, { + | "name" : "b", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] + | }, { + | "name" : "c", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | }, { + | "name" : "d", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] + | }, { + | "name" : "e", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5, 6 ] + | } ] + | } ] + |} + """.stripMargin + + val data = List(1, 2, 3, 4, 5, 6) + val data_tuples = for (d <- data) yield { + (d.toShort, d.toFloat, d.toInt, d.toDouble, d.toLong) + } + val df = data_tuples.toDF("a", "b", "c", "d", "e") + + collectAndValidate(df, json, "mixed_numeric_types.json") + } + + test("string type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "upper_case", + | "type" : { + | "name" : "utf8" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | }, { + | "name" : "lower_case", + | "type" : { + | "name" : "utf8" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | }, { + | "name" : "null_str", + | "type" : { + | "name" : "utf8" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "upper_case", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "OFFSET" : [ 0, 1, 2, 3 ], + | "DATA" : [ "A", "B", "C" ] + | }, { + | "name" : "lower_case", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "OFFSET" : [ 0, 1, 2, 3 ], + | "DATA" : [ "a", "b", "c" ] + | }, { + | "name" : "null_str", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 0 ], + | "OFFSET" : [ 0, 2, 5, 5 ], + | "DATA" : [ "ab", "CDE", "" ] + | } ] + | } ] + |} + """.stripMargin + + val upperCase = Seq("A", "B", "C") + val lowerCase = Seq("a", "b", "c") + val nullStr = Seq("ab", "CDE", null) + val df = (upperCase, lowerCase, nullStr).zipped.toList + .toDF("upper_case", "lower_case", "null_str") + + collectAndValidate(df, json, "stringData.json") + } + + test("boolean type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_bool", + | "type" : { + | "name" : "bool" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 1 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "a_bool", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ true, true, false, true ] + | } ] + | } 
] + |} + """.stripMargin + val df = Seq(true, true, false, true).toDF("a_bool") + collectAndValidate(df, json, "boolData.json") + } + + test("byte type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_byte", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 8 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "a_byte", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 64, 127 ] + | } ] + | } ] + |} + | + """.stripMargin + val df = List[Byte](1.toByte, (-1).toByte, 64.toByte, Byte.MaxValue).toDF("a_byte") + collectAndValidate(df, json, "byteData.json") + } + + test("binary type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_binary", + | "type" : { + | "name" : "binary" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 8 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a_binary", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "OFFSET" : [ 0, 3, 4, 6 ], + | "DATA" : [ "616263", "64", "6566" ] + | } ] + | } ] + |} + """.stripMargin + + val data = Seq("abc", "d", "ef") + val rdd = sparkContext.parallelize(data.map(s => Row(s.getBytes("utf-8")))) + val df = spark.createDataFrame(rdd, StructType(Seq(StructField("a_binary", BinaryType)))) + + collectAndValidate(df, json, "binaryData.json") + } + + test("floating-point NaN") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "NaN_f", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "SINGLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "NaN_d", + | "type" : { + | "name" : "floatingpoint", + | "precision" : "DOUBLE" + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 2, + | "columns" : [ { + | "name" : "NaN_f", + | "count" : 2, + | "VALIDITY" : [ 1, 1 ], + | "DATA" : [ 1.2000000476837158, "NaN" ] + | }, { + | "name" : "NaN_d", + | "count" : 2, + | "VALIDITY" : [ 1, 1 ], + | "DATA" : [ "NaN", 1.2 ] + | } ] + | } ] + |} + """.stripMargin + + val fnan = Seq(1.2F, Float.NaN) + val dnan = Seq(Double.NaN, 1.2) + val df = fnan.zip(dnan).toDF("NaN_f", "NaN_d") + + collectAndValidate(df, json, "nanData-floating_point.json") + } + + test("partitioned DataFrame") { + val json1 = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | 
"bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 1, 1, 2 ] + | }, { + | "name" : "b", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 1, 2, 1 ] + | } ] + | } ] + |} + """.stripMargin + val json2 = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 2, 3, 3 ] + | }, { + | "name" : "b", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 2, 1, 2 ] + | } ] + | } ] + |} + """.stripMargin + + val arrowPayloads = testData2.toArrowPayload.collect() + // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload + assert(arrowPayloads.length === 2) + val schema = testData2.schema + + val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") + val tempFile2 = new File(tempDataPath, "testData2-ints-part2.json") + Files.write(json1, tempFile1, StandardCharsets.UTF_8) + Files.write(json2, tempFile2, StandardCharsets.UTF_8) + + validateConversion(schema, arrowPayloads(0), tempFile1) + validateConversion(schema, arrowPayloads(1), tempFile2) + } + + test("empty frame collect") { + val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect() + assert(arrowPayload.isEmpty) + + val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i") + val filteredArrowPayload = filteredDF.filter("i < 0").toArrowPayload.collect() + assert(filteredArrowPayload.isEmpty) + } + + test("empty partition collect") { + val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i") + val arrowPayloads = emptyPart.toArrowPayload.collect() + assert(arrowPayloads.length === 1) + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + assert(arrowRecordBatches.head.getLength == 1) + arrowRecordBatches.foreach(_.close()) + allocator.close() + } + + test("max records in batch conf") { + val totalRecords = 10 + val maxRecordsPerBatch = 3 + spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) + val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") + val arrowPayloads = df.toArrowPayload.collect() + val allocator = new RootAllocator(Long.MaxValue) + val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + var recordCount = 0 + arrowRecordBatches.foreach { batch => + assert(batch.getLength > 0) + assert(batch.getLength <= maxRecordsPerBatch) + recordCount += batch.getLength + batch.close() + } + assert(recordCount == totalRecords) + allocator.close() + 
spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") + } + + testQuietly("unsupported types") { + def runUnsupported(block: => Unit): Unit = { + val msg = intercept[SparkException] { + block + } + assert(msg.getMessage.contains("Unsupported data type")) + assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) + } + + runUnsupported { decimalData.toArrowPayload.collect() } + runUnsupported { arrayData.toDF().toArrowPayload.collect() } + runUnsupported { mapData.toDF().toArrowPayload.collect() } + runUnsupported { complexData.toArrowPayload.collect() } + + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) + val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) + val d2 = new Date(sdf.parse("2016-05-09 13:10:15.000 UTC").getTime) + runUnsupported { Seq(d1, d2).toDF("date").toArrowPayload.collect() } + + val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) + val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) + runUnsupported { Seq(ts1, ts2).toDF("timestamp").toArrowPayload.collect() } + } + + test("test Arrow Validator") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] + | }, { + | "name" : "b_i", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] + | } ] + | } ] + |} + """.stripMargin + val json_diff_col_order = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "b_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "a_i", + | "type" : { + | "name" : "int", + | "isSigned" : true, + | "bitWidth" : 32 + | }, + | "nullable" : false, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_i", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] + | }, { + | "name" : "b_i", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] + | } ] + | } ] + |} + """.stripMargin + + val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) + val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) + val df = a_i.zip(b_i).toDF("a_i", "b_i") + + // Different schema + intercept[IllegalArgumentException] { + collectAndValidate(df, 
json_diff_col_order, "validator_diff_schema.json") + } + + // Different values + intercept[IllegalArgumentException] { + collectAndValidate(df.sort($"a_i".desc), json, "validator_diff_values.json") + } + } + + /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ + private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = { + // NOTE: coalesce to single partition because can only load 1 batch in validator + val arrowPayload = df.coalesce(1).toArrowPayload.collect().head + val tempFile = new File(tempDataPath, file) + Files.write(json, tempFile, StandardCharsets.UTF_8) + validateConversion(df.schema, arrowPayload, tempFile) + } + + private def validateConversion( + sparkSchema: StructType, + arrowPayload: ArrowPayload, + jsonFile: File): Unit = { + val allocator = new RootAllocator(Long.MaxValue) + val jsonReader = new JsonFileReader(jsonFile, allocator) + + val arrowSchema = ArrowConverters.schemaToArrowSchema(sparkSchema) + val jsonSchema = jsonReader.start() + Validator.compareSchemas(arrowSchema, jsonSchema) + + val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator) + val vectorLoader = new VectorLoader(arrowRoot) + val arrowRecordBatch = arrowPayload.loadBatch(allocator) + vectorLoader.load(arrowRecordBatch) + val jsonRoot = jsonReader.read() + Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) + + jsonRoot.close() + jsonReader.close() + arrowRecordBatch.close() + arrowRoot.close() + allocator.close() + } +} From 5b5a69bea9de806e2c39b04b248ee82a7b664d7b Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 23 Jun 2017 09:19:02 +0800 Subject: [PATCH 061/118] [SPARK-20923] turn tracking of TaskMetrics._updatedBlockStatuses off ## What changes were proposed in this pull request? Turn tracking of TaskMetrics._updatedBlockStatuses off by default. As far as I can see its not used by anything and it uses a lot of memory when caching and processing a lot of blocks. In my case it was taking 5GB of a 10GB heap and I even went up to 50GB heap and the job still ran out of memory. With this change in place the same job easily runs in less then 10GB of heap. We leave the api there as well as a config to turn it back on just in case anyone is using it. TaskMetrics is exposed via SparkListenerTaskEnd so if users are relying on it they can turn it back on. ## How was this patch tested? Ran unit tests that were modified and manually tested on a couple of jobs (with and without caching). Clicked through the UI and didn't see anything missing. Ran my very large hive query job with 200,000 small tasks, 1000 executors, cached 6+TB of data this runs fine now whereas without this change it would go into full gcs and eventually die. Author: Thomas Graves Author: Tom Graves Closes #18162 from tgravescs/SPARK-20923. 
--- .../apache/spark/executor/TaskMetrics.scala | 6 ++++ .../spark/internal/config/package.scala | 8 +++++ .../apache/spark/storage/BlockManager.scala | 6 ++-- .../spark/storage/BlockManagerSuite.scala | 32 ++++++++++++++++++- 4 files changed, 49 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 341a6da8107ef..85b2745a2aec4 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -112,6 +112,12 @@ class TaskMetrics private[spark] () extends Serializable { /** * Storage statuses of any blocks that have been updated as a result of this task. + * + * Tracking the _updatedBlockStatuses can use a lot of memory. + * It is not used anywhere inside of Spark so we would ideally remove it, but its exposed to + * the user in SparkListenerTaskEnd so the api is kept for compatibility. + * Tracking can be turned off to save memory via config + * TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES. */ def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = { // This is called on driver. All accumulator updates have a fixed value. So it's safe to use diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 615497d36fd14..462c1890fd8df 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -322,4 +322,12 @@ package object config { "above this threshold. This is to avoid a giant request takes too much memory.") .bytesConf(ByteUnit.BYTE) .createWithDefaultString("200m") + + private[spark] val TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES = + ConfigBuilder("spark.taskMetrics.trackUpdatedBlockStatuses") + .doc("Enable tracking of updatedBlockStatuses in the TaskMetrics. 
Off by default since " + + "tracking the block statuses can use a lot of memory and its not used anywhere within " + + "spark.") + .booleanConf + .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 74be70348305c..adbe3cfd89ea6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1473,8 +1473,10 @@ private[spark] class BlockManager( } private def addUpdatedBlockStatusToTaskMetrics(blockId: BlockId, status: BlockStatus): Unit = { - Option(TaskContext.get()).foreach { c => - c.taskMetrics().incUpdatedBlockStatuses(blockId -> status) + if (conf.get(config.TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES)) { + Option(TaskContext.get()).foreach { c => + c.taskMetrics().incUpdatedBlockStatuses(blockId -> status) + } } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 88f18294aa015..086adccea954c 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -922,8 +922,38 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } } + test("turn off updated block statuses") { + val conf = new SparkConf() + conf.set(TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES, false) + store = makeBlockManager(12000, testConf = Some(conf)) + + store.registerTask(0) + val list = List.fill(2)(new Array[Byte](2000)) + + def getUpdatedBlocks(task: => Unit): Seq[(BlockId, BlockStatus)] = { + val context = TaskContext.empty() + try { + TaskContext.setTaskContext(context) + task + } finally { + TaskContext.unset() + } + context.taskMetrics.updatedBlockStatuses + } + + // 1 updated block (i.e. list1) + val updatedBlocks1 = getUpdatedBlocks { + store.putIterator( + "list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + } + assert(updatedBlocks1.size === 0) + } + + test("updated block statuses") { - store = makeBlockManager(12000) + val conf = new SparkConf() + conf.set(TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES, true) + store = makeBlockManager(12000, testConf = Some(conf)) store.registerTask(0) val list = List.fill(2)(new Array[Byte](2000)) val bigList = List.fill(8)(new Array[Byte](2000)) From b8a743b6a531432e57eb50ecff06798ebc19483e Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Fri, 23 Jun 2017 09:27:35 +0800 Subject: [PATCH 062/118] [SPARK-21174][SQL] Validate sampling fraction in logical operator level ## What changes were proposed in this pull request? Currently the validation of sampling fraction in dataset is incomplete. As an improvement, validate sampling fraction in logical operator level: 1) if with replacement: fraction should be nonnegative 2) else: fraction should be on interval [0, 1] Also add test cases for the validation. ## How was this patch tested? integration tests gatorsmile cloud-fan Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Wang Gengliang Closes #18387 from gengliangwang/sample_ratio_validate. 
--- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 3 +- .../plans/logical/basicLogicalOperators.scala | 13 ++++ .../scala/org/apache/spark/sql/Dataset.scala | 3 - .../sql-tests/inputs/tablesample-negative.sql | 14 +++++ .../sql-tests/results/operators.sql.out | 8 +-- .../results/tablesample-negative.sql.out | 62 +++++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 28 +++++++++ 8 files changed, 124 insertions(+), 9 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/tablesample-negative.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/tablesample-negative.sql.out diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index ef5648c6dbe47..9456031736528 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -440,7 +440,7 @@ joinCriteria sample : TABLESAMPLE '(' - ( (percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT) + ( (negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT) | (expression sampleType=ROWS) | sampleType=BYTELENGTH_LITERAL | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON (identifier | qualifiedName '(' ')'))?)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 500d999c30da7..315c6721b3f65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -636,7 +636,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case SqlBaseParser.PERCENTLIT => val fraction = ctx.percentage.getText.toDouble - sample(fraction / 100.0d) + val sign = if (ctx.negativeSign == null) 1 else -1 + sample(sign * fraction / 100.0d) case SqlBaseParser.BYTELENGTH_LITERAL => throw new ParseException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 6878b6b179c3a..6e88b7a57dc33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +import org.apache.spark.util.random.RandomSampler /** * When planning take() or collect() operations, this special node that is inserted at the top of @@ -817,6 +818,18 @@ case class Sample( child: LogicalPlan)( val isTableSample: java.lang.Boolean = false) extends UnaryNode { + val eps = RandomSampler.roundingEpsilon + val fraction = upperBound - lowerBound + if (withReplacement) { + require( + fraction >= 0.0 - eps, + s"Sampling fraction ($fraction) must be nonnegative with replacement") + } else { + require( + fraction >= 0.0 - eps && fraction <= 1.0 + eps, + s"Sampling fraction ($fraction) must be on interval [0, 1] without replacement") + } + override 
def output: Seq[Attribute] = child.output override def computeStats(conf: SQLConf): Statistics = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a2af9c2efe2ab..767dad3e63a6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1806,9 +1806,6 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = { - require(fraction >= 0, - s"Fraction must be nonnegative, but got ${fraction}") - withTypedPlan { Sample(0.0, fraction, withReplacement, seed, logicalPlan)() } diff --git a/sql/core/src/test/resources/sql-tests/inputs/tablesample-negative.sql b/sql/core/src/test/resources/sql-tests/inputs/tablesample-negative.sql new file mode 100644 index 0000000000000..72508f59bee27 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/tablesample-negative.sql @@ -0,0 +1,14 @@ +-- Negative testcases for tablesample +CREATE DATABASE mydb1; +USE mydb1; +CREATE TABLE t1 USING parquet AS SELECT 1 AS i1; + +-- Negative tests: negative percentage +SELECT mydb1.t1 FROM t1 TABLESAMPLE (-1 PERCENT); + +-- Negative tests: percentage over 100 +-- The TABLESAMPLE clause samples without replacement, so the value of PERCENT must not exceed 100 +SELECT mydb1.t1 FROM t1 TABLESAMPLE (101 PERCENT); + +-- reset +DROP DATABASE mydb1 CASCADE; diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 5cb6ed3e27bf2..fec423fca5bbe 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 56 +-- Number of queries: 57 -- !query 0 @@ -462,9 +462,9 @@ struct 3.13 2.19 --- !query 55 +-- !query 56 select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11) --- !query 55 schema +-- !query 56 schema struct<(+ CAST(-1.11 AS DOUBLE)):double,(+ -1.11):decimal(3,2),(- CAST(-1.11 AS DOUBLE)):double,(- -1.11):decimal(3,2)> --- !query 55 output +-- !query 56 output -1.11 -1.11 1.11 1.11 diff --git a/sql/core/src/test/resources/sql-tests/results/tablesample-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/tablesample-negative.sql.out new file mode 100644 index 0000000000000..35f3931736b83 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/tablesample-negative.sql.out @@ -0,0 +1,62 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +CREATE DATABASE mydb1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +USE mydb1 +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TABLE t1 USING parquet AS SELECT 1 AS i1 +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +SELECT mydb1.t1 FROM t1 TABLESAMPLE (-1 PERCENT) +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.catalyst.parser.ParseException + +Sampling fraction (-0.01) must be on interval [0, 1](line 1, pos 24) + +== SQL == +SELECT mydb1.t1 FROM t1 TABLESAMPLE (-1 PERCENT) +------------------------^^^ + + +-- !query 4 +SELECT mydb1.t1 FROM t1 TABLESAMPLE (101 PERCENT) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.catalyst.parser.ParseException + +Sampling fraction (1.01) must be on 
interval [0, 1](line 1, pos 24) + +== SQL == +SELECT mydb1.t1 FROM t1 TABLESAMPLE (101 PERCENT) +------------------------^^^ + + +-- !query 5 +DROP DATABASE mydb1 CASCADE +-- !query 5 schema +struct<> +-- !query 5 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 8eb381b91f46d..165176f6c040e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -457,6 +457,34 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 3, 17, 27, 58, 62) } + test("sample fraction should not be negative with replacement") { + val data = sparkContext.parallelize(1 to 2, 1).toDS() + val errMsg = intercept[IllegalArgumentException] { + data.sample(withReplacement = true, -0.1, 0) + }.getMessage + assert(errMsg.contains("Sampling fraction (-0.1) must be nonnegative with replacement")) + + // Sampling fraction can be greater than 1 with replacement. + checkDataset( + data.sample(withReplacement = true, 1.05, seed = 13), + 1, 2) + } + + test("sample fraction should be on interval [0, 1] without replacement") { + val data = sparkContext.parallelize(1 to 2, 1).toDS() + val errMsg1 = intercept[IllegalArgumentException] { + data.sample(withReplacement = false, -0.1, 0) + }.getMessage() + assert(errMsg1.contains( + "Sampling fraction (-0.1) must be on interval [0, 1] without replacement")) + + val errMsg2 = intercept[IllegalArgumentException] { + data.sample(withReplacement = false, 1.1, 0) + }.getMessage() + assert(errMsg2.contains( + "Sampling fraction (1.1) must be on interval [0, 1] without replacement")) + } + test("SPARK-16686: Dataset.sample with seed results shouldn't depend on downstream usage") { val simpleUdf = udf((n: Int) => { require(n != 1, "simpleUdf shouldn't see id=1!") From fe24634d14bc0973ca38222db2f58eafbf0c890d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 23 Jun 2017 00:43:21 -0700 Subject: [PATCH 063/118] [SPARK-21145][SS] Added StateStoreProviderId with queryRunId to reload StateStoreProviders when query is restarted ## What changes were proposed in this pull request? StateStoreProvider instances are loaded on-demand in a executor when a query is started. When a query is restarted, the loaded provider instance will get reused. Now, there is a non-trivial chance, that the task of the previous query run is still running, while the tasks of the restarted run has started. So for a stateful partition, there may be two concurrent tasks related to the same stateful partition, and there for using the same provider instance. This can lead to inconsistent results and possibly random failures, as state store implementations are not designed to be thread-safe. To fix this, I have introduced a `StateStoreProviderId`, that unique identifies a provider loaded in an executor. It has the query run id in it, thus making sure that restarted queries will force the executor to load a new provider instance, thus avoiding two concurrent tasks (from two different runs) from reusing the same provider instance. Additional minor bug fixes - All state stores related to query run is marked as deactivated in the `StateStoreCoordinator` so that the executors can unload them and clear resources. 
- Moved the code that determined the checkpoint directory of a state store from implementation-specific code (`HDFSBackedStateStoreProvider`) to non-specific code (StateStoreId), so that implementation do not accidentally get it wrong. - Also added store name to the path, to support multiple stores per sql operator partition. *Note:* This change does not address the scenario where two tasks of the same run (e.g. speculative tasks) are concurrently running in the same executor. The chance of this very small, because ideally speculative tasks should never run in the same executor. ## How was this patch tested? Existing unit tests + new unit test. Author: Tathagata Das Closes #18355 from tdas/SPARK-21145. --- .../sql/execution/aggregate/AggUtils.scala | 2 +- .../sql/execution/command/commands.scala | 5 +- .../FlatMapGroupsWithStateExec.scala | 7 +- .../streaming/IncrementalExecution.scala | 27 +++--- .../execution/streaming/StreamExecution.scala | 1 + .../state/HDFSBackedStateStoreProvider.scala | 16 ++-- .../streaming/state/StateStore.scala | 91 +++++++++++++----- .../state/StateStoreCoordinator.scala | 41 ++++---- .../streaming/state/StateStoreRDD.scala | 21 ++++- .../execution/streaming/state/package.scala | 25 ++--- .../streaming/statefulOperators.scala | 38 ++++---- .../sql/streaming/StreamingQueryManager.scala | 1 + .../state/StateStoreCoordinatorSuite.scala | 61 ++++++++++-- .../streaming/state/StateStoreRDDSuite.scala | 51 +++++----- .../streaming/state/StateStoreSuite.scala | 93 +++++++++++++++---- .../spark/sql/streaming/StreamSuite.scala | 2 +- .../spark/sql/streaming/StreamTest.scala | 13 ++- 17 files changed, 329 insertions(+), 166 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index aa789af6f812f..12f8cffb6774a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -311,7 +311,7 @@ object AggUtils { val saved = StateStoreSaveExec( groupingAttributes, - stateId = None, + stateInfo = None, outputMode = None, eventTimeWatermark = None, partialMerged2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 2d82fcf4da6e9..81bc93e7ebcf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.util.UUID + import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -117,7 +119,8 @@ case class ExplainCommand( // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the // output mode does not matter since there is no `Sink`. 
new IncrementalExecution( - sparkSession, logicalPlan, OutputMode.Append(), "", 0, OffsetSeqMetadata(0, 0)) + sparkSession, logicalPlan, OutputMode.Append(), "", + UUID.randomUUID, 0, OffsetSeqMetadata(0, 0)) } else { sparkSession.sessionState.executePlan(logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 2aad8701a4eca..9dcac33b4107c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -50,7 +50,7 @@ case class FlatMapGroupsWithStateExec( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputObjAttr: Attribute, - stateId: Option[OperatorStateId], + stateInfo: Option[StatefulOperatorStateInfo], stateEncoder: ExpressionEncoder[Any], outputMode: OutputMode, timeoutConf: GroupStateTimeout, @@ -107,10 +107,7 @@ case class FlatMapGroupsWithStateExec( } child.execute().mapPartitionsWithStateStore[InternalRow]( - getStateId.checkpointLocation, - getStateId.operatorId, - storeName = "default", - getStateId.batchId, + getStateInfo, groupingAttributes.toStructType, stateAttributes.toStructType, indexOrdinal = None, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 622e049630db2..ab89dc6b705d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import java.util.UUID import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.internal.Logging @@ -36,6 +37,7 @@ class IncrementalExecution( logicalPlan: LogicalPlan, val outputMode: OutputMode, val checkpointLocation: String, + val runId: UUID, val currentBatchId: Long, offsetSeqMetadata: OffsetSeqMetadata) extends QueryExecution(sparkSession, logicalPlan) with Logging { @@ -69,7 +71,13 @@ class IncrementalExecution( * Records the current id for a given stateful operator in the query plan as the `state` * preparation walks the query plan. */ - private val operatorId = new AtomicInteger(0) + private val statefulOperatorId = new AtomicInteger(0) + + /** Get the state info of the next stateful operator */ + private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = { + StatefulOperatorStateInfo( + checkpointLocation, runId, statefulOperatorId.getAndIncrement(), currentBatchId) + } /** Locates save/restore pairs surrounding aggregation. 
*/ val state = new Rule[SparkPlan] { @@ -78,35 +86,28 @@ class IncrementalExecution( case StateStoreSaveExec(keys, None, None, None, UnaryExecNode(agg, StateStoreRestoreExec(keys2, None, child))) => - val stateId = - OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - + val aggStateInfo = nextStatefulOperationStateInfo StateStoreSaveExec( keys, - Some(stateId), + Some(aggStateInfo), Some(outputMode), Some(offsetSeqMetadata.batchWatermarkMs), agg.withNewChildren( StateStoreRestoreExec( keys, - Some(stateId), + Some(aggStateInfo), child) :: Nil)) case StreamingDeduplicateExec(keys, child, None, None) => - val stateId = - OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - StreamingDeduplicateExec( keys, child, - Some(stateId), + Some(nextStatefulOperationStateInfo), Some(offsetSeqMetadata.batchWatermarkMs)) case m: FlatMapGroupsWithStateExec => - val stateId = - OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) m.copy( - stateId = Some(stateId), + stateInfo = Some(nextStatefulOperationStateInfo), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 74f0f509bbf85..06bdec8b06407 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -652,6 +652,7 @@ class StreamExecution( triggerLogicalPlan, outputMode, checkpointFile("state"), + runId, currentBatchId, offsetSeqMetadata) lastExecution.executedPlan // Force the lazy generation of execution plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 67d86daf10812..bae7a15165e43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -92,7 +92,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit @volatile private var state: STATE = UPDATING @volatile private var finalDeltaFile: Path = null - override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id + override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId override def get(key: UnsafeRow): UnsafeRow = { mapToUpdate.get(key) @@ -177,7 +177,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** * Whether all updates have been committed */ - override private[streaming] def hasCommitted: Boolean = { + override def hasCommitted: Boolean = { state == COMMITTED } @@ -205,7 +205,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit indexOrdinal: Option[Int], // for sorting the data storeConf: StateStoreConf, hadoopConf: Configuration): Unit = { - this.stateStoreId = stateStoreId + this.stateStoreId_ = stateStoreId this.keySchema = keySchema this.valueSchema = valueSchema this.storeConf = storeConf @@ -213,7 +213,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit fs.mkdirs(baseDir) } - override 
def id: StateStoreId = stateStoreId + override def stateStoreId: StateStoreId = stateStoreId_ /** Do maintenance backing data files, including creating snapshots and cleaning up old files */ override def doMaintenance(): Unit = { @@ -231,20 +231,20 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } override def toString(): String = { - s"HDFSStateStoreProvider[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" + s"HDFSStateStoreProvider[" + + s"id = (op=${stateStoreId.operatorId},part=${stateStoreId.partitionId}),dir = $baseDir]" } /* Internal fields and methods */ - @volatile private var stateStoreId: StateStoreId = _ + @volatile private var stateStoreId_ : StateStoreId = _ @volatile private var keySchema: StructType = _ @volatile private var valueSchema: StructType = _ @volatile private var storeConf: StateStoreConf = _ @volatile private var hadoopConf: Configuration = _ private lazy val loadedMaps = new mutable.HashMap[Long, MapType] - private lazy val baseDir = - new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}") + private lazy val baseDir = stateStoreId.storeCheckpointLocation() private lazy val fs = baseDir.getFileSystem(hadoopConf) private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 29c456f86e1ed..a94ff8a7ebd1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.UUID import java.util.concurrent.{ScheduledFuture, TimeUnit} import javax.annotation.concurrent.GuardedBy @@ -24,14 +25,14 @@ import scala.collection.mutable import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.StructType import org.apache.spark.util.{ThreadUtils, Utils} - /** * Base trait for a versioned key-value store. Each instance of a `StateStore` represents a specific * version of state data, and such instances are created through a [[StateStoreProvider]]. @@ -99,7 +100,7 @@ trait StateStore { /** * Whether all updates have been committed */ - private[streaming] def hasCommitted: Boolean + def hasCommitted: Boolean } @@ -147,7 +148,7 @@ trait StateStoreProvider { * Return the id of the StateStores this provider will generate. * Should be the same as the one passed in init(). */ - def id: StateStoreId + def stateStoreId: StateStoreId /** Called when the provider instance is unloaded from the executor */ def close(): Unit @@ -179,13 +180,46 @@ object StateStoreProvider { } } +/** + * Unique identifier for a provider, used to identify when providers can be reused. + * Note that `queryRunId` is used uniquely identify a provider, so that the same provider + * instance is not reused across query restarts. + */ +case class StateStoreProviderId(storeId: StateStoreId, queryRunId: UUID) -/** Unique identifier for a bunch of keyed state data. */ +/** + * Unique identifier for a bunch of keyed state data. 
+ * @param checkpointRootLocation Root directory where all the state data of a query is stored + * @param operatorId Unique id of a stateful operator + * @param partitionId Index of the partition of an operators state data + * @param storeName Optional, name of the store. Each partition can optionally use multiple state + * stores, but they have to be identified by distinct names. + */ case class StateStoreId( - checkpointLocation: String, + checkpointRootLocation: String, operatorId: Long, partitionId: Int, - name: String = "") + storeName: String = StateStoreId.DEFAULT_STORE_NAME) { + + /** + * Checkpoint directory to be used by a single state store, identified uniquely by the tuple + * (operatorId, partitionId, storeName). All implementations of [[StateStoreProvider]] should + * use this path for saving state data, as this ensures that distinct stores will write to + * different locations. + */ + def storeCheckpointLocation(): Path = { + if (storeName == StateStoreId.DEFAULT_STORE_NAME) { + // For reading state store data that was generated before store names were used (Spark <= 2.2) + new Path(checkpointRootLocation, s"$operatorId/$partitionId") + } else { + new Path(checkpointRootLocation, s"$operatorId/$partitionId/$storeName") + } + } +} + +object StateStoreId { + val DEFAULT_STORE_NAME = "default" +} /** Mutable, and reusable class for representing a pair of UnsafeRows. */ class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) { @@ -211,7 +245,7 @@ object StateStore extends Logging { val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60 @GuardedBy("loadedProviders") - private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]() + private val loadedProviders = new mutable.HashMap[StateStoreProviderId, StateStoreProvider]() /** * Runs the `task` periodically and automatically cancels it if there is an exception. `onError` @@ -253,7 +287,7 @@ object StateStore extends Logging { /** Get or create a store associated with the id. 
*/ def get( - storeId: StateStoreId, + storeProviderId: StateStoreProviderId, keySchema: StructType, valueSchema: StructType, indexOrdinal: Option[Int], @@ -264,24 +298,24 @@ object StateStore extends Logging { val storeProvider = loadedProviders.synchronized { startMaintenanceIfNeeded() val provider = loadedProviders.getOrElseUpdate( - storeId, + storeProviderId, StateStoreProvider.instantiate( - storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) + storeProviderId.storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) ) - reportActiveStoreInstance(storeId) + reportActiveStoreInstance(storeProviderId) provider } storeProvider.getStore(version) } /** Unload a state store provider */ - def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized { - loadedProviders.remove(storeId).foreach(_.close()) + def unload(storeProviderId: StateStoreProviderId): Unit = loadedProviders.synchronized { + loadedProviders.remove(storeProviderId).foreach(_.close()) } /** Whether a state store provider is loaded or not */ - def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized { - loadedProviders.contains(storeId) + def isLoaded(storeProviderId: StateStoreProviderId): Boolean = loadedProviders.synchronized { + loadedProviders.contains(storeProviderId) } def isMaintenanceRunning: Boolean = loadedProviders.synchronized { @@ -340,21 +374,21 @@ object StateStore extends Logging { } } - private def reportActiveStoreInstance(storeId: StateStoreId): Unit = { + private def reportActiveStoreInstance(storeProviderId: StateStoreProviderId): Unit = { if (SparkEnv.get != null) { val host = SparkEnv.get.blockManager.blockManagerId.host val executorId = SparkEnv.get.blockManager.blockManagerId.executorId - coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId)) - logDebug(s"Reported that the loaded instance $storeId is active") + coordinatorRef.foreach(_.reportActiveInstance(storeProviderId, host, executorId)) + logInfo(s"Reported that the loaded instance $storeProviderId is active") } } - private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = { + private def verifyIfStoreInstanceActive(storeProviderId: StateStoreProviderId): Boolean = { if (SparkEnv.get != null) { val executorId = SparkEnv.get.blockManager.blockManagerId.executorId val verified = - coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false) - logDebug(s"Verified whether the loaded instance $storeId is active: $verified") + coordinatorRef.map(_.verifyIfInstanceActive(storeProviderId, executorId)).getOrElse(false) + logDebug(s"Verified whether the loaded instance $storeProviderId is active: $verified") verified } else { false @@ -364,12 +398,21 @@ object StateStore extends Logging { private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { val env = SparkEnv.get if (env != null) { - if (_coordRef == null) { + logInfo("Env is not null") + val isDriver = + env.executorId == SparkContext.DRIVER_IDENTIFIER || + env.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER + // If running locally, then the coordinator reference in _coordRef may be have become inactive + // as SparkContext + SparkEnv may have been restarted. Hence, when running in driver, + // always recreate the reference. 
+ if (isDriver || _coordRef == null) { + logInfo("Getting StateStoreCoordinatorRef") _coordRef = StateStoreCoordinatorRef.forExecutor(env) } - logDebug(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") + logInfo(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") Some(_coordRef) } else { + logInfo("Env is null") _coordRef = null None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index d0f81887e62d1..3884f5e6ce766 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.UUID + import scala.collection.mutable import org.apache.spark.SparkEnv @@ -29,16 +31,19 @@ import org.apache.spark.util.RpcUtils private sealed trait StateStoreCoordinatorMessage extends Serializable /** Classes representing messages */ -private case class ReportActiveInstance(storeId: StateStoreId, host: String, executorId: String) +private case class ReportActiveInstance( + storeId: StateStoreProviderId, + host: String, + executorId: String) extends StateStoreCoordinatorMessage -private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: String) +private case class VerifyIfInstanceActive(storeId: StateStoreProviderId, executorId: String) extends StateStoreCoordinatorMessage -private case class GetLocation(storeId: StateStoreId) +private case class GetLocation(storeId: StateStoreProviderId) extends StateStoreCoordinatorMessage -private case class DeactivateInstances(checkpointLocation: String) +private case class DeactivateInstances(runId: UUID) extends StateStoreCoordinatorMessage private object StopCoordinator @@ -80,25 +85,27 @@ object StateStoreCoordinatorRef extends Logging { class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { private[state] def reportActiveInstance( - storeId: StateStoreId, + stateStoreProviderId: StateStoreProviderId, host: String, executorId: String): Unit = { - rpcEndpointRef.send(ReportActiveInstance(storeId, host, executorId)) + rpcEndpointRef.send(ReportActiveInstance(stateStoreProviderId, host, executorId)) } /** Verify whether the given executor has the active instance of a state store */ - private[state] def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = { - rpcEndpointRef.askSync[Boolean](VerifyIfInstanceActive(storeId, executorId)) + private[state] def verifyIfInstanceActive( + stateStoreProviderId: StateStoreProviderId, + executorId: String): Boolean = { + rpcEndpointRef.askSync[Boolean](VerifyIfInstanceActive(stateStoreProviderId, executorId)) } /** Get the location of the state store */ - private[state] def getLocation(storeId: StateStoreId): Option[String] = { - rpcEndpointRef.askSync[Option[String]](GetLocation(storeId)) + private[state] def getLocation(stateStoreProviderId: StateStoreProviderId): Option[String] = { + rpcEndpointRef.askSync[Option[String]](GetLocation(stateStoreProviderId)) } - /** Deactivate instances related to a set of operator */ - private[state] def deactivateInstances(storeRootLocation: String): Unit = { - rpcEndpointRef.askSync[Boolean](DeactivateInstances(storeRootLocation)) + /** Deactivate instances related to a query */ + private[sql] def 
deactivateInstances(runId: UUID): Unit = { + rpcEndpointRef.askSync[Boolean](DeactivateInstances(runId)) } private[state] def stop(): Unit = { @@ -113,7 +120,7 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { */ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { - private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] + private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation] override def receive: PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId) => @@ -135,11 +142,11 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) logDebug(s"Got location of the state store $id: $executorId") context.reply(executorId) - case DeactivateInstances(checkpointLocation) => + case DeactivateInstances(runId) => val storeIdsToRemove = - instances.keys.filter(_.checkpointLocation == checkpointLocation).toSeq + instances.keys.filter(_.queryRunId == runId).toSeq instances --= storeIdsToRemove - logDebug(s"Deactivating instances related to checkpoint location $checkpointLocation: " + + logDebug(s"Deactivating instances related to checkpoint location $runId: " + storeIdsToRemove.mkString(", ")) context.reply(true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index b744c25dc97a8..01d8e75980993 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.UUID + import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} @@ -34,8 +36,8 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( dataRDD: RDD[T], storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], checkpointLocation: String, + queryRunId: UUID, operatorId: Long, - storeName: String, storeVersion: Long, keySchema: StructType, valueSchema: StructType, @@ -52,16 +54,25 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( override protected def getPartitions: Array[Partition] = dataRDD.partitions + /** + * Set the preferred location of each partition using the executor that has the related + * [[StateStoreProvider]] already loaded. 
+ */ override def getPreferredLocations(partition: Partition): Seq[String] = { - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName) - storeCoordinator.flatMap(_.getLocation(storeId)).toSeq + val stateStoreProviderId = StateStoreProviderId( + StateStoreId(checkpointLocation, operatorId, partition.index), + queryRunId) + storeCoordinator.flatMap(_.getLocation(stateStoreProviderId)).toSeq } override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { var store: StateStore = null - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName) + val storeProviderId = StateStoreProviderId( + StateStoreId(checkpointLocation, operatorId, partition.index), + queryRunId) + store = StateStore.get( - storeId, keySchema, valueSchema, indexOrdinal, storeVersion, + storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion, storeConf, hadoopConfBroadcast.value.value) val inputIter = dataRDD.iterator(partition, ctxt) storeUpdateFunction(store, inputIter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 228fe86d59940..a0086e251f9c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.util.UUID + import scala.reflect.ClassTag import org.apache.spark.TaskContext @@ -32,20 +34,14 @@ package object state { /** Map each partition of an RDD along with data in a [[StateStore]]. */ def mapPartitionsWithStateStore[U: ClassTag]( sqlContext: SQLContext, - checkpointLocation: String, - operatorId: Long, - storeName: String, - storeVersion: Long, + stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType, indexOrdinal: Option[Int])( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { mapPartitionsWithStateStore( - checkpointLocation, - operatorId, - storeName, - storeVersion, + stateInfo, keySchema, valueSchema, indexOrdinal, @@ -56,10 +52,7 @@ package object state { /** Map each partition of an RDD along with data in a [[StateStore]]. 
*/ private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( - checkpointLocation: String, - operatorId: Long, - storeName: String, - storeVersion: Long, + stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType, indexOrdinal: Option[Int], @@ -79,10 +72,10 @@ package object state { new StateStoreRDD( dataRDD, wrappedF, - checkpointLocation, - operatorId, - storeName, - storeVersion, + stateInfo.checkpointLocation, + stateInfo.queryRunId, + stateInfo.operatorId, + stateInfo.storeVersion, keySchema, valueSchema, indexOrdinal, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 3e57f3fbada32..c5722466a33af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import java.util.UUID import java.util.concurrent.TimeUnit._ import org.apache.spark.rdd.RDD @@ -36,20 +37,22 @@ import org.apache.spark.util.{CompletionIterator, NextIterator} /** Used to identify the state store for a given operator. */ -case class OperatorStateId( +case class StatefulOperatorStateInfo( checkpointLocation: String, + queryRunId: UUID, operatorId: Long, - batchId: Long) + storeVersion: Long) /** - * An operator that reads or writes state from the [[StateStore]]. The [[OperatorStateId]] should - * be filled in by `prepareForExecution` in [[IncrementalExecution]]. + * An operator that reads or writes state from the [[StateStore]]. + * The [[StatefulOperatorStateInfo]] should be filled in by `prepareForExecution` in + * [[IncrementalExecution]]. 
*/ trait StatefulOperator extends SparkPlan { - def stateId: Option[OperatorStateId] + def stateInfo: Option[StatefulOperatorStateInfo] - protected def getStateId: OperatorStateId = attachTree(this) { - stateId.getOrElse { + protected def getStateInfo: StatefulOperatorStateInfo = attachTree(this) { + stateInfo.getOrElse { throw new IllegalStateException("State location not present for execution") } } @@ -140,7 +143,7 @@ trait WatermarkSupport extends UnaryExecNode { */ case class StateStoreRestoreExec( keyExpressions: Seq[Attribute], - stateId: Option[OperatorStateId], + stateInfo: Option[StatefulOperatorStateInfo], child: SparkPlan) extends UnaryExecNode with StateStoreReader { @@ -148,10 +151,7 @@ case class StateStoreRestoreExec( val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitionsWithStateStore( - getStateId.checkpointLocation, - operatorId = getStateId.operatorId, - storeName = "default", - storeVersion = getStateId.batchId, + getStateInfo, keyExpressions.toStructType, child.output.toStructType, indexOrdinal = None, @@ -177,7 +177,7 @@ case class StateStoreRestoreExec( */ case class StateStoreSaveExec( keyExpressions: Seq[Attribute], - stateId: Option[OperatorStateId] = None, + stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, eventTimeWatermark: Option[Long] = None, child: SparkPlan) @@ -189,10 +189,7 @@ case class StateStoreSaveExec( "Incorrect planning in IncrementalExecution, outputMode has not been set") child.execute().mapPartitionsWithStateStore( - getStateId.checkpointLocation, - getStateId.operatorId, - storeName = "default", - getStateId.batchId, + getStateInfo, keyExpressions.toStructType, child.output.toStructType, indexOrdinal = None, @@ -319,7 +316,7 @@ case class StateStoreSaveExec( case class StreamingDeduplicateExec( keyExpressions: Seq[Attribute], child: SparkPlan, - stateId: Option[OperatorStateId] = None, + stateInfo: Option[StatefulOperatorStateInfo] = None, eventTimeWatermark: Option[Long] = None) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { @@ -331,10 +328,7 @@ case class StreamingDeduplicateExec( metrics // force lazy init at driver child.execute().mapPartitionsWithStateStore( - getStateId.checkpointLocation, - getStateId.operatorId, - storeName = "default", - getStateId.batchId, + getStateInfo, keyExpressions.toStructType, child.output.toStructType, indexOrdinal = None, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 002c45413b4c2..48b0ea20e5da1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -332,5 +332,6 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } awaitTerminationLock.notifyAll() } + stateStoreCoordinator.deactivateInstances(terminatedQuery.runId) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index a7e32626264cc..9a7595eee7bd0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -17,11 +17,17 @@ package 
org.apache.spark.sql.execution.streaming.state +import java.util.UUID + import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.functions.count +import org.apache.spark.util.Utils class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { @@ -29,7 +35,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("report, verify, getLocation") { withCoordinatorRef(sc) { coordinatorRef => - val id = StateStoreId("x", 0, 0) + val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID) assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) assert(coordinatorRef.getLocation(id) === None) @@ -57,9 +63,11 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("make inactive") { withCoordinatorRef(sc) { coordinatorRef => - val id1 = StateStoreId("x", 0, 0) - val id2 = StateStoreId("y", 1, 0) - val id3 = StateStoreId("x", 0, 1) + val runId1 = UUID.randomUUID + val runId2 = UUID.randomUUID + val id1 = StateStoreProviderId(StateStoreId("x", 0, 0), runId1) + val id2 = StateStoreProviderId(StateStoreId("y", 1, 0), runId2) + val id3 = StateStoreProviderId(StateStoreId("x", 0, 1), runId1) val host = "hostX" val exec = "exec1" @@ -73,7 +81,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === true) } - coordinatorRef.deactivateInstances("x") + coordinatorRef.deactivateInstances(runId1) assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === false) assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) @@ -85,7 +93,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { Some(ExecutorCacheTaskLocation(host, exec).toString)) assert(coordinatorRef.getLocation(id3) === None) - coordinatorRef.deactivateInstances("y") + coordinatorRef.deactivateInstances(runId2) assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === false) assert(coordinatorRef.getLocation(id2) === None) } @@ -95,7 +103,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { withCoordinatorRef(sc) { coordRef1 => val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env) - val id = StateStoreId("x", 0, 0) + val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID) coordRef1.reportActiveInstance(id, "hostX", "exec1") @@ -107,6 +115,45 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } } } + + test("query stop deactivates related store providers") { + var coordRef: StateStoreCoordinatorRef = null + try { + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + SparkSession.setActiveSession(spark) + import spark.implicits._ + coordRef = spark.streams.stateStoreCoordinator + implicit val sqlContext = spark.sqlContext + spark.conf.set("spark.sql.shuffle.partitions", "1") + + // Start a query and run a batch to load state stores + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().groupBy("value").agg(count("*")) // stateful query + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + 
.queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + inputData.addData(1, 2, 3) + query.processAllAvailable() + + // Verify state store has been loaded + val stateCheckpointDir = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation + val providerId = StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, 0), query.runId) + assert(coordRef.getLocation(providerId).nonEmpty) + + // Stop and verify whether the stores are deactivated in the coordinator + query.stop() + assert(coordRef.getLocation(providerId).isEmpty) + } finally { + SparkSession.getActiveSession.foreach(_.streams.active.foreach(_.stop())) + if (coordRef != null) coordRef.stop() + StateStore.stop() + } + } } object StateStoreCoordinatorSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 4a1a089af54c2..defb9ed63a881 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -19,20 +19,19 @@ package org.apache.spark.sql.execution.streaming.state import java.io.File import java.nio.file.Files +import java.util.UUID import scala.util.Random import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.scalatest.concurrent.Eventually._ -import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.sql.LocalSparkSession._ -import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.LocalSparkSession._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.{CompletionIterator, Utils} @@ -57,16 +56,14 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("versioning and immutability") { withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - val opId = 0 - val rdd1 = - makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( + spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)( + spark.sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)( increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) @@ -76,7 +73,6 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("recovering from files") { - val opId = 0 val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString def makeStoreRDD( @@ -85,7 +81,8 @@ class 
StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn storeVersion: Int): RDD[(String, Int)] = { implicit val sqlContext = spark.sqlContext makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion, keySchema, valueSchema, None)(increment) + sqlContext, operatorStateInfo(path, version = storeVersion), + keySchema, valueSchema, None)(increment) } // Generate RDDs and state store data @@ -132,17 +129,17 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( iteratorOfGets) assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None)) val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( iteratorOfPuts) assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1)) val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)( + sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)( iteratorOfGets) assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None)) } @@ -150,22 +147,25 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("preferred locations using StateStoreCoordinator") { quietly { + val queryRunId = UUID.randomUUID val opId = 0 val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext val coordinatorRef = sqlContext.streams.stateStoreCoordinator - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0, "name"), "host1", "exec1") - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1, "name"), "host2", "exec2") + val storeProviderId1 = StateStoreProviderId(StateStoreId(path, opId, 0), queryRunId) + val storeProviderId2 = StateStoreProviderId(StateStoreId(path, opId, 1), queryRunId) + coordinatorRef.reportActiveInstance(storeProviderId1, "host1", "exec1") + coordinatorRef.reportActiveInstance(storeProviderId2, "host2", "exec2") - assert( - coordinatorRef.getLocation(StateStoreId(path, opId, 0, "name")) === + require( + coordinatorRef.getLocation(storeProviderId1) === Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) val rdd = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( - increment) + sqlContext, operatorStateInfo(path, queryRunId = queryRunId), + keySchema, valueSchema, None)(increment) require(rdd.partitions.length === 2) assert( @@ -192,12 +192,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(increment) + 
sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)(increment) + sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. @@ -210,6 +210,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2) } + private def operatorStateInfo( + path: String, + queryRunId: UUID = UUID.randomUUID, + version: Int = 0): StatefulOperatorStateInfo = { + StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version) + } + private val increment = (store: StateStore, iter: Iterator[String]) => { iter.foreach { s => val key = stringToRow(s) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index af2b9f1c11fb6..c2087ec219e57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI +import java.util.UUID import scala.collection.JavaConverters._ import scala.collection.mutable @@ -33,8 +34,11 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.LocalSparkContext._ +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.functions.count import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -143,7 +147,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] provider.getStore(0).commit() // Verify we don't leak temp files - val tempFiles = FileUtils.listFiles(new File(provider.id.checkpointLocation), + val tempFiles = FileUtils.listFiles(new File(provider.stateStoreId.checkpointRootLocation), null, true).asScala.filter(_.getName.startsWith("temp-")) assert(tempFiles.isEmpty) } @@ -183,7 +187,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] test("StateStore.get") { quietly { val dir = newDir() - val storeId = StateStoreId(dir, 0, 0) + val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), UUID.randomUUID) val storeConf = StateStoreConf.empty val hadoopConf = new Configuration() @@ -243,18 +247,18 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] .set("spark.rpc.numRetries", "1") val opId = 0 val dir = newDir() - val storeId = StateStoreId(dir, opId, 0) + val storeProviderId = StateStoreProviderId(StateStoreId(dir, opId, 0), UUID.randomUUID) val sqlConf = new SQLConf() sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) val storeConf = StateStoreConf(sqlConf) val 
hadoopConf = new Configuration() - val provider = newStoreProvider(storeId) + val provider = newStoreProvider(storeProviderId.storeId) var latestStoreVersion = 0 def generateStoreVersions() { for (i <- 1 to 20) { - val store = StateStore.get(storeId, keySchema, valueSchema, None, + val store = StateStore.get(storeProviderId, keySchema, valueSchema, None, latestStoreVersion, storeConf, hadoopConf) put(store, "a", i) store.commit() @@ -274,7 +278,8 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] eventually(timeout(timeoutDuration)) { // Store should have been reported to the coordinator - assert(coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported") + assert(coordinatorRef.getLocation(storeProviderId).nonEmpty, + "active instance was not reported") // Background maintenance should clean up and generate snapshots assert(StateStore.isMaintenanceRunning, "Maintenance task is not running") @@ -295,35 +300,35 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") } - // If driver decides to deactivate all instances of the store, then this instance - // should be unloaded - coordinatorRef.deactivateInstances(dir) + // If driver decides to deactivate all stores related to a query run, + // then this instance should be unloaded + coordinatorRef.deactivateInstances(storeProviderId.queryRunId) eventually(timeout(timeoutDuration)) { - assert(!StateStore.isLoaded(storeId)) + assert(!StateStore.isLoaded(storeProviderId)) } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None, + StateStore.get(storeProviderId, keySchema, valueSchema, indexOrdinal = None, latestStoreVersion, storeConf, hadoopConf) - assert(StateStore.isLoaded(storeId)) + assert(StateStore.isLoaded(storeProviderId)) // If some other executor loads the store, then this instance should be unloaded - coordinatorRef.reportActiveInstance(storeId, "other-host", "other-exec") + coordinatorRef.reportActiveInstance(storeProviderId, "other-host", "other-exec") eventually(timeout(timeoutDuration)) { - assert(!StateStore.isLoaded(storeId)) + assert(!StateStore.isLoaded(storeProviderId)) } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None, + StateStore.get(storeProviderId, keySchema, valueSchema, indexOrdinal = None, latestStoreVersion, storeConf, hadoopConf) - assert(StateStore.isLoaded(storeId)) + assert(StateStore.isLoaded(storeProviderId)) } } // Verify if instance is unloaded if SparkContext is stopped eventually(timeout(timeoutDuration)) { require(SparkEnv.get === null) - assert(!StateStore.isLoaded(storeId)) + assert(!StateStore.isLoaded(storeProviderId)) assert(!StateStore.isMaintenanceRunning) } } @@ -344,7 +349,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] test("SPARK-18416: do not create temp delta file until the store is updated") { val dir = newDir() - val storeId = StateStoreId(dir, 0, 0) + val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), UUID.randomUUID) val storeConf = StateStoreConf.empty val hadoopConf = new Configuration() val deltaFileDir = new File(s"$dir/0/0/") @@ -408,12 +413,60 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(numDeltaFiles === 3) } + test("SPARK-21145: Restarted queries create new provider instances") { + try { + val checkpointLocation = 
Utils.createTempDir().getAbsoluteFile + val spark = SparkSession.builder().master("local[2]").getOrCreate() + SparkSession.setActiveSession(spark) + implicit val sqlContext = spark.sqlContext + spark.conf.set("spark.sql.shuffle.partitions", "1") + import spark.implicits._ + val inputData = MemoryStream[Int] + + def runQueryAndGetLoadedProviders(): Seq[StateStoreProvider] = { + val aggregated = inputData.toDF().groupBy("value").agg(count("*")) + // stateful query + val query = aggregated.writeStream + .format("memory") + .outputMode("complete") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + inputData.addData(1, 2, 3) + query.processAllAvailable() + require(query.lastProgress != null) // at least one batch processed after start + val loadedProvidersMethod = + PrivateMethod[mutable.HashMap[StateStoreProviderId, StateStoreProvider]]('loadedProviders) + val loadedProvidersMap = StateStore invokePrivate loadedProvidersMethod() + val loadedProviders = loadedProvidersMap.synchronized { loadedProvidersMap.values.toSeq } + query.stop() + loadedProviders + } + + val loadedProvidersAfterRun1 = runQueryAndGetLoadedProviders() + require(loadedProvidersAfterRun1.length === 1) + + val loadedProvidersAfterRun2 = runQueryAndGetLoadedProviders() + assert(loadedProvidersAfterRun2.length === 2) // two providers loaded for 2 runs + + // Both providers should have the same StateStoreId, but the should be different objects + assert(loadedProvidersAfterRun2(0).stateStoreId === loadedProvidersAfterRun2(1).stateStoreId) + assert(loadedProvidersAfterRun2(0) ne loadedProvidersAfterRun2(1)) + + } finally { + SparkSession.getActiveSession.foreach { spark => + spark.streams.active.foreach(_.stop()) + spark.stop() + } + } + } + override def newStoreProvider(): HDFSBackedStateStoreProvider = { newStoreProvider(opId = Random.nextInt(), partition = 0) } override def newStoreProvider(storeId: StateStoreId): HDFSBackedStateStoreProvider = { - newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointLocation) + newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointRootLocation) } override def getLatestData(storeProvider: HDFSBackedStateStoreProvider): Set[(String, Int)] = { @@ -423,7 +476,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] override def getData( provider: HDFSBackedStateStoreProvider, version: Int = -1): Set[(String, Int)] = { - val reloadedProvider = newStoreProvider(provider.id) + val reloadedProvider = newStoreProvider(provider.stateStoreId) if (version < 0) { reloadedProvider.latestIterator().map(rowsToStringInt).toSet } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 4ede4fd9a035e..86c3a35a59c13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -777,7 +777,7 @@ class TestStateStoreProvider extends StateStoreProvider { throw new Exception("Successfully instantiated") } - override def id: StateStoreId = null + override def stateStoreId: StateStoreId = null override def close(): Unit = { } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 2a4039cc5831a..b2c42eef88f6d 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -26,9 +26,8 @@ import scala.reflect.ClassTag import scala.util.Random import scala.util.control.NonFatal -import org.scalatest.Assertions +import org.scalatest.{Assertions, BeforeAndAfterAll} import org.scalatest.concurrent.{Eventually, Timeouts} -import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.Span @@ -39,9 +38,10 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ro import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * A framework for implementing tests for streaming queries and sources. @@ -67,7 +67,12 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} * avoid hanging forever in the case of failures. However, individual suites can change this * by overriding `streamingTimeout`. */ -trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { +trait StreamTest extends QueryTest with SharedSQLContext with Timeouts with BeforeAndAfterAll { + + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() // stop the state store maintenance thread and unload store providers + } /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds From 153dd49b74e1b6df2b8e35760806c9754ca7bfae Mon Sep 17 00:00:00 2001 From: jinxing Date: Fri, 23 Jun 2017 20:41:17 +0800 Subject: [PATCH 064/118] [SPARK-21047] Add test suites for complicated cases in ColumnarBatchSuite ## What changes were proposed in this pull request? Current ColumnarBatchSuite has very simple test cases for `Array` and `Struct`. This pr wants to add some test suites for complicated cases in ColumnVector. Author: jinxing Closes #18327 from jinxing64/SPARK-21047. 
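Besides the new test cases, the diff below also fills in the generic `ColumnarBatch.Row.get(ordinal, dataType)` accessor, which previously just threw `UnsupportedOperationException`. A minimal Scala sketch of how a caller might lean on that generic accessor; the helper name and the assumption that a populated `InternalRow` and its schema are already in scope are illustrative, not part of this patch:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType

// Hedged sketch, not from the patch: dump every field of a row through the
// type-dispatching get(ordinal, dataType) that this change implements for
// ColumnarBatch.Row (any InternalRow behaves the same way here).
def dumpRow(row: InternalRow, schema: StructType): Unit = {
  schema.fields.zipWithIndex.foreach { case (field, ordinal) =>
    val value = if (row.isNullAt(ordinal)) null else row.get(ordinal, field.dataType)
    println(s"${field.name} -> $value")
  }
}
```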
--- .../execution/vectorized/ColumnarBatch.java | 35 ++++- .../vectorized/ColumnarBatchSuite.scala | 122 ++++++++++++++++++ 2 files changed, 156 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 8b7b0e655b31d..e23a64350cbc5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -241,7 +241,40 @@ public MapData getMap(int ordinal) { @Override public Object get(int ordinal, DataType dataType) { - throw new UnsupportedOperationException(); + if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof IntegerType) { + return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8String(ordinal); + } else if (dataType instanceof BinaryType) { + return getBinary(ordinal); + } else if (dataType instanceof DecimalType) { + DecimalType t = (DecimalType) dataType; + return getDecimal(ordinal, t.precision(), t.scale()); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); + } else if (dataType instanceof ArrayType) { + return getArray(ordinal); + } else if (dataType instanceof StructType) { + return getStruct(ordinal, ((StructType)dataType).fields().length); + } else if (dataType instanceof MapType) { + return getMap(ordinal); + } else { + throw new UnsupportedOperationException("Datatype not supported " + dataType); + } } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index e48e3f6402901..80d41577dcf2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -739,6 +739,128 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } + test("Nest Array in Array.") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => + val column = ColumnVector.allocate(10, new ArrayType(new ArrayType(IntegerType, true), true), + memMode) + val childColumn = column.arrayData() + val data = column.arrayData().arrayData() + (0 until 6).foreach { + case 3 => data.putNull(3) + case i => data.putInt(i, i) + } + // Arrays in child column: [0], [1, 2], [], [null, 4, 5] + childColumn.putArray(0, 0, 1) + childColumn.putArray(1, 1, 2) + childColumn.putArray(2, 2, 0) + childColumn.putArray(3, 3, 3) + // Arrays in column: [[0]], [[1, 2], []], [[], [null, 4, 5]], null + column.putArray(0, 0, 1) + column.putArray(1, 1, 2) + column.putArray(2, 2, 2) + column.putNull(3) + + assert(column.getArray(0).getArray(0).toIntArray() === Array(0)) + assert(column.getArray(1).getArray(0).toIntArray() === Array(1, 2)) + assert(column.getArray(1).getArray(1).toIntArray() === Array()) + 
assert(column.getArray(2).getArray(0).toIntArray() === Array()) + assert(column.getArray(2).getArray(1).isNullAt(0)) + assert(column.getArray(2).getArray(1).getInt(1) === 4) + assert(column.getArray(2).getArray(1).getInt(2) === 5) + assert(column.isNullAt(3)) + } + } + + test("Nest Struct in Array.") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => + val schema = new StructType().add("int", IntegerType).add("long", LongType) + val column = ColumnVector.allocate(10, new ArrayType(schema, true), memMode) + val data = column.arrayData() + val c0 = data.getChildColumn(0) + val c1 = data.getChildColumn(1) + // Structs in child column: (0, 0), (1, 10), (2, 20), (3, 30), (4, 40), (5, 50) + (0 until 6).foreach { i => + c0.putInt(i, i) + c1.putLong(i, i * 10) + } + // Arrays in column: [(0, 0), (1, 10)], [(1, 10), (2, 20), (3, 30)], + // [(4, 40), (5, 50)] + column.putArray(0, 0, 2) + column.putArray(1, 1, 3) + column.putArray(2, 4, 2) + + assert(column.getArray(0).getStruct(0, 2).toSeq(schema) === Seq(0, 0)) + assert(column.getArray(0).getStruct(1, 2).toSeq(schema) === Seq(1, 10)) + assert(column.getArray(1).getStruct(0, 2).toSeq(schema) === Seq(1, 10)) + assert(column.getArray(1).getStruct(1, 2).toSeq(schema) === Seq(2, 20)) + assert(column.getArray(1).getStruct(2, 2).toSeq(schema) === Seq(3, 30)) + assert(column.getArray(2).getStruct(0, 2).toSeq(schema) === Seq(4, 40)) + assert(column.getArray(2).getStruct(1, 2).toSeq(schema) === Seq(5, 50)) + } + } + + test("Nest Array in Struct.") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => + val schema = new StructType() + .add("int", IntegerType) + .add("array", new ArrayType(IntegerType, true)) + val column = ColumnVector.allocate(10, schema, memMode) + val c0 = column.getChildColumn(0) + val c1 = column.getChildColumn(1) + c0.putInt(0, 0) + c0.putInt(1, 1) + c0.putInt(2, 2) + val c1Child = c1.arrayData() + (0 until 6).foreach { i => + c1Child.putInt(i, i) + } + // Arrays in c1: [0, 1], [2], [3, 4, 5] + c1.putArray(0, 0, 2) + c1.putArray(1, 2, 1) + c1.putArray(2, 3, 3) + + assert(column.getStruct(0).getInt(0) === 0) + assert(column.getStruct(0).getArray(1).toIntArray() === Array(0, 1)) + assert(column.getStruct(1).getInt(0) === 1) + assert(column.getStruct(1).getArray(1).toIntArray() === Array(2)) + assert(column.getStruct(2).getInt(0) === 2) + assert(column.getStruct(2).getArray(1).toIntArray() === Array(3, 4, 5)) + } + } + + test("Nest Struct in Struct.") { + (MemoryMode.ON_HEAP :: Nil).foreach { memMode => + val subSchema = new StructType() + .add("int", IntegerType) + .add("int", IntegerType) + val schema = new StructType() + .add("int", IntegerType) + .add("struct", subSchema) + val column = ColumnVector.allocate(10, schema, memMode) + val c0 = column.getChildColumn(0) + val c1 = column.getChildColumn(1) + c0.putInt(0, 0) + c0.putInt(1, 1) + c0.putInt(2, 2) + val c1c0 = c1.getChildColumn(0) + val c1c1 = c1.getChildColumn(1) + // Structs in c1: (7, 70), (8, 80), (9, 90) + c1c0.putInt(0, 7) + c1c0.putInt(1, 8) + c1c0.putInt(2, 9) + c1c1.putInt(0, 70) + c1c1.putInt(1, 80) + c1c1.putInt(2, 90) + + assert(column.getStruct(0).getInt(0) === 0) + assert(column.getStruct(0).getStruct(1, 2).toSeq(subSchema) === Seq(7, 70)) + assert(column.getStruct(1).getInt(0) === 1) + assert(column.getStruct(1).getStruct(1, 2).toSeq(subSchema) === Seq(8, 80)) + assert(column.getStruct(2).getInt(0) === 2) + assert(column.getStruct(2).getStruct(1, 2).toSeq(subSchema) === Seq(9, 90)) + } + } + test("ColumnarBatch basic") 
{ (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val schema = new StructType() From acd208ee50b29bde4e097bf88761867b1d57a665 Mon Sep 17 00:00:00 2001 From: 10129659 Date: Fri, 23 Jun 2017 20:53:26 +0800 Subject: [PATCH 065/118] [SPARK-21115][CORE] If the cores left is less than the coresPerExecutor,the cores left will not be allocated, so it should not to check in every schedule ## What changes were proposed in this pull request? If we start an app with the param --total-executor-cores=4 and spark.executor.cores=3, the cores left is always 1, so it will try to allocate executors in the function org.apache.spark.deploy.master.startExecutorsOnWorkers in every schedule. Another question is, is it will be better to allocate another executor with 1 core for the cores left. ## How was this patch tested? unit test Author: 10129659 Closes #18322 from eatoncys/leftcores. --- .../scala/org/apache/spark/SparkConf.scala | 11 +++++++ .../apache/spark/deploy/master/Master.scala | 29 ++++++++++--------- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index ba7a65f79c414..de2f475c6895f 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -543,6 +543,17 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } } + if (contains("spark.cores.max") && contains("spark.executor.cores")) { + val totalCores = getInt("spark.cores.max", 1) + val executorCores = getInt("spark.executor.cores", 1) + val leftCores = totalCores % executorCores + if (leftCores != 0) { + logWarning(s"Total executor cores: ${totalCores} is not " + + s"divisible by cores per executor: ${executorCores}, " + + s"the left cores: ${leftCores} will not be allocated") + } + } + val encryptionEnabled = get(NETWORK_ENCRYPTION_ENABLED) || get(SASL_ENCRYPTION_ENABLED) require(!encryptionEnabled || get(NETWORK_AUTH_ENABLED), s"${NETWORK_AUTH_ENABLED.key} must be enabled when enabling encryption.") diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index c192a0cc82ef6..0dee25fb2ebe2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -659,19 +659,22 @@ private[deploy] class Master( private def startExecutorsOnWorkers(): Unit = { // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app // in the queue, then the second app, etc. 
- for (app <- waitingApps if app.coresLeft > 0) { - val coresPerExecutor: Option[Int] = app.desc.coresPerExecutor - // Filter out workers that don't have enough resources to launch an executor - val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) - .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && - worker.coresFree >= coresPerExecutor.getOrElse(1)) - .sortBy(_.coresFree).reverse - val assignedCores = scheduleExecutorsOnWorkers(app, usableWorkers, spreadOutApps) - - // Now that we've decided how many cores to allocate on each worker, let's allocate them - for (pos <- 0 until usableWorkers.length if assignedCores(pos) > 0) { - allocateWorkerResourceToExecutors( - app, assignedCores(pos), coresPerExecutor, usableWorkers(pos)) + for (app <- waitingApps) { + val coresPerExecutor = app.desc.coresPerExecutor.getOrElse(1) + // If the cores left is less than the coresPerExecutor,the cores left will not be allocated + if (app.coresLeft >= coresPerExecutor) { + // Filter out workers that don't have enough resources to launch an executor + val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) + .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && + worker.coresFree >= coresPerExecutor) + .sortBy(_.coresFree).reverse + val assignedCores = scheduleExecutorsOnWorkers(app, usableWorkers, spreadOutApps) + + // Now that we've decided how many cores to allocate on each worker, let's allocate them + for (pos <- 0 until usableWorkers.length if assignedCores(pos) > 0) { + allocateWorkerResourceToExecutors( + app, assignedCores(pos), app.desc.coresPerExecutor, usableWorkers(pos)) + } } } } From 5dca10b8fdec81a3cc476301fa4f82ea917c34ec Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 23 Jun 2017 21:51:55 +0800 Subject: [PATCH 066/118] [SPARK-21193][PYTHON] Specify Pandas version in setup.py ## What changes were proposed in this pull request? It looks we missed specifying the Pandas version. This PR proposes to fix it. For the current state, it should be Pandas 0.13.0 given my test. This PR propose to fix it as 0.13.0. Running the codes below: ```python from pyspark.sql.types import * schema = StructType().add("a", IntegerType()).add("b", StringType())\ .add("c", BooleanType()).add("d", FloatType()) data = [ (1, "foo", True, 3.0,), (2, "foo", True, 5.0), (3, "bar", False, -1.0), (4, "bar", False, 6.0), ] spark.createDataFrame(data, schema).toPandas().dtypes ``` prints ... 
**With Pandas 0.13.0** - released, 2014-01 ``` a int32 b object c bool d float32 dtype: object ``` **With Pandas 0.12.0** - - released, 2013-06 ``` Traceback (most recent call last): File "", line 1, in File ".../spark/python/pyspark/sql/dataframe.py", line 1734, in toPandas pdf[f] = pdf[f].astype(t, copy=False) TypeError: astype() got an unexpected keyword argument 'copy' ``` without `copy` ``` a int32 b object c bool d float32 dtype: object ``` **With Pandas 0.11.0** - released, 2013-03 ``` Traceback (most recent call last): File "", line 1, in File ".../spark/python/pyspark/sql/dataframe.py", line 1734, in toPandas pdf[f] = pdf[f].astype(t, copy=False) TypeError: astype() got an unexpected keyword argument 'copy' ``` without `copy` ``` a int32 b object c bool d float32 dtype: object ``` **With Pandas 0.10.0** - released, 2012-12 ``` Traceback (most recent call last): File "", line 1, in File ".../spark/python/pyspark/sql/dataframe.py", line 1734, in toPandas pdf[f] = pdf[f].astype(t, copy=False) TypeError: astype() got an unexpected keyword argument 'copy' ``` without `copy` ``` a int64 # <- this should be 'int32' b object c bool d float64 # <- this should be 'float32' ``` ## How was this patch tested? Manually tested with Pandas from 0.10.0 to 0.13.0. Author: hyukjinkwon Closes #18403 from HyukjinKwon/SPARK-21193. --- python/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/setup.py b/python/setup.py index f50035435e26b..2644d3e79dea1 100644 --- a/python/setup.py +++ b/python/setup.py @@ -199,7 +199,7 @@ def _supports_symlinks(): extras_require={ 'ml': ['numpy>=1.7'], 'mllib': ['numpy>=1.7'], - 'sql': ['pandas'] + 'sql': ['pandas>=0.13.0'] }, classifiers=[ 'Development Status :: 5 - Production/Stable', From f3dea60793d86212ba1068e88ad89cb3dcf07801 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 23 Jun 2017 09:28:02 -0700 Subject: [PATCH 067/118] [SPARK-21144][SQL] Print a warning if the data schema and partition schema have the duplicate columns ## What changes were proposed in this pull request? The current master outputs unexpected results when the data schema and partition schema have the duplicate columns: ``` withTempPath { dir => val basePath = dir.getCanonicalPath spark.range(0, 3).toDF("foo").write.parquet(new Path(basePath, "foo=1").toString) spark.range(0, 3).toDF("foo").write.parquet(new Path(basePath, "foo=a").toString) spark.read.parquet(basePath).show() } +---+ |foo| +---+ | 1| | 1| | a| | a| | 1| | a| +---+ ``` This patch added code to print a warning when the duplication found. ## How was this patch tested? Manually checked. Author: Takeshi Yamamuro Closes #18375 from maropu/SPARK-21144-3. --- .../apache/spark/sql/util/SchemaUtils.scala | 53 +++++++++++++++++++ .../execution/datasources/DataSource.scala | 6 +++ 2 files changed, 59 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala new file mode 100644 index 0000000000000..e881685ce6262 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import org.apache.spark.internal.Logging + + +/** + * Utils for handling schemas. + * + * TODO: Merge this file with [[org.apache.spark.ml.util.SchemaUtils]]. + */ +private[spark] object SchemaUtils extends Logging { + + /** + * Checks if input column names have duplicate identifiers. Prints a warning message if + * the duplication exists. + * + * @param columnNames column names to check + * @param colType column type name, used in a warning message + * @param caseSensitiveAnalysis whether duplication checks should be case sensitive or not + */ + def checkColumnNameDuplication( + columnNames: Seq[String], colType: String, caseSensitiveAnalysis: Boolean): Unit = { + val names = if (caseSensitiveAnalysis) { + columnNames + } else { + columnNames.map(_.toLowerCase) + } + if (names.distinct.length != names.length) { + val duplicateColumns = names.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => s"`$x`" + } + logWarning(s"Found duplicate column(s) $colType: ${duplicateColumns.mkString(", ")}. " + + "You might need to assign different column names.") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 08c78e6e326af..75e530607570f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{CalendarIntervalType, StructType} +import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.Utils /** @@ -182,6 +183,11 @@ case class DataSource( throw new AnalysisException( s"Unable to infer schema for $format. It must be specified manually.") } + + SchemaUtils.checkColumnNameDuplication( + (dataSchema ++ partitionSchema).map(_.name), "in the data schema and the partition schema", + sparkSession.sessionState.conf.caseSensitiveAnalysis) + (dataSchema, partitionSchema) } From 07479b3cfb7a617a18feca14e9e31c208c80630e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 23 Jun 2017 09:59:24 -0700 Subject: [PATCH 068/118] [SPARK-21149][R] Add job description API for R ## What changes were proposed in this pull request? Extend `setJobDescription` to SparkR API. ## How was this patch tested? It looks difficult to add a test. Manually tested as below: ```r df <- createDataFrame(iris) count(df) setJobDescription("This is an example job.") count(df) ``` prints ... ![2017-06-22 12 05 49](https://user-images.githubusercontent.com/6477701/27415670-2a649936-5743-11e7-8e95-312f1cd103af.png) Author: hyukjinkwon Closes #18382 from HyukjinKwon/SPARK-21149. 
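Since the new R helper simply forwards to the JVM `SparkContext` (via `callJMethod(sc, "setJobDescription", value)` in the diff below), its behaviour can also be pictured from the Scala side. The sketch below is illustrative only and is not part of this patch; the application name and queries are made up:

```scala
import org.apache.spark.sql.SparkSession

object JobDescriptionDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("job-description-demo").getOrCreate()
    val sc = spark.sparkContext

    // Each call replaces the description shown for subsequently submitted jobs in the Spark UI.
    sc.setJobDescription("Counting a small range")
    spark.range(0, 1000).count()

    sc.setJobDescription("Summing a small range")
    spark.range(0, 1000).selectExpr("sum(id)").collect()

    spark.stop()
  }
}
```

From SparkR the equivalent is simply calling `setJobDescription("...")` before triggering an action, as the R example earlier in this commit message shows.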
--- R/pkg/NAMESPACE | 3 ++- R/pkg/R/sparkR.R | 17 +++++++++++++++++ R/pkg/tests/fulltests/test_context.R | 1 + 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 229de4a997eef..b7fdae58de459 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -75,7 +75,8 @@ exportMethods("glm", # Job group lifecycle management methods export("setJobGroup", "clearJobGroup", - "cancelJobGroup") + "cancelJobGroup", + "setJobDescription") # Export Utility methods export("setLogLevel") diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index d0a12b7ecec65..f2d2620e5447a 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -535,6 +535,23 @@ cancelJobGroup <- function(sc, groupId) { } } +#' Set a human readable description of the current job. +#' +#' Set a description that is shown as a job description in UI. +#' +#' @param value The job description of the current job. +#' @rdname setJobDescription +#' @name setJobDescription +#' @examples +#'\dontrun{ +#' setJobDescription("This is an example job.") +#'} +#' @note setJobDescription since 2.3.0 +setJobDescription <- function(value) { + sc <- getSparkContext() + invisible(callJMethod(sc, "setJobDescription", value)) +} + sparkConfToSubmitOps <- new.env() sparkConfToSubmitOps[["spark.driver.memory"]] <- "--driver-memory" sparkConfToSubmitOps[["spark.driver.extraClassPath"]] <- "--driver-class-path" diff --git a/R/pkg/tests/fulltests/test_context.R b/R/pkg/tests/fulltests/test_context.R index 710485d56685a..77635c5a256b9 100644 --- a/R/pkg/tests/fulltests/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -100,6 +100,7 @@ test_that("job group functions can be called", { setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") clearJobGroup() + setJobDescription("job description") suppressWarnings(setJobGroup(sc, "groupId", "job description", TRUE)) suppressWarnings(cancelJobGroup(sc, "groupId")) From b803b66a8133f705463039325ee71ee6827ce1a7 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Fri, 23 Jun 2017 10:33:53 -0700 Subject: [PATCH 069/118] [SPARK-21180][SQL] Remove conf from stats functions since now we have conf in LogicalPlan ## What changes were proposed in this pull request? After wiring `SQLConf` in logical plan ([PR 18299](https://github.com/apache/spark/pull/18299)), we can remove the need of passing `conf` into `def stats` and `def computeStats`. ## How was this patch tested? Covered by existing tests, plus some modified existing tests. Author: wangzhenhua Author: Zhenhua Wang Closes #18391 from wzhfy/removeConf. 
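To make the shape of this refactoring concrete, here is a minimal sketch (not part of the patch) of what a leaf node looks like after the change: `computeStats` no longer takes a `SQLConf`, and estimation code can read the plan's own `conf` instead (as `Project` and `Filter` do below with `conf.cboEnabled`). `TinyRelation` and its size formula are hypothetical, for illustration only:

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}

// Hypothetical relation used only to illustrate the new, parameterless signature.
case class TinyRelation(output: Seq[Attribute]) extends LeafNode {
  // Before this patch this would have been `computeStats(conf: SQLConf)`;
  // the SQLConf is now reachable from the plan itself, so no parameter is needed.
  override def computeStats: Statistics = {
    // Rough single-row size estimate: fixed row overhead plus the default size of each column.
    val rowSize = output.map(_.dataType.defaultSize).sum + 8
    Statistics(sizeInBytes = BigInt(rowSize))
  }
}
```

Callers change accordingly, reading `plan.stats.sizeInBytes` instead of `plan.stats(conf).sizeInBytes`, which is what most of the diff below consists of.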
--- .../sql/catalyst/catalog/interface.scala | 3 +- .../optimizer/CostBasedJoinReorder.scala | 4 +- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../optimizer/StarSchemaDetection.scala | 14 ++-- .../plans/logical/LocalRelation.scala | 3 +- .../catalyst/plans/logical/LogicalPlan.scala | 15 ++--- .../plans/logical/basicLogicalOperators.scala | 65 +++++++++---------- .../sql/catalyst/plans/logical/hints.scala | 5 +- .../statsEstimation/AggregateEstimation.scala | 7 +- .../statsEstimation/EstimationUtils.scala | 5 +- .../statsEstimation/FilterEstimation.scala | 5 +- .../statsEstimation/JoinEstimation.scala | 21 +++--- .../statsEstimation/ProjectEstimation.scala | 7 +- .../optimizer/JoinOptimizationSuite.scala | 2 +- .../optimizer/LimitPushdownSuite.scala | 6 +- .../AggregateEstimationSuite.scala | 30 +++++---- .../BasicStatsEstimationSuite.scala | 27 +++++--- .../FilterEstimationSuite.scala | 2 +- .../statsEstimation/JoinEstimationSuite.scala | 26 ++++---- .../ProjectEstimationSuite.scala | 4 +- .../StatsEstimationTestBase.scala | 18 +++-- .../spark/sql/execution/ExistingRDD.scala | 5 +- .../spark/sql/execution/QueryExecution.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 13 ++-- .../execution/columnar/InMemoryRelation.scala | 3 +- .../datasources/LogicalRelation.scala | 3 +- .../sql/execution/streaming/memory.scala | 3 +- .../apache/spark/sql/CachedTableSuite.scala | 2 +- .../org/apache/spark/sql/DatasetSuite.scala | 2 +- .../org/apache/spark/sql/JoinSuite.scala | 2 +- .../spark/sql/StatisticsCollectionSuite.scala | 18 ++--- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../datasources/HadoopFsRelationSuite.scala | 2 +- .../execution/streaming/MemorySinkSuite.scala | 6 +- .../apache/spark/sql/test/SQLTestData.scala | 3 - .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../spark/sql/hive/StatisticsSuite.scala | 10 +-- .../PruneFileSourcePartitionsSuite.scala | 2 +- 38 files changed, 178 insertions(+), 173 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index c043ed9c431b7..b63bef9193332 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Attri import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -436,7 +435,7 @@ case class CatalogRelation( createTime = -1 )) - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { // For data source tables, we will create a `LogicalRelation` and won't call this method, for // hive serde tables, we will always generate a statistics. // TODO: unify the table stats generation. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 51eca6ca33760..3a7543e2141e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -58,7 +58,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr // Do reordering if the number of items is appropriate and join conditions exist. // We also need to check if costs of all items can be evaluated. if (items.size > 2 && items.size <= conf.joinReorderDPThreshold && conditions.nonEmpty && - items.forall(_.stats(conf).rowCount.isDefined)) { + items.forall(_.stats.rowCount.isDefined)) { JoinReorderDP.search(conf, items, conditions, output) } else { plan @@ -322,7 +322,7 @@ object JoinReorderDP extends PredicateHelper with Logging { /** Get the cost of the root node of this plan tree. */ def rootCost(conf: SQLConf): Cost = { if (itemIds.size > 1) { - val rootStats = plan.stats(conf) + val rootStats = plan.stats Cost(rootStats.rowCount.get, rootStats.sizeInBytes) } else { // If the plan is a leaf item, it has zero cost. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3ab70fb90470c..b410312030c5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -317,7 +317,7 @@ case class LimitPushDown(conf: SQLConf) extends Rule[LogicalPlan] { case FullOuter => (left.maxRows, right.maxRows) match { case (None, None) => - if (left.stats(conf).sizeInBytes >= right.stats(conf).sizeInBytes) { + if (left.stats.sizeInBytes >= right.stats.sizeInBytes) { join.copy(left = maybePushLimit(exp, left)) } else { join.copy(right = maybePushLimit(exp, right)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala index 97ee9988386dd..ca729127e7d1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -82,7 +82,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { // Find if the input plans are eligible for star join detection. // An eligible plan is a base table access with valid statistics. 
val foundEligibleJoin = input.forall { - case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true + case PhysicalOperation(_, _, t: LeafNode) if t.stats.rowCount.isDefined => true case _ => false } @@ -181,7 +181,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { val leafCol = findLeafNodeCol(column, plan) leafCol match { case Some(col) if t.outputSet.contains(col) => - val stats = t.stats(conf) + val stats = t.stats stats.rowCount match { case Some(rowCount) if rowCount >= 0 => if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) { @@ -237,7 +237,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { val leafCol = findLeafNodeCol(column, plan) leafCol match { case Some(col) if t.outputSet.contains(col) => - val stats = t.stats(conf) + val stats = t.stats stats.attributeStats.nonEmpty && stats.attributeStats.contains(col) case None => false } @@ -296,11 +296,11 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { */ private def getTableAccessCardinality( input: LogicalPlan): Option[BigInt] = input match { - case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined => - if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) { - Option(input.stats(conf).rowCount.get) + case PhysicalOperation(_, cond, t: LeafNode) if t.stats.rowCount.isDefined => + if (conf.cboEnabled && input.stats.rowCount.isDefined) { + Option(input.stats.rowCount.get) } else { - Option(t.stats(conf).rowCount.get) + Option(t.stats.rowCount.get) } case _ => None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 9cd5dfd21b160..dc2add64b68b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { @@ -67,7 +66,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) } } - override def computeStats(conf: SQLConf): Statistics = + override def computeStats: Statistics = Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 95b4165f6b10d..0c098ac0209e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.CurrentOrigin -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -90,8 +89,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai * first 
time. If the configuration changes, the cache can be invalidated by calling * [[invalidateStatsCache()]]. */ - final def stats(conf: SQLConf): Statistics = statsCache.getOrElse { - statsCache = Some(computeStats(conf)) + final def stats: Statistics = statsCache.getOrElse { + statsCache = Some(computeStats) statsCache.get } @@ -108,11 +107,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai * * [[LeafNode]]s must override this. */ - protected def computeStats(conf: SQLConf): Statistics = { + protected def computeStats: Statistics = { if (children.isEmpty) { throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") } - Statistics(sizeInBytes = children.map(_.stats(conf).sizeInBytes).product) + Statistics(sizeInBytes = children.map(_.stats.sizeInBytes).product) } override def verboseStringWithSuffix: String = { @@ -333,13 +332,13 @@ abstract class UnaryNode extends LogicalPlan { override protected def validConstraints: Set[Expression] = child.constraints - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { // There should be some overhead in Row object, the size should not be zero when there is // no columns, this help to prevent divide-by-zero error. val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8 val outputRowSize = output.map(_.dataType.defaultSize).sum + 8 // Assume there will be the same number of rows as child has. - var sizeInBytes = (child.stats(conf).sizeInBytes * outputRowSize) / childRowSize + var sizeInBytes = (child.stats.sizeInBytes * outputRowSize) / childRowSize if (sizeInBytes == 0) { // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero // (product of children). @@ -347,7 +346,7 @@ abstract class UnaryNode extends LogicalPlan { } // Don't propagate rowCount and attributeStats, since they are not estimated here. 
- Statistics(sizeInBytes = sizeInBytes, hints = child.stats(conf).hints) + Statistics(sizeInBytes = sizeInBytes, hints = child.stats.hints) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 6e88b7a57dc33..d8f89b108e63f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.RandomSampler @@ -65,11 +64,11 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend override def validConstraints: Set[Expression] = child.constraints.union(getAliasedConstraints(projectList)) - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { if (conf.cboEnabled) { - ProjectEstimation.estimate(conf, this).getOrElse(super.computeStats(conf)) + ProjectEstimation.estimate(this).getOrElse(super.computeStats) } else { - super.computeStats(conf) + super.computeStats } } } @@ -139,11 +138,11 @@ case class Filter(condition: Expression, child: LogicalPlan) child.constraints.union(predicates.toSet) } - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { if (conf.cboEnabled) { - FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf)) + FilterEstimation(this).estimate.getOrElse(super.computeStats) } else { - super.computeStats(conf) + super.computeStats } } } @@ -192,13 +191,13 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation } } - override def computeStats(conf: SQLConf): Statistics = { - val leftSize = left.stats(conf).sizeInBytes - val rightSize = right.stats(conf).sizeInBytes + override def computeStats: Statistics = { + val leftSize = left.stats.sizeInBytes + val rightSize = right.stats.sizeInBytes val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize Statistics( sizeInBytes = sizeInBytes, - hints = left.stats(conf).hints.resetForJoin()) + hints = left.stats.hints.resetForJoin()) } } @@ -209,8 +208,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le override protected def validConstraints: Set[Expression] = leftConstraints - override def computeStats(conf: SQLConf): Statistics = { - left.stats(conf).copy() + override def computeStats: Statistics = { + left.stats.copy() } } @@ -248,8 +247,8 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { children.length > 1 && childrenResolved && allChildrenCompatible } - override def computeStats(conf: SQLConf): Statistics = { - val sizeInBytes = children.map(_.stats(conf).sizeInBytes).sum + override def computeStats: Statistics = { + val sizeInBytes = children.map(_.stats.sizeInBytes).sum Statistics(sizeInBytes = sizeInBytes) } @@ -357,20 +356,20 @@ case class Join( case _ => resolvedExceptNatural } - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { 
def simpleEstimation: Statistics = joinType match { case LeftAnti | LeftSemi => // LeftSemi and LeftAnti won't ever be bigger than left - left.stats(conf) + left.stats case _ => // Make sure we don't propagate isBroadcastable in other joins, because // they could explode the size. - val stats = super.computeStats(conf) + val stats = super.computeStats stats.copy(hints = stats.hints.resetForJoin()) } if (conf.cboEnabled) { - JoinEstimation.estimate(conf, this).getOrElse(simpleEstimation) + JoinEstimation.estimate(this).getOrElse(simpleEstimation) } else { simpleEstimation } @@ -523,7 +522,7 @@ case class Range( override def newInstance(): Range = copy(output = output.map(_.newInstance())) - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { val sizeInBytes = LongType.defaultSize * numElements Statistics( sizeInBytes = sizeInBytes ) } @@ -556,20 +555,20 @@ case class Aggregate( child.constraints.union(getAliasedConstraints(nonAgg)) } - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { def simpleEstimation: Statistics = { if (groupingExpressions.isEmpty) { Statistics( sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1), rowCount = Some(1), - hints = child.stats(conf).hints) + hints = child.stats.hints) } else { - super.computeStats(conf) + super.computeStats } } if (conf.cboEnabled) { - AggregateEstimation.estimate(conf, this).getOrElse(simpleEstimation) + AggregateEstimation.estimate(this).getOrElse(simpleEstimation) } else { simpleEstimation } @@ -672,8 +671,8 @@ case class Expand( override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) - override def computeStats(conf: SQLConf): Statistics = { - val sizeInBytes = super.computeStats(conf).sizeInBytes * projections.length + override def computeStats: Statistics = { + val sizeInBytes = super.computeStats.sizeInBytes * projections.length Statistics(sizeInBytes = sizeInBytes) } @@ -743,9 +742,9 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN case _ => None } } - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] - val childStats = child.stats(conf) + val childStats = child.stats val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit) // Don't propagate column stats, because we don't know the distribution after a limit operation Statistics( @@ -763,9 +762,9 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo case _ => None } } - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] - val childStats = child.stats(conf) + val childStats = child.stats if (limit == 0) { // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero // (product of children). 
@@ -832,9 +831,9 @@ case class Sample( override def output: Seq[Attribute] = child.output - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { val ratio = upperBound - lowerBound - val childStats = child.stats(conf) + val childStats = child.stats var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio) if (sizeInBytes == 0) { sizeInBytes = 1 @@ -898,7 +897,7 @@ case class RepartitionByExpression( case object OneRowRelation extends LeafNode { override def maxRows: Option[Long] = Some(1) override def output: Seq[Attribute] = Nil - override def computeStats(conf: SQLConf): Statistics = Statistics(sizeInBytes = 1) + override def computeStats: Statistics = Statistics(sizeInBytes = 1) } /** A logical plan for `dropDuplicates`. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index e49970df80457..8479c702d7561 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.internal.SQLConf /** * A general hint for the child that is not yet resolved. This node is generated by the parser and @@ -44,8 +43,8 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo()) override lazy val canonicalized: LogicalPlan = child.canonicalized - override def computeStats(conf: SQLConf): Statistics = { - val stats = child.stats(conf) + override def computeStats: Statistics = { + val stats = child.stats stats.copy(hints = hints) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index a0c23198451a8..c41fac4015ec0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics} -import org.apache.spark.sql.internal.SQLConf object AggregateEstimation { @@ -29,13 +28,13 @@ object AggregateEstimation { * Estimate the number of output rows based on column stats of group-by columns, and propagate * column stats for aggregate expressions. */ - def estimate(conf: SQLConf, agg: Aggregate): Option[Statistics] = { - val childStats = agg.child.stats(conf) + def estimate(agg: Aggregate): Option[Statistics] = { + val childStats = agg.child.stats // Check if we have column stats for all group-by columns. val colStatsExist = agg.groupingExpressions.forall { e => e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute]) } - if (rowCountsExist(conf, agg.child) && colStatsExist) { + if (rowCountsExist(agg.child) && colStatsExist) { // Multiply distinct counts of group-by columns. This is an upper bound, which assumes // the data contains all combinations of distinct values of group-by columns. 
var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index e5fcdf9039be9..9c34a9b7aa756 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -21,15 +21,14 @@ import scala.math.BigDecimal.RoundingMode import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DecimalType, _} object EstimationUtils { /** Check if each plan has rowCount in its statistics. */ - def rowCountsExist(conf: SQLConf, plans: LogicalPlan*): Boolean = - plans.forall(_.stats(conf).rowCount.isDefined) + def rowCountsExist(plans: LogicalPlan*): Boolean = + plans.forall(_.stats.rowCount.isDefined) /** Check if each attribute has column stat in the corresponding statistics. */ def columnStatsExist(statsAndAttr: (Statistics, Attribute)*): Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index df190867189ec..5a3bee7b9e449 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -25,12 +25,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging { +case class FilterEstimation(plan: Filter) extends Logging { - private val childStats = plan.child.stats(catalystConf) + private val childStats = plan.child.stats private val colStatsMap = new ColumnStatsMap(childStats.attributeStats) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 8ef905c45d50d..f48196997a24d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ -import org.apache.spark.sql.internal.SQLConf object JoinEstimation extends Logging { @@ -34,12 +33,12 @@ object JoinEstimation extends Logging { * 
Estimate statistics after join. Return `None` if the join type is not supported, or we don't * have enough statistics for estimation. */ - def estimate(conf: SQLConf, join: Join): Option[Statistics] = { + def estimate(join: Join): Option[Statistics] = { join.joinType match { case Inner | Cross | LeftOuter | RightOuter | FullOuter => - InnerOuterEstimation(conf, join).doEstimate() + InnerOuterEstimation(join).doEstimate() case LeftSemi | LeftAnti => - LeftSemiAntiEstimation(conf, join).doEstimate() + LeftSemiAntiEstimation(join).doEstimate() case _ => logDebug(s"[CBO] Unsupported join type: ${join.joinType}") None @@ -47,16 +46,16 @@ object JoinEstimation extends Logging { } } -case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging { +case class InnerOuterEstimation(join: Join) extends Logging { - private val leftStats = join.left.stats(conf) - private val rightStats = join.right.stats(conf) + private val leftStats = join.left.stats + private val rightStats = join.right.stats /** * Estimate output size and number of rows after a join operator, and update output column stats. */ def doEstimate(): Option[Statistics] = join match { - case _ if !rowCountsExist(conf, join.left, join.right) => + case _ if !rowCountsExist(join.left, join.right) => None case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) => @@ -273,13 +272,13 @@ case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging { } } -case class LeftSemiAntiEstimation(conf: SQLConf, join: Join) { +case class LeftSemiAntiEstimation(join: Join) { def doEstimate(): Option[Statistics] = { // TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic // column stats. Now we just propagate the statistics from left side. We should do more // accurate estimation when advanced stats (e.g. histograms) are available. 
- if (rowCountsExist(conf, join.left)) { - val leftStats = join.left.stats(conf) + if (rowCountsExist(join.left)) { + val leftStats = join.left.stats // Propagate the original column stats for cartesian product val outputRows = leftStats.rowCount.get Some(Statistics( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala index d700cd3b20f7d..489eb904ffd05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{Project, Statistics} -import org.apache.spark.sql.internal.SQLConf object ProjectEstimation { import EstimationUtils._ - def estimate(conf: SQLConf, project: Project): Option[Statistics] = { - if (rowCountsExist(conf, project.child)) { - val childStats = project.child.stats(conf) + def estimate(project: Project): Option[Statistics] = { + if (rowCountsExist(project.child)) { + val childStats = project.child.stats val inputAttrStats = childStats.attributeStats // Match alias with its child's column stat val aliasStats = project.expressions.collect { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index 105407d43bf39..a6584aa5fbba7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -142,7 +142,7 @@ class JoinOptimizationSuite extends PlanTest { comparePlans(optimized, expected) val broadcastChildren = optimized.collect { - case Join(_, r, _, _) if r.stats(conf).sizeInBytes == 1 => r + case Join(_, r, _, _) if r.stats.sizeInBytes == 1 => r } assert(broadcastChildren.size == 1) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index fb34c82de468b..d8302dfc9462d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -112,7 +112,7 @@ class LimitPushdownSuite extends PlanTest { } test("full outer join where neither side is limited and both sides have same statistics") { - assert(x.stats(conf).sizeInBytes === y.stats(conf).sizeInBytes) + assert(x.stats.sizeInBytes === y.stats.sizeInBytes) val originalQuery = x.join(y, FullOuter).limit(1) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = Limit(1, LocalLimit(1, x).join(y, FullOuter)).analyze @@ -121,7 +121,7 @@ class LimitPushdownSuite extends PlanTest { test("full outer join where neither side is limited and left side has larger statistics") { val xBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('x) - assert(xBig.stats(conf).sizeInBytes > y.stats(conf).sizeInBytes) + assert(xBig.stats.sizeInBytes > 
y.stats.sizeInBytes) val originalQuery = xBig.join(y, FullOuter).limit(1) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = Limit(1, LocalLimit(1, xBig).join(y, FullOuter)).analyze @@ -130,7 +130,7 @@ class LimitPushdownSuite extends PlanTest { test("full outer join where neither side is limited and right side has larger statistics") { val yBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('y) - assert(x.stats(conf).sizeInBytes < yBig.stats(conf).sizeInBytes) + assert(x.stats.sizeInBytes < yBig.stats.sizeInBytes) val originalQuery = x.join(yBig, FullOuter).limit(1) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = Limit(1, x.join(LocalLimit(1, yBig), FullOuter)).analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala index 38483a298cef0..30ddf03bd3c4f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala @@ -100,17 +100,23 @@ class AggregateEstimationSuite extends StatsEstimationTestBase { size = Some(4 * (8 + 4)), attributeStats = AttributeMap(Seq("key12").map(nameToColInfo))) - val noGroupAgg = Aggregate(groupingExpressions = Nil, - aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child) - assert(noGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == - // overhead + count result size - Statistics(sizeInBytes = 8 + 8, rowCount = Some(1))) - - val hasGroupAgg = Aggregate(groupingExpressions = attributes, - aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child) - assert(hasGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == - // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize - Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4))) + val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED) + try { + SQLConf.get.setConf(SQLConf.CBO_ENABLED, false) + val noGroupAgg = Aggregate(groupingExpressions = Nil, + aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child) + assert(noGroupAgg.stats == + // overhead + count result size + Statistics(sizeInBytes = 8 + 8, rowCount = Some(1))) + + val hasGroupAgg = Aggregate(groupingExpressions = attributes, + aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child) + assert(hasGroupAgg.stats == + // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize + Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4))) + } finally { + SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue) + } } private def checkAggStats( @@ -134,6 +140,6 @@ class AggregateEstimationSuite extends StatsEstimationTestBase { rowCount = Some(expectedOutputRowCount), attributeStats = expectedAttrStats) - assert(testAgg.stats(conf) == expectedStats) + assert(testAgg.stats == expectedStats) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 833f5a71994f7..e9ed36feec48c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -57,16 +57,16 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { val localLimit = LocalLimit(Literal(2), plan) val globalLimit = GlobalLimit(Literal(2), plan) // LocalLimit's stats is just its child's stats except column stats - checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + checkStats(localLimit, plan.stats.copy(attributeStats = AttributeMap(Nil))) checkStats(globalLimit, Statistics(sizeInBytes = 24, rowCount = Some(2))) } test("limit estimation: limit > child's rowCount") { val localLimit = LocalLimit(Literal(20), plan) val globalLimit = GlobalLimit(Literal(20), plan) - checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + checkStats(localLimit, plan.stats.copy(attributeStats = AttributeMap(Nil))) // Limit is larger than child's rowCount, so GlobalLimit's stats is equal to its child's stats. - checkStats(globalLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + checkStats(globalLimit, plan.stats.copy(attributeStats = AttributeMap(Nil))) } test("limit estimation: limit = 0") { @@ -113,12 +113,19 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { plan: LogicalPlan, expectedStatsCboOn: Statistics, expectedStatsCboOff: Statistics): Unit = { - // Invalidate statistics - plan.invalidateStatsCache() - assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> true)) == expectedStatsCboOn) - - plan.invalidateStatsCache() - assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == expectedStatsCboOff) + val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED) + try { + // Invalidate statistics + plan.invalidateStatsCache() + SQLConf.get.setConf(SQLConf.CBO_ENABLED, true) + assert(plan.stats == expectedStatsCboOn) + + plan.invalidateStatsCache() + SQLConf.get.setConf(SQLConf.CBO_ENABLED, false) + assert(plan.stats == expectedStatsCboOff) + } finally { + SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue) + } } /** Check estimated stats when it's the same whether cbo is turned on or off. 
*/ @@ -135,6 +142,6 @@ private case class DummyLogicalPlan( cboStats: Statistics) extends LogicalPlan { override def output: Seq[Attribute] = Nil override def children: Seq[LogicalPlan] = Nil - override def computeStats(conf: SQLConf): Statistics = + override def computeStats: Statistics = if (conf.cboEnabled) cboStats else defaultStats } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 2fa53a6466ef2..455037e6c9952 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -620,7 +620,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { rowCount = Some(expectedRowCount), attributeStats = expectedAttributeMap) - val filterStats = filter.stats(conf) + val filterStats = filter.stats assert(filterStats.sizeInBytes == expectedStats.sizeInBytes) assert(filterStats.rowCount == expectedStats.rowCount) val rowCountValue = filterStats.rowCount.getOrElse(0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index 2d6b6e8e21f34..097c78eb27fca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -77,7 +77,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // Keep the column stat from both sides unchanged. 
attributeStats = AttributeMap( Seq("key-1-5", "key-5-9", "key-1-2", "key-2-4").map(nameToColInfo))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("disjoint inner join") { @@ -90,7 +90,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 1, rowCount = Some(0), attributeStats = AttributeMap(Nil)) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("disjoint left outer join") { @@ -106,7 +106,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // Null count for right side columns = left row count Seq(nameToAttr("key-1-2") -> nullColumnStat(nameToAttr("key-1-2").dataType, 5), nameToAttr("key-2-4") -> nullColumnStat(nameToAttr("key-2-4").dataType, 5)))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("disjoint right outer join") { @@ -122,7 +122,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // Null count for left side columns = right row count Seq(nameToAttr("key-1-5") -> nullColumnStat(nameToAttr("key-1-5").dataType, 3), nameToAttr("key-5-9") -> nullColumnStat(nameToAttr("key-5-9").dataType, 3)))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("disjoint full outer join") { @@ -140,7 +140,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = 3), nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = 5), nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = 5)))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("inner join") { @@ -161,7 +161,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { attributeStats = AttributeMap( Seq(nameToAttr("key-1-5") -> joinedColStat, nameToAttr("key-1-2") -> joinedColStat, nameToAttr("key-5-9") -> colStatForkey59, nameToColInfo("key-2-4")))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("inner join with multiple equi-join keys") { @@ -183,7 +183,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { attributeStats = AttributeMap( Seq(nameToAttr("key-1-2") -> joinedColStat1, nameToAttr("key-1-2") -> joinedColStat1, nameToAttr("key-2-4") -> joinedColStat2, nameToAttr("key-2-3") -> joinedColStat2))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("left outer join") { @@ -201,7 +201,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { attributeStats = AttributeMap( Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-3"), nameToColInfo("key-1-2"), nameToAttr("key-2-4") -> joinedColStat))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("right outer join") { @@ -219,7 +219,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { attributeStats = AttributeMap( Seq(nameToColInfo("key-1-2"), nameToAttr("key-2-4") -> joinedColStat, nameToColInfo("key-1-2"), nameToColInfo("key-2-3")))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("full outer join") { @@ -234,7 +234,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // Keep the column stat from both sides unchanged. 
attributeStats = AttributeMap(Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-4"), nameToColInfo("key-1-2"), nameToColInfo("key-2-3")))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } test("left semi/anti join") { @@ -248,7 +248,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 3 * (8 + 4 * 2), rowCount = Some(3), attributeStats = AttributeMap(Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-4")))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } } @@ -306,7 +306,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 1 * (8 + 2 * getColSize(key1, columnInfo1(key1))), rowCount = Some(1), attributeStats = AttributeMap(Seq(key1 -> columnInfo1(key1), key2 -> columnInfo1(key1)))) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } } } @@ -323,6 +323,6 @@ class JoinEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 1, rowCount = Some(0), attributeStats = AttributeMap(Nil)) - assert(join.stats(conf) == expectedStats) + assert(join.stats == expectedStats) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala index a5c4d22a29386..cda54fa9d64f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala @@ -45,7 +45,7 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 2 * (8 + 4 + 4), rowCount = Some(2), attributeStats = expectedAttrStats) - assert(proj.stats(conf) == expectedStats) + assert(proj.stats == expectedStats) } test("project on empty table") { @@ -131,6 +131,6 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { sizeInBytes = expectedSize, rowCount = Some(expectedRowCount), attributeStats = projectAttrMap) - assert(proj.stats(conf) == expectedStats) + assert(proj.stats == expectedStats) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index 263f4e18803d5..eaa33e44a6a5a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -21,14 +21,24 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, CBO_ENABLED} import org.apache.spark.sql.types.{IntegerType, StringType} trait StatsEstimationTestBase extends SparkFunSuite { - /** Enable stats estimation based on CBO. */ - protected val conf = new SQLConf().copy(CASE_SENSITIVE -> true, CBO_ENABLED -> true) + var originalValue: Boolean = false + + override def beforeAll(): Unit = { + super.beforeAll() + // Enable stats estimation based on CBO. 
+ originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED) + SQLConf.get.setConf(SQLConf.CBO_ENABLED, true) + } + + override def afterAll(): Unit = { + SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue) + super.afterAll() + } def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match { // For UTF8String: base + offset + numBytes @@ -55,7 +65,7 @@ case class StatsTestPlan( attributeStats: AttributeMap[ColumnStat], size: Option[BigInt] = None) extends LeafNode { override def output: Seq[Attribute] = outputList - override def computeStats(conf: SQLConf): Statistics = Statistics( + override def computeStats: Statistics = Statistics( // If sizeInBytes is useless in testing, we just use a fake value sizeInBytes = size.getOrElse(Int.MaxValue), rowCount = Some(rowCount), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 3d1b481a53e75..66f66a289a065 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils @@ -89,7 +88,7 @@ case class ExternalRDD[T]( override protected def stringArgs: Iterator[Any] = Iterator(output) - @transient override def computeStats(conf: SQLConf): Statistics = Statistics( + @transient override def computeStats: Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) @@ -157,7 +156,7 @@ case class LogicalRDD( override protected def stringArgs: Iterator[Any] = Iterator(output) - @transient override def computeStats(conf: SQLConf): Statistics = Statistics( + @transient override def computeStats: Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. 
sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 34998cbd61552..c7cac332a0377 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -221,7 +221,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { def stringWithStats: String = { // trigger to compute stats for logical plans - optimizedPlan.stats(sparkSession.sessionState.conf) + optimizedPlan.stats // only show optimized logical plan and physical plan s"""== Optimized Logical Plan == diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ea86f6e00fefa..a57d5abb90c0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -114,9 +113,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Matches a plan whose output should be small enough to be used in broadcast join. */ private def canBroadcast(plan: LogicalPlan): Boolean = { - plan.stats(conf).hints.broadcast || - (plan.stats(conf).sizeInBytes >= 0 && - plan.stats(conf).sizeInBytes <= conf.autoBroadcastJoinThreshold) + plan.stats.hints.broadcast || + (plan.stats.sizeInBytes >= 0 && + plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold) } /** @@ -126,7 +125,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * dynamic. */ private def canBuildLocalHashMap(plan: LogicalPlan): Boolean = { - plan.stats(conf).sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions + plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions } /** @@ -137,7 +136,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * use the size of bytes here as estimation. 
*/ private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = { - a.stats(conf).sizeInBytes * 3 <= b.stats(conf).sizeInBytes + a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes } private def canBuildRight(joinType: JoinType): Boolean = joinType match { @@ -206,7 +205,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Join(left, right, joinType, condition) => val buildSide = - if (right.stats(conf).sizeInBytes <= left.stats(conf).sizeInBytes) { + if (right.stats.sizeInBytes <= left.stats.sizeInBytes) { BuildRight } else { BuildLeft diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 456a8f3b20f30..2972132336de0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator @@ -70,7 +69,7 @@ case class InMemoryRelation( @transient val partitionStatistics = new PartitionStatistics(output) - override def computeStats(conf: SQLConf): Statistics = { + override def computeStats: Statistics = { if (batchStats.value == 0L) { // Underlying columnar RDD hasn't been materialized, no useful statistics information // available, return the default statistics. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 3813f953e06a3..c1b2895f1747e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.util.Utils @@ -46,7 +45,7 @@ case class LogicalRelation( // Only care about relation when canonicalizing. 
override def preCanonicalized: LogicalPlan = copy(catalogTable = None) - @transient override def computeStats(conf: SQLConf): Statistics = { + @transient override def computeStats: Statistics = { catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse( Statistics(sizeInBytes = relation.sizeInBytes)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 7eaa803a9ecb4..a5dac469f85b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -230,6 +229,6 @@ case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum - override def computeStats(conf: SQLConf): Statistics = + override def computeStats: Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 8532a5b5bc8eb..506cc2548e260 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -313,7 +313,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext spark.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum - assert(cached.stats(sqlConf).sizeInBytes === actualSizeInBytes) + assert(cached.stats.sizeInBytes === actualSizeInBytes) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 165176f6c040e..87b7b090de3bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1146,7 +1146,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { // instead of Int for avoiding possible overflow. 
val ds = (0 to 10000).map( i => (i, Seq((i, Seq((i, "This is really not that long of a string")))))).toDS() - val sizeInBytes = ds.logicalPlan.stats(sqlConf).sizeInBytes + val sizeInBytes = ds.logicalPlan.stats.sizeInBytes // sizeInBytes is 2404280404, before the fix, it overflows to a negative number assert(sizeInBytes > 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 1a66aa85f5a02..895ca196a7a51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -33,7 +33,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { setupTestData() def statisticSizeInByte(df: DataFrame): BigInt = { - df.queryExecution.optimizedPlan.stats(sqlConf).sizeInBytes + df.queryExecution.optimizedPlan.stats.sizeInBytes } test("equi-join is hash-join") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 601324f2c0172..9824062f969b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -60,7 +60,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared val df = df1.join(df2, Seq("k"), "left") val sizes = df.queryExecution.analyzed.collect { case g: Join => - g.stats(conf).sizeInBytes + g.stats.sizeInBytes } assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") @@ -107,9 +107,9 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("SPARK-15392: DataFrame created from RDD should not be broadcasted") { val rdd = sparkContext.range(1, 100).map(i => Row(i, i)) val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType)) - assert(df.queryExecution.analyzed.stats(conf).sizeInBytes > + assert(df.queryExecution.analyzed.stats.sizeInBytes > spark.sessionState.conf.autoBroadcastJoinThreshold) - assert(df.selectExpr("a").queryExecution.analyzed.stats(conf).sizeInBytes > + assert(df.selectExpr("a").queryExecution.analyzed.stats.sizeInBytes > spark.sessionState.conf.autoBroadcastJoinThreshold) } @@ -250,13 +250,13 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils test("SPARK-18856: non-empty partitioned table should not report zero size") { withTable("ds_tbl", "hive_tbl") { spark.range(100).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("ds_tbl") - val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats(conf) + val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats assert(stats.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { sql("CREATE TABLE hive_tbl(i int) PARTITIONED BY (j int)") sql("INSERT INTO hive_tbl PARTITION(j=1) SELECT 1") - val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats(conf) + val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats assert(stats2.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") } } @@ -296,10 +296,10 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat)) // Check relation statistics - 
assert(relation.stats(conf).sizeInBytes == 0) - assert(relation.stats(conf).rowCount == Some(0)) - assert(relation.stats(conf).attributeStats.size == 1) - val (attribute, colStat) = relation.stats(conf).attributeStats.head + assert(relation.stats.sizeInBytes == 0) + assert(relation.stats.rowCount == Some(0)) + assert(relation.stats.attributeStats.size == 1) + val (attribute, colStat) = relation.stats.attributeStats.head assert(attribute.name == "c1") assert(colStat == emptyColStat) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 109b1d9db60d2..8d411eb191cd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -126,7 +126,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { .toDF().createOrReplaceTempView("sizeTst") spark.catalog.cacheTable("sizeTst") assert( - spark.table("sizeTst").queryExecution.analyzed.stats(sqlConf).sizeInBytes > + spark.table("sizeTst").queryExecution.analyzed.stats.sizeInBytes > spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index becb3aa270401..caf03885e3873 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -36,7 +36,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { }) val totalSize = allFiles.map(_.length()).sum val df = spark.read.parquet(dir.toString) - assert(df.queryExecution.logical.stats(sqlConf).sizeInBytes === BigInt(totalSize)) + assert(df.queryExecution.logical.stats.sizeInBytes === BigInt(totalSize)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index 24a7b7740fa5b..e8420eee7fe9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -216,15 +216,15 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { // Before adding data, check output checkAnswer(sink.allData, Seq.empty) - assert(plan.stats(sqlConf).sizeInBytes === 0) + assert(plan.stats.sizeInBytes === 0) sink.addBatch(0, 1 to 3) plan.invalidateStatsCache() - assert(plan.stats(sqlConf).sizeInBytes === 12) + assert(plan.stats.sizeInBytes === 12) sink.addBatch(1, 4 to 6) plan.invalidateStatsCache() - assert(plan.stats(sqlConf).sizeInBytes === 24) + assert(plan.stats.sizeInBytes === 24) } ignore("stress test") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index f9b3ff8405823..0cfe260e52152 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -21,7 +21,6 @@ import java.nio.charset.StandardCharsets import org.apache.spark.rdd.RDD import 
org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits} -import org.apache.spark.sql.internal.SQLConf /** * A collection of sample data used in SQL tests. @@ -29,8 +28,6 @@ import org.apache.spark.sql.internal.SQLConf private[sql] trait SQLTestData { self => protected def spark: SparkSession - protected def sqlConf: SQLConf = spark.sessionState.conf - // Helper object to import SQL implicits without a concrete SQLContext private object internalImplicits extends SQLImplicits { protected override def _sqlContext: SQLContext = self.spark.sqlContext diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index ff5afc8e3ce05..808dc013f170b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -154,7 +154,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log Some(partitionSchema)) val logicalRelation = cached.getOrElse { - val sizeInBytes = relation.stats(sparkSession.sessionState.conf).sizeInBytes.toLong + val sizeInBytes = relation.stats.sizeInBytes.toLong val fileIndex = { val index = new CatalogFileIndex(sparkSession, relation.tableMeta, sizeInBytes) if (lazyPruningEnabled) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 001bbc230ff18..279db9a397258 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -68,7 +68,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto assert(properties("totalSize").toLong <= 0, "external table totalSize must be <= 0") assert(properties("rawDataSize").toLong <= 0, "external table rawDataSize must be <= 0") - val sizeInBytes = relation.stats(conf).sizeInBytes + val sizeInBytes = relation.stats.sizeInBytes assert(sizeInBytes === BigInt(file1.length() + file2.length())) } } @@ -77,7 +77,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("analyze Hive serde tables") { def queryTotalSize(tableName: String): BigInt = - spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes // Non-partitioned table sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() @@ -659,7 +659,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("estimates the size of a test Hive serde tables") { val df = sql("""SELECT * FROM src""") val sizes = df.queryExecution.analyzed.collect { - case relation: CatalogRelation => relation.stats(conf).sizeInBytes + case relation: CatalogRelation => relation.stats.sizeInBytes } assert(sizes.size === 1, s"Size wrong for:\n ${df.queryExecution}") assert(sizes(0).equals(BigInt(5812)), @@ -679,7 +679,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // Assert src has a size smaller than the threshold. 
val sizes = df.queryExecution.analyzed.collect { - case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.stats(conf).sizeInBytes + case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.stats.sizeInBytes } assert(sizes.size === 2 && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold, @@ -733,7 +733,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // Assert src has a size smaller than the threshold. val sizes = df.queryExecution.analyzed.collect { - case relation: CatalogRelation => relation.stats(conf).sizeInBytes + case relation: CatalogRelation => relation.stats.sizeInBytes } assert(sizes.size === 2 && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index d91f25a4da013..3a724aa14f2a9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -86,7 +86,7 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te case relation: LogicalRelation => relation } assert(relations.size === 1, s"Size wrong for:\n ${df.queryExecution}") - val size2 = relations(0).computeStats(conf).sizeInBytes + val size2 = relations(0).computeStats.sizeInBytes assert(size2 == relations(0).catalogTable.get.stats.get.sizeInBytes) assert(size2 < tableStats.get.sizeInBytes) } From 1ebe7ffe072bcac03360e65e959a6cd36530a9c4 Mon Sep 17 00:00:00 2001 From: Dhruve Ashar Date: Fri, 23 Jun 2017 10:36:29 -0700 Subject: [PATCH 070/118] [SPARK-21181] Release byteBuffers to suppress netty error messages ## What changes were proposed in this pull request? We are explicitly calling release on the byteBuf's used to encode the string to Base64 to suppress the memory leak error message reported by netty. This is to make it less confusing for the user. ### Changes proposed in this fix By explicitly invoking release on the byteBuf's we are decrement the internal reference counts for the wrappedByteBuf's. Now, when the GC kicks in, these would be reclaimed as before, just that netty wouldn't report any memory leak error messages as the internal ref. counts are now 0. ## How was this patch tested? Ran a few spark-applications and examined the logs. The error message no longer appears. Original PR was opened against branch-2.1 => https://github.com/apache/spark/pull/18392 Author: Dhruve Ashar Closes #18407 from dhruve/master. 
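For illustration, a minimal standalone sketch of the release pattern described above (the object and method names are illustrative and not part of Spark; the actual change lives in SparkSaslServer below):

```scala
import java.nio.charset.StandardCharsets

import io.netty.buffer.{ByteBuf, Unpooled}
import io.netty.handler.codec.base64.Base64

// Both the wrapped input buffer and the buffer returned by Base64.encode are
// reference counted, so each one is released explicitly once the encoded
// string has been materialized.
object Base64Sketch {
  def encodeToBase64String(str: String): String = {
    var byteBuf: ByteBuf = null
    var encodedByteBuf: ByteBuf = null
    try {
      byteBuf = Unpooled.wrappedBuffer(str.getBytes(StandardCharsets.UTF_8))
      encodedByteBuf = Base64.encode(byteBuf)
      encodedByteBuf.toString(StandardCharsets.UTF_8)
    } finally {
      // Dropping the reference counts to 0 silences netty's leak detector;
      // the underlying memory is still reclaimed by GC as before.
      if (byteBuf != null) byteBuf.release()
      if (encodedByteBuf != null) encodedByteBuf.release()
    }
  }
}
```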
--- .../spark/network/sasl/SparkSaslServer.java | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java index e24fdf0c74de3..00f3e83dbc8b3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -34,6 +34,7 @@ import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; +import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.handler.codec.base64.Base64; import org.slf4j.Logger; @@ -187,14 +188,31 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback /* Encode a byte[] identifier as a Base64-encoded string. */ public static String encodeIdentifier(String identifier) { Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); - return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(StandardCharsets.UTF_8))) - .toString(StandardCharsets.UTF_8); + return getBase64EncodedString(identifier); } /** Encode a password as a base64-encoded char[] array. */ public static char[] encodePassword(String password) { Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled"); - return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(StandardCharsets.UTF_8))) - .toString(StandardCharsets.UTF_8).toCharArray(); + return getBase64EncodedString(password).toCharArray(); + } + + /** Return a Base64-encoded string. */ + private static String getBase64EncodedString(String str) { + ByteBuf byteBuf = null; + ByteBuf encodedByteBuf = null; + try { + byteBuf = Unpooled.wrappedBuffer(str.getBytes(StandardCharsets.UTF_8)); + encodedByteBuf = Base64.encode(byteBuf); + return encodedByteBuf.toString(StandardCharsets.UTF_8); + } finally { + // The release is called to suppress the memory leak error messages raised by netty. + if (byteBuf != null) { + byteBuf.release(); + if (encodedByteBuf != null) { + encodedByteBuf.release(); + } + } + } } } From 2ebd0838d165fe33b404e8d86c0fa445d1f47439 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 23 Jun 2017 10:55:02 -0700 Subject: [PATCH 071/118] [SPARK-21192][SS] Preserve State Store provider class configuration across StreamingQuery restarts ## What changes were proposed in this pull request? If the SQL conf for StateStore provider class is changed between restarts (i.e. query started with providerClass1 and attempted to restart using providerClass2), then the query will fail in a unpredictable way as files saved by one provider class cannot be used by the newer one. Ideally, the provider class used to start the query should be used to restart the query, and the configuration in the session where it is being restarted should be ignored. This PR saves the provider class config to OffsetSeqLog, in the same way # shuffle partitions is saved and recovered. ## How was this patch tested? new unit tests Author: Tathagata Das Closes #18402 from tdas/SPARK-21192. 
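For clarity, a simplified sketch of the recovery rule this patch introduces (plain Scala maps stand in for Spark's RuntimeConfig; the function name is illustrative):

```scala
import scala.collection.mutable

// A conf value recorded in the offset log when the query first started wins
// on restart; a key missing from the log (older checkpoints) keeps whatever
// the current session has, for backward compatibility.
def reconcileConfs(
    relevantKeys: Seq[String],
    metadataConf: Map[String, String],
    sessionConf: mutable.Map[String, String]): Unit = {
  relevantKeys.foreach { key =>
    metadataConf.get(key) match {
      case Some(recorded) => sessionConf(key) = recorded
      case None => // nothing recorded in the log; leave the session value as is
    }
  }
}

// e.g. relevantKeys = Seq("spark.sql.shuffle.partitions",
//                         "spark.sql.streaming.stateStore.providerClass")
```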
--- .../apache/spark/sql/internal/SQLConf.scala | 5 +- .../sql/execution/streaming/OffsetSeq.scala | 39 +++++++++++++- .../execution/streaming/StreamExecution.scala | 26 +++------- .../streaming/state/StateStore.scala | 3 +- .../streaming/state/StateStoreConf.scala | 2 +- .../streaming/OffsetSeqLogSuite.scala | 10 ++-- .../spark/sql/streaming/StreamSuite.scala | 51 +++++++++++++++---- 7 files changed, 96 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e609256db2802..9c8e26a8eeadf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -601,7 +601,8 @@ object SQLConf { "The class used to manage state data in stateful streaming queries. This class must " + "be a subclass of StateStoreProvider, and must have a zero-arg constructor.") .stringConf - .createOptional + .createWithDefault( + "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider") val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT = buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot") @@ -897,7 +898,7 @@ class SQLConf extends Serializable with Logging { def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD) - def stateStoreProviderClass: Option[String] = getConf(STATE_STORE_PROVIDER_CLASS) + def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS) def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 8249adab4bba8..4e0a468b962a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -20,6 +20,10 @@ package org.apache.spark.sql.execution.streaming import org.json4s.NoTypeHints import org.json4s.jackson.Serialization +import org.apache.spark.internal.Logging +import org.apache.spark.sql.RuntimeConfig +import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS} + /** * An ordered collection of offsets, used to track the progress of processing data from one or more * [[Source]]s that are present in a streaming query. 
This is similar to simplified, single-instance @@ -78,7 +82,40 @@ case class OffsetSeqMetadata( def json: String = Serialization.write(this)(OffsetSeqMetadata.format) } -object OffsetSeqMetadata { +object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) + private val relevantSQLConfs = Seq(SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS) + def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) + + def apply( + batchWatermarkMs: Long, + batchTimestampMs: Long, + sessionConf: RuntimeConfig): OffsetSeqMetadata = { + val confs = relevantSQLConfs.map { conf => conf.key -> sessionConf.get(conf.key) }.toMap + OffsetSeqMetadata(batchWatermarkMs, batchTimestampMs, confs) + } + + /** Set the SparkSession configuration with the values in the metadata */ + def setSessionConf(metadata: OffsetSeqMetadata, sessionConf: RuntimeConfig): Unit = { + OffsetSeqMetadata.relevantSQLConfs.map(_.key).foreach { confKey => + + metadata.conf.get(confKey) match { + + case Some(valueInMetadata) => + // Config value exists in the metadata, update the session config with this value + val optionalValueInSession = sessionConf.getOption(confKey) + if (optionalValueInSession.isDefined && optionalValueInSession.get != valueInMetadata) { + logWarning(s"Updating the value of conf '$confKey' in current session from " + + s"'${optionalValueInSession.get}' to '$valueInMetadata'.") + } + sessionConf.set(confKey, valueInMetadata) + + case None => + // For backward compatibility, if a config was not recorded in the offset log, + // then log it, and let the existing conf value in SparkSession prevail. + logWarning (s"Conf '$confKey' was not found in the offset log, using existing value") + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 06bdec8b06407..d5f8d2acba92b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -125,9 +125,8 @@ class StreamExecution( } /** Metadata associated with the offset seq of a batch in the query. 
*/ - protected var offsetSeqMetadata = OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, - conf = Map(SQLConf.SHUFFLE_PARTITIONS.key -> - sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS).toString)) + protected var offsetSeqMetadata = OffsetSeqMetadata( + batchWatermarkMs = 0, batchTimestampMs = 0, sparkSession.conf) override val id: UUID = UUID.fromString(streamMetadata.id) @@ -285,9 +284,8 @@ class StreamExecution( val sparkSessionToRunBatches = sparkSession.cloneSession() // Adaptive execution can change num shuffle partitions, disallow sparkSessionToRunBatches.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false") - offsetSeqMetadata = OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, - conf = Map(SQLConf.SHUFFLE_PARTITIONS.key -> - sparkSessionToRunBatches.conf.get(SQLConf.SHUFFLE_PARTITIONS.key))) + offsetSeqMetadata = OffsetSeqMetadata( + batchWatermarkMs = 0, batchTimestampMs = 0, sparkSessionToRunBatches.conf) if (state.compareAndSet(INITIALIZING, ACTIVE)) { // Unblock `awaitInitialization` @@ -441,21 +439,9 @@ class StreamExecution( // update offset metadata nextOffsets.metadata.foreach { metadata => - val shufflePartitionsSparkSession: Int = - sparkSessionToRunBatches.conf.get(SQLConf.SHUFFLE_PARTITIONS) - val shufflePartitionsToUse = metadata.conf.getOrElse(SQLConf.SHUFFLE_PARTITIONS.key, { - // For backward compatibility, if # partitions was not recorded in the offset log, - // then ensure it is not missing. The new value is picked up from the conf. - logWarning("Number of shuffle partitions from previous run not found in checkpoint. " - + s"Using the value from the conf, $shufflePartitionsSparkSession partitions.") - shufflePartitionsSparkSession - }) + OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) offsetSeqMetadata = OffsetSeqMetadata( - metadata.batchWatermarkMs, metadata.batchTimestampMs, - metadata.conf + (SQLConf.SHUFFLE_PARTITIONS.key -> shufflePartitionsToUse.toString)) - // Update conf with correct number of shuffle partitions - sparkSessionToRunBatches.conf.set( - SQLConf.SHUFFLE_PARTITIONS.key, shufflePartitionsToUse.toString) + metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) } /* identify the current batch id: if commit log indicates we successfully processed the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index a94ff8a7ebd1e..86886466c4f56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -172,8 +172,7 @@ object StateStoreProvider { indexOrdinal: Option[Int], // for sorting the data storeConf: StateStoreConf, hadoopConf: Configuration): StateStoreProvider = { - val providerClass = storeConf.providerClass.map(Utils.classForName) - .getOrElse(classOf[HDFSBackedStateStoreProvider]) + val providerClass = Utils.classForName(storeConf.providerClass) val provider = providerClass.newInstance().asInstanceOf[StateStoreProvider] provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) provider diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index bab297c7df594..765ff076cb467 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -38,7 +38,7 @@ class StateStoreConf(@transient private val sqlConf: SQLConf) * Optional fully qualified name of the subclass of [[StateStoreProvider]] * managing state data. That is, the implementation of the State Store to use. */ - val providerClass: Option[String] = sqlConf.stateStoreProviderClass + val providerClass: String = sqlConf.stateStoreProviderClass /** * Additional configurations related to state store. This will capture all configs in diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index dc556322beddb..e6cdc063c4e9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -37,16 +37,18 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { } // None set - assert(OffsetSeqMetadata(0, 0, Map.empty) === OffsetSeqMetadata("""{}""")) + assert(new OffsetSeqMetadata(0, 0, Map.empty) === OffsetSeqMetadata("""{}""")) // One set - assert(OffsetSeqMetadata(1, 0, Map.empty) === OffsetSeqMetadata("""{"batchWatermarkMs":1}""")) - assert(OffsetSeqMetadata(0, 2, Map.empty) === OffsetSeqMetadata("""{"batchTimestampMs":2}""")) + assert(new OffsetSeqMetadata(1, 0, Map.empty) === + OffsetSeqMetadata("""{"batchWatermarkMs":1}""")) + assert(new OffsetSeqMetadata(0, 2, Map.empty) === + OffsetSeqMetadata("""{"batchTimestampMs":2}""")) assert(OffsetSeqMetadata(0, 0, getConfWith(shufflePartitions = 2)) === OffsetSeqMetadata(s"""{"conf": {"$key":2}}""")) // Two set - assert(OffsetSeqMetadata(1, 2, Map.empty) === + assert(new OffsetSeqMetadata(1, 2, Map.empty) === OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}""")) assert(OffsetSeqMetadata(1, 0, getConfWith(shufflePartitions = 3)) === OffsetSeqMetadata(s"""{"batchWatermarkMs":1,"conf": {"$key":3}}""")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 86c3a35a59c13..6f7b9d35a6bb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -637,19 +637,11 @@ class StreamSuite extends StreamTest { } testQuietly("specify custom state store provider") { - val queryName = "memStream" val providerClassName = classOf[TestStateStoreProvider].getCanonicalName withSQLConf("spark.sql.streaming.stateStore.providerClass" -> providerClassName) { val input = MemoryStream[Int] - val query = input - .toDS() - .groupBy() - .count() - .writeStream - .outputMode("complete") - .format("memory") - .queryName(queryName) - .start() + val df = input.toDS().groupBy().count() + val query = df.writeStream.outputMode("complete").format("memory").queryName("name").start() input.addData(1, 2, 3) val e = intercept[Exception] { query.awaitTermination() @@ -659,6 +651,45 @@ class StreamSuite extends StreamTest { assert(e.getMessage.contains("instantiated")) } } + + testQuietly("custom state store provider read from offset log") { + val input = MemoryStream[Int] + val df = input.toDS().groupBy().count() + val providerConf1 = 
"spark.sql.streaming.stateStore.providerClass" -> + "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider" + val providerConf2 = "spark.sql.streaming.stateStore.providerClass" -> + classOf[TestStateStoreProvider].getCanonicalName + + def runQuery(queryName: String, checkpointLoc: String): Unit = { + val query = df.writeStream + .outputMode("complete") + .format("memory") + .queryName(queryName) + .option("checkpointLocation", checkpointLoc) + .start() + input.addData(1, 2, 3) + query.processAllAvailable() + query.stop() + } + + withTempDir { dir => + val checkpointLoc1 = new File(dir, "1").getCanonicalPath + withSQLConf(providerConf1) { + runQuery("query1", checkpointLoc1) // generate checkpoints + } + + val checkpointLoc2 = new File(dir, "2").getCanonicalPath + withSQLConf(providerConf2) { + // Verify new query will use new provider that throw error on loading + intercept[Exception] { + runQuery("query2", checkpointLoc2) + } + + // Verify old query from checkpoint will still use old provider + runQuery("query1", checkpointLoc1) + } + } + } } abstract class FakeSource extends StreamSourceProvider { From 4cc62951a2b12a372a2b267bf8597a0a31e2b2cb Mon Sep 17 00:00:00 2001 From: Ong Ming Yang Date: Fri, 23 Jun 2017 10:56:59 -0700 Subject: [PATCH 072/118] [MINOR][DOCS] Docs in DataFrameNaFunctions.scala use wrong method ## What changes were proposed in this pull request? * Following the first few examples in this file, the remaining methods should also be methods of `df.na` not `df`. * Filled in some missing parentheses ## How was this patch tested? N/A Author: Ong Ming Yang Closes #18398 from ongmingyang/master. --- .../spark/sql/DataFrameNaFunctions.scala | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index ee949e78fa3ba..871fff71e5538 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -268,13 +268,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * import com.google.common.collect.ImmutableMap; * * // Replaces all occurrences of 1.0 with 2.0 in column "height". - * df.replace("height", ImmutableMap.of(1.0, 2.0)); + * df.na.replace("height", ImmutableMap.of(1.0, 2.0)); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". - * df.replace("name", ImmutableMap.of("UNKNOWN", "unnamed")); + * df.na.replace("name", ImmutableMap.of("UNKNOWN", "unnamed")); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. - * df.replace("*", ImmutableMap.of("UNKNOWN", "unnamed")); + * df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed")); * }}} * * @param col name of the column to apply the value replacement @@ -295,10 +295,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * import com.google.common.collect.ImmutableMap; * * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". - * df.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0)); + * df.na.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0)); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". 
- * df.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed")); + * df.na.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed")); * }}} * * @param cols list of columns to apply the value replacement @@ -319,13 +319,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height". - * df.replace("height", Map(1.0 -> 2.0)) + * df.na.replace("height", Map(1.0 -> 2.0)); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". - * df.replace("name", Map("UNKNOWN" -> "unnamed") + * df.na.replace("name", Map("UNKNOWN" -> "unnamed")); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. - * df.replace("*", Map("UNKNOWN" -> "unnamed") + * df.na.replace("*", Map("UNKNOWN" -> "unnamed")); * }}} * * @param col name of the column to apply the value replacement @@ -348,10 +348,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". - * df.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0)); + * df.na.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0)); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". - * df.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed"); + * df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed")); * }}} * * @param cols list of columns to apply the value replacement From 13c2a4f2f8c6d3484f920caadddf4e5edce0a945 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 23 Jun 2017 11:02:54 -0700 Subject: [PATCH 073/118] [SPARK-20417][SQL] Move subquery error handling to checkAnalysis from Analyzer ## What changes were proposed in this pull request? Currently we do a lot of validations for subquery in the Analyzer. We should move them to CheckAnalysis which is the framework to catch and report Analysis errors. This was mentioned as a review comment in SPARK-18874. ## How was this patch tested? Exists tests + A few tests added to SQLQueryTestSuite. Author: Dilip Biswal Closes #17713 from dilipbiswal/subquery_checkanalysis. 
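As an illustration of the kind of error the relocated checks report, assuming tables t1(c1, c2) and t2(c1) exist, an active SparkSession is in scope as `spark`, and the snippet runs inside a ScalaTest suite:

```scala
import org.apache.spark.sql.AnalysisException

// The left side of the IN has two columns while the subquery returns one,
// so the query is rejected during analysis with the column-count mismatch
// message exercised by subq-input-typecheck.sql.
val e = intercept[AnalysisException] {
  spark.sql("SELECT * FROM t1 WHERE (t1.c1, t1.c2) IN (SELECT t2.c1 FROM t2)")
}
assert(e.getMessage.contains("number of columns"))
```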
--- .../sql/catalyst/analysis/Analyzer.scala | 230 +----------- .../sql/catalyst/analysis/CheckAnalysis.scala | 338 ++++++++++++++---- .../sql/catalyst/expressions/predicates.scala | 46 ++- .../analysis/AnalysisErrorSuite.scala | 3 +- .../analysis/ResolveSubquerySuite.scala | 2 +- .../negative-cases/subq-input-typecheck.sql | 47 +++ .../subq-input-typecheck.sql.out | 106 ++++++ .../org/apache/spark/sql/SubquerySuite.scala | 2 +- 8 files changed, 464 insertions(+), 310 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 647fc0b9342c1..193082eb77024 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects} import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ -import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ @@ -1257,217 +1256,16 @@ class Analyzer( } /** - * Validates to make sure the outer references appearing inside the subquery - * are legal. This function also returns the list of expressions - * that contain outer references. These outer references would be kept as children - * of subquery expressions by the caller of this function. - */ - private def checkAndGetOuterReferences(sub: LogicalPlan): Seq[Expression] = { - val outerReferences = ArrayBuffer.empty[Expression] - - // Validate that correlated aggregate expression do not contain a mixture - // of outer and local references. - def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = { - expr.foreach { - case a: AggregateExpression if containsOuter(a) => - val outer = a.collect { case OuterReference(e) => e.toAttribute } - val local = a.references -- outer - if (local.nonEmpty) { - val msg = - s""" - |Found an aggregate expression in a correlated predicate that has both - |outer and local references, which is not supported yet. - |Aggregate expression: ${SubExprUtils.stripOuterReference(a).sql}, - |Outer references: ${outer.map(_.sql).mkString(", ")}, - |Local references: ${local.map(_.sql).mkString(", ")}. - """.stripMargin.replace("\n", " ").trim() - failAnalysis(msg) - } - case _ => - } - } - - // Make sure a plan's subtree does not contain outer references - def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { - if (hasOuterReferences(p)) { - failAnalysis(s"Accessing outer query column is not allowed in:\n$p") - } - } - - // Make sure a plan's expressions do not contain : - // 1. Aggregate expressions that have mixture of outer and local references. - // 2. Expressions containing outer references on plan nodes other than Filter. 
- def failOnInvalidOuterReference(p: LogicalPlan): Unit = { - p.expressions.foreach(checkMixedReferencesInsideAggregateExpr) - if (!p.isInstanceOf[Filter] && p.expressions.exists(containsOuter)) { - failAnalysis( - "Expressions referencing the outer query are not supported outside of WHERE/HAVING " + - s"clauses:\n$p") - } - } - - // SPARK-17348: A potential incorrect result case. - // When a correlated predicate is a non-equality predicate, - // certain operators are not permitted from the operator - // hosting the correlated predicate up to the operator on the outer table. - // Otherwise, the pull up of the correlated predicate - // will generate a plan with a different semantics - // which could return incorrect result. - // Currently we check for Aggregate and Window operators - // - // Below shows an example of a Logical Plan during Analyzer phase that - // show this problem. Pulling the correlated predicate [outer(c2#77) >= ..] - // through the Aggregate (or Window) operator could alter the result of - // the Aggregate. - // - // Project [c1#76] - // +- Project [c1#87, c2#88] - // : (Aggregate or Window operator) - // : +- Filter [outer(c2#77) >= c2#88)] - // : +- SubqueryAlias t2, `t2` - // : +- Project [_1#84 AS c1#87, _2#85 AS c2#88] - // : +- LocalRelation [_1#84, _2#85] - // +- SubqueryAlias t1, `t1` - // +- Project [_1#73 AS c1#76, _2#74 AS c2#77] - // +- LocalRelation [_1#73, _2#74] - def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = { - if (found) { - // Report a non-supported case as an exception - failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p") - } - } - - var foundNonEqualCorrelatedPred : Boolean = false - - // Simplify the predicates before validating any unsupported correlation patterns - // in the plan. - BooleanSimplification(sub).foreachUp { - - // Whitelist operators allowed in a correlated subquery - // There are 4 categories: - // 1. Operators that are allowed anywhere in a correlated subquery, and, - // by definition of the operators, they either do not contain - // any columns or cannot host outer references. - // 2. Operators that are allowed anywhere in a correlated subquery - // so long as they do not host outer references. - // 3. Operators that need special handlings. These operators are - // Project, Filter, Join, Aggregate, and Generate. - // - // Any operators that are not in the above list are allowed - // in a correlated subquery only if they are not on a correlation path. - // In other word, these operators are allowed only under a correlation point. - // - // A correlation path is defined as the sub-tree of all the operators that - // are on the path from the operator hosting the correlated expressions - // up to the operator producing the correlated values. - - // Category 1: - // BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias - case _: ResolvedHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => - - // Category 2: - // These operators can be anywhere in a correlated subquery. - // so long as they do not host outer references in the operators. - case s: Sort => - failOnInvalidOuterReference(s) - case r: RepartitionByExpression => - failOnInvalidOuterReference(r) - - // Category 3: - // Filter is one of the two operators allowed to host correlated expressions. - // The other operator is Join. Filter can be anywhere in a correlated subquery. - case f: Filter => - // Find all predicates with an outer reference. 
- val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter) - - // Find any non-equality correlated predicates - foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { - case _: EqualTo | _: EqualNullSafe => false - case _ => true - } - - failOnInvalidOuterReference(f) - // The aggregate expressions are treated in a special way by getOuterReferences. If the - // aggregate expression contains only outer reference attributes then the entire aggregate - // expression is isolated as an OuterReference. - // i.e min(OuterReference(b)) => OuterReference(min(b)) - outerReferences ++= getOuterReferences(correlated) - - // Project cannot host any correlated expressions - // but can be anywhere in a correlated subquery. - case p: Project => - failOnInvalidOuterReference(p) - - // Aggregate cannot host any correlated expressions - // It can be on a correlation path if the correlation contains - // only equality correlated predicates. - // It cannot be on a correlation path if the correlation has - // non-equality correlated predicates. - case a: Aggregate => - failOnInvalidOuterReference(a) - failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) - - // Join can host correlated expressions. - case j @ Join(left, right, joinType, _) => - joinType match { - // Inner join, like Filter, can be anywhere. - case _: InnerLike => - failOnInvalidOuterReference(j) - - // Left outer join's right operand cannot be on a correlation path. - // LeftAnti and ExistenceJoin are special cases of LeftOuter. - // Note that ExistenceJoin cannot be expressed externally in both SQL and DataFrame - // so it should not show up here in Analysis phase. This is just a safety net. - // - // LeftSemi does not allow output from the right operand. - // Any correlated references in the subplan - // of the right operand cannot be pulled up. - case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => - failOnInvalidOuterReference(j) - failOnOuterReferenceInSubTree(right) - - // Likewise, Right outer join's left operand cannot be on a correlation path. - case RightOuter => - failOnInvalidOuterReference(j) - failOnOuterReferenceInSubTree(left) - - // Any other join types not explicitly listed above, - // including Full outer join, are treated as Category 4. - case _ => - failOnOuterReferenceInSubTree(j) - } - - // Generator with join=true, i.e., expressed with - // LATERAL VIEW [OUTER], similar to inner join, - // allows to have correlation under it - // but must not host any outer references. - // Note: - // Generator with join=false is treated as Category 4. - case g: Generate if g.join => - failOnInvalidOuterReference(g) - - // Category 4: Any other operators not in the above 3 categories - // cannot be on a correlation path, that is they are allowed only - // under a correlation point but they and their descendant operators - // are not allowed to have any correlated expressions. - case p => - failOnOuterReferenceInSubTree(p) - } - outerReferences - } - - /** - * Resolves the subquery. The subquery is resolved using its outer plans. This method - * will resolve the subquery by alternating between the regular analyzer and by applying the - * resolveOuterReferences rule. + * Resolves the subquery plan that is referenced in a subquery expression. The normal + * attribute references are resolved using regular analyzer and the outer references are + * resolved from the outer plans using the resolveOuterReferences method. 
* * Outer references from the correlated predicates are updated as children of * Subquery expression. */ private def resolveSubQuery( e: SubqueryExpression, - plans: Seq[LogicalPlan], - requiredColumns: Int = 0)( + plans: Seq[LogicalPlan])( f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = { // Step 1: Resolve the outer expressions. var previous: LogicalPlan = null @@ -1488,15 +1286,8 @@ class Analyzer( // Step 2: If the subquery plan is fully resolved, pull the outer references and record // them as children of SubqueryExpression. if (current.resolved) { - // Make sure the resolved query has the required number of output columns. This is only - // needed for Scalar and IN subqueries. - if (requiredColumns > 0 && requiredColumns != current.output.size) { - failAnalysis(s"The number of columns in the subquery (${current.output.size}) " + - s"does not match the required number of columns ($requiredColumns)") - } - // Validate the outer reference and record the outer references as children of - // subquery expression. - f(current, checkAndGetOuterReferences(current)) + // Record the outer references as children of subquery expression. + f(current, SubExprUtils.getOuterReferences(current)) } else { e.withNewPlan(current) } @@ -1514,16 +1305,11 @@ class Analyzer( private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { plan transformExpressions { case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => - resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId)) + resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId) if !sub.resolved => resolveSubQuery(e, plans)(Exists(_, _, exprId)) case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved => - // Get the left hand side expressions. 
- val expressions = value match { - case cns : CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - val expr = resolveSubQuery(l, plans, expressions.size)(ListQuery(_, _, exprId)) + val expr = resolveSubQuery(l, plans)(ListQuery(_, _, exprId)) In(value, Seq(expr)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 2e3ac3e474866..fb81a7006bc5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -21,6 +21,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ +import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -129,61 +131,8 @@ trait CheckAnalysis extends PredicateHelper { case None => w } - case s @ ScalarSubquery(query, conditions, _) => - checkAnalysis(query) - - // If no correlation, the output must be exactly one column - if (conditions.isEmpty && query.output.size != 1) { - failAnalysis( - s"Scalar subquery must return only one column, but got ${query.output.size}") - } else if (conditions.nonEmpty) { - def checkAggregate(agg: Aggregate): Unit = { - // Make sure correlated scalar subqueries contain one row for every outer row by - // enforcing that they are aggregates containing exactly one aggregate expression. - // The analyzer has already checked that subquery contained only one output column, - // and added all the grouping expressions to the aggregate. - val aggregates = agg.expressions.flatMap(_.collect { - case a: AggregateExpression => a - }) - if (aggregates.isEmpty) { - failAnalysis("The output of a correlated scalar subquery must be aggregated") - } - - // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns - // are not part of the correlated columns. - val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) - // Collect the local references from the correlated predicate in the subquery. - val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references) - .filterNot(conditions.flatMap(_.references).contains) - val correlatedCols = AttributeSet(subqueryColumns) - val invalidCols = groupByCols -- correlatedCols - // GROUP BY columns must be a subset of columns in the predicates - if (invalidCols.nonEmpty) { - failAnalysis( - "A GROUP BY clause in a scalar correlated subquery " + - "cannot contain non-correlated columns: " + - invalidCols.mkString(",")) - } - } - - // Skip subquery aliases added by the Analyzer. - // For projects, do the necessary mapping and skip to its child. 
- def cleanQuery(p: LogicalPlan): LogicalPlan = p match { - case s: SubqueryAlias => cleanQuery(s.child) - case p: Project => cleanQuery(p.child) - case child => child - } - - cleanQuery(query) match { - case a: Aggregate => checkAggregate(a) - case Filter(_, a: Aggregate) => checkAggregate(a) - case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail") - } - } - s - case s: SubqueryExpression => - checkAnalysis(s.plan) + checkSubqueryExpression(operator, s) s } @@ -291,19 +240,6 @@ trait CheckAnalysis extends PredicateHelper { case LocalLimit(limitExpr, _) => checkLimitClause(limitExpr) - case p if p.expressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) => - p match { - case _: Filter | _: Aggregate | _: Project => // Ok - case other => failAnalysis( - s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p") - } - - case p if p.expressions.exists(SubqueryExpression.hasInOrExistsSubquery) => - p match { - case _: Filter => // Ok - case _ => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") - } - case _: Union | _: SetOperation if operator.children.length > 1 => def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType) def ordinalNumber(i: Int): String = i match { @@ -414,4 +350,272 @@ trait CheckAnalysis extends PredicateHelper { plan.foreach(_.setAnalyzed()) } + + /** + * Validates subquery expressions in the plan. Upon failure, returns an user facing error. + */ + private def checkSubqueryExpression(plan: LogicalPlan, expr: SubqueryExpression): Unit = { + def checkAggregateInScalarSubquery( + conditions: Seq[Expression], + query: LogicalPlan, agg: Aggregate): Unit = { + // Make sure correlated scalar subqueries contain one row for every outer row by + // enforcing that they are aggregates containing exactly one aggregate expression. + val aggregates = agg.expressions.flatMap(_.collect { + case a: AggregateExpression => a + }) + if (aggregates.isEmpty) { + failAnalysis("The output of a correlated scalar subquery must be aggregated") + } + + // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns + // are not part of the correlated columns. + val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) + // Collect the local references from the correlated predicate in the subquery. + val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references) + .filterNot(conditions.flatMap(_.references).contains) + val correlatedCols = AttributeSet(subqueryColumns) + val invalidCols = groupByCols -- correlatedCols + // GROUP BY columns must be a subset of columns in the predicates + if (invalidCols.nonEmpty) { + failAnalysis( + "A GROUP BY clause in a scalar correlated subquery " + + "cannot contain non-correlated columns: " + + invalidCols.mkString(",")) + } + } + + // Skip subquery aliases added by the Analyzer. + // For projects, do the necessary mapping and skip to its child. + def cleanQueryInScalarSubquery(p: LogicalPlan): LogicalPlan = p match { + case s: SubqueryAlias => cleanQueryInScalarSubquery(s.child) + case p: Project => cleanQueryInScalarSubquery(p.child) + case child => child + } + + // Validate the subquery plan. + checkAnalysis(expr.plan) + + expr match { + case ScalarSubquery(query, conditions, _) => + // Scalar subquery must return one column as output. 
+ if (query.output.size != 1) { + failAnalysis( + s"Scalar subquery must return only one column, but got ${query.output.size}") + } + + if (conditions.nonEmpty) { + cleanQueryInScalarSubquery(query) match { + case a: Aggregate => checkAggregateInScalarSubquery(conditions, query, a) + case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(conditions, query, a) + case fail => failAnalysis(s"Correlated scalar subqueries must be aggregated: $fail") + } + + // Only certain operators are allowed to host subquery expression containing + // outer references. + plan match { + case _: Filter | _: Aggregate | _: Project => // Ok + case other => failAnalysis( + "Correlated scalar sub-queries can only be used in a " + + s"Filter/Aggregate/Project: $plan") + } + } + + case inSubqueryOrExistsSubquery => + plan match { + case _: Filter => // Ok + case _ => + failAnalysis(s"IN/EXISTS predicate sub-queries can only be used in a Filter: $plan") + } + } + + // Validate to make sure the correlations appearing in the query are valid and + // allowed by spark. + checkCorrelationsInSubquery(expr.plan) + } + + /** + * Validates to make sure the outer references appearing inside the subquery + * are allowed. + */ + private def checkCorrelationsInSubquery(sub: LogicalPlan): Unit = { + // Validate that correlated aggregate expression do not contain a mixture + // of outer and local references. + def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = { + expr.foreach { + case a: AggregateExpression if containsOuter(a) => + val outer = a.collect { case OuterReference(e) => e.toAttribute } + val local = a.references -- outer + if (local.nonEmpty) { + val msg = + s""" + |Found an aggregate expression in a correlated predicate that has both + |outer and local references, which is not supported yet. + |Aggregate expression: ${SubExprUtils.stripOuterReference(a).sql}, + |Outer references: ${outer.map(_.sql).mkString(", ")}, + |Local references: ${local.map(_.sql).mkString(", ")}. + """.stripMargin.replace("\n", " ").trim() + failAnalysis(msg) + } + case _ => + } + } + + // Make sure a plan's subtree does not contain outer references + def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { + if (hasOuterReferences(p)) { + failAnalysis(s"Accessing outer query column is not allowed in:\n$p") + } + } + + // Make sure a plan's expressions do not contain : + // 1. Aggregate expressions that have mixture of outer and local references. + // 2. Expressions containing outer references on plan nodes other than Filter. + def failOnInvalidOuterReference(p: LogicalPlan): Unit = { + p.expressions.foreach(checkMixedReferencesInsideAggregateExpr) + if (!p.isInstanceOf[Filter] && p.expressions.exists(containsOuter)) { + failAnalysis( + "Expressions referencing the outer query are not supported outside of WHERE/HAVING " + + s"clauses:\n$p") + } + } + + // SPARK-17348: A potential incorrect result case. + // When a correlated predicate is a non-equality predicate, + // certain operators are not permitted from the operator + // hosting the correlated predicate up to the operator on the outer table. + // Otherwise, the pull up of the correlated predicate + // will generate a plan with a different semantics + // which could return incorrect result. + // Currently we check for Aggregate and Window operators + // + // Below shows an example of a Logical Plan during Analyzer phase that + // show this problem. Pulling the correlated predicate [outer(c2#77) >= ..] 
+ // through the Aggregate (or Window) operator could alter the result of + // the Aggregate. + // + // Project [c1#76] + // +- Project [c1#87, c2#88] + // : (Aggregate or Window operator) + // : +- Filter [outer(c2#77) >= c2#88)] + // : +- SubqueryAlias t2, `t2` + // : +- Project [_1#84 AS c1#87, _2#85 AS c2#88] + // : +- LocalRelation [_1#84, _2#85] + // +- SubqueryAlias t1, `t1` + // +- Project [_1#73 AS c1#76, _2#74 AS c2#77] + // +- LocalRelation [_1#73, _2#74] + def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = { + if (found) { + // Report a non-supported case as an exception + failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p") + } + } + + var foundNonEqualCorrelatedPred: Boolean = false + + // Simplify the predicates before validating any unsupported correlation patterns + // in the plan. + BooleanSimplification(sub).foreachUp { + // Whitelist operators allowed in a correlated subquery + // There are 4 categories: + // 1. Operators that are allowed anywhere in a correlated subquery, and, + // by definition of the operators, they either do not contain + // any columns or cannot host outer references. + // 2. Operators that are allowed anywhere in a correlated subquery + // so long as they do not host outer references. + // 3. Operators that need special handlings. These operators are + // Filter, Join, Aggregate, and Generate. + // + // Any operators that are not in the above list are allowed + // in a correlated subquery only if they are not on a correlation path. + // In other word, these operators are allowed only under a correlation point. + // + // A correlation path is defined as the sub-tree of all the operators that + // are on the path from the operator hosting the correlated expressions + // up to the operator producing the correlated values. + + // Category 1: + // ResolvedHint, Distinct, LeafNode, Repartition, and SubqueryAlias + case _: ResolvedHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => + + // Category 2: + // These operators can be anywhere in a correlated subquery. + // so long as they do not host outer references in the operators. + case p: Project => + failOnInvalidOuterReference(p) + + case s: Sort => + failOnInvalidOuterReference(s) + + case r: RepartitionByExpression => + failOnInvalidOuterReference(r) + + // Category 3: + // Filter is one of the two operators allowed to host correlated expressions. + // The other operator is Join. Filter can be anywhere in a correlated subquery. + case f: Filter => + val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter) + + // Find any non-equality correlated predicates + foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { + case _: EqualTo | _: EqualNullSafe => false + case _ => true + } + failOnInvalidOuterReference(f) + + // Aggregate cannot host any correlated expressions + // It can be on a correlation path if the correlation contains + // only equality correlated predicates. + // It cannot be on a correlation path if the correlation has + // non-equality correlated predicates. + case a: Aggregate => + failOnInvalidOuterReference(a) + failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) + + // Join can host correlated expressions. + case j @ Join(left, right, joinType, _) => + joinType match { + // Inner join, like Filter, can be anywhere. 
+ case _: InnerLike => + failOnInvalidOuterReference(j) + + // Left outer join's right operand cannot be on a correlation path. + // LeftAnti and ExistenceJoin are special cases of LeftOuter. + // Note that ExistenceJoin cannot be expressed externally in both SQL and DataFrame + // so it should not show up here in Analysis phase. This is just a safety net. + // + // LeftSemi does not allow output from the right operand. + // Any correlated references in the subplan + // of the right operand cannot be pulled up. + case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => + failOnInvalidOuterReference(j) + failOnOuterReferenceInSubTree(right) + + // Likewise, Right outer join's left operand cannot be on a correlation path. + case RightOuter => + failOnInvalidOuterReference(j) + failOnOuterReferenceInSubTree(left) + + // Any other join types not explicitly listed above, + // including Full outer join, are treated as Category 4. + case _ => + failOnOuterReferenceInSubTree(j) + } + + // Generator with join=true, i.e., expressed with + // LATERAL VIEW [OUTER], similar to inner join, + // allows to have correlation under it + // but must not host any outer references. + // Note: + // Generator with join=false is treated as Category 4. + case g: Generate if g.join => + failOnInvalidOuterReference(g) + + // Category 4: Any other operators not in the above 3 categories + // cannot be on a correlation path, that is they are allowed only + // under a correlation point but they and their descendant operators + // are not allowed to have any correlated expressions. + case p => + failOnOuterReferenceInSubTree(p) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index c15ee2ab270bc..f3fe58caa6fe2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -144,27 +144,39 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { case cns: CreateNamedStruct => cns.valExprs case expr => Seq(expr) } - - val mismatchedColumns = valExprs.zip(sub.output).flatMap { - case (l, r) if l.dataType != r.dataType => - s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" - case _ => None - } - - if (mismatchedColumns.nonEmpty) { + if (valExprs.length != sub.output.length) { TypeCheckResult.TypeCheckFailure( s""" - |The data type of one or more elements in the left hand side of an IN subquery - |is not compatible with the data type of the output of the subquery - |Mismatched columns: - |[${mismatchedColumns.mkString(", ")}] - |Left side: - |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. - |Right side: - |[${sub.output.map(_.dataType.catalogString).mkString(", ")}]. + |The number of columns in the left hand side of an IN subquery does not match the + |number of columns in the output of subquery. + |#columns in left hand side: ${valExprs.length}. + |#columns in right hand side: ${sub.output.length}. + |Left side columns: + |[${valExprs.map(_.sql).mkString(", ")}]. + |Right side columns: + |[${sub.output.map(_.sql).mkString(", ")}]. 
""".stripMargin) } else { - TypeCheckResult.TypeCheckSuccess + val mismatchedColumns = valExprs.zip(sub.output).flatMap { + case (l, r) if l.dataType != r.dataType => + s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" + case _ => None + } + if (mismatchedColumns.nonEmpty) { + TypeCheckResult.TypeCheckFailure( + s""" + |The data type of one or more elements in the left hand side of an IN subquery + |is not compatible with the data type of the output of the subquery + |Mismatched columns: + |[${mismatchedColumns.mkString(", ")}] + |Left side: + |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. + |Right side: + |[${sub.output.map(_.dataType.catalogString).mkString(", ")}]. + """.stripMargin) + } else { + TypeCheckResult.TypeCheckSuccess + } } case _ => if (list.exists(l => l.dataType != value.dataType)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 5050318d96358..4ed995e20d7ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -111,8 +111,7 @@ class AnalysisErrorSuite extends AnalysisTest { "scalar subquery with 2 columns", testRelation.select( (ScalarSubquery(testRelation.select('a, dateLit.as('b))) + Literal(1)).as('a)), - "The number of columns in the subquery (2)" :: - "does not match the required number of columns (1)":: Nil) + "Scalar subquery must return only one column, but got 2" :: Nil) errorTest( "scalar subquery with no column", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 55693121431a2..1bf8d76da04d8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -35,7 +35,7 @@ class ResolveSubquerySuite extends AnalysisTest { test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1) val m = intercept[AnalysisException] { - SimpleAnalyzer.ResolveSubquery(expr) + SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(expr)) }.getMessage assert(m.contains( "Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses")) diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql new file mode 100644 index 0000000000000..b15f4da81dd93 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql @@ -0,0 +1,47 @@ +-- The test file contains negative test cases +-- of invalid queries where error messages are expected. 
+ +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (1, 2, 3) +AS t1(t1a, t1b, t1c); + +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES + (1, 0, 1) +AS t2(t2a, t2b, t2c); + +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES + (3, 1, 2) +AS t3(t3a, t3b, t3c); + +-- TC 01.01 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b = t1.t1b + GROUP BY t2.t2b + ) +FROM t1; + +-- TC 01.01 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b > 0 + GROUP BY t2.t2b + ) +FROM t1; + +-- TC 01.03 +SELECT * FROM t1 +WHERE +t1a IN (SELECT t2a, t2b + FROM t2 + WHERE t1a = t2a); + +-- TC 01.04 +SELECT * FROM T1 +WHERE +(t1a, t1b) IN (SELECT t2a + FROM t2 + WHERE t1a = t2a); + diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out new file mode 100644 index 0000000000000..9ea9d3c4c6f40 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out @@ -0,0 +1,106 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (1, 2, 3) +AS t1(t1a, t1b, t1c) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES + (1, 0, 1) +AS t2(t2a, t2b, t2c) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES + (3, 1, 2) +AS t3(t3a, t3b, t3c) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b = t1.t1b + GROUP BY t2.t2b + ) +FROM t1 +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +Scalar subquery must return only one column, but got 2; + + +-- !query 4 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b > 0 + GROUP BY t2.t2b + ) +FROM t1 +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +Scalar subquery must return only one column, but got 2; + + +-- !query 5 +SELECT * FROM t1 +WHERE +t1a IN (SELECT t2a, t2b + FROM t2 + WHERE t1a = t2a) +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve '(t1.`t1a` IN (listquery(t1.`t1a`)))' due to data type mismatch: +The number of columns in the left hand side of an IN subquery does not match the +number of columns in the output of subquery. +#columns in left hand side: 1. +#columns in right hand side: 2. +Left side columns: +[t1.`t1a`]. +Right side columns: +[t2.`t2a`, t2.`t2b`]. + ; + + +-- !query 6 +SELECT * FROM T1 +WHERE +(t1a, t1b) IN (SELECT t2a + FROM t2 + WHERE t1a = t2a) +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +cannot resolve '(named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`)))' due to data type mismatch: +The number of columns in the left hand side of an IN subquery does not match the +number of columns in the output of subquery. +#columns in left hand side: 2. +#columns in right hand side: 1. +Left side columns: +[t1.`t1a`, t1.`t1b`]. +Right side columns: +[t2.`t2a`]. 
+ ; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 4629a8c0dbe5f..820cff655c4ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -517,7 +517,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { val msg1 = intercept[AnalysisException] { sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1") } - assert(msg1.getMessage.contains("Correlated scalar subqueries must be Aggregated")) + assert(msg1.getMessage.contains("Correlated scalar subqueries must be aggregated")) val msg2 = intercept[AnalysisException] { sql("select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b from l l1") From 03eb6117affcca21798be25706a39e0d5a2f7288 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 23 Jun 2017 14:48:33 -0700 Subject: [PATCH 074/118] [SPARK-21164][SQL] Remove isTableSample from Sample and isGenerated from Alias and AttributeReference ## What changes were proposed in this pull request? `isTableSample` and `isGenerated ` were introduced for SQL Generation respectively by https://github.com/apache/spark/pull/11148 and https://github.com/apache/spark/pull/11050 Since SQL Generation is removed, we do not need to keep `isTableSample`. ## How was this patch tested? The existing test cases Author: Xiao Li Closes #18379 from gatorsmile/CleanSample. --- .../sql/catalyst/analysis/Analyzer.scala | 8 ++--- .../expressions/namedExpressions.scala | 34 +++++++------------ .../optimizer/RewriteDistinctAggregates.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../sql/catalyst/planning/patterns.scala | 4 +-- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 +- .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../plans/logical/basicLogicalOperators.scala | 6 +--- .../analysis/AnalysisErrorSuite.scala | 2 +- .../analysis/UnsupportedOperationsSuite.scala | 2 +- .../optimizer/ColumnPruningSuite.scala | 8 ++--- .../sql/catalyst/parser/PlanParserSuite.scala | 4 +-- .../spark/sql/catalyst/plans/PlanTest.scala | 10 +++--- .../BasicStatsEstimationSuite.scala | 4 +-- .../scala/org/apache/spark/sql/Dataset.scala | 4 +-- 15 files changed, 40 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 193082eb77024..7e5ebfc93286f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -874,7 +874,7 @@ class Analyzer( def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = { expressions.map { - case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated) + case a: Alias => Alias(a.child, a.name)() case other => other } } @@ -1368,7 +1368,7 @@ class Analyzer( val aggregatedCondition = Aggregate( grouping, - Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil, + Alias(havingCondition, "havingCondition")() :: Nil, child) val resolvedOperator = execute(aggregatedCondition) def resolvedAggregateFilter = @@ -1424,7 +1424,7 @@ class Analyzer( try { val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s)) val aliasedOrdering = - unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")(isGenerated = true)) + unresolvedSortOrders.map(o 
=> Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] val resolvedAliasedOrdering: Seq[Alias] = @@ -1935,7 +1935,7 @@ class Analyzer( leafNondeterministic.distinct.map { e => val ne = e match { case n: NamedExpression => n - case _ => Alias(e, "_nondeterministic")(isGenerated = true) + case _ => Alias(e, "_nondeterministic")() } e -> ne } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index c842f85af693c..29c33804f077a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -81,9 +81,6 @@ trait NamedExpression extends Expression { /** Returns the metadata when an expression is a reference to another expression with metadata. */ def metadata: Metadata = Metadata.empty - /** Returns true if the expression is generated by Catalyst */ - def isGenerated: java.lang.Boolean = false - /** Returns a copy of this expression with a new `exprId`. */ def newInstance(): NamedExpression @@ -128,13 +125,11 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn * qualified way. Consider the examples tableName.name, subQueryAlias.name. * tableName and subQueryAlias are possible qualifiers. * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's. - * @param isGenerated A flag to indicate if this alias is generated by Catalyst */ case class Alias(child: Expression, name: String)( val exprId: ExprId = NamedExpression.newExprId, val qualifier: Option[String] = None, - val explicitMetadata: Option[Metadata] = None, - override val isGenerated: java.lang.Boolean = false) + val explicitMetadata: Option[Metadata] = None) extends UnaryExpression with NamedExpression { // Alias(Generator, xx) need to be transformed into Generate(generator, ...) @@ -159,13 +154,11 @@ case class Alias(child: Expression, name: String)( } def newInstance(): NamedExpression = - Alias(child, name)( - qualifier = qualifier, explicitMetadata = explicitMetadata, isGenerated = isGenerated) + Alias(child, name)(qualifier = qualifier, explicitMetadata = explicitMetadata) override def toAttribute: Attribute = { if (resolved) { - AttributeReference(name, child.dataType, child.nullable, metadata)( - exprId, qualifier, isGenerated) + AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifier) } else { UnresolvedAttribute(name) } @@ -174,7 +167,7 @@ case class Alias(child: Expression, name: String)( override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix" override protected final def otherCopyArgs: Seq[AnyRef] = { - exprId :: qualifier :: explicitMetadata :: isGenerated :: Nil + exprId :: qualifier :: explicitMetadata :: Nil } override def hashCode(): Int = { @@ -207,7 +200,6 @@ case class Alias(child: Expression, name: String)( * @param qualifier An optional string that can be used to referred to this attribute in a fully * qualified way. Consider the examples tableName.name, subQueryAlias.name. * tableName and subQueryAlias are possible qualifiers. 
- * @param isGenerated A flag to indicate if this reference is generated by Catalyst */ case class AttributeReference( name: String, @@ -215,8 +207,7 @@ case class AttributeReference( nullable: Boolean = true, override val metadata: Metadata = Metadata.empty)( val exprId: ExprId = NamedExpression.newExprId, - val qualifier: Option[String] = None, - override val isGenerated: java.lang.Boolean = false) + val qualifier: Option[String] = None) extends Attribute with Unevaluable { /** @@ -253,8 +244,7 @@ case class AttributeReference( } override def newInstance(): AttributeReference = - AttributeReference(name, dataType, nullable, metadata)( - qualifier = qualifier, isGenerated = isGenerated) + AttributeReference(name, dataType, nullable, metadata)(qualifier = qualifier) /** * Returns a copy of this [[AttributeReference]] with changed nullability. @@ -263,7 +253,7 @@ case class AttributeReference( if (nullable == newNullability) { this } else { - AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifier, isGenerated) + AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifier) } } @@ -271,7 +261,7 @@ case class AttributeReference( if (name == newName) { this } else { - AttributeReference(newName, dataType, nullable, metadata)(exprId, qualifier, isGenerated) + AttributeReference(newName, dataType, nullable, metadata)(exprId, qualifier) } } @@ -282,7 +272,7 @@ case class AttributeReference( if (newQualifier == qualifier) { this } else { - AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifier, isGenerated) + AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifier) } } @@ -290,16 +280,16 @@ case class AttributeReference( if (exprId == newExprId) { this } else { - AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifier, isGenerated) + AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifier) } } override def withMetadata(newMetadata: Metadata): Attribute = { - AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier, isGenerated) + AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier) } override protected final def otherCopyArgs: Seq[AnyRef] = { - exprId :: qualifier :: isGenerated :: Nil + exprId :: qualifier :: Nil } /** Used to signal the column used to calculate an eventTime watermark (e.g. a#1-T{delayMs}) */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 3b27cd2ffe028..4448ace7105a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -134,7 +134,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // Aggregation strategy can handle queries with a single distinct group. if (distinctAggGroups.size > 1) { // Create the attributes for the grouping id and the group by clause. 
- val gid = AttributeReference("gid", IntegerType, nullable = false)(isGenerated = true) + val gid = AttributeReference("gid", IntegerType, nullable = false)() val groupByMap = a.groupingExpressions.collect { case ne: NamedExpression => ne -> ne.toAttribute case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 315c6721b3f65..ef79cbcaa0ce6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -627,7 +627,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps, s"Sampling fraction ($fraction) must be on interval [0, 1]", ctx) - Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true) + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query) } ctx.sampleType.getType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index ef925f92ecc7e..7f370fb731b2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -80,12 +80,12 @@ object PhysicalOperation extends PredicateHelper { expr.transform { case a @ Alias(ref: AttributeReference, name) => aliases.get(ref) - .map(Alias(_, name)(a.exprId, a.qualifier, isGenerated = a.isGenerated)) + .map(Alias(_, name)(a.exprId, a.qualifier)) .getOrElse(a) case a: AttributeReference => aliases.get(a) - .map(Alias(_, a.name)(a.exprId, a.qualifier, isGenerated = a.isGenerated)).getOrElse(a) + .map(Alias(_, a.name)(a.exprId, a.qualifier)).getOrElse(a) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 1f6d05bc8d816..01b3da3f7c482 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -200,7 +200,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT // normalize that for equality testing, by assigning expr id from 0 incrementally. The // alias name doesn't matter and should be erased. 
val normalizedChild = QueryPlan.normalizeExprId(a.child, allAttributes) - Alias(normalizedChild, "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated) + Alias(normalizedChild, "")(ExprId(id), a.qualifier) case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 => // Top level `AttributeReference` may also be used for output like `Alias`, we should diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 0c098ac0209e8..0d30aa76049a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -221,7 +221,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai nameParts: Seq[String], resolver: Resolver, attribute: Attribute): Option[(Attribute, List[String])] = { - if (!attribute.isGenerated && resolver(attribute.name, nameParts.head)) { + if (resolver(attribute.name, nameParts.head)) { Option((attribute.withName(nameParts.head), nameParts.tail.toList)) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index d8f89b108e63f..e89caabf252d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -807,15 +807,13 @@ case class SubqueryAlias( * @param withReplacement Whether to sample with replacement. * @param seed the random seed * @param child the LogicalPlan - * @param isTableSample Is created from TABLESAMPLE in the parser. 
*/ case class Sample( lowerBound: Double, upperBound: Double, withReplacement: Boolean, seed: Long, - child: LogicalPlan)( - val isTableSample: java.lang.Boolean = false) extends UnaryNode { + child: LogicalPlan) extends UnaryNode { val eps = RandomSampler.roundingEpsilon val fraction = upperBound - lowerBound @@ -842,8 +840,6 @@ case class Sample( // Don't propagate column stats, because we don't know the distribution after a sample operation Statistics(sizeInBytes, sampledRowCount, hints = childStats.hints) } - - override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 4ed995e20d7ce..7311dc3899e53 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -573,7 +573,7 @@ class AnalysisErrorSuite extends AnalysisTest { val plan5 = Filter( Exists( Sample(0.0, 0.5, false, 1L, - Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))().select('b) + Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b))).select('b) ), LocalRelation(a)) assertAnalysisError(plan5, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index c39e372c272b1..f68d930f60523 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -491,7 +491,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { // Other unary operations testUnaryOperatorInStreamingPlan( - "sample", Sample(0.1, 1, true, 1L, _)(), expectedMsg = "sampling") + "sample", Sample(0.1, 1, true, 1L, _), expectedMsg = "sampling") testUnaryOperatorInStreamingPlan( "window", Window(Nil, Nil, Nil, _), expectedMsg = "non-time-based windows") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 0b419e9631b29..08e58d47e0e25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -349,14 +349,14 @@ class ColumnPruningSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val x = testRelation.subquery('x) - val query1 = Sample(0.0, 0.6, false, 11L, x)().select('a) + val query1 = Sample(0.0, 0.6, false, 11L, x).select('a) val optimized1 = Optimize.execute(query1.analyze) - val expected1 = Sample(0.0, 0.6, false, 11L, x.select('a))() + val expected1 = Sample(0.0, 0.6, false, 11L, x.select('a)) comparePlans(optimized1, expected1.analyze) - val query2 = Sample(0.0, 0.6, false, 11L, x)().select('a as 'aa) + val query2 = Sample(0.0, 0.6, false, 11L, x).select('a as 'aa) val optimized2 = Optimize.execute(query2.analyze) - val expected2 = Sample(0.0, 0.6, false, 11L, x.select('a))().select('a as 'aa) + val expected2 = Sample(0.0, 0.6, false, 11L, x.select('a)).select('a as 'aa) comparePlans(optimized2, expected2.analyze) } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 0a4ae098d65cc..bf15b85d5b510 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -411,9 +411,9 @@ class PlanParserSuite extends AnalysisTest { assertEqual(s"$sql tablesample(100 rows)", table("t").limit(100).select(star())) assertEqual(s"$sql tablesample(43 percent) as x", - Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + Sample(0, .43d, withReplacement = false, 10L, table("t").as("x")).select(star())) assertEqual(s"$sql tablesample(bucket 4 out of 10) as x", - Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + Sample(0, .4d, withReplacement = false, 10L, table("t").as("x")).select(star())) intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x", "TABLESAMPLE(BUCKET x OUT OF y ON colname) is not supported") intercept(s"$sql tablesample(bucket 11 out of 10) as x", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 25313af2be184..6883d23d477e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -63,14 +63,14 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { */ protected def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan transform { - case filter @ Filter(condition: Expression, child: LogicalPlan) => - Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode()) + case Filter(condition: Expression, child: LogicalPlan) => + Filter(splitConjunctivePredicates(condition).map(rewriteEqual).sortBy(_.hashCode()) .reduce(And), child) case sample: Sample => - sample.copy(seed = 0L)(true) - case join @ Join(left, right, joinType, condition) if condition.isDefined => + sample.copy(seed = 0L) + case Join(left, right, joinType, condition) if condition.isDefined => val newCondition = - splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode()) + splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(_.hashCode()) .reduce(And) Join(left, right, joinType, Some(newCondition)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index e9ed36feec48c..912c5fed63450 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -78,14 +78,14 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { } test("sample estimation") { - val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan)() + val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan) checkStats(sample, Statistics(sizeInBytes = 60, rowCount = Some(5))) // Child doesn't have rowCount in stats val childStats = Statistics(sizeInBytes = 120) val childPlan = DummyLogicalPlan(childStats, 
childStats) val sample2 = - Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan)() + Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan) checkStats(sample2, Statistics(sizeInBytes = 14)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 767dad3e63a6d..6e66e92091ff9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1807,7 +1807,7 @@ class Dataset[T] private[sql]( */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = { withTypedPlan { - Sample(0.0, fraction, withReplacement, seed, logicalPlan)() + Sample(0.0, fraction, withReplacement, seed, logicalPlan) } } @@ -1863,7 +1863,7 @@ class Dataset[T] private[sql]( val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => new Dataset[T]( - sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan)(), encoder) + sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan), encoder) }.toArray } From 7525ce98b4575b1ac4e44cc9b3a5773f03eba19e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 24 Jun 2017 11:39:41 +0800 Subject: [PATCH 075/118] [SPARK-20431][SS][FOLLOWUP] Specify a schema by using a DDL-formatted string in DataStreamReader ## What changes were proposed in this pull request? This pr supported a DDL-formatted string in `DataStreamReader.schema`. This fix could make users easily define a schema without importing the type classes. For example, ```scala scala> spark.readStream.schema("col0 INT, col1 DOUBLE").load("/tmp/abc").printSchema() root |-- col0: integer (nullable = true) |-- col1: double (nullable = true) ``` ## How was this patch tested? Added tests in `DataStreamReaderWriterSuite`. Author: hyukjinkwon Closes #18373 from HyukjinKwon/SPARK-20431. --- python/pyspark/sql/readwriter.py | 2 ++ python/pyspark/sql/streaming.py | 24 ++++++++++++------- .../sql/streaming/DataStreamReader.scala | 12 ++++++++++ .../test/DataStreamReaderWriterSuite.scala | 12 ++++++++++ 4 files changed, 42 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index aef71f9ca7001..7279173df6e4f 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -98,6 +98,8 @@ def schema(self, schema): :param schema: a :class:`pyspark.sql.types.StructType` object or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). + + >>> s = spark.read.schema("col0 INT, col1 DOUBLE") """ from pyspark.sql import SparkSession spark = SparkSession.builder.getOrCreate() diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 58aa2468e006d..5bbd70cf0a789 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -319,16 +319,21 @@ def schema(self, schema): .. note:: Evolving. - :param schema: a :class:`pyspark.sql.types.StructType` object + :param schema: a :class:`pyspark.sql.types.StructType` object or a DDL-formatted string + (For example ``col0 INT, col1 DOUBLE``). 
>>> s = spark.readStream.schema(sdf_schema) + >>> s = spark.readStream.schema("col0 INT, col1 DOUBLE") """ from pyspark.sql import SparkSession - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType") spark = SparkSession.builder.getOrCreate() - jschema = spark._jsparkSession.parseDataType(schema.json()) - self._jreader = self._jreader.schema(jschema) + if isinstance(schema, StructType): + jschema = spark._jsparkSession.parseDataType(schema.json()) + self._jreader = self._jreader.schema(jschema) + elif isinstance(schema, basestring): + self._jreader = self._jreader.schema(schema) + else: + raise TypeError("schema should be StructType or string") return self @since(2.0) @@ -372,7 +377,8 @@ def load(self, path=None, format=None, schema=None, **options): :param path: optional string for file-system backed data sources. :param format: optional string for format of the data source. Default to 'parquet'. - :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema. + :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema + or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param options: all other string options >>> json_sdf = spark.readStream.format("json") \\ @@ -415,7 +421,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param path: string represents path to the JSON dataset, or RDD of Strings storing JSON objects. - :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema + or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param primitivesAsString: infers all primitive values as a string type. If None is set, it uses the default value, ``false``. :param prefersDecimal: infers all floating-point values as a decimal type. If the values @@ -542,7 +549,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non .. note:: Evolving. :param path: string, or list of strings, for input path(s). - :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema + or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param sep: sets the single character as a separator for each field and value. If None is set, it uses the default value, ``,``. :param encoding: decodes the CSV files by the given encoding type. If None is set, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 7e8e6394b4862..70ddfa8e9b835 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -59,6 +59,18 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo this } + /** + * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) can + * infer the input schema automatically from data. By specifying the schema here, the underlying + * data source can skip the schema inference step, and thus speed up data loading. 
+ * + * @since 2.3.0 + */ + def schema(schemaString: String): DataStreamReader = { + this.userSpecifiedSchema = Option(StructType.fromDDL(schemaString)) + this + } + /** * Adds an input option for the underlying data source. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index b5f1e28d7396a..3de0ae67a3892 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -663,4 +663,16 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } assert(fs.exists(checkpointDir)) } + + test("SPARK-20431: Specify a schema by using a DDL-formatted string") { + spark.readStream + .format("org.apache.spark.sql.streaming.test") + .schema("aa INT") + .load() + + assert(LastOptions.schema.isDefined) + assert(LastOptions.schema.get === StructType(StructField("aa", IntegerType) :: Nil)) + + LastOptions.clear() + } } From b837bf9ae97cf7ee7558c10a5a34636e69367a05 Mon Sep 17 00:00:00 2001 From: Gabor Feher Date: Fri, 23 Jun 2017 21:53:38 -0700 Subject: [PATCH 076/118] [SPARK-20555][SQL] Fix mapping of Oracle DECIMAL types to Spark types in read path ## What changes were proposed in this pull request? This PR is to revert some code changes in the read path of https://github.com/apache/spark/pull/14377. The original fix is https://github.com/apache/spark/pull/17830 When merging this PR, please give the credit to gaborfeher ## How was this patch tested? Added a test case to OracleIntegrationSuite.scala Author: Gabor Feher Author: gatorsmile Closes #18408 from gatorsmile/OracleType. 
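As an editorial illustration of the behavior this revert restores, below is a minimal, self-contained Scala sketch. It is not code from the patch: a toy ADT stands in for Spark's `DataType` classes and all names are hypothetical. It only mirrors the read-path idea that Oracle `NUMBER(p, s)` columns should stay decimals (with the unscaled `scale == -127` case special-cased) rather than being narrowed to Boolean/Int/Long/Float, which could not faithfully hold values such as 4 in a `DECIMAL(1)` column or 9999999999 in a `DECIMAL(10)` column.

```scala
// Toy stand-ins for Spark's DataType hierarchy; illustrative names only.
sealed trait ToyType
case class ToyDecimal(precision: Int, scale: Int) extends ToyType

object OracleNumberMappingSketch {
  private val MaxPrecision = 38

  // Sketch of the read-path decision after the revert: only NUMBER columns
  // reported with scale == -127 (no declared scale) get a special decimal
  // mapping; every other NUMBER(p, s) falls through (None) to the generic
  // JDBC DECIMAL handling, i.e. a decimal with the declared precision/scale.
  def mapOracleNumber(precision: Int, scale: Int): Option[ToyType] = scale match {
    case -127 => Some(ToyDecimal(MaxPrecision, 10))
    case _    => None
  }

  def main(args: Array[String]): Unit = {
    // A DECIMAL(1) holding 4 is no longer coerced to a boolean, and a
    // DECIMAL(10) value above Int.MaxValue keeps its full precision.
    println(mapOracleNumber(precision = 1, scale = 0))    // None -> generic decimal(1, 0)
    println(mapOracleNumber(precision = 10, scale = 0))   // None -> generic decimal(10, 0)
    println(mapOracleNumber(precision = 0, scale = -127)) // Some(ToyDecimal(38, 10))
  }
}
```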
--- .../sql/jdbc/OracleIntegrationSuite.scala | 65 +++++++++++++------ .../apache/spark/sql/jdbc/OracleDialect.scala | 4 -- 2 files changed, 45 insertions(+), 24 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index f7b1ec34ced76..b2f096964427e 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, Date, Timestamp} import java.util.Properties +import java.math.BigDecimal import org.apache.spark.sql.Row import org.apache.spark.sql.test.SharedSQLContext @@ -93,8 +94,31 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo |USING org.apache.spark.sql.jdbc |OPTIONS (url '$jdbcUrl', dbTable 'datetime1', oracle.jdbc.mapDateToTimestamp 'false') """.stripMargin.replaceAll("\n", " ")) + + + conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate(); + conn.prepareStatement( + "INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate(); + conn.commit(); } + + test("SPARK-16625 : Importing Oracle numeric types") { + val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties); + val rows = df.collect() + assert(rows.size == 1) + val row = rows(0) + // The main point of the below assertions is not to make sure that these Oracle types are + // mapped to decimal types, but to make sure that the returned values are correct. + // A value > 1 from DECIMAL(1) is correct: + assert(row.getDecimal(0).compareTo(BigDecimal.valueOf(4)) == 0) + // A value with fractions from DECIMAL(3, 2) is correct: + assert(row.getDecimal(1).compareTo(BigDecimal.valueOf(1.23)) == 0) + // A value > Int.MaxValue from DECIMAL(10) is correct: + assert(row.getDecimal(2).compareTo(BigDecimal.valueOf(9999999999l)) == 0) + } + + test("SPARK-12941: String datatypes to be mapped to Varchar in Oracle") { // create a sample dataframe with string type val df1 = sparkContext.parallelize(Seq(("foo"))).toDF("x") @@ -154,27 +178,28 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo val dfRead = spark.read.jdbc(jdbcUrl, tableName, props) val rows = dfRead.collect() // verify the data type is inserted - val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types(0).equals("class java.lang.Boolean")) - assert(types(1).equals("class java.lang.Integer")) - assert(types(2).equals("class java.lang.Long")) - assert(types(3).equals("class java.lang.Float")) - assert(types(4).equals("class java.lang.Float")) - assert(types(5).equals("class java.lang.Integer")) - assert(types(6).equals("class java.lang.Integer")) - assert(types(7).equals("class java.lang.String")) - assert(types(8).equals("class [B")) - assert(types(9).equals("class java.sql.Date")) - assert(types(10).equals("class java.sql.Timestamp")) + val types = dfRead.schema.map(field => field.dataType) + assert(types(0).equals(DecimalType(1, 0))) + assert(types(1).equals(DecimalType(10, 0))) + assert(types(2).equals(DecimalType(19, 0))) + assert(types(3).equals(DecimalType(19, 4))) + assert(types(4).equals(DecimalType(19, 4))) + assert(types(5).equals(DecimalType(3, 0))) + assert(types(6).equals(DecimalType(5, 0))) + 
assert(types(7).equals(StringType)) + assert(types(8).equals(BinaryType)) + assert(types(9).equals(DateType)) + assert(types(10).equals(TimestampType)) + // verify the value is the inserted correct or not val values = rows(0) - assert(values.getBoolean(0).equals(booleanVal)) - assert(values.getInt(1).equals(integerVal)) - assert(values.getLong(2).equals(longVal)) - assert(values.getFloat(3).equals(floatVal)) - assert(values.getFloat(4).equals(doubleVal.toFloat)) - assert(values.getInt(5).equals(byteVal.toInt)) - assert(values.getInt(6).equals(shortVal.toInt)) + assert(values.getDecimal(0).compareTo(BigDecimal.valueOf(1)) == 0) + assert(values.getDecimal(1).compareTo(BigDecimal.valueOf(integerVal)) == 0) + assert(values.getDecimal(2).compareTo(BigDecimal.valueOf(longVal)) == 0) + assert(values.getDecimal(3).compareTo(BigDecimal.valueOf(floatVal)) == 0) + assert(values.getDecimal(4).compareTo(BigDecimal.valueOf(doubleVal)) == 0) + assert(values.getDecimal(5).compareTo(BigDecimal.valueOf(byteVal)) == 0) + assert(values.getDecimal(6).compareTo(BigDecimal.valueOf(shortVal)) == 0) assert(values.getString(7).equals(stringVal)) assert(values.getAs[Array[Byte]](8).mkString.equals("678")) assert(values.getDate(9).equals(dateVal)) @@ -183,7 +208,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo test("SPARK-19318: connection property keys should be case-sensitive") { def checkRow(row: Row): Unit = { - assert(row.getInt(0) == 1) + assert(row.getDecimal(0).equals(BigDecimal.valueOf(1))) assert(row.getDate(1).equals(Date.valueOf("1991-11-09"))) assert(row.getTimestamp(2).equals(Timestamp.valueOf("1996-01-01 01:23:45"))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index f541996b651e9..20e634c06b610 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -43,10 +43,6 @@ private case object OracleDialect extends JdbcDialect { // Not sure if there is a more robust way to identify the field as a float (or other // numeric types that do not specify a scale. case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) - case 1 => Option(BooleanType) - case 3 | 5 | 10 => Option(IntegerType) - case 19 if scale == 0L => Option(LongType) - case 19 if scale == 4L => Option(FloatType) case _ => None } } else { From bfd73a7c48b87456d1b84d826e04eca938a1be64 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sat, 24 Jun 2017 13:23:43 +0800 Subject: [PATCH 077/118] [SPARK-21159][CORE] Don't try to connect to launcher in standalone cluster mode. Monitoring for standalone cluster mode is not implemented (see SPARK-11033), but the same scheduler implementation is used, and if it tries to connect to the launcher it will fail. So fix the scheduler so it only tries that in client mode; cluster mode applications will be correctly launched and will work, but monitoring through the launcher handle will not be available. Tested by running a cluster mode app with "SparkLauncher.startApplication". Author: Marcelo Vanzin Closes #18397 from vanzin/SPARK-21159. 
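For context, a sketch of the scenario described above, driving a standalone cluster-mode submission through the public `SparkLauncher` API from Scala. The master URL, jar path, and class name are placeholders, not values from this patch; the point is that with this fix the application is submitted and runs, while state updates through the returned handle are still not expected because cluster-mode monitoring is not implemented (SPARK-11033).

```scala
import org.apache.spark.launcher.{SparkAppHandle, SparkLauncher}

object StandaloneClusterLaunchSketch {
  def main(args: Array[String]): Unit = {
    val handle = new SparkLauncher()
      .setMaster("spark://master-host:7077")      // placeholder standalone master
      .setDeployMode("cluster")                   // the case addressed by this patch
      .setAppResource("/path/to/example-app.jar") // placeholder application jar
      .setMainClass("com.example.ExampleApp")     // placeholder main class
      .startApplication(new SparkAppHandle.Listener {
        override def stateChanged(h: SparkAppHandle): Unit =
          println(s"launcher state: ${h.getState}")
        override def infoChanged(h: SparkAppHandle): Unit = ()
      })

    // With this fix the driver-side scheduler no longer tries to connect back
    // to the launcher in cluster mode, so the app launches and runs; the
    // handle, however, will not report progress beyond the submission states.
    println(s"initial state: ${handle.getState}")
  }
}
```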
--- .../scheduler/cluster/StandaloneSchedulerBackend.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index fd8e64454bf70..a4e2a74341283 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -58,7 +58,13 @@ private[spark] class StandaloneSchedulerBackend( override def start() { super.start() - launcherBackend.connect() + + // SPARK-21159. The scheduler backend should only try to connect to the launcher when in client + // mode. In cluster mode, the code that submits the application to the Master needs to connect + // to the launcher instead. + if (sc.deployMode == "client") { + launcherBackend.connect() + } // The endpoint for executors to talk to us val driverUrl = RpcEndpointAddress( From 7c7bc8fc0ff85fe70968b47433bb7757326a6b12 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 24 Jun 2017 10:14:31 +0100 Subject: [PATCH 078/118] [SPARK-21189][INFRA] Handle unknown error codes in Jenkins rather then leaving incomplete comment in PRs ## What changes were proposed in this pull request? Recently, Jenkins tests were unstable due to unknown reasons as below: ``` /home/jenkins/workspace/SparkPullRequestBuilder/dev/lint-r ; process was terminated by signal 9 test_result_code, test_result_note = run_tests(tests_timeout) File "./dev/run-tests-jenkins.py", line 140, in run_tests test_result_note = ' * This patch **fails %s**.' % failure_note_by_errcode[test_result_code] KeyError: -9 ``` ``` Traceback (most recent call last): File "./dev/run-tests-jenkins.py", line 226, in main() File "./dev/run-tests-jenkins.py", line 213, in main test_result_code, test_result_note = run_tests(tests_timeout) File "./dev/run-tests-jenkins.py", line 140, in run_tests test_result_note = ' * This patch **fails %s**.' % failure_note_by_errcode[test_result_code] KeyError: -10 ``` This exception looks causing failing to update the comments in the PR. For example: ![2017-06-23 4 19 41](https://user-images.githubusercontent.com/6477701/27470626-d035ecd8-582f-11e7-883e-0ae6941659b7.png) ![2017-06-23 4 19 50](https://user-images.githubusercontent.com/6477701/27470629-d11ba782-582f-11e7-97e0-64d28cbc19aa.png) these comment just remain. This always requires, for both reviewers and the author, a overhead to click and check the logs, which I believe are not really useful. This PR proposes to leave the code in the PR comment messages and let update the comments. ## How was this patch tested? Jenkins tests below, I manually gave the error code to test this. Author: hyukjinkwon Closes #18399 from HyukjinKwon/jenkins-print-errors. --- dev/run-tests-jenkins.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 53061bc947e5f..914eb93622d51 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -137,7 +137,9 @@ def run_tests(tests_timeout): if test_result_code == 0: test_result_note = ' * This patch passes all tests.' else: - test_result_note = ' * This patch **fails %s**.' % failure_note_by_errcode[test_result_code] + note = failure_note_by_errcode.get( + test_result_code, "due to an unknown error code, %s" % test_result_code) + test_result_note = ' * This patch **fails %s**.' 
% note return [test_result_code, test_result_note] From 2e1586f60a77ea0adb6f3f68ba74323f0c242199 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 24 Jun 2017 22:35:59 +0800 Subject: [PATCH 079/118] [SPARK-21203][SQL] Fix wrong results of insertion of Array of Struct ### What changes were proposed in this pull request? ```SQL CREATE TABLE `tab1` (`custom_fields` ARRAY>) USING parquet INSERT INTO `tab1` SELECT ARRAY(named_struct('id', 1, 'value', 'a'), named_struct('id', 2, 'value', 'b')) SELECT custom_fields.id, custom_fields.value FROM tab1 ``` The above query always return the last struct of the array, because the rule `SimplifyCasts` incorrectly rewrites the query. The underlying cause is we always use the same `GenericInternalRow` object when doing the cast. ### How was this patch tested? Author: gatorsmile Closes #18412 from gatorsmile/castStruct. --- .../spark/sql/catalyst/expressions/Cast.scala | 4 ++-- .../spark/sql/sources/InsertSuite.scala | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a53ef426f79b5..43df19ba009a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -482,15 +482,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case (fromField, toField) => cast(fromField.dataType, toField.dataType) } // TODO: Could be faster? - val newRow = new GenericInternalRow(from.fields.length) buildCast[InternalRow](_, row => { + val newRow = new GenericInternalRow(from.fields.length) var i = 0 while (i < row.numFields) { newRow.update(i, if (row.isNullAt(i)) null else castFuncs(i)(row.get(i, from.apply(i).dataType))) i += 1 } - newRow.copy() + newRow }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 2eae66dda88de..41abff2a5da25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -345,4 +345,25 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { ) } } + + test("SPARK-21203 wrong results of insertion of Array of Struct") { + val tabName = "tab1" + withTable(tabName) { + spark.sql( + """ + |CREATE TABLE `tab1` + |(`custom_fields` ARRAY>) + |USING parquet + """.stripMargin) + spark.sql( + """ + |INSERT INTO `tab1` + |SELECT ARRAY(named_struct('id', 1, 'value', 'a'), named_struct('id', 2, 'value', 'b')) + """.stripMargin) + + checkAnswer( + spark.sql("SELECT custom_fields.id, custom_fields.value FROM tab1"), + Row(Array(1, 2), Array("a", "b"))) + } + } } From b449a1d6aa322a50cf221cd7a2ae85a91d6c7e9f Mon Sep 17 00:00:00 2001 From: Masha Basmanova Date: Sat, 24 Jun 2017 22:49:35 -0700 Subject: [PATCH 080/118] [SPARK-21079][SQL] Calculate total size of a partition table as a sum of individual partitions ## What changes were proposed in this pull request? Storage URI of a partitioned table may or may not point to a directory under which individual partitions are stored. In fact, individual partitions may be located in totally unrelated directories. 
Before this change, ANALYZE TABLE table COMPUTE STATISTICS command calculated total size of a table by adding up sizes of files found under table's storage URI. This calculation could produce 0 if partitions are stored elsewhere. This change uses storage URIs of individual partitions to calculate the sizes of all partitions of a table and adds these up to produce the total size of a table. CC: wzhfy ## How was this patch tested? Added unit test. Ran ANALYZE TABLE xxx COMPUTE STATISTICS on a partitioned Hive table and verified that sizeInBytes is calculated correctly. Before this change, the size would be zero. Author: Masha Basmanova Closes #18309 from mbasmanova/mbasmanova-analyze-part-table. --- .../command/AnalyzeTableCommand.scala | 29 ++++++-- .../spark/sql/hive/StatisticsSuite.scala | 72 +++++++++++++++++++ 2 files changed, 95 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 3c59b982c2dca..06e588f56f1e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.net.URI + import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} @@ -81,6 +83,21 @@ case class AnalyzeTableCommand( object AnalyzeTableCommand extends Logging { def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): Long = { + if (catalogTable.partitionColumnNames.isEmpty) { + calculateLocationSize(sessionState, catalogTable.identifier, catalogTable.storage.locationUri) + } else { + // Calculate table size as a sum of the visible partitions. See SPARK-21079 + val partitions = sessionState.catalog.listPartitions(catalogTable.identifier) + partitions.map(p => + calculateLocationSize(sessionState, catalogTable.identifier, p.storage.locationUri) + ).sum + } + } + + private def calculateLocationSize( + sessionState: SessionState, + tableId: TableIdentifier, + locationUri: Option[URI]): Long = { // This method is mainly based on // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) // in Hive 0.13 (except that we do not use fs.getContentSummary). @@ -91,13 +108,13 @@ object AnalyzeTableCommand extends Logging { // countFileSize to count the table size. 
val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") - def calculateTableSize(fs: FileSystem, path: Path): Long = { + def calculateLocationSize(fs: FileSystem, path: Path): Long = { val fileStatus = fs.getFileStatus(path) val size = if (fileStatus.isDirectory) { fs.listStatus(path) .map { status => if (!status.getPath.getName.startsWith(stagingDir)) { - calculateTableSize(fs, status.getPath) + calculateLocationSize(fs, status.getPath) } else { 0L } @@ -109,16 +126,16 @@ object AnalyzeTableCommand extends Logging { size } - catalogTable.storage.locationUri.map { p => + locationUri.map { p => val path = new Path(p) try { val fs = path.getFileSystem(sessionState.newHadoopConf()) - calculateTableSize(fs, path) + calculateLocationSize(fs, path) } catch { case NonFatal(e) => logWarning( - s"Failed to get the size of table ${catalogTable.identifier.table} in the " + - s"database ${catalogTable.identifier.database} because of ${e.toString}", e) + s"Failed to get the size of table ${tableId.table} in the " + + s"database ${tableId.database} because of ${e.toString}", e) 0L } }.getOrElse(0L) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 279db9a397258..0ee18bbe9befe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { @@ -128,6 +129,77 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto TableIdentifier("tempTable"), ignoreIfNotExists = true, purge = false) } + test("SPARK-21079 - analyze table with location different than that of individual partitions") { + def queryTotalSize(tableName: String): BigInt = + spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + + val tableName = "analyzeTable_part" + withTable(tableName) { + withTempPath { path => + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)") + + val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03") + partitionDates.foreach { ds => + sql(s"INSERT INTO TABLE $tableName PARTITION (ds='$ds') SELECT * FROM src") + } + + sql(s"ALTER TABLE $tableName SET LOCATION '$path'") + + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") + + assert(queryTotalSize(tableName) === BigInt(17436)) + } + } + } + + test("SPARK-21079 - analyze partitioned table with only a subset of partitions visible") { + def queryTotalSize(tableName: String): BigInt = + spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + + val sourceTableName = "analyzeTable_part" + val tableName = "analyzeTable_part_vis" + withTable(sourceTableName, tableName) { + withTempPath { path => + // Create a table with 3 partitions all located under a single top-level directory 'path' + sql( + s""" + |CREATE TABLE $sourceTableName (key STRING, value STRING) + |PARTITIONED BY (ds STRING) + |LOCATION '$path' + """.stripMargin) + + val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03") + partitionDates.foreach { ds => + sql( + s""" + |INSERT INTO TABLE $sourceTableName PARTITION (ds='$ds') + 
|SELECT * FROM src + """.stripMargin) + } + + // Create another table referring to the same location + sql( + s""" + |CREATE TABLE $tableName (key STRING, value STRING) + |PARTITIONED BY (ds STRING) + |LOCATION '$path' + """.stripMargin) + + // Register only one of the partitions found on disk + val ds = partitionDates.head + sql(s"ALTER TABLE $tableName ADD PARTITION (ds='$ds')").collect() + + // Analyze original table - expect 3 partitions + sql(s"ANALYZE TABLE $sourceTableName COMPUTE STATISTICS noscan") + assert(queryTotalSize(sourceTableName) === BigInt(3 * 5812)) + + // Analyze partial-copy table - expect only 1 partition + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") + assert(queryTotalSize(tableName) === BigInt(5812)) + } + } + } + test("analyzing views is not supported") { def assertAnalyzeUnsupported(analyzeCommand: String): Unit = { val err = intercept[AnalysisException] { From 884347e1f79e4e7c157834881e79447d7ee58f88 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Sun, 25 Jun 2017 15:06:29 +0100 Subject: [PATCH 081/118] [HOT FIX] fix stats functions in the recent patch ## What changes were proposed in this pull request? Builds failed due to the recent [merge](https://github.com/apache/spark/commit/b449a1d6aa322a50cf221cd7a2ae85a91d6c7e9f). This is because [PR#18309](https://github.com/apache/spark/pull/18309) needed update after [this patch](https://github.com/apache/spark/commit/b803b66a8133f705463039325ee71ee6827ce1a7) was merged. ## How was this patch tested? N/A Author: Zhenhua Wang Closes #18415 from wzhfy/hotfixStats. --- .../scala/org/apache/spark/sql/hive/StatisticsSuite.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 0ee18bbe9befe..64deb3818d5d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { @@ -131,7 +130,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("SPARK-21079 - analyze table with location different than that of individual partitions") { def queryTotalSize(tableName: String): BigInt = - spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes val tableName = "analyzeTable_part" withTable(tableName) { @@ -154,7 +153,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("SPARK-21079 - analyze partitioned table with only a subset of partitions visible") { def queryTotalSize(tableName: String): BigInt = - spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes val sourceTableName = "analyzeTable_part" val tableName = "analyzeTable_part_vis" From 6b3d02285ee0debc73cbcab01b10398a498fbeb8 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 25 Jun 2017 11:05:57 -0700 Subject: [PATCH 082/118] [SPARK-21093][R] Terminate R's worker processes in the parent of R's daemon to prevent a leak ## What changes were 
proposed in this pull request? `mcfork` in R looks opening a pipe ahead but the existing logic does not properly close it when it is executed hot. This leads to the failure of more forking due to the limit for number of files open. This hot execution looks particularly for `gapply`/`gapplyCollect`. For unknown reason, this happens more easily in CentOS and could be reproduced in Mac too. All the details are described in https://issues.apache.org/jira/browse/SPARK-21093 This PR proposes simply to terminate R's worker processes in the parent of R's daemon to prevent a leak. ## How was this patch tested? I ran the codes below on both CentOS and Mac with that configuration disabled/enabled. ```r df <- createDataFrame(list(list(1L, 1, "1", 0.1)), c("a", "b", "c", "d")) collect(gapply(df, "a", function(key, x) { x }, schema(df))) collect(gapply(df, "a", function(key, x) { x }, schema(df))) ... # 30 times ``` Also, now it passes R tests on CentOS as below: ``` SparkSQL functions: Spark package found in SPARK_HOME: .../spark .............................................................................................................................................................. .............................................................................................................................................................. .............................................................................................................................................................. .............................................................................................................................................................. .............................................................................................................................................................. .................................................................................................................................... ``` Author: hyukjinkwon Closes #18320 from HyukjinKwon/SPARK-21093. --- R/pkg/inst/worker/daemon.R | 59 +++++++++++++++++++++++++++++++++++--- 1 file changed, 55 insertions(+), 4 deletions(-) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 3a318b71ea06d..6e385b2a27622 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -30,8 +30,55 @@ port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) +# Waits indefinitely for a socket connecion by default. +selectTimeout <- NULL + +# Exit code that children send to the parent to indicate they exited. +exitCode <- 1 + while (TRUE) { - ready <- socketSelect(list(inputCon)) + ready <- socketSelect(list(inputCon), timeout = selectTimeout) + + # Note that the children should be terminated in the parent. If each child terminates + # itself, it appears that the resource is not released properly, that causes an unexpected + # termination of this daemon due to, for example, running out of file descriptors + # (see SPARK-21093). Therefore, the current implementation tries to retrieve children + # that are exited (but not terminated) and then sends a kill signal to terminate them properly + # in the parent. + # + # There are two paths that it attempts to send a signal to terminate the children in the parent. + # + # 1. Every second if any socket connection is not available and if there are child workers + # running. + # 2. Right after a socket connection is available. 
+ # + # In other words, the parent attempts to send the signal to the children every second if + # any worker is running or right before launching other worker children from the following + # new socket connection. + + # Only the process IDs of children sent data to the parent are returned below. The children + # send a custom exit code to the parent after being exited and the parent tries + # to terminate them only if they sent the exit code. + children <- parallel:::selectChildren(timeout = 0) + + if (is.integer(children)) { + lapply(children, function(child) { + # This data should be raw bytes if any data was sent from this child. + # Otherwise, this returns the PID. + data <- parallel:::readChild(child) + if (is.raw(data)) { + # This checks if the data from this child is the exit code that indicates an exited child. + if (unserialize(data) == exitCode) { + # If so, we terminate this child. + tools::pskill(child, tools::SIGUSR1) + } + } + }) + } else if (is.null(children)) { + # If it is NULL, there are no children. Waits indefinitely for a socket connecion. + selectTimeout <- NULL + } + if (ready) { port <- SparkR:::readInt(inputCon) # There is a small chance that it could be interrupted by signal, retry one time @@ -44,12 +91,16 @@ while (TRUE) { } p <- parallel:::mcfork() if (inherits(p, "masterProcess")) { + # Reach here because this is a child process. close(inputCon) Sys.setenv(SPARKR_WORKER_PORT = port) try(source(script)) - # Set SIGUSR1 so that child can exit - tools::pskill(Sys.getpid(), tools::SIGUSR1) - parallel:::mcexit(0L) + # Note that this mcexit does not fully terminate this child. So, this writes back + # a custom exit code so that the parent can read and terminate this child. + parallel:::mcexit(0L, send = exitCode) + } else { + # Forking succeeded and we need to check if they finished their jobs every second. + selectTimeout <- 1 } } } From 5282bae0408dec8aa0cefafd7673dd34d232ead9 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 26 Jun 2017 01:26:32 -0700 Subject: [PATCH 083/118] [SPARK-21153] Use project instead of expand in tumbling windows ## What changes were proposed in this pull request? Time windowing in Spark currently performs an Expand + Filter, because there is no way to guarantee the amount of windows a timestamp will fall in, in the general case. However, for tumbling windows, a record is guaranteed to fall into a single bucket. In this case, doubling the number of records with Expand is wasteful, and can be improved by using a simple Projection instead. Benchmarks show that we get an order of magnitude performance improvement after this patch. ## How was this patch tested? Existing unit tests. Benchmarked using the following code: ```scala import org.apache.spark.sql.functions._ spark.time { spark.range(numRecords) .select(from_unixtime((current_timestamp().cast("long") * 1000 + 'id / 1000) / 1000) as 'time) .select(window('time, "10 seconds")) .count() } ``` Setup: - 1 c3.2xlarge worker (8 cores) ![image](https://user-images.githubusercontent.com/5243515/27348748-ed991b84-55a9-11e7-8f8b-6e7abc524417.png) 1 B rows ran in 287 seconds after this optimization. I didn't wait for it to finish without the optimization. Shows about 5x improvement for large number of records. Author: Burak Yavuz Closes #18364 from brkyvz/opt-tumble. 
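For illustration: a record in a tumbling window (windowDuration == slideDuration) lands in exactly one bucket, which is why a plain Project suffices. Below is a minimal, self-contained Scala sketch of the bucket arithmetic this change encodes, working in microseconds with illustrative names; it is a sketch of the math only, not the analyzer rule itself.

```scala
// Sketch of the tumbling-window bucket computation (all values in microseconds).
// Names are illustrative; this is not the Catalyst rule, only the arithmetic it encodes.
object TumblingWindowSketch {

  /** Returns the [start, end) bucket a timestamp falls into when
    * windowDuration == slideDuration (a tumbling window). */
  def tumblingWindow(tsUs: Long, windowUs: Long, startTimeUs: Long): (Long, Long) = {
    val division = (tsUs - startTimeUs).toDouble / windowUs
    val ceil = math.ceil(division).toLong
    // A record sitting exactly on a boundary belongs to the window that starts there.
    val windowId = if (ceil.toDouble == division) ceil + 1 else ceil
    val windowStart = (windowId - 1) * windowUs + startTimeUs
    (windowStart, windowStart + windowUs)
  }

  def main(args: Array[String]): Unit = {
    val second = 1000000L
    // Analogous to the new test cases (seconds since an arbitrary origin): a boundary
    // record and a mid-window record both map to the single bucket [30s, 40s).
    println(tumblingWindow(30 * second, 10 * second, 0L)) // (30000000,40000000)
    println(tumblingWindow(34 * second, 10 * second, 0L)) // (30000000,40000000)
  }
}
```

For a sliding window the same windowId formula is evaluated once per overlapping offset (ceil(windowDuration / slideDuration) of them), so the Expand-plus-Filter plan is still required in that case.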
--- .../sql/catalyst/analysis/Analyzer.scala | 72 +++++++++++++------ .../sql/catalyst/expressions/TimeWindow.scala | 12 ++-- .../sql/DataFrameTimeWindowingSuite.scala | 49 +++++++++---- 3 files changed, 94 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7e5ebfc93286f..434b6ffee37fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2301,6 +2301,7 @@ object EliminateEventTimeWatermark extends Rule[LogicalPlan] { object TimeWindowing extends Rule[LogicalPlan] { import org.apache.spark.sql.catalyst.dsl.expressions._ + private final val WINDOW_COL_NAME = "window" private final val WINDOW_START = "start" private final val WINDOW_END = "end" @@ -2336,49 +2337,76 @@ object TimeWindowing extends Rule[LogicalPlan] { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = - p.expressions.flatMap(_.collect { case t: TimeWindow => t }).distinct.toList // Not correct. + p.expressions.flatMap(_.collect { case t: TimeWindow => t }).toSet // Only support a single window expression for now if (windowExpressions.size == 1 && windowExpressions.head.timeColumn.resolved && windowExpressions.head.checkInputDataTypes().isSuccess) { + val window = windowExpressions.head val metadata = window.timeColumn match { case a: Attribute => a.metadata case _ => Metadata.empty } - val windowAttr = - AttributeReference("window", window.dataType, metadata = metadata)() - - val maxNumOverlapping = math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt - val windows = Seq.tabulate(maxNumOverlapping + 1) { i => - val windowId = Ceil((PreciseTimestamp(window.timeColumn) - window.startTime) / - window.slideDuration) - val windowStart = (windowId + i - maxNumOverlapping) * - window.slideDuration + window.startTime + + def getWindow(i: Int, overlappingWindows: Int): Expression = { + val division = (PreciseTimestampConversion( + window.timeColumn, TimestampType, LongType) - window.startTime) / window.slideDuration + val ceil = Ceil(division) + // if the division is equal to the ceiling, our record is the start of a window + val windowId = CaseWhen(Seq((ceil === division, ceil + 1)), Some(ceil)) + val windowStart = (windowId + i - overlappingWindows) * + window.slideDuration + window.startTime val windowEnd = windowStart + window.windowDuration CreateNamedStruct( - Literal(WINDOW_START) :: windowStart :: - Literal(WINDOW_END) :: windowEnd :: Nil) + Literal(WINDOW_START) :: + PreciseTimestampConversion(windowStart, LongType, TimestampType) :: + Literal(WINDOW_END) :: + PreciseTimestampConversion(windowEnd, LongType, TimestampType) :: + Nil) } - val projections = windows.map(_ +: p.children.head.output) + val windowAttr = AttributeReference( + WINDOW_COL_NAME, window.dataType, metadata = metadata)() + + if (window.windowDuration == window.slideDuration) { + val windowStruct = Alias(getWindow(0, 1), WINDOW_COL_NAME)( + exprId = windowAttr.exprId) + + val replacedPlan = p transformExpressions { + case t: TimeWindow => windowAttr + } + + // For backwards compatibility we add a filter to filter out nulls + val filterExpr = IsNotNull(window.timeColumn) - val filterExpr = - window.timeColumn >= windowAttr.getField(WINDOW_START) && - window.timeColumn < windowAttr.getField(WINDOW_END) + 
replacedPlan.withNewChildren( + Filter(filterExpr, + Project(windowStruct +: child.output, child)) :: Nil) + } else { + val overlappingWindows = + math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt + val windows = + Seq.tabulate(overlappingWindows)(i => getWindow(i, overlappingWindows)) + + val projections = windows.map(_ +: child.output) + + val filterExpr = + window.timeColumn >= windowAttr.getField(WINDOW_START) && + window.timeColumn < windowAttr.getField(WINDOW_END) - val expandedPlan = - Filter(filterExpr, + val substitutedPlan = Filter(filterExpr, Expand(projections, windowAttr +: child.output, child)) - val substitutedPlan = p transformExpressions { - case t: TimeWindow => windowAttr - } + val renamedPlan = p transformExpressions { + case t: TimeWindow => windowAttr + } - substitutedPlan.withNewChildren(expandedPlan :: Nil) + renamedPlan.withNewChildren(substitutedPlan :: Nil) + } } else if (windowExpressions.size > 1) { p.failAnalysis("Multiple time window expressions would result in a cartesian product " + "of rows, therefore they are currently not supported.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 7ff61ee479452..9a9f579b37f58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -152,12 +152,15 @@ object TimeWindow { } /** - * Expression used internally to convert the TimestampType to Long without losing + * Expression used internally to convert the TimestampType to Long and back without losing * precision, i.e. in microseconds. Used in time windowing. 
*/ -case class PreciseTimestamp(child: Expression) extends UnaryExpression with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) - override def dataType: DataType = LongType +case class PreciseTimestampConversion( + child: Expression, + fromType: DataType, + toType: DataType) extends UnaryExpression with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(fromType) + override def dataType: DataType = toType override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) ev.copy(code = eval.code + @@ -165,4 +168,5 @@ case class PreciseTimestamp(child: Expression) extends UnaryExpression with Expe |${ctx.javaType(dataType)} ${ev.value} = ${eval.value}; """.stripMargin) } + override def nullSafeEval(input: Any): Any = input } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 22d5c47a6fb51..6fe356877c268 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql -import java.util.TimeZone - import org.scalatest.BeforeAndAfterEach +import org.apache.spark.sql.catalyst.plans.logical.Expand import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StringType @@ -29,11 +28,27 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B import testImplicits._ + test("simple tumbling window with record at window start") { + val df = Seq( + ("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(window($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"), + Seq( + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1) + ) + ) + } + test("tumbling window groupBy statement") { val df = Seq( ("2016-03-27 19:39:34", 1, "a"), ("2016-03-27 19:39:56", 2, "a"), ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + checkAnswer( df.groupBy(window($"time", "10 seconds")) .agg(count("*").as("counts")) @@ -59,14 +74,18 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B test("tumbling window with multi-column projection") { val df = Seq( - ("2016-03-27 19:39:34", 1, "a"), - ("2016-03-27 19:39:56", 2, "a"), - ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + .select(window($"time", "10 seconds"), $"value") + .orderBy($"window.start".asc) + .select($"window.start".cast("string"), $"window.end".cast("string"), $"value") + + val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand]) + assert(expands.isEmpty, "Tumbling windows shouldn't require expand") checkAnswer( - df.select(window($"time", "10 seconds"), $"value") - .orderBy($"window.start".asc) - .select($"window.start".cast("string"), $"window.end".cast("string"), $"value"), + df, Seq( Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4), Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1), @@ -104,13 +123,17 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B test("sliding window projection") { val df = 
Seq( - ("2016-03-27 19:39:34", 1, "a"), - ("2016-03-27 19:39:56", 2, "a"), - ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + .select(window($"time", "10 seconds", "3 seconds", "0 second"), $"value") + .orderBy($"window.start".asc, $"value".desc).select("value") + + val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand]) + assert(expands.nonEmpty, "Sliding windows require expand") checkAnswer( - df.select(window($"time", "10 seconds", "3 seconds", "0 second"), $"value") - .orderBy($"window.start".asc, $"value".desc).select("value"), + df, // 2016-03-27 19:39:27 UTC -> 4 bins // 2016-03-27 19:39:34 UTC -> 3 bins // 2016-03-27 19:39:56 UTC -> 3 bins From 9e50a1d37a4cf0c34e20a7c1a910ceaff41535a2 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 26 Jun 2017 11:14:03 -0500 Subject: [PATCH 084/118] [SPARK-13669][SPARK-20898][CORE] Improve the blacklist mechanism to handle external shuffle service unavailable situation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Currently we are running into an issue with Yarn work preserving enabled + external shuffle service. In the work preserving enabled scenario, the failure of NM will not lead to the exit of executors, so executors can still accept and run the tasks. The problem here is when NM is failed, external shuffle service is actually inaccessible, so reduce tasks will always complain about the “Fetch failure”, and the failure of reduce stage will make the parent stage (map stage) rerun. The tricky thing here is Spark scheduler is not aware of the unavailability of external shuffle service, and will reschedule the map tasks on the executor where NM is failed, and again reduce stage will be failed with “Fetch failure”, and after 4 retries, the job is failed. This could also apply to other cluster manager with external shuffle service. So here the main problem is that we should avoid assigning tasks to those bad executors (where shuffle service is unavailable). Current Spark's blacklist mechanism could blacklist executors/nodes by failure tasks, but it doesn't handle this specific fetch failure scenario. So here propose to improve the current application blacklist mechanism to handle fetch failure issue (especially with external shuffle service unavailable issue), to blacklist the executors/nodes where shuffle fetch is unavailable. ## How was this patch tested? Unit test and small cluster verification. Author: jerryshao Closes #17113 from jerryshao/SPARK-13669. 
--- .../spark/internal/config/package.scala | 5 + .../spark/scheduler/BlacklistTracker.scala | 95 ++++++++++++++----- .../spark/scheduler/TaskSchedulerImpl.scala | 18 +--- .../spark/scheduler/TaskSetManager.scala | 6 ++ .../scheduler/BlacklistTrackerSuite.scala | 55 +++++++++++ .../scheduler/TaskSchedulerImplSuite.scala | 4 +- .../spark/scheduler/TaskSetManagerSuite.scala | 32 +++++++ docs/configuration.md | 9 ++ 8 files changed, 186 insertions(+), 38 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 462c1890fd8df..be63c637a3a13 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -149,6 +149,11 @@ package object config { .internal() .timeConf(TimeUnit.MILLISECONDS) .createOptional + + private[spark] val BLACKLIST_FETCH_FAILURE_ENABLED = + ConfigBuilder("spark.blacklist.application.fetchFailure.enabled") + .booleanConf + .createWithDefault(false) // End blacklist confs private[spark] val UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE = diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index e130e609e4f63..cd8e61d6d0208 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -61,6 +61,7 @@ private[scheduler] class BlacklistTracker ( private val MAX_FAILURES_PER_EXEC = conf.get(config.MAX_FAILURES_PER_EXEC) private val MAX_FAILED_EXEC_PER_NODE = conf.get(config.MAX_FAILED_EXEC_PER_NODE) val BLACKLIST_TIMEOUT_MILLIS = BlacklistTracker.getBlacklistTimeout(conf) + private val BLACKLIST_FETCH_FAILURE_ENABLED = conf.get(config.BLACKLIST_FETCH_FAILURE_ENABLED) /** * A map from executorId to information on task failures. Tracks the time of each task failure, @@ -145,6 +146,74 @@ private[scheduler] class BlacklistTracker ( nextExpiryTime = math.min(execMinExpiry, nodeMinExpiry) } + private def killBlacklistedExecutor(exec: String): Unit = { + if (conf.get(config.BLACKLIST_KILL_ENABLED)) { + allocationClient match { + case Some(a) => + logInfo(s"Killing blacklisted executor id $exec " + + s"since ${config.BLACKLIST_KILL_ENABLED.key} is set.") + a.killExecutors(Seq(exec), true, true) + case None => + logWarning(s"Not attempting to kill blacklisted executor id $exec " + + s"since allocation client is not defined.") + } + } + } + + private def killExecutorsOnBlacklistedNode(node: String): Unit = { + if (conf.get(config.BLACKLIST_KILL_ENABLED)) { + allocationClient match { + case Some(a) => + logInfo(s"Killing all executors on blacklisted host $node " + + s"since ${config.BLACKLIST_KILL_ENABLED.key} is set.") + if (a.killExecutorsOnHost(node) == false) { + logError(s"Killing executors on node $node failed.") + } + case None => + logWarning(s"Not attempting to kill executors on blacklisted host $node " + + s"since allocation client is not defined.") + } + } + } + + def updateBlacklistForFetchFailure(host: String, exec: String): Unit = { + if (BLACKLIST_FETCH_FAILURE_ENABLED) { + // If we blacklist on fetch failures, we are implicitly saying that we believe the failure is + // non-transient, and can't be recovered from (even if this is the first fetch failure, + // stage is retried after just one failure, so we don't always get a chance to collect + // multiple fetch failures). 
+ // If the external shuffle-service is on, then every other executor on this node would + // be suffering from the same issue, so we should blacklist (and potentially kill) all + // of them immediately. + + val now = clock.getTimeMillis() + val expiryTimeForNewBlacklists = now + BLACKLIST_TIMEOUT_MILLIS + + if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) { + if (!nodeIdToBlacklistExpiryTime.contains(host)) { + logInfo(s"blacklisting node $host due to fetch failure of external shuffle service") + + nodeIdToBlacklistExpiryTime.put(host, expiryTimeForNewBlacklists) + listenerBus.post(SparkListenerNodeBlacklisted(now, host, 1)) + _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) + killExecutorsOnBlacklistedNode(host) + updateNextExpiryTime() + } + } else if (!executorIdToBlacklistStatus.contains(exec)) { + logInfo(s"Blacklisting executor $exec due to fetch failure") + + executorIdToBlacklistStatus.put(exec, BlacklistedExecutor(host, expiryTimeForNewBlacklists)) + // We hardcoded number of failure tasks to 1 for fetch failure, because there's no + // reattempt for such failure. + listenerBus.post(SparkListenerExecutorBlacklisted(now, exec, 1)) + updateNextExpiryTime() + killBlacklistedExecutor(exec) + + val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(exec, HashSet[String]()) + blacklistedExecsOnNode += exec + } + } + } def updateBlacklistForSuccessfulTaskSet( stageId: Int, @@ -174,17 +243,7 @@ private[scheduler] class BlacklistTracker ( listenerBus.post(SparkListenerExecutorBlacklisted(now, exec, newTotal)) executorIdToFailureList.remove(exec) updateNextExpiryTime() - if (conf.get(config.BLACKLIST_KILL_ENABLED)) { - allocationClient match { - case Some(allocationClient) => - logInfo(s"Killing blacklisted executor id $exec " + - s"since spark.blacklist.killBlacklistedExecutors is set.") - allocationClient.killExecutors(Seq(exec), true, true) - case None => - logWarning(s"Not attempting to kill blacklisted executor id $exec " + - s"since allocation client is not defined.") - } - } + killBlacklistedExecutor(exec) // In addition to blacklisting the executor, we also update the data for failures on the // node, and potentially put the entire node into a blacklist as well. 
@@ -199,19 +258,7 @@ private[scheduler] class BlacklistTracker ( nodeIdToBlacklistExpiryTime.put(node, expiryTimeForNewBlacklists) listenerBus.post(SparkListenerNodeBlacklisted(now, node, blacklistedExecsOnNode.size)) _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) - if (conf.get(config.BLACKLIST_KILL_ENABLED)) { - allocationClient match { - case Some(allocationClient) => - logInfo(s"Killing all executors on blacklisted host $node " + - s"since spark.blacklist.killBlacklistedExecutors is set.") - if (allocationClient.killExecutorsOnHost(node) == false) { - logError(s"Killing executors on node $node failed.") - } - case None => - logWarning(s"Not attempting to kill executors on blacklisted host $node " + - s"since allocation client is not defined.") - } - } + killExecutorsOnBlacklistedNode(node) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index bba0b294f1afb..91ec172ffeda1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -51,29 +51,21 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} * acquire a lock on us, so we need to make sure that we don't try to lock the backend while * we are holding a lock on ourselves. */ -private[spark] class TaskSchedulerImpl private[scheduler]( +private[spark] class TaskSchedulerImpl( val sc: SparkContext, val maxTaskFailures: Int, - private[scheduler] val blacklistTrackerOpt: Option[BlacklistTracker], isLocal: Boolean = false) extends TaskScheduler with Logging { import TaskSchedulerImpl._ def this(sc: SparkContext) = { - this( - sc, - sc.conf.get(config.MAX_TASK_FAILURES), - TaskSchedulerImpl.maybeCreateBlacklistTracker(sc)) + this(sc, sc.conf.get(config.MAX_TASK_FAILURES)) } - def this(sc: SparkContext, maxTaskFailures: Int, isLocal: Boolean) = { - this( - sc, - maxTaskFailures, - TaskSchedulerImpl.maybeCreateBlacklistTracker(sc), - isLocal = isLocal) - } + // Lazily initializing blackListTrackOpt to avoid getting empty ExecutorAllocationClient, + // because ExecutorAllocationClient is created after this TaskSchedulerImpl. 
+ private[scheduler] lazy val blacklistTrackerOpt = maybeCreateBlacklistTracker(sc) val conf = sc.conf diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a41b059fa7dec..02d374dc37cd5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -774,6 +774,12 @@ private[spark] class TaskSetManager( tasksSuccessful += 1 } isZombie = true + + if (fetchFailed.bmAddress != null) { + blacklistTracker.foreach(_.updateBlacklistForFetchFailure( + fetchFailed.bmAddress.host, fetchFailed.bmAddress.executorId)) + } + None case ef: ExceptionFailure => diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index 571c6bbb4585d..7ff03c44b0611 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -530,4 +530,59 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M verify(allocationClientMock).killExecutors(Seq("2"), true, true) verify(allocationClientMock).killExecutorsOnHost("hostA") } + + test("fetch failure blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") { + val allocationClientMock = mock[ExecutorAllocationClient] + when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called")) + when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] { + // To avoid a race between blacklisting and killing, it is important that the nodeBlacklist + // is updated before we ask the executor allocation client to kill all the executors + // on a particular host. + override def answer(invocation: InvocationOnMock): Boolean = { + if (blacklist.nodeBlacklist.contains("hostA") == false) { + throw new IllegalStateException("hostA should be on the blacklist") + } + true + } + }) + + conf.set(config.BLACKLIST_FETCH_FAILURE_ENABLED, true) + blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock) + + // Disable auto-kill. Blacklist an executor and make sure killExecutors is not called. + conf.set(config.BLACKLIST_KILL_ENABLED, false) + blacklist.updateBlacklistForFetchFailure("hostA", exec = "1") + + verify(allocationClientMock, never).killExecutors(any(), any(), any()) + verify(allocationClientMock, never).killExecutorsOnHost(any()) + + // Enable auto-kill. Blacklist an executor and make sure killExecutors is called. + conf.set(config.BLACKLIST_KILL_ENABLED, true) + blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock) + clock.advance(1000) + blacklist.updateBlacklistForFetchFailure("hostA", exec = "1") + + verify(allocationClientMock).killExecutors(Seq("1"), true, true) + verify(allocationClientMock, never).killExecutorsOnHost(any()) + + assert(blacklist.executorIdToBlacklistStatus.contains("1")) + assert(blacklist.executorIdToBlacklistStatus("1").node === "hostA") + assert(blacklist.executorIdToBlacklistStatus("1").expiryTime === + 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + assert(blacklist.nextExpiryTime === 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + assert(blacklist.nodeIdToBlacklistExpiryTime.isEmpty) + + // Enable external shuffle service to see if all the executors on this node will be killed. 
+ conf.set(config.SHUFFLE_SERVICE_ENABLED, true) + clock.advance(1000) + blacklist.updateBlacklistForFetchFailure("hostA", exec = "2") + + verify(allocationClientMock, never).killExecutors(Seq("2"), true, true) + verify(allocationClientMock).killExecutorsOnHost("hostA") + + assert(blacklist.nodeIdToBlacklistExpiryTime.contains("hostA")) + assert(blacklist.nodeIdToBlacklistExpiryTime("hostA") === + 2000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + assert(blacklist.nextExpiryTime === 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 8b9d45f734cda..a00337776dadc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -87,7 +87,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B conf.set(config.BLACKLIST_ENABLED, true) sc = new SparkContext(conf) taskScheduler = - new TaskSchedulerImpl(sc, sc.conf.getInt("spark.task.maxFailures", 4), Some(blacklist)) { + new TaskSchedulerImpl(sc, sc.conf.getInt("spark.task.maxFailures", 4)) { override def createTaskSetManager(taskSet: TaskSet, maxFailures: Int): TaskSetManager = { val tsm = super.createTaskSetManager(taskSet, maxFailures) // we need to create a spied tsm just so we can set the TaskSetBlacklist @@ -98,6 +98,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B stageToMockTaskSetBlacklist(taskSet.stageId) = taskSetBlacklist tsmSpy } + + override private[scheduler] lazy val blacklistTrackerOpt = Some(blacklist) } setupHelper() } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index db14c9acfdce5..80fb674725814 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -1140,6 +1140,38 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg .updateBlacklistForFailedTask(anyString(), anyString(), anyInt()) } + test("update application blacklist for shuffle-fetch") { + // Setup a taskset, and fail some one task for fetch failure. 
+ val conf = new SparkConf() + .set(config.BLACKLIST_ENABLED, true) + .set(config.SHUFFLE_SERVICE_ENABLED, true) + .set(config.BLACKLIST_FETCH_FAILURE_ENABLED, true) + sc = new SparkContext("local", "test", conf) + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = FakeTask.createTaskSet(4) + val blacklistTracker = new BlacklistTracker(sc, None) + val tsm = new TaskSetManager(sched, taskSet, 4, Some(blacklistTracker)) + + // make some offers to our taskset, to get tasks we will fail + val taskDescs = Seq( + "exec1" -> "host1", + "exec2" -> "host2" + ).flatMap { case (exec, host) => + // offer each executor twice (simulating 2 cores per executor) + (0 until 2).flatMap{ _ => tsm.resourceOffer(exec, host, TaskLocality.ANY)} + } + assert(taskDescs.size === 4) + + assert(!blacklistTracker.isExecutorBlacklisted(taskDescs(0).executorId)) + assert(!blacklistTracker.isNodeBlacklisted("host1")) + + // Fail the task with fetch failure + tsm.handleFailedTask(taskDescs(0).taskId, TaskState.FAILED, + FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0, 0, "ignored")) + + assert(blacklistTracker.isNodeBlacklisted("host1")) + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = { diff --git a/docs/configuration.md b/docs/configuration.md index f4bec589208be..c8e61537a457c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1479,6 +1479,15 @@ Apart from these, the following properties are also available, and may be useful all of the executors on that node will be killed. + + spark.blacklist.application.fetchFailure.enabled + false + + (Experimental) If set to "true", Spark will blacklist the executor immediately when a fetch + failure happenes. If external shuffle service is enabled, then the whole node will be + blacklisted. + + spark.speculation false From c22810004fb2db249be6477c9801d09b807af851 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 27 Jun 2017 02:35:51 +0800 Subject: [PATCH 085/118] [SPARK-20213][SQL][FOLLOW-UP] introduce SQLExecution.ignoreNestedExecutionId ## What changes were proposed in this pull request? in https://github.com/apache/spark/pull/18064, to work around the nested sql execution id issue, we introduced several internal methods in `Dataset`, like `collectInternal`, `countInternal`, `showInternal`, etc., to avoid nested execution id. However, this approach has poor expansibility. When we hit other nested execution id cases, we may need to add more internal methods in `Dataset`. Our goal is to ignore the nested execution id in some cases, and we can have a better approach to achieve this goal, by introducing `SQLExecution.ignoreNestedExecutionId`. Whenever we find a place which needs to ignore the nested execution, we can just wrap the action with `SQLExecution.ignoreNestedExecutionId`, and this is more expansible than the previous approach. The idea comes from https://github.com/apache/spark/pull/17540/files#diff-ab49028253e599e6e74cc4f4dcb2e3a8R57 by rdblue ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #18419 from cloud-fan/follow. 
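The mechanism behind ignoreNestedExecutionId is a save/set/restore of a SparkContext local property around the wrapped action. Below is a minimal sketch of that scoping pattern only; a ThreadLocal map stands in for SparkContext's local properties, and all names are illustrative rather than Spark's API.

```scala
// Sketch of the save/set/restore scoping used by ignoreNestedExecutionId.
// The ThreadLocal map is a stand-in for SparkContext local properties.
import scala.collection.mutable

object LocalPropertySketch {
  private val props = new ThreadLocal[mutable.Map[String, String]] {
    override def initialValue(): mutable.Map[String, String] = mutable.Map.empty
  }

  def get(key: String): Option[String] = props.get.get(key)

  private def set(key: String, value: Option[String]): Unit = value match {
    case Some(v) => props.get.update(key, v)
    case None    => props.get.remove(key)
  }

  /** Runs `body` with `key` set to `value` and restores the previous value afterwards,
    * even if `body` throws, so the flag is visible only while the action is in scope. */
  def withProperty[T](key: String, value: String)(body: => T): T = {
    val previous = get(key)
    try {
      set(key, Some(value))
      body
    } finally {
      set(key, previous)
    }
  }

  def main(args: Array[String]): Unit = {
    withProperty("ignoreNestedExecutionId", "true") {
      println(get("ignoreNestedExecutionId")) // Some(true) inside the wrapped action
    }
    println(get("ignoreNestedExecutionId")) // None: restored after the action
  }
}
```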
--- .../scala/org/apache/spark/sql/Dataset.scala | 39 ++----------------- .../spark/sql/execution/SQLExecution.scala | 39 +++++++++++++++++-- .../command/AnalyzeTableCommand.scala | 5 ++- .../spark/sql/execution/command/cache.scala | 19 ++++----- .../datasources/csv/CSVDataSource.scala | 6 ++- .../datasources/jdbc/JDBCRelation.scala | 14 +++---- .../sql/execution/streaming/console.scala | 13 +++++-- .../sql/execution/streaming/memory.scala | 33 +++++++++------- 8 files changed, 89 insertions(+), 79 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 6e66e92091ff9..268a37ff5d271 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -246,13 +246,8 @@ class Dataset[T] private[sql]( _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { val numRows = _numRows.max(0) val takeResult = toDF().take(numRows + 1) - showString(takeResult, numRows, truncate, vertical) - } - - private def showString( - dataWithOneMoreRow: Array[Row], numRows: Int, truncate: Int, vertical: Boolean): String = { - val hasMoreData = dataWithOneMoreRow.length > numRows - val data = dataWithOneMoreRow.take(numRows) + val hasMoreData = takeResult.length > numRows + val data = takeResult.take(numRows) lazy val timeZone = DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone) @@ -688,19 +683,6 @@ class Dataset[T] private[sql]( println(showString(numRows, truncate = 0)) } - // An internal version of `show`, which won't set execution id and trigger listeners. - private[sql] def showInternal(_numRows: Int, truncate: Boolean): Unit = { - val numRows = _numRows.max(0) - val takeResult = toDF().takeInternal(numRows + 1) - - if (truncate) { - println(showString(takeResult, numRows, truncate = 20, vertical = false)) - } else { - println(showString(takeResult, numRows, truncate = 0, vertical = false)) - } - } - // scalastyle:on println - /** * Displays the Dataset in a tabular form. For example: * {{{ @@ -2467,11 +2449,6 @@ class Dataset[T] private[sql]( */ def take(n: Int): Array[T] = head(n) - // An internal version of `take`, which won't set execution id and trigger listeners. - private[sql] def takeInternal(n: Int): Array[T] = { - collectFromPlan(limit(n).queryExecution.executedPlan) - } - /** * Returns the first `n` rows in the Dataset as a list. * @@ -2496,11 +2473,6 @@ class Dataset[T] private[sql]( */ def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan) - // An internal version of `collect`, which won't set execution id and trigger listeners. - private[sql] def collectInternal(): Array[T] = { - collectFromPlan(queryExecution.executedPlan) - } - /** * Returns a Java list that contains all rows in this Dataset. * @@ -2542,11 +2514,6 @@ class Dataset[T] private[sql]( plan.executeCollect().head.getLong(0) } - // An internal version of `count`, which won't set execution id and trigger listeners. - private[sql] def countInternal(): Long = { - groupBy().count().queryExecution.executedPlan.executeCollect().head.getLong(0) - } - /** * Returns a new Dataset that has exactly `numPartitions` partitions. 
* @@ -2792,7 +2759,7 @@ class Dataset[T] private[sql]( createTempViewCommand(viewName, replace = true, global = true) } - private[spark] def createTempViewCommand( + private def createTempViewCommand( viewName: String, replace: Boolean, global: Boolean): CreateViewCommand = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index bb206e84325fd..ca8bed5214f87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -29,6 +29,8 @@ object SQLExecution { val EXECUTION_ID_KEY = "spark.sql.execution.id" + private val IGNORE_NESTED_EXECUTION_ID = "spark.sql.execution.ignoreNestedExecutionId" + private val _nextExecutionId = new AtomicLong(0) private def nextExecutionId: Long = _nextExecutionId.getAndIncrement @@ -42,8 +44,11 @@ object SQLExecution { private val testing = sys.props.contains("spark.testing") private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = { + val sc = sparkSession.sparkContext + val isNestedExecution = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null + val hasExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) != null // only throw an exception during tests. a missing execution ID should not fail a job. - if (testing && sparkSession.sparkContext.getLocalProperty(EXECUTION_ID_KEY) == null) { + if (testing && !isNestedExecution && !hasExecutionId) { // Attention testers: when a test fails with this exception, it means that the action that // started execution of a query didn't call withNewExecutionId. The execution ID should be // set by calling withNewExecutionId in the action that begins execution, like @@ -65,7 +70,7 @@ object SQLExecution { val executionId = SQLExecution.nextExecutionId sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) executionIdToQueryExecution.put(executionId, queryExecution) - val r = try { + try { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on // streaming queries would give us call site like "run at :0" @@ -84,7 +89,15 @@ object SQLExecution { executionIdToQueryExecution.remove(executionId) sc.setLocalProperty(EXECUTION_ID_KEY, null) } - r + } else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) { + // If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore the execution id while evaluating the + // `body`, so that Spark jobs issued in the `body` won't be tracked. + try { + sc.setLocalProperty(EXECUTION_ID_KEY, null) + body + } finally { + sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId) + } } else { // Don't support nested `withNewExecutionId`. This is an example of the nested // `withNewExecutionId`: @@ -100,7 +113,9 @@ object SQLExecution { // all accumulator metrics will be 0. It will confuse people if we show them in Web UI. // // A real case is the `DataFrame.count` method. - throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set") + throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set, please wrap your " + + "action with SQLExecution.ignoreNestedExecutionId if you don't want to track the Spark " + + "jobs issued by the nested execution.") } } @@ -118,4 +133,20 @@ object SQLExecution { sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) } } + + /** + * Wrap an action which may have nested execution id. 
This method can be used to run an execution + * inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`. Note that, + * all Spark jobs issued in the body won't be tracked in UI. + */ + def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = { + val sc = sparkSession.sparkContext + val allowNestedPreviousValue = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) + try { + sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, "true") + body + } finally { + sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, allowNestedPreviousValue) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 06e588f56f1e9..13b8faff844c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTableType} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.internal.SessionState @@ -58,7 +59,9 @@ case class AnalyzeTableCommand( // 2. when total size is changed, `oldRowCount` becomes invalid. // This is to make sure that we only record the right statistics. if (!noscan) { - val newRowCount = sparkSession.table(tableIdentWithDB).countInternal() + val newRowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) { + sparkSession.table(tableIdentWithDB).count() + } if (newRowCount >= 0 && newRowCount != oldRowCount) { newStats = if (newStats.isDefined) { newStats.map(_.copy(rowCount = Some(BigInt(newRowCount)))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 184d0387ebfa9..d36eb7587a3ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SQLExecution case class CacheTableCommand( tableIdent: TableIdentifier, @@ -33,16 +34,16 @@ case class CacheTableCommand( override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq override def run(sparkSession: SparkSession): Seq[Row] = { - plan.foreach { logicalPlan => - Dataset.ofRows(sparkSession, logicalPlan) - .createTempViewCommand(tableIdent.quotedString, replace = false, global = false) - .run(sparkSession) - } - sparkSession.catalog.cacheTable(tableIdent.quotedString) + SQLExecution.ignoreNestedExecutionId(sparkSession) { + plan.foreach { logicalPlan => + Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString) + } + sparkSession.catalog.cacheTable(tableIdent.quotedString) - if (!isLazy) { - // Performs eager caching - sparkSession.table(tableIdent).countInternal() + if (!isLazy) { + // Performs eager caching + sparkSession.table(tableIdent).count() + } } Seq.empty[Row] diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index eadc6c94f4b3c..99133bd70989a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -32,6 +32,7 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -144,8 +145,9 @@ object TextInputCSVDataSource extends CSVDataSource { inputPaths: Seq[FileStatus], parsedOptions: CSVOptions): StructType = { val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions) - val maybeFirstLine = - CSVUtils.filterCommentAndEmpty(csv, parsedOptions).takeInternal(1).headOption + val maybeFirstLine = SQLExecution.ignoreNestedExecutionId(sparkSession) { + CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption + } inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index a06f1ce3287e6..b11da7045de22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.Partition import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType @@ -129,14 +130,11 @@ private[sql] case class JDBCRelation( } override def insert(data: DataFrame, overwrite: Boolean): Unit = { - import scala.collection.JavaConverters._ - - val options = jdbcOptions.asProperties.asScala + - ("url" -> jdbcOptions.url, "dbtable" -> jdbcOptions.table) - val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append - - new JdbcRelationProvider().createRelation( - data.sparkSession.sqlContext, mode, options.toMap, data) + SQLExecution.ignoreNestedExecutionId(data.sparkSession) { + data.write + .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) + .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties) + } } override def toString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 9e889ff679450..6fa7c113defaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider} import 
org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.types.StructType class ConsoleSink(options: Map[String, String]) extends Sink with Logging { @@ -47,9 +48,11 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging { println(batchIdStr) println("-------------------------------------------") // scalastyle:off println - data.sparkSession.createDataFrame( - data.sparkSession.sparkContext.parallelize(data.collectInternal()), data.schema) - .showInternal(numRowsToShow, isTruncated) + SQLExecution.ignoreNestedExecutionId(data.sparkSession) { + data.sparkSession.createDataFrame( + data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) + .show(numRowsToShow, isTruncated) + } } } @@ -79,7 +82,9 @@ class ConsoleSinkProvider extends StreamSinkProvider // Truncate the displayed data if it is too long, by default it is true val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true) - data.showInternal(numRowsToShow, isTruncated) + SQLExecution.ignoreNestedExecutionId(sqlContext.sparkSession) { + data.show(numRowsToShow, isTruncated) + } ConsoleRelation(sqlContext, data) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index a5dac469f85b6..198a342582804 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -193,21 +194,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi } if (notCommitted) { logDebug(s"Committing batch $batchId to $this") - outputMode match { - case Append | Update => - val rows = AddedData(batchId, data.collectInternal()) - synchronized { batches += rows } - - case Complete => - val rows = AddedData(batchId, data.collectInternal()) - synchronized { - batches.clear() - batches += rows - } - - case _ => - throw new IllegalArgumentException( - s"Output mode $outputMode is not supported by MemorySink") + SQLExecution.ignoreNestedExecutionId(data.sparkSession) { + outputMode match { + case Append | Update => + val rows = AddedData(batchId, data.collect()) + synchronized { batches += rows } + + case Complete => + val rows = AddedData(batchId, data.collect()) + synchronized { + batches.clear() + batches += rows + } + + case _ => + throw new IllegalArgumentException( + s"Output mode $outputMode is not supported by MemorySink") + } } } else { logDebug(s"Skipping already committed batch: $batchId") From 3cb3ccce120fa9f0273133912624b877b42d95fd Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Tue, 27 Jun 2017 17:24:46 +0800 Subject: [PATCH 086/118] [SPARK-21196] Split codegen info of query plan into sequence codegen info of query plan can be very long. In debugging console / web page, it would be more readable if the subtrees and corresponding codegen are split into sequence. 
Example: ```java codegenStringSeq(sql("select 1").queryExecution.executedPlan) ``` The example will return Seq[(String, String)] of length 1, containing the subtree as string and the corresponding generated code. The subtree as string: > (*Project [1 AS 1#0] > +- Scan OneRowRelation[] The generated code: ```java /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private UnsafeRow project_result; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder project_holder; /* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter project_rowWriter; /* 012 */ /* 013 */ public GeneratedIterator(Object[] references) { /* 014 */ this.references = references; /* 015 */ } /* 016 */ /* 017 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 018 */ partitionIndex = index; /* 019 */ this.inputs = inputs; /* 020 */ inputadapter_input = inputs[0]; /* 021 */ project_result = new UnsafeRow(1); /* 022 */ project_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(project_result, 0); /* 023 */ project_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(project_holder, 1); /* 024 */ /* 025 */ } /* 026 */ /* 027 */ protected void processNext() throws java.io.IOException { /* 028 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 029 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 030 */ project_rowWriter.write(0, 1); /* 031 */ append(project_result); /* 032 */ if (shouldStop()) return; /* 033 */ } /* 034 */ } /* 035 */ /* 036 */ } ``` ## What changes were proposed in this pull request? add method codegenToSeq: split codegen info of query plan into sequence ## How was this patch tested? unit test cloud-fan gatorsmile Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Wang Gengliang Closes #18409 from gengliangwang/codegen. 
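For illustration only, here is a minimal spark-shell sketch (not part of the patch) of how the new `codegenStringSeq` helper could be consumed; it assumes a `SparkSession` named `spark` is already in scope and prints only a prefix of each generated class:

```scala
// Hedged usage sketch of codegenStringSeq; `spark` is assumed to be an existing SparkSession.
import org.apache.spark.sql.execution.debug._

val plan = spark.range(10).groupBy("id").count().queryExecution.executedPlan
val subtrees: Seq[(String, String)] = codegenStringSeq(plan)
subtrees.zipWithIndex.foreach { case ((subtree, code), i) =>
  println(s"== Subtree ${i + 1} / ${subtrees.size} ==")
  println(subtree)          // the WholeStageCodegen subtree rendered as a string
  println(code.take(300))   // only a prefix of the generated Java source
}
```

These are the same (subtree, code) pairs that `codegenString` now concatenates into its single-string output.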
--- .../spark/sql/execution/QueryExecution.scala | 9 +++++ .../spark/sql/execution/debug/package.scala | 35 ++++++++++++++----- .../sql/execution/debug/DebuggingSuite.scala | 7 ++++ 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index c7cac332a0377..9533144214a10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -245,5 +245,14 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { println(org.apache.spark.sql.execution.debug.codegenString(executedPlan)) // scalastyle:on println } + + /** + * Get WholeStageCodegenExec subtrees and the codegen in a query plan + * + * @return Sequence of WholeStageCodegen subtrees and corresponding codegen + */ + def codegenToSeq(): Seq[(String, String)] = { + org.apache.spark.sql.execution.debug.codegenStringSeq(executedPlan) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 0395c43ba2cbc..a717cbd4a7df9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -50,7 +50,31 @@ package object debug { // scalastyle:on println } + /** + * Get WholeStageCodegenExec subtrees and the codegen in a query plan into one String + * + * @param plan the query plan for codegen + * @return single String containing all WholeStageCodegen subtrees and corresponding codegen + */ def codegenString(plan: SparkPlan): String = { + val codegenSeq = codegenStringSeq(plan) + var output = s"Found ${codegenSeq.size} WholeStageCodegen subtrees.\n" + for (((subtree, code), i) <- codegenSeq.zipWithIndex) { + output += s"== Subtree ${i + 1} / ${codegenSeq.size} ==\n" + output += subtree + output += "\nGenerated code:\n" + output += s"${code}\n" + } + output + } + + /** + * Get WholeStageCodegenExec subtrees and the codegen in a query plan + * + * @param plan the query plan for codegen + * @return Sequence of WholeStageCodegen subtrees and corresponding codegen + */ + def codegenStringSeq(plan: SparkPlan): Seq[(String, String)] = { val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]() plan transform { case s: WholeStageCodegenExec => @@ -58,15 +82,10 @@ package object debug { s case s => s } - var output = s"Found ${codegenSubtrees.size} WholeStageCodegen subtrees.\n" - for ((s, i) <- codegenSubtrees.toSeq.zipWithIndex) { - output += s"== Subtree ${i + 1} / ${codegenSubtrees.size} ==\n" - output += s - output += "\nGenerated code:\n" - val (_, source) = s.doCodeGen() - output += s"${CodeFormatter.format(source)}\n" + codegenSubtrees.toSeq.map { subtree => + val (_, source) = subtree.doCodeGen() + (subtree.toString, CodeFormatter.format(source)) } - output } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 4fc52c99fbeeb..adcaf2d76519f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -38,4 +38,11 @@ class DebuggingSuite extends SparkFunSuite 
with SharedSQLContext { assert(res.contains("Subtree 2 / 2")) assert(res.contains("Object[]")) } + + test("debugCodegenStringSeq") { + val res = codegenStringSeq(spark.range(10).groupBy("id").count().queryExecution.executedPlan) + assert(res.length == 2) + assert(res.forall{ case (subtree, code) => + subtree.contains("Range") && code.contains("Object[]")}) + } } From b32bd005e46443bbd487b7a1f1078578c8f4c181 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 27 Jun 2017 13:14:12 +0100 Subject: [PATCH 087/118] [INFRA] Close stale PRs ## What changes were proposed in this pull request? This PR proposes to close stale PRs, mostly the same instances with https://github.com/apache/spark/pull/18017 I believe the author in #14807 removed his account. Closes #7075 Closes #8927 Closes #9202 Closes #9366 Closes #10861 Closes #11420 Closes #12356 Closes #13028 Closes #13506 Closes #14191 Closes #14198 Closes #14330 Closes #14807 Closes #15839 Closes #16225 Closes #16685 Closes #16692 Closes #16995 Closes #17181 Closes #17211 Closes #17235 Closes #17237 Closes #17248 Closes #17341 Closes #17708 Closes #17716 Closes #17721 Closes #17937 Added: Closes #14739 Closes #17139 Closes #17445 Closes #18042 Closes #18359 Added: Closes #16450 Closes #16525 Closes #17738 Added: Closes #16458 Closes #16508 Closes #17714 Added: Closes #17830 Closes #14742 ## How was this patch tested? N/A Author: hyukjinkwon Closes #18417 from HyukjinKwon/close-stale-pr. From fd8c931a30a084ee981b75aa469fc97dda6cfaa9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 28 Jun 2017 00:57:05 +0800 Subject: [PATCH 088/118] [SPARK-19104][SQL] Lambda variables in ExternalMapToCatalyst should be global ## What changes were proposed in this pull request? The issue happens in `ExternalMapToCatalyst`. For example, the following codes create `ExternalMapToCatalyst` to convert Scala Map to catalyst map format. val data = Seq.tabulate(10)(i => NestedData(1, Map("key" -> InnerData("name", i + 100)))) val ds = spark.createDataset(data) The `valueConverter` in `ExternalMapToCatalyst` looks like: if (isnull(lambdavariable(ExternalMapToCatalyst_value52, ExternalMapToCatalyst_value_isNull52, ObjectType(class org.apache.spark.sql.InnerData), true))) null else named_struct(name, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(lambdavariable(ExternalMapToCatalyst_value52, ExternalMapToCatalyst_value_isNull52, ObjectType(class org.apache.spark.sql.InnerData), true)).name, true), value, assertnotnull(lambdavariable(ExternalMapToCatalyst_value52, ExternalMapToCatalyst_value_isNull52, ObjectType(class org.apache.spark.sql.InnerData), true)).value) There is a `CreateNamedStruct` expression (`named_struct`) to create a row of `InnerData.name` and `InnerData.value` that are referred by `ExternalMapToCatalyst_value52`. Because `ExternalMapToCatalyst_value52` are local variable, when `CreateNamedStruct` splits expressions to individual functions, the local variable can't be accessed anymore. ## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh Closes #18418 from viirya/SPARK-19104. 
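For context, a minimal repro sketch assembled from the description above (assuming a spark-shell session where `spark` and its implicits are available; the case classes mirror the ones added to `DatasetPrimitiveSuite`). Before this change the generated serializer could fail to compile once the struct construction was split into helper functions:

```scala
// Hedged repro sketch of SPARK-19104; names follow the PR description, not a definitive test.
case class InnerData(name: String, value: Int)
case class NestedData(id: Int, param: Map[String, InnerData])

import spark.implicits._
val data = Seq.tabulate(10)(i => NestedData(1, Map("key" -> InnerData("name", i + 100))))
val ds = spark.createDataset(data)
// Previously this could throw a codegen CompileException; with the fix it round-trips the rows.
ds.collect()
```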
--- .../catalyst/expressions/objects/objects.scala | 18 ++++++++++++------ .../spark/sql/DatasetPrimitiveSuite.scala | 8 ++++++++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 073993cccdf8a..4b651836ff4d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -911,6 +911,12 @@ case class ExternalMapToCatalyst private( val entry = ctx.freshName("entry") val entries = ctx.freshName("entries") + val keyElementJavaType = ctx.javaType(keyType) + val valueElementJavaType = ctx.javaType(valueType) + ctx.addMutableState(keyElementJavaType, key, "") + ctx.addMutableState("boolean", valueIsNull, "") + ctx.addMutableState(valueElementJavaType, value, "") + val (defineEntries, defineKeyValue) = child.dataType match { case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => val javaIteratorCls = classOf[java.util.Iterator[_]].getName @@ -922,8 +928,8 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next(); - ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry.getKey(); - ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry.getValue(); + $key = (${ctx.boxedType(keyType)}) $entry.getKey(); + $value = (${ctx.boxedType(valueType)}) $entry.getValue(); """ defineEntries -> defineKeyValue @@ -937,17 +943,17 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next(); - ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry._1(); - ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry._2(); + $key = (${ctx.boxedType(keyType)}) $entry._1(); + $value = (${ctx.boxedType(valueType)}) $entry._2(); """ defineEntries -> defineKeyValue } val valueNullCheck = if (ctx.isPrimitiveType(valueType)) { - s"boolean $valueIsNull = false;" + s"$valueIsNull = false;" } else { - s"boolean $valueIsNull = $value == null;" + s"$valueIsNull = $value == null;" } val arrayCls = classOf[GenericArrayData].getName diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 4126660b5d102..a6847dcfbffc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -39,6 +39,9 @@ case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass) case class ComplexMapClass(map: MapClass, lhmap: LHMapClass) +case class InnerData(name: String, value: Int) +case class NestedData(id: Int, param: Map[String, InnerData]) + package object packageobject { case class PackageClass(value: Int) } @@ -354,4 +357,9 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) } + test("SPARK-19104: Lambda variables in ExternalMapToCatalyst should be global") { + val data = Seq.tabulate(10)(i => NestedData(1, Map("key" -> InnerData("name", i + 100)))) + val ds = spark.createDataset(data) + checkDataset(ds, data: _*) + } } From 2d686a19e341a31d976aa42228b7589f87dfd6c2 Mon Sep 
17 00:00:00 2001 From: Eric Vandenberg Date: Wed, 28 Jun 2017 09:26:33 +0800 Subject: [PATCH 089/118] [SPARK-21155][WEBUI] Add (? running tasks) into Spark UI progress ## What changes were proposed in this pull request? Add metric on number of running tasks to status bar on Jobs / Active Jobs. ## How was this patch tested? Run a long running (1 minute) query in spark-shell and use localhost:4040 web UI to observe progress. See jira for screen snapshot. Author: Eric Vandenberg Closes #18369 from ericvandenbergfb/runningTasks. --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 2610f673d27f6..ba798df13c95d 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -356,6 +356,7 @@ private[spark] object UIUtils extends Logging {
    {completed}/{total} + { if (failed == 0 && skipped == 0 && started > 0) s"($started running)" } { if (failed > 0) s"($failed failed)" } { if (skipped > 0) s"($skipped skipped)" } { reasonToNumKilled.toSeq.sortBy(-_._2).map { From e793bf248bc3c71b9664f26377bce06b0ffa97a7 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Tue, 27 Jun 2017 23:15:45 -0700 Subject: [PATCH 090/118] [SPARK-20889][SPARKR] Grouped documentation for MATH column methods ## What changes were proposed in this pull request? Grouped documentation for math column methods. Author: actuaryzhang Author: Wayne Zhang Closes #18371 from actuaryzhang/sparkRDocMath. --- R/pkg/R/functions.R | 619 +++++++++++++++----------------------------- R/pkg/R/generics.R | 48 ++-- 2 files changed, 241 insertions(+), 426 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 31028585aaa13..23ccdf941a8c7 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -86,6 +86,31 @@ NULL #' df <- createDataFrame(data.frame(time = as.POSIXct(dts), y = y))} NULL +#' Math functions for Column operations +#' +#' Math functions defined for \code{Column}. +#' +#' @param x Column to compute on. In \code{shiftLeft}, \code{shiftRight} and \code{shiftRightUnsigned}, +#' this is the number of bits to shift. +#' @param y Column to compute on. +#' @param ... additional argument(s). +#' @name column_math_functions +#' @rdname column_math_functions +#' @family math functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' tmp <- mutate(df, v1 = log(df$mpg), v2 = cbrt(df$disp), +#' v3 = bround(df$wt, 1), v4 = bin(df$cyl), +#' v5 = hex(df$wt), v6 = toDegrees(df$gear), +#' v7 = atan2(df$cyl, df$am), v8 = hypot(df$cyl, df$am), +#' v9 = pmod(df$hp, df$cyl), v10 = shiftLeft(df$disp, 1), +#' v11 = conv(df$hp, 10, 16), v12 = sign(df$vs - 0.5), +#' v13 = sqrt(df$disp), v14 = ceil(df$wt)) +#' head(tmp)} +NULL + #' lit #' #' A new \linkS4class{Column} is created to represent the literal value. @@ -112,18 +137,12 @@ setMethod("lit", signature("ANY"), column(jc) }) -#' abs -#' -#' Computes the absolute value. -#' -#' @param x Column to compute on. +#' @details +#' \code{abs}: Computes the absolute value. #' -#' @rdname abs -#' @name abs -#' @family non-aggregate functions +#' @rdname column_math_functions #' @export -#' @examples \dontrun{abs(df$c)} -#' @aliases abs,Column-method +#' @aliases abs abs,Column-method #' @note abs since 1.5.0 setMethod("abs", signature(x = "Column"), @@ -132,19 +151,13 @@ setMethod("abs", column(jc) }) -#' acos -#' -#' Computes the cosine inverse of the given value; the returned angle is in the range -#' 0.0 through pi. -#' -#' @param x Column to compute on. +#' @details +#' \code{acos}: Computes the cosine inverse of the given value; the returned angle is in +#' the range 0.0 through pi. #' -#' @rdname acos -#' @name acos -#' @family math functions +#' @rdname column_math_functions #' @export -#' @examples \dontrun{acos(df$c)} -#' @aliases acos,Column-method +#' @aliases acos acos,Column-method #' @note acos since 1.5.0 setMethod("acos", signature(x = "Column"), @@ -196,19 +209,13 @@ setMethod("ascii", column(jc) }) -#' asin -#' -#' Computes the sine inverse of the given value; the returned angle is in the range -#' -pi/2 through pi/2. -#' -#' @param x Column to compute on. +#' @details +#' \code{asin}: Computes the sine inverse of the given value; the returned angle is in +#' the range -pi/2 through pi/2. 
#' -#' @rdname asin -#' @name asin -#' @family math functions +#' @rdname column_math_functions #' @export -#' @aliases asin,Column-method -#' @examples \dontrun{asin(df$c)} +#' @aliases asin asin,Column-method #' @note asin since 1.5.0 setMethod("asin", signature(x = "Column"), @@ -217,18 +224,12 @@ setMethod("asin", column(jc) }) -#' atan -#' -#' Computes the tangent inverse of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{atan}: Computes the tangent inverse of the given value. #' -#' @rdname atan -#' @name atan -#' @family math functions +#' @rdname column_math_functions #' @export -#' @aliases atan,Column-method -#' @examples \dontrun{atan(df$c)} +#' @aliases atan atan,Column-method #' @note atan since 1.5.0 setMethod("atan", signature(x = "Column"), @@ -276,19 +277,13 @@ setMethod("base64", column(jc) }) -#' bin -#' -#' An expression that returns the string representation of the binary value of the given long -#' column. For example, bin("12") returns "1100". -#' -#' @param x Column to compute on. +#' @details +#' \code{bin}: An expression that returns the string representation of the binary value +#' of the given long column. For example, bin("12") returns "1100". #' -#' @rdname bin -#' @name bin -#' @family math functions +#' @rdname column_math_functions #' @export -#' @aliases bin,Column-method -#' @examples \dontrun{bin(df$c)} +#' @aliases bin bin,Column-method #' @note bin since 1.5.0 setMethod("bin", signature(x = "Column"), @@ -317,18 +312,12 @@ setMethod("bitwiseNOT", column(jc) }) -#' cbrt -#' -#' Computes the cube-root of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{cbrt}: Computes the cube-root of the given value. #' -#' @rdname cbrt -#' @name cbrt -#' @family math functions +#' @rdname column_math_functions #' @export -#' @aliases cbrt,Column-method -#' @examples \dontrun{cbrt(df$c)} +#' @aliases cbrt cbrt,Column-method #' @note cbrt since 1.4.0 setMethod("cbrt", signature(x = "Column"), @@ -337,18 +326,12 @@ setMethod("cbrt", column(jc) }) -#' Computes the ceiling of the given value -#' -#' Computes the ceiling of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{ceil}: Computes the ceiling of the given value. #' -#' @rdname ceil -#' @name ceil -#' @family math functions +#' @rdname column_math_functions #' @export -#' @aliases ceil,Column-method -#' @examples \dontrun{ceil(df$c)} +#' @aliases ceil ceil,Column-method #' @note ceil since 1.5.0 setMethod("ceil", signature(x = "Column"), @@ -357,6 +340,19 @@ setMethod("ceil", column(jc) }) +#' @details +#' \code{ceiling}: Alias for \code{ceil}. +#' +#' @rdname column_math_functions +#' @aliases ceiling ceiling,Column-method +#' @export +#' @note ceiling since 1.5.0 +setMethod("ceiling", + signature(x = "Column"), + function(x) { + ceil(x) + }) + #' Returns the first column that is not NA #' #' Returns the first column that is not NA, or NA if all inputs are. @@ -405,6 +401,7 @@ setMethod("column", function(x) { col(x) }) + #' corr #' #' Computes the Pearson Correlation Coefficient for two Columns. @@ -493,18 +490,12 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr column(jc) }) -#' cos -#' -#' Computes the cosine of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{cos}: Computes the cosine of the given value. 
#' -#' @rdname cos -#' @name cos -#' @family math functions -#' @aliases cos,Column-method +#' @rdname column_math_functions +#' @aliases cos cos,Column-method #' @export -#' @examples \dontrun{cos(df$c)} #' @note cos since 1.5.0 setMethod("cos", signature(x = "Column"), @@ -513,18 +504,12 @@ setMethod("cos", column(jc) }) -#' cosh -#' -#' Computes the hyperbolic cosine of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{cosh}: Computes the hyperbolic cosine of the given value. #' -#' @rdname cosh -#' @name cosh -#' @family math functions -#' @aliases cosh,Column-method +#' @rdname column_math_functions +#' @aliases cosh cosh,Column-method #' @export -#' @examples \dontrun{cosh(df$c)} #' @note cosh since 1.5.0 setMethod("cosh", signature(x = "Column"), @@ -679,18 +664,12 @@ setMethod("encode", column(jc) }) -#' exp -#' -#' Computes the exponential of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{exp}: Computes the exponential of the given value. #' -#' @rdname exp -#' @name exp -#' @family math functions -#' @aliases exp,Column-method +#' @rdname column_math_functions +#' @aliases exp exp,Column-method #' @export -#' @examples \dontrun{exp(df$c)} #' @note exp since 1.5.0 setMethod("exp", signature(x = "Column"), @@ -699,18 +678,12 @@ setMethod("exp", column(jc) }) -#' expm1 -#' -#' Computes the exponential of the given value minus one. -#' -#' @param x Column to compute on. +#' @details +#' \code{expm1}: Computes the exponential of the given value minus one. #' -#' @rdname expm1 -#' @name expm1 -#' @aliases expm1,Column-method -#' @family math functions +#' @rdname column_math_functions +#' @aliases expm1 expm1,Column-method #' @export -#' @examples \dontrun{expm1(df$c)} #' @note expm1 since 1.5.0 setMethod("expm1", signature(x = "Column"), @@ -719,18 +692,12 @@ setMethod("expm1", column(jc) }) -#' factorial -#' -#' Computes the factorial of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{factorial}: Computes the factorial of the given value. #' -#' @rdname factorial -#' @name factorial -#' @aliases factorial,Column-method -#' @family math functions +#' @rdname column_math_functions +#' @aliases factorial factorial,Column-method #' @export -#' @examples \dontrun{factorial(df$c)} #' @note factorial since 1.5.0 setMethod("factorial", signature(x = "Column"), @@ -772,18 +739,12 @@ setMethod("first", column(jc) }) -#' floor -#' -#' Computes the floor of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{floor}: Computes the floor of the given value. #' -#' @rdname floor -#' @name floor -#' @aliases floor,Column-method -#' @family math functions +#' @rdname column_math_functions +#' @aliases floor floor,Column-method #' @export -#' @examples \dontrun{floor(df$c)} #' @note floor since 1.5.0 setMethod("floor", signature(x = "Column"), @@ -792,18 +753,12 @@ setMethod("floor", column(jc) }) -#' hex -#' -#' Computes hex value of the given column. -#' -#' @param x Column to compute on. +#' @details +#' \code{hex}: Computes hex value of the given column. #' -#' @rdname hex -#' @name hex -#' @family math functions -#' @aliases hex,Column-method +#' @rdname column_math_functions +#' @aliases hex hex,Column-method #' @export -#' @examples \dontrun{hex(df$c)} #' @note hex since 1.5.0 setMethod("hex", signature(x = "Column"), @@ -983,18 +938,12 @@ setMethod("length", column(jc) }) -#' log -#' -#' Computes the natural logarithm of the given value. -#' -#' @param x Column to compute on. 
+#' @details +#' \code{log}: Computes the natural logarithm of the given value. #' -#' @rdname log -#' @name log -#' @aliases log,Column-method -#' @family math functions +#' @rdname column_math_functions +#' @aliases log log,Column-method #' @export -#' @examples \dontrun{log(df$c)} #' @note log since 1.5.0 setMethod("log", signature(x = "Column"), @@ -1003,18 +952,12 @@ setMethod("log", column(jc) }) -#' log10 -#' -#' Computes the logarithm of the given value in base 10. -#' -#' @param x Column to compute on. +#' @details +#' \code{log10}: Computes the logarithm of the given value in base 10. #' -#' @rdname log10 -#' @name log10 -#' @family math functions -#' @aliases log10,Column-method +#' @rdname column_math_functions +#' @aliases log10 log10,Column-method #' @export -#' @examples \dontrun{log10(df$c)} #' @note log10 since 1.5.0 setMethod("log10", signature(x = "Column"), @@ -1023,18 +966,12 @@ setMethod("log10", column(jc) }) -#' log1p -#' -#' Computes the natural logarithm of the given value plus one. -#' -#' @param x Column to compute on. +#' @details +#' \code{log1p}: Computes the natural logarithm of the given value plus one. #' -#' @rdname log1p -#' @name log1p -#' @family math functions -#' @aliases log1p,Column-method +#' @rdname column_math_functions +#' @aliases log1p log1p,Column-method #' @export -#' @examples \dontrun{log1p(df$c)} #' @note log1p since 1.5.0 setMethod("log1p", signature(x = "Column"), @@ -1043,18 +980,12 @@ setMethod("log1p", column(jc) }) -#' log2 -#' -#' Computes the logarithm of the given column in base 2. -#' -#' @param x Column to compute on. +#' @details +#' \code{log2}: Computes the logarithm of the given column in base 2. #' -#' @rdname log2 -#' @name log2 -#' @family math functions -#' @aliases log2,Column-method +#' @rdname column_math_functions +#' @aliases log2 log2,Column-method #' @export -#' @examples \dontrun{log2(df$c)} #' @note log2 since 1.5.0 setMethod("log2", signature(x = "Column"), @@ -1287,19 +1218,13 @@ setMethod("reverse", column(jc) }) -#' rint -#' -#' Returns the double value that is closest in value to the argument and +#' @details +#' \code{rint}: Returns the double value that is closest in value to the argument and #' is equal to a mathematical integer. #' -#' @param x Column to compute on. -#' -#' @rdname rint -#' @name rint -#' @family math functions -#' @aliases rint,Column-method +#' @rdname column_math_functions +#' @aliases rint rint,Column-method #' @export -#' @examples \dontrun{rint(df$c)} #' @note rint since 1.5.0 setMethod("rint", signature(x = "Column"), @@ -1308,18 +1233,13 @@ setMethod("rint", column(jc) }) -#' round -#' -#' Returns the value of the column \code{e} rounded to 0 decimal places using HALF_UP rounding mode. -#' -#' @param x Column to compute on. +#' @details +#' \code{round}: Returns the value of the column rounded to 0 decimal places +#' using HALF_UP rounding mode. #' -#' @rdname round -#' @name round -#' @family math functions -#' @aliases round,Column-method +#' @rdname column_math_functions +#' @aliases round round,Column-method #' @export -#' @examples \dontrun{round(df$c)} #' @note round since 1.5.0 setMethod("round", signature(x = "Column"), @@ -1328,24 +1248,18 @@ setMethod("round", column(jc) }) -#' bround -#' -#' Returns the value of the column \code{e} rounded to \code{scale} decimal places using HALF_EVEN rounding -#' mode if \code{scale} >= 0 or at integer part when \code{scale} < 0. 
+#' @details +#' \code{bround}: Returns the value of the column \code{e} rounded to \code{scale} decimal places +#' using HALF_EVEN rounding mode if \code{scale} >= 0 or at integer part when \code{scale} < 0. #' Also known as Gaussian rounding or bankers' rounding that rounds to the nearest even number. #' bround(2.5, 0) = 2, bround(3.5, 0) = 4. #' -#' @param x Column to compute on. #' @param scale round to \code{scale} digits to the right of the decimal point when \code{scale} > 0, #' the nearest even number when \code{scale} = 0, and \code{scale} digits to the left #' of the decimal point when \code{scale} < 0. -#' @param ... further arguments to be passed to or from other methods. -#' @rdname bround -#' @name bround -#' @family math functions -#' @aliases bround,Column-method +#' @rdname column_math_functions +#' @aliases bround bround,Column-method #' @export -#' @examples \dontrun{bround(df$c, 0)} #' @note bround since 2.0.0 setMethod("bround", signature(x = "Column"), @@ -1354,7 +1268,6 @@ setMethod("bround", column(jc) }) - #' rtrim #' #' Trim the spaces from right end for the specified string value. @@ -1375,7 +1288,6 @@ setMethod("rtrim", column(jc) }) - #' @details #' \code{sd}: Alias for \code{stddev_samp}. #' @@ -1429,18 +1341,12 @@ setMethod("sha1", column(jc) }) -#' signum -#' -#' Computes the signum of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{signum}: Computes the signum of the given value. #' -#' @rdname sign -#' @name signum -#' @aliases signum,Column-method -#' @family math functions +#' @rdname column_math_functions +#' @aliases signum signum,Column-method #' @export -#' @examples \dontrun{signum(df$c)} #' @note signum since 1.5.0 setMethod("signum", signature(x = "Column"), @@ -1449,18 +1355,24 @@ setMethod("signum", column(jc) }) -#' sin -#' -#' Computes the sine of the given value. +#' @details +#' \code{sign}: Alias for \code{signum}. #' -#' @param x Column to compute on. +#' @rdname column_math_functions +#' @aliases sign sign,Column-method +#' @export +#' @note sign since 1.5.0 +setMethod("sign", signature(x = "Column"), + function(x) { + signum(x) + }) + +#' @details +#' \code{sin}: Computes the sine of the given value. #' -#' @rdname sin -#' @name sin -#' @family math functions -#' @aliases sin,Column-method +#' @rdname column_math_functions +#' @aliases sin sin,Column-method #' @export -#' @examples \dontrun{sin(df$c)} #' @note sin since 1.5.0 setMethod("sin", signature(x = "Column"), @@ -1469,18 +1381,12 @@ setMethod("sin", column(jc) }) -#' sinh -#' -#' Computes the hyperbolic sine of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{sinh}: Computes the hyperbolic sine of the given value. #' -#' @rdname sinh -#' @name sinh -#' @family math functions -#' @aliases sinh,Column-method +#' @rdname column_math_functions +#' @aliases sinh sinh,Column-method #' @export -#' @examples \dontrun{sinh(df$c)} #' @note sinh since 1.5.0 setMethod("sinh", signature(x = "Column"), @@ -1616,18 +1522,12 @@ setMethod("struct", column(jc) }) -#' sqrt -#' -#' Computes the square root of the specified float value. -#' -#' @param x Column to compute on. +#' @details +#' \code{sqrt}: Computes the square root of the specified float value. 
#' -#' @rdname sqrt -#' @name sqrt -#' @family math functions -#' @aliases sqrt,Column-method +#' @rdname column_math_functions +#' @aliases sqrt sqrt,Column-method #' @export -#' @examples \dontrun{sqrt(df$c)} #' @note sqrt since 1.5.0 setMethod("sqrt", signature(x = "Column"), @@ -1669,18 +1569,12 @@ setMethod("sumDistinct", column(jc) }) -#' tan -#' -#' Computes the tangent of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{tan}: Computes the tangent of the given value. #' -#' @rdname tan -#' @name tan -#' @family math functions -#' @aliases tan,Column-method +#' @rdname column_math_functions +#' @aliases tan tan,Column-method #' @export -#' @examples \dontrun{tan(df$c)} #' @note tan since 1.5.0 setMethod("tan", signature(x = "Column"), @@ -1689,18 +1583,12 @@ setMethod("tan", column(jc) }) -#' tanh -#' -#' Computes the hyperbolic tangent of the given value. -#' -#' @param x Column to compute on. +#' @details +#' \code{tanh}: Computes the hyperbolic tangent of the given value. #' -#' @rdname tanh -#' @name tanh -#' @family math functions -#' @aliases tanh,Column-method +#' @rdname column_math_functions +#' @aliases tanh tanh,Column-method #' @export -#' @examples \dontrun{tanh(df$c)} #' @note tanh since 1.5.0 setMethod("tanh", signature(x = "Column"), @@ -1709,18 +1597,13 @@ setMethod("tanh", column(jc) }) -#' toDegrees -#' -#' Converts an angle measured in radians to an approximately equivalent angle measured in degrees. -#' -#' @param x Column to compute on. +#' @details +#' \code{toDegrees}: Converts an angle measured in radians to an approximately equivalent angle +#' measured in degrees. #' -#' @rdname toDegrees -#' @name toDegrees -#' @family math functions -#' @aliases toDegrees,Column-method +#' @rdname column_math_functions +#' @aliases toDegrees toDegrees,Column-method #' @export -#' @examples \dontrun{toDegrees(df$c)} #' @note toDegrees since 1.4.0 setMethod("toDegrees", signature(x = "Column"), @@ -1729,18 +1612,13 @@ setMethod("toDegrees", column(jc) }) -#' toRadians -#' -#' Converts an angle measured in degrees to an approximately equivalent angle measured in radians. -#' -#' @param x Column to compute on. +#' @details +#' \code{toRadians}: Converts an angle measured in degrees to an approximately equivalent angle +#' measured in radians. #' -#' @rdname toRadians -#' @name toRadians -#' @family math functions -#' @aliases toRadians,Column-method +#' @rdname column_math_functions +#' @aliases toRadians toRadians,Column-method #' @export -#' @examples \dontrun{toRadians(df$c)} #' @note toRadians since 1.4.0 setMethod("toRadians", signature(x = "Column"), @@ -1894,19 +1772,13 @@ setMethod("unbase64", column(jc) }) -#' unhex -#' -#' Inverse of hex. Interprets each pair of characters as a hexadecimal number +#' @details +#' \code{unhex}: Inverse of hex. Interprets each pair of characters as a hexadecimal number #' and converts to the byte representation of number. #' -#' @param x Column to compute on. -#' -#' @rdname unhex -#' @name unhex -#' @family math functions -#' @aliases unhex,Column-method +#' @rdname column_math_functions +#' @aliases unhex unhex,Column-method #' @export -#' @examples \dontrun{unhex(df$c)} #' @note unhex since 1.5.0 setMethod("unhex", signature(x = "Column"), @@ -2020,20 +1892,13 @@ setMethod("year", column(jc) }) -#' atan2 -#' -#' Returns the angle theta from the conversion of rectangular coordinates (x, y) to -#' polar coordinates (r, theta). -# -#' @param x Column to compute on. -#' @param y Column to compute on. 
+#' @details +#' \code{atan2}: Returns the angle theta from the conversion of rectangular coordinates +#' (x, y) to polar coordinates (r, theta). #' -#' @rdname atan2 -#' @name atan2 -#' @family math functions -#' @aliases atan2,Column-method +#' @rdname column_math_functions +#' @aliases atan2 atan2,Column-method #' @export -#' @examples \dontrun{atan2(df$c, x)} #' @note atan2 since 1.5.0 setMethod("atan2", signature(y = "Column"), function(y, x) { @@ -2068,19 +1933,12 @@ setMethod("datediff", signature(y = "Column"), column(jc) }) -#' hypot -#' -#' Computes "sqrt(a^2 + b^2)" without intermediate overflow or underflow. -# -#' @param x Column to compute on. -#' @param y Column to compute on. +#' @details +#' \code{hypot}: Computes "sqrt(a^2 + b^2)" without intermediate overflow or underflow. #' -#' @rdname hypot -#' @name hypot -#' @family math functions -#' @aliases hypot,Column-method +#' @rdname column_math_functions +#' @aliases hypot hypot,Column-method #' @export -#' @examples \dontrun{hypot(df$c, x)} #' @note hypot since 1.4.0 setMethod("hypot", signature(y = "Column"), function(y, x) { @@ -2154,20 +2012,13 @@ setMethod("nanvl", signature(y = "Column"), column(jc) }) -#' pmod -#' -#' Returns the positive value of dividend mod divisor. -#' -#' @param x divisor Column. -#' @param y dividend Column. +#' @details +#' \code{pmod}: Returns the positive value of dividend mod divisor. +#' Column \code{x} is divisor column, and column \code{y} is the dividend column. #' -#' @rdname pmod -#' @name pmod -#' @docType methods -#' @family math functions -#' @aliases pmod,Column-method +#' @rdname column_math_functions +#' @aliases pmod pmod,Column-method #' @export -#' @examples \dontrun{pmod(df$c, x)} #' @note pmod since 1.5.0 setMethod("pmod", signature(y = "Column"), function(y, x) { @@ -2290,31 +2141,6 @@ setMethod("least", column(jc) }) -#' @rdname ceil -#' -#' @name ceiling -#' @aliases ceiling,Column-method -#' @export -#' @examples \dontrun{ceiling(df$c)} -#' @note ceiling since 1.5.0 -setMethod("ceiling", - signature(x = "Column"), - function(x) { - ceil(x) - }) - -#' @rdname sign -#' -#' @name sign -#' @aliases sign,Column-method -#' @export -#' @examples \dontrun{sign(df$c)} -#' @note sign since 1.5.0 -setMethod("sign", signature(x = "Column"), - function(x) { - signum(x) - }) - #' @details #' \code{n_distinct}: Returns the number of distinct items in a group. #' @@ -2564,20 +2390,13 @@ setMethod("sha2", signature(y = "Column", x = "numeric"), column(jc) }) -#' shiftLeft -#' -#' Shift the given value numBits left. If the given value is a long value, this function -#' will return a long value else it will return an integer value. -#' -#' @param y column to compute on. -#' @param x number of bits to shift. +#' @details +#' \code{shiftLeft}: Shifts the given value numBits left. If the given value is a long value, +#' this function will return a long value else it will return an integer value. #' -#' @family math functions -#' @rdname shiftLeft -#' @name shiftLeft -#' @aliases shiftLeft,Column,numeric-method +#' @rdname column_math_functions +#' @aliases shiftLeft shiftLeft,Column,numeric-method #' @export -#' @examples \dontrun{shiftLeft(df$c, 1)} #' @note shiftLeft since 1.5.0 setMethod("shiftLeft", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2587,20 +2406,13 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), column(jc) }) -#' shiftRight -#' -#' (Signed) shift the given value numBits right. 
If the given value is a long value, it will return -#' a long value else it will return an integer value. -#' -#' @param y column to compute on. -#' @param x number of bits to shift. +#' @details +#' \code{shiftRight}: (Signed) shifts the given value numBits right. If the given value is a long value, +#' it will return a long value else it will return an integer value. #' -#' @family math functions -#' @rdname shiftRight -#' @name shiftRight -#' @aliases shiftRight,Column,numeric-method +#' @rdname column_math_functions +#' @aliases shiftRight shiftRight,Column,numeric-method #' @export -#' @examples \dontrun{shiftRight(df$c, 1)} #' @note shiftRight since 1.5.0 setMethod("shiftRight", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2610,20 +2422,13 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), column(jc) }) -#' shiftRightUnsigned -#' -#' Unsigned shift the given value numBits right. If the given value is a long value, +#' @details +#' \code{shiftRight}: (Unigned) shifts the given value numBits right. If the given value is a long value, #' it will return a long value else it will return an integer value. #' -#' @param y column to compute on. -#' @param x number of bits to shift. -#' -#' @family math functions -#' @rdname shiftRightUnsigned -#' @name shiftRightUnsigned -#' @aliases shiftRightUnsigned,Column,numeric-method +#' @rdname column_math_functions +#' @aliases shiftRightUnsigned shiftRightUnsigned,Column,numeric-method #' @export -#' @examples \dontrun{shiftRightUnsigned(df$c, 1)} #' @note shiftRightUnsigned since 1.5.0 setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2656,20 +2461,14 @@ setMethod("concat_ws", signature(sep = "character", x = "Column"), column(jc) }) -#' conv -#' -#' Convert a number in a string column from one base to another. +#' @details +#' \code{conv}: Converts a number in a string column from one base to another. #' -#' @param x column to convert. #' @param fromBase base to convert from. #' @param toBase base to convert to. -#' -#' @family math functions -#' @rdname conv -#' @aliases conv,Column,numeric,numeric-method -#' @name conv +#' @rdname column_math_functions +#' @aliases conv conv,Column,numeric,numeric-method #' @export -#' @examples \dontrun{conv(df$n, 2, 16)} #' @note conv since 1.5.0 setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), function(x, fromBase, toBase) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index f105174cea70d..0248ec585d771 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -931,24 +931,28 @@ setGeneric("avg", function(x, ...) { standardGeneric("avg") }) #' @export setGeneric("base64", function(x) { standardGeneric("base64") }) -#' @rdname bin +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("bin", function(x) { standardGeneric("bin") }) #' @rdname bitwiseNOT #' @export setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) -#' @rdname bround +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("bround", function(x, ...) { standardGeneric("bround") }) -#' @rdname cbrt +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) -#' @rdname ceil +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("ceil", function(x) { standardGeneric("ceil") }) #' @rdname column_aggregate_functions @@ -973,8 +977,9 @@ setGeneric("concat", function(x, ...) 
{ standardGeneric("concat") }) #' @export setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) -#' @rdname conv +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) #' @rdname column_aggregate_functions @@ -1094,8 +1099,9 @@ setGeneric("grouping_bit", function(x) { standardGeneric("grouping_bit") }) #' @name NULL setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") }) -#' @rdname hex +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("hex", function(x) { standardGeneric("hex") }) #' @rdname column_datetime_functions @@ -1103,8 +1109,9 @@ setGeneric("hex", function(x) { standardGeneric("hex") }) #' @name NULL setGeneric("hour", function(x) { standardGeneric("hour") }) -#' @rdname hypot +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) #' @rdname initcap @@ -1235,8 +1242,9 @@ setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) #' @export setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_rank") }) -#' @rdname pmod +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) #' @rdname posexplode @@ -1281,8 +1289,9 @@ setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) #' @export setGeneric("reverse", function(x) { standardGeneric("reverse") }) -#' @rdname rint +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("rint", function(x) { standardGeneric("rint") }) #' @param x empty. Should be used with no argument. @@ -1316,20 +1325,24 @@ setGeneric("sha1", function(x) { standardGeneric("sha1") }) #' @export setGeneric("sha2", function(y, x) { standardGeneric("sha2") }) -#' @rdname shiftLeft +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("shiftLeft", function(y, x) { standardGeneric("shiftLeft") }) -#' @rdname shiftRight +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") }) -#' @rdname shiftRightUnsigned +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") }) -#' @rdname sign +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("signum", function(x) { standardGeneric("signum") }) #' @rdname size @@ -1386,12 +1399,14 @@ setGeneric("substring_index", function(x, delim, count) { standardGeneric("subst #' @name NULL setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) -#' @rdname toDegrees +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) -#' @rdname toRadians +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @rdname column_datetime_functions @@ -1425,8 +1440,9 @@ setGeneric("trim", function(x) { standardGeneric("trim") }) #' @export setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) -#' @rdname unhex +#' @rdname column_math_functions #' @export +#' @name NULL setGeneric("unhex", function(x) { standardGeneric("unhex") }) #' @rdname column_datetime_functions From 838effb98a0d3410766771533402ce0386133af3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 28 Jun 2017 14:28:40 +0800 Subject: [PATCH 
091/118] Revert "[SPARK-13534][PYSPARK] Using Apache Arrow to increase performance of DataFrame.toPandas" This reverts commit e44697606f429b01808c1a22cb44cb5b89585c5c. --- bin/pyspark | 2 +- dev/deps/spark-deps-hadoop-2.6 | 5 - dev/deps/spark-deps-hadoop-2.7 | 5 - dev/run-pip-tests | 6 - pom.xml | 20 - python/pyspark/serializers.py | 17 - python/pyspark/sql/dataframe.py | 48 +- python/pyspark/sql/tests.py | 79 +- .../apache/spark/sql/internal/SQLConf.scala | 22 - sql/core/pom.xml | 4 - .../scala/org/apache/spark/sql/Dataset.scala | 20 - .../sql/execution/arrow/ArrowConverters.scala | 429 ------ .../arrow/ArrowConvertersSuite.scala | 1222 ----------------- 13 files changed, 13 insertions(+), 1866 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala diff --git a/bin/pyspark b/bin/pyspark index 8eeea7716cc98..98387c2ec5b8a 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m "$@" + exec "$PYSPARK_DRIVER_PYTHON" -m "$1" exit fi diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 9868c1ab7c2ab..9287bd47cf113 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -13,9 +13,6 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.4.0.jar -arrow-memory-0.4.0.jar -arrow-vector-0.4.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -58,7 +55,6 @@ datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar -flatbuffers-1.2.0-3f79e055.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -81,7 +77,6 @@ hadoop-yarn-server-web-proxy-2.6.5.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar -hppc-0.7.1.jar htrace-core-3.0.4.jar httpclient-4.5.2.jar httpcore-4.4.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 57c78cfe12087..9127413ab6c23 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -13,9 +13,6 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.4.0.jar -arrow-memory-0.4.0.jar -arrow-vector-0.4.0.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -58,7 +55,6 @@ datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar -flatbuffers-1.2.0-3f79e055.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -81,7 +77,6 @@ hadoop-yarn-server-web-proxy-2.7.3.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar -hppc-0.7.1.jar htrace-core-3.1.0-incubating.jar httpclient-4.5.2.jar httpcore-4.4.4.jar diff --git a/dev/run-pip-tests b/dev/run-pip-tests index 225e9209536f0..d51dde12a03c5 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -83,8 +83,6 @@ for python in "${PYTHON_EXECS[@]}"; do if [ -n "$USE_CONDA" ]; then conda create -y -p "$VIRTUALENV_PATH" python=$python numpy pandas pip setuptools source activate "$VIRTUALENV_PATH" - conda install -y -c conda-forge pyarrow=0.4.0 - TEST_PYARROW=1 else mkdir -p "$VIRTUALENV_PATH" virtualenv --python=$python "$VIRTUALENV_PATH" @@ -122,10 +120,6 @@ for python in "${PYTHON_EXECS[@]}"; 
do python "$FWDIR"/dev/pip-sanity-check.py echo "Run the tests for context.py" python "$FWDIR"/python/pyspark/context.py - if [ -n "$TEST_PYARROW" ]; then - echo "Run tests for pyarrow" - SPARK_TESTING=1 "$FWDIR"/bin/pyspark pyspark.sql.tests ArrowTests - fi cd "$FWDIR" diff --git a/pom.xml b/pom.xml index f124ba45007b7..5f524079495c0 100644 --- a/pom.xml +++ b/pom.xml @@ -181,7 +181,6 @@ 2.6 1.8 1.0.0 - 0.4.0 ${java.home} @@ -1879,25 +1878,6 @@ paranamer ${paranamer.version} - - org.apache.arrow - arrow-vector - ${arrow.version} - - - com.fasterxml.jackson.core - jackson-annotations - - - com.fasterxml.jackson.core - jackson-databind - - - io.netty - netty-handler - - - diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d5c2a7518b18f..ea5e00e9eeef5 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -182,23 +182,6 @@ def loads(self, obj): raise NotImplementedError -class ArrowSerializer(FramedSerializer): - """ - Serializes an Arrow stream. - """ - - def dumps(self, obj): - raise NotImplementedError - - def loads(self, obj): - import pyarrow as pa - reader = pa.RecordBatchFileReader(pa.BufferReader(obj)) - return reader.read_all() - - def __repr__(self): - return "ArrowSerializer" - - class BatchedSerializer(Serializer): """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 760f113dfd197..0649271ed2246 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,8 +29,7 @@ from pyspark import copy_func, since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ - UTF8Deserializer +from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import _parse_datatype_json_string @@ -1709,8 +1708,7 @@ def toDF(self, *cols): @since(1.3) def toPandas(self): - """ - Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. This is only available if Pandas is installed and available. 
@@ -1723,42 +1721,18 @@ def toPandas(self): 1 5 Bob """ import pandas as pd - if self.sql_ctx.getConf("spark.sql.execution.arrow.enable", "false").lower() == "true": - try: - import pyarrow - tables = self._collectAsArrow() - if tables: - table = pyarrow.concat_tables(tables) - return table.to_pandas() - else: - return pd.DataFrame.from_records([], columns=self.columns) - except ImportError as e: - msg = "note: pyarrow must be installed and available on calling Python process " \ - "if using spark.sql.execution.arrow.enable=true" - raise ImportError("%s\n%s" % (e.message, msg)) - else: - dtype = {} - for field in self.schema: - pandas_type = _to_corrected_pandas_type(field.dataType) - if pandas_type is not None: - dtype[field.name] = pandas_type - pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) + dtype = {} + for field in self.schema: + pandas_type = _to_corrected_pandas_type(field.dataType) + if pandas_type is not None: + dtype[field.name] = pandas_type - for f, t in dtype.items(): - pdf[f] = pdf[f].astype(t, copy=False) - return pdf + pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) - def _collectAsArrow(self): - """ - Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed - and available. - - .. note:: Experimental. - """ - with SCCallSiteSync(self._sc) as css: - port = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(port, ArrowSerializer())) + for f, t in dtype.items(): + pdf[f] = pdf[f].astype(t, copy=False) + return pdf ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 326e8548a617c..0a1cd6856b8e8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -58,21 +58,12 @@ from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type -from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests +from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException -_have_arrow = False -try: - import pyarrow - _have_arrow = True -except: - # No Arrow, but that's okay, we'll skip those tests - pass - - class UTCOffsetTimezone(datetime.tzinfo): """ Specifies timezone in UTC offset @@ -2629,74 +2620,6 @@ def range_frame_match(): importlib.reload(window) - -@unittest.skipIf(not _have_arrow, "Arrow not installed") -class ArrowTests(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.spark = SparkSession(cls.sc) - cls.spark.conf.set("spark.sql.execution.arrow.enable", "true") - cls.schema = StructType([ - StructField("1_str_t", StringType(), True), - StructField("2_int_t", IntegerType(), True), - StructField("3_long_t", LongType(), True), - StructField("4_float_t", FloatType(), True), - StructField("5_double_t", DoubleType(), True)]) - cls.data = [("a", 1, 10, 0.2, 2.0), - ("b", 2, 20, 0.4, 4.0), - ("c", 3, 30, 0.8, 6.0)] - - def assertFramesEqual(self, df_with_arrow, df_without): - msg = ("DataFrame from Arrow is not equal" + - ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) + - ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) - 
self.assertTrue(df_without.equals(df_with_arrow), msg=msg) - - def test_unsupported_datatype(self): - schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)]) - df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema) - with QuietTest(self.sc): - self.assertRaises(Exception, lambda: df.toPandas()) - - def test_null_conversion(self): - df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + - self.data) - pdf = df_null.toPandas() - null_counts = pdf.isnull().sum().tolist() - self.assertTrue(all([c == 1 for c in null_counts])) - - def test_toPandas_arrow_toggle(self): - df = self.spark.createDataFrame(self.data, schema=self.schema) - self.spark.conf.set("spark.sql.execution.arrow.enable", "false") - pdf = df.toPandas() - self.spark.conf.set("spark.sql.execution.arrow.enable", "true") - pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) - - def test_pandas_round_trip(self): - import pandas as pd - import numpy as np - data_dict = {} - for j, name in enumerate(self.schema.names): - data_dict[name] = [self.data[i][j] for i in range(len(self.data))] - # need to convert these to numpy types first - data_dict["2_int_t"] = np.int32(data_dict["2_int_t"]) - data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) - pdf = pd.DataFrame(data=data_dict) - df = self.spark.createDataFrame(self.data, schema=self.schema) - pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) - - def test_filtered_frame(self): - df = self.spark.range(3).toDF("i") - pdf = df.filter("i < 0").toPandas() - self.assertEqual(len(pdf.columns), 1) - self.assertEqual(pdf.columns[0], "i") - self.assertTrue(pdf.empty) - - if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9c8e26a8eeadf..c641e4d3a23e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -847,24 +847,6 @@ object SQLConf { .intConf .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) - val ARROW_EXECUTION_ENABLE = - buildConf("spark.sql.execution.arrow.enable") - .internal() - .doc("Make use of Apache Arrow for columnar data transfers. Currently available " + - "for use with pyspark.sql.DataFrame.toPandas with the following data types: " + - "StringType, BinaryType, BooleanType, DoubleType, FloatType, ByteType, IntegerType, " + - "LongType, ShortType") - .booleanConf - .createWithDefault(false) - - val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = - buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") - .internal() - .doc("When using Apache Arrow, limit the maximum number of records that can be written " + - "to a single ArrowRecordBatch in memory. If set to zero or negative there is no limit.") - .intConf - .createWithDefault(10000) - object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1123,10 +1105,6 @@ class SQLConf extends Serializable with Logging { def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO) - def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE) - - def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) - /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. 
*/ diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 661c31ded7148..1bc34a6b069d9 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -103,10 +103,6 @@ jackson-databind ${fasterxml.jackson.version} - - org.apache.arrow - arrow-vector - org.apache.xbean xbean-asm5-shaded diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 268a37ff5d271..7be4aa1ca9562 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -47,7 +47,6 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython @@ -2887,16 +2886,6 @@ class Dataset[T] private[sql]( } } - /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. - */ - private[sql] def collectAsArrowToPython(): Int = { - withNewExecutionId { - val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") - } - } - private[sql] def toPythonIterator(): Int = { withNewExecutionId { PythonRDD.toLocalIteratorAndServe(javaToPython.rdd) @@ -2978,13 +2967,4 @@ class Dataset[T] private[sql]( Dataset(sparkSession, logicalPlan) } } - - /** Convert to an RDD of ArrowPayload byte arrays */ - private[sql] def toArrowPayload: RDD[ArrowPayload] = { - val schemaCaptured = this.schema - val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch - queryExecution.toRdd.mapPartitionsInternal { iter => - ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala deleted file mode 100644 index 6af5c73422377..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ /dev/null @@ -1,429 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*/ - -package org.apache.spark.sql.execution.arrow - -import java.io.ByteArrayOutputStream -import java.nio.channels.Channels - -import scala.collection.JavaConverters._ - -import io.netty.buffer.ArrowBuf -import org.apache.arrow.memory.{BufferAllocator, RootAllocator} -import org.apache.arrow.vector._ -import org.apache.arrow.vector.BaseValueVector.BaseMutator -import org.apache.arrow.vector.file._ -import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} -import org.apache.arrow.vector.types.FloatingPointPrecision -import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} -import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils - - -/** - * Store Arrow data in a form that can be serialized by Spark and served to a Python process. - */ -private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable { - - /** - * Convert the ArrowPayload to an ArrowRecordBatch. - */ - def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { - ArrowConverters.byteArrayToBatch(payload, allocator) - } - - /** - * Get the ArrowPayload as a type that can be served to Python. - */ - def asPythonSerializable: Array[Byte] = payload -} - -private[sql] object ArrowPayload { - - /** - * Create an ArrowPayload from an ArrowRecordBatch and Spark schema. - */ - def apply( - batch: ArrowRecordBatch, - schema: StructType, - allocator: BufferAllocator): ArrowPayload = { - new ArrowPayload(ArrowConverters.batchToByteArray(batch, schema, allocator)) - } -} - -private[sql] object ArrowConverters { - - /** - * Map a Spark DataType to ArrowType. - */ - private[arrow] def sparkTypeToArrowType(dataType: DataType): ArrowType = { - dataType match { - case BooleanType => ArrowType.Bool.INSTANCE - case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true) - case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true) - case LongType => new ArrowType.Int(8 * LongType.defaultSize, true) - case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) - case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) - case ByteType => new ArrowType.Int(8, true) - case StringType => ArrowType.Utf8.INSTANCE - case BinaryType => ArrowType.Binary.INSTANCE - case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") - } - } - - /** - * Convert a Spark Dataset schema to Arrow schema. - */ - private[arrow] def schemaToArrowSchema(schema: StructType): Schema = { - val arrowFields = schema.fields.map { f => - new Field(f.name, f.nullable, sparkTypeToArrowType(f.dataType), List.empty[Field].asJava) - } - new Schema(arrowFields.toList.asJava) - } - - /** - * Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload - * by setting maxRecordsPerBatch or use 0 to fully consume rowIter. 
- */ - private[sql] def toPayloadIterator( - rowIter: Iterator[InternalRow], - schema: StructType, - maxRecordsPerBatch: Int): Iterator[ArrowPayload] = { - new Iterator[ArrowPayload] { - private val _allocator = new RootAllocator(Long.MaxValue) - private var _nextPayload = if (rowIter.nonEmpty) convert() else null - - override def hasNext: Boolean = _nextPayload != null - - override def next(): ArrowPayload = { - val obj = _nextPayload - if (hasNext) { - if (rowIter.hasNext) { - _nextPayload = convert() - } else { - _allocator.close() - _nextPayload = null - } - } - obj - } - - private def convert(): ArrowPayload = { - val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator, maxRecordsPerBatch) - ArrowPayload(batch, schema, _allocator) - } - } - } - - /** - * Iterate over InternalRows and write to an ArrowRecordBatch, stopping when rowIter is consumed - * or the number of records in the batch equals maxRecordsInBatch. If maxRecordsPerBatch is 0, - * then rowIter will be fully consumed. - */ - private def internalRowIterToArrowBatch( - rowIter: Iterator[InternalRow], - schema: StructType, - allocator: BufferAllocator, - maxRecordsPerBatch: Int = 0): ArrowRecordBatch = { - - val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) => - ColumnWriter(field.dataType, ordinal, allocator).init() - } - - val writerLength = columnWriters.length - var recordsInBatch = 0 - while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || recordsInBatch < maxRecordsPerBatch)) { - val row = rowIter.next() - var i = 0 - while (i < writerLength) { - columnWriters(i).write(row) - i += 1 - } - recordsInBatch += 1 - } - - val (fieldNodes, bufferArrays) = columnWriters.map(_.finish()).unzip - val buffers = bufferArrays.flatten - - val rowLength = if (fieldNodes.nonEmpty) fieldNodes.head.getLength else 0 - val recordBatch = new ArrowRecordBatch(rowLength, - fieldNodes.toList.asJava, buffers.toList.asJava) - - buffers.foreach(_.release()) - recordBatch - } - - /** - * Convert an ArrowRecordBatch to a byte array and close batch to release resources. Once closed, - * the batch can no longer be used. - */ - private[arrow] def batchToByteArray( - batch: ArrowRecordBatch, - schema: StructType, - allocator: BufferAllocator): Array[Byte] = { - val arrowSchema = ArrowConverters.schemaToArrowSchema(schema) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - val out = new ByteArrayOutputStream() - val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) - - // Write a batch to byte stream, ensure the batch, allocator and writer are closed - Utils.tryWithSafeFinally { - val loader = new VectorLoader(root) - loader.load(batch) - writer.writeBatch() // writeBatch can throw IOException - } { - batch.close() - root.close() - writer.close() - } - out.toByteArray - } - - /** - * Convert a byte array to an ArrowRecordBatch. - */ - private[arrow] def byteArrayToBatch( - batchBytes: Array[Byte], - allocator: BufferAllocator): ArrowRecordBatch = { - val in = new ByteArrayReadableSeekableByteChannel(batchBytes) - val reader = new ArrowFileReader(in, allocator) - - // Read a batch from a byte stream, ensure the reader is closed - Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch - } { - reader.close() - } - } -} - -/** - * Interface for writing InternalRows to Arrow Buffers. 
- */ -private[arrow] trait ColumnWriter { - def init(): this.type - def write(row: InternalRow): Unit - - /** - * Clear the column writer and return the ArrowFieldNode and ArrowBuf. - * This should be called only once after all the data is written. - */ - def finish(): (ArrowFieldNode, Array[ArrowBuf]) -} - -/** - * Base class for flat arrow column writer, i.e., column without children. - */ -private[arrow] abstract class PrimitiveColumnWriter(val ordinal: Int) - extends ColumnWriter { - - def getFieldType(dtype: ArrowType): FieldType = FieldType.nullable(dtype) - - def valueVector: BaseDataValueVector - def valueMutator: BaseMutator - - def setNull(): Unit - def setValue(row: InternalRow): Unit - - protected var count = 0 - protected var nullCount = 0 - - override def init(): this.type = { - valueVector.allocateNew() - this - } - - override def write(row: InternalRow): Unit = { - if (row.isNullAt(ordinal)) { - setNull() - nullCount += 1 - } else { - setValue(row) - } - count += 1 - } - - override def finish(): (ArrowFieldNode, Array[ArrowBuf]) = { - valueMutator.setValueCount(count) - val fieldNode = new ArrowFieldNode(count, nullCount) - val valueBuffers = valueVector.getBuffers(true) - (fieldNode, valueBuffers) - } -} - -private[arrow] class BooleanColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableBitVector - = new NullableBitVector("BooleanValue", getFieldType(dtype), allocator) - override val valueMutator: NullableBitVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, if (row.getBoolean(ordinal)) 1 else 0 ) -} - -private[arrow] class ShortColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableSmallIntVector - = new NullableSmallIntVector("ShortValue", getFieldType(dtype: ArrowType), allocator) - override val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getShort(ordinal)) -} - -private[arrow] class IntegerColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableIntVector - = new NullableIntVector("IntValue", getFieldType(dtype), allocator) - override val valueMutator: NullableIntVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getInt(ordinal)) -} - -private[arrow] class LongColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableBigIntVector - = new NullableBigIntVector("LongValue", getFieldType(dtype), allocator) - override val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getLong(ordinal)) -} - -private[arrow] class FloatColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableFloat4Vector - = new 
NullableFloat4Vector("FloatValue", getFieldType(dtype), allocator) - override val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getFloat(ordinal)) -} - -private[arrow] class DoubleColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableFloat8Vector - = new NullableFloat8Vector("DoubleValue", getFieldType(dtype), allocator) - override val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getDouble(ordinal)) -} - -private[arrow] class ByteColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableUInt1Vector - = new NullableUInt1Vector("ByteValue", getFieldType(dtype), allocator) - override val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getByte(ordinal)) -} - -private[arrow] class UTF8StringColumnWriter( - dtype: ArrowType, - ordinal: Int, - allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableVarCharVector - = new NullableVarCharVector("UTF8StringValue", getFieldType(dtype), allocator) - override val valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - val str = row.getUTF8String(ordinal) - valueMutator.setSafe(count, str.getByteBuffer, 0, str.numBytes) - } -} - -private[arrow] class BinaryColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableVarBinaryVector - = new NullableVarBinaryVector("BinaryValue", getFieldType(dtype), allocator) - override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - val bytes = row.getBinary(ordinal) - valueMutator.setSafe(count, bytes, 0, bytes.length) - } -} - -private[arrow] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableDateDayVector - = new NullableDateDayVector("DateValue", getFieldType(dtype), allocator) - override val valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - valueMutator.setSafe(count, row.getInt(ordinal)) - } -} - -private[arrow] class TimeStampColumnWriter( - dtype: ArrowType, - ordinal: Int, - allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableTimeStampMicroVector - = new NullableTimeStampMicroVector("TimeStampValue", getFieldType(dtype), allocator) - override val valueMutator: NullableTimeStampMicroVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: 
InternalRow): Unit = { - valueMutator.setSafe(count, row.getLong(ordinal)) - } -} - -private[arrow] object ColumnWriter { - - /** - * Create an Arrow ColumnWriter given the type and ordinal of row. - */ - def apply(dataType: DataType, ordinal: Int, allocator: BufferAllocator): ColumnWriter = { - val dtype = ArrowConverters.sparkTypeToArrowType(dataType) - dataType match { - case BooleanType => new BooleanColumnWriter(dtype, ordinal, allocator) - case ShortType => new ShortColumnWriter(dtype, ordinal, allocator) - case IntegerType => new IntegerColumnWriter(dtype, ordinal, allocator) - case LongType => new LongColumnWriter(dtype, ordinal, allocator) - case FloatType => new FloatColumnWriter(dtype, ordinal, allocator) - case DoubleType => new DoubleColumnWriter(dtype, ordinal, allocator) - case ByteType => new ByteColumnWriter(dtype, ordinal, allocator) - case StringType => new UTF8StringColumnWriter(dtype, ordinal, allocator) - case BinaryType => new BinaryColumnWriter(dtype, ordinal, allocator) - case DateType => new DateColumnWriter(dtype, ordinal, allocator) - case TimestampType => new TimeStampColumnWriter(dtype, ordinal, allocator) - case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala deleted file mode 100644 index 159328cc0d958..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ /dev/null @@ -1,1222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.execution.arrow - -import java.io.File -import java.nio.charset.StandardCharsets -import java.sql.{Date, Timestamp} -import java.text.SimpleDateFormat -import java.util.Locale - -import com.google.common.io.Files -import org.apache.arrow.memory.RootAllocator -import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} -import org.apache.arrow.vector.file.json.JsonFileReader -import org.apache.arrow.vector.util.Validator -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.SparkException -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{BinaryType, StructField, StructType} -import org.apache.spark.util.Utils - - -class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { - import testImplicits._ - - private var tempDataPath: String = _ - - override def beforeAll(): Unit = { - super.beforeAll() - tempDataPath = Utils.createTempDir(namePrefix = "arrow").getAbsolutePath - } - - test("collect to arrow record batch") { - val indexData = (1 to 6).toDF("i") - val arrowPayloads = indexData.toArrowPayload.collect() - assert(arrowPayloads.nonEmpty) - assert(arrowPayloads.length == indexData.rdd.getNumPartitions) - val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) - val rowCount = arrowRecordBatches.map(_.getLength).sum - assert(rowCount === indexData.count()) - arrowRecordBatches.foreach(batch => assert(batch.getNodes.size() > 0)) - arrowRecordBatches.foreach(_.close()) - allocator.close() - } - - test("short conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_s", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 16 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 16 - | } ] - | } - | }, { - | "name" : "b_s", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 16 - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 16 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_s", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, -1, 2, -2, 32767, -32768 ] - | }, { - | "name" : "b_s", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1, 0, 0, -2, 0, -32768 ] - | } ] - | } ] - |} - """.stripMargin - - val a_s = List[Short](1, -1, 2, -2, 32767, -32768) - val b_s = List[Option[Short]](Some(1), None, None, Some(-2), None, Some(-32768)) - val df = a_s.zip(b_s).toDF("a_s", "b_s") - - collectAndValidate(df, json, "integer-16bit.json") - } - - test("int conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "b_i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : 
"VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_i", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] - | }, { - | "name" : "b_i", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] - | } ] - | } ] - |} - """.stripMargin - - val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) - val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) - val df = a_i.zip(b_i).toDF("a_i", "b_i") - - collectAndValidate(df, json, "integer-32bit.json") - } - - test("long conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_l", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 64 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | }, { - | "name" : "b_l", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 64 - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_l", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, -1, 2, -2, 9223372036854775807, -9223372036854775808 ] - | }, { - | "name" : "b_l", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1, 0, 0, -2, 0, -9223372036854775808 ] - | } ] - | } ] - |} - """.stripMargin - - val a_l = List[Long](1, -1, 2, -2, 9223372036854775807L, -9223372036854775808L) - val b_l = List[Option[Long]](Some(1), None, None, Some(-2), None, Some(-9223372036854775808L)) - val df = a_l.zip(b_l).toDF("a_l", "b_l") - - collectAndValidate(df, json, "integer-64bit.json") - } - - test("float conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_f", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "SINGLE" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "b_f", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "SINGLE" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_f", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0 ] - | }, { - | "name" : "b_f", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ] - | } ] - | } ] - |} - """.stripMargin - - val a_f = List(1.0f, 2.0f, 0.01f, 200.0f, 0.0001f, 20000.0f) - val b_f = List[Option[Float]](Some(1.1f), None, None, Some(2.2f), None, Some(3.3f)) - val df = a_f.zip(b_f).toDF("a_f", "b_f") - - collectAndValidate(df, json, "floating_point-single_precision.json") - } - - test("double conversion") { - val json = - s""" - |{ - | "schema" : { - | 
"fields" : [ { - | "name" : "a_d", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "DOUBLE" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | }, { - | "name" : "b_d", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "DOUBLE" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_d", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 1.0E-4, 20000.0 ] - | }, { - | "name" : "b_d", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ] - | } ] - | } ] - |} - """.stripMargin - - val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0) - val b_d = List[Option[Double]](Some(1.1), None, None, Some(2.2), None, Some(3.3)) - val df = a_d.zip(b_d).toDF("a_d", "b_d") - - collectAndValidate(df, json, "floating_point-double_precision.json") - } - - test("index conversion") { - val data = List[Int](1, 2, 3, 4, 5, 6) - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "i", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, 2, 3, 4, 5, 6 ] - | } ] - | } ] - |} - """.stripMargin - val df = data.toDF("i") - - collectAndValidate(df, json, "indexData-ints.json") - } - - test("mixed numeric type conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 16 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 16 - | } ] - | } - | }, { - | "name" : "b", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "SINGLE" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "c", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "d", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "DOUBLE" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | }, { - | "name" : "e", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 64 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | 
"type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, 2, 3, 4, 5, 6 ] - | }, { - | "name" : "b", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] - | }, { - | "name" : "c", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, 2, 3, 4, 5, 6 ] - | }, { - | "name" : "d", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ] - | }, { - | "name" : "e", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, 2, 3, 4, 5, 6 ] - | } ] - | } ] - |} - """.stripMargin - - val data = List(1, 2, 3, 4, 5, 6) - val data_tuples = for (d <- data) yield { - (d.toShort, d.toFloat, d.toInt, d.toDouble, d.toLong) - } - val df = data_tuples.toDF("a", "b", "c", "d", "e") - - collectAndValidate(df, json, "mixed_numeric_types.json") - } - - test("string type conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "upper_case", - | "type" : { - | "name" : "utf8" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } - | }, { - | "name" : "lower_case", - | "type" : { - | "name" : "utf8" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } - | }, { - | "name" : "null_str", - | "type" : { - | "name" : "utf8" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 3, - | "columns" : [ { - | "name" : "upper_case", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "OFFSET" : [ 0, 1, 2, 3 ], - | "DATA" : [ "A", "B", "C" ] - | }, { - | "name" : "lower_case", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "OFFSET" : [ 0, 1, 2, 3 ], - | "DATA" : [ "a", "b", "c" ] - | }, { - | "name" : "null_str", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 0 ], - | "OFFSET" : [ 0, 2, 5, 5 ], - | "DATA" : [ "ab", "CDE", "" ] - | } ] - | } ] - |} - """.stripMargin - - val upperCase = Seq("A", "B", "C") - val lowerCase = Seq("a", "b", "c") - val nullStr = Seq("ab", "CDE", null) - val df = (upperCase, lowerCase, nullStr).zipped.toList - .toDF("upper_case", "lower_case", "null_str") - - collectAndValidate(df, json, "stringData.json") - } - - test("boolean type conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_bool", - | "type" : { - | "name" : "bool" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 1 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 4, - | "columns" : [ { - | "name" : "a_bool", - | "count" : 4, - | "VALIDITY" : [ 1, 1, 1, 1 ], - | "DATA" : [ true, true, false, true ] - | } ] - | } 
] - |} - """.stripMargin - val df = Seq(true, true, false, true).toDF("a_bool") - collectAndValidate(df, json, "boolData.json") - } - - test("byte type conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_byte", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 8 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 4, - | "columns" : [ { - | "name" : "a_byte", - | "count" : 4, - | "VALIDITY" : [ 1, 1, 1, 1 ], - | "DATA" : [ 1, -1, 64, 127 ] - | } ] - | } ] - |} - | - """.stripMargin - val df = List[Byte](1.toByte, (-1).toByte, 64.toByte, Byte.MaxValue).toDF("a_byte") - collectAndValidate(df, json, "byteData.json") - } - - test("binary type conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_binary", - | "type" : { - | "name" : "binary" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "OFFSET", - | "typeBitWidth" : 32 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 8 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 3, - | "columns" : [ { - | "name" : "a_binary", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "OFFSET" : [ 0, 3, 4, 6 ], - | "DATA" : [ "616263", "64", "6566" ] - | } ] - | } ] - |} - """.stripMargin - - val data = Seq("abc", "d", "ef") - val rdd = sparkContext.parallelize(data.map(s => Row(s.getBytes("utf-8")))) - val df = spark.createDataFrame(rdd, StructType(Seq(StructField("a_binary", BinaryType)))) - - collectAndValidate(df, json, "binaryData.json") - } - - test("floating-point NaN") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "NaN_f", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "SINGLE" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "NaN_d", - | "type" : { - | "name" : "floatingpoint", - | "precision" : "DOUBLE" - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 2, - | "columns" : [ { - | "name" : "NaN_f", - | "count" : 2, - | "VALIDITY" : [ 1, 1 ], - | "DATA" : [ 1.2000000476837158, "NaN" ] - | }, { - | "name" : "NaN_d", - | "count" : 2, - | "VALIDITY" : [ 1, 1 ], - | "DATA" : [ "NaN", 1.2 ] - | } ] - | } ] - |} - """.stripMargin - - val fnan = Seq(1.2F, Float.NaN) - val dnan = Seq(Double.NaN, 1.2) - val df = fnan.zip(dnan).toDF("NaN_f", "NaN_d") - - collectAndValidate(df, json, "nanData-floating_point.json") - } - - test("partitioned DataFrame") { - val json1 = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "b", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | 
"bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 3, - | "columns" : [ { - | "name" : "a", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "DATA" : [ 1, 1, 2 ] - | }, { - | "name" : "b", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "DATA" : [ 1, 2, 1 ] - | } ] - | } ] - |} - """.stripMargin - val json2 = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "b", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 3, - | "columns" : [ { - | "name" : "a", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "DATA" : [ 2, 3, 3 ] - | }, { - | "name" : "b", - | "count" : 3, - | "VALIDITY" : [ 1, 1, 1 ], - | "DATA" : [ 2, 1, 2 ] - | } ] - | } ] - |} - """.stripMargin - - val arrowPayloads = testData2.toArrowPayload.collect() - // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload - assert(arrowPayloads.length === 2) - val schema = testData2.schema - - val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") - val tempFile2 = new File(tempDataPath, "testData2-ints-part2.json") - Files.write(json1, tempFile1, StandardCharsets.UTF_8) - Files.write(json2, tempFile2, StandardCharsets.UTF_8) - - validateConversion(schema, arrowPayloads(0), tempFile1) - validateConversion(schema, arrowPayloads(1), tempFile2) - } - - test("empty frame collect") { - val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect() - assert(arrowPayload.isEmpty) - - val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i") - val filteredArrowPayload = filteredDF.filter("i < 0").toArrowPayload.collect() - assert(filteredArrowPayload.isEmpty) - } - - test("empty partition collect") { - val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i") - val arrowPayloads = emptyPart.toArrowPayload.collect() - assert(arrowPayloads.length === 1) - val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) - assert(arrowRecordBatches.head.getLength == 1) - arrowRecordBatches.foreach(_.close()) - allocator.close() - } - - test("max records in batch conf") { - val totalRecords = 10 - val maxRecordsPerBatch = 3 - spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) - val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") - val arrowPayloads = df.toArrowPayload.collect() - val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) - var recordCount = 0 - arrowRecordBatches.foreach { batch => - assert(batch.getLength > 0) - assert(batch.getLength <= maxRecordsPerBatch) - recordCount += batch.getLength - batch.close() - } - assert(recordCount == totalRecords) - allocator.close() - 
spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") - } - - testQuietly("unsupported types") { - def runUnsupported(block: => Unit): Unit = { - val msg = intercept[SparkException] { - block - } - assert(msg.getMessage.contains("Unsupported data type")) - assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) - } - - runUnsupported { decimalData.toArrowPayload.collect() } - runUnsupported { arrayData.toDF().toArrowPayload.collect() } - runUnsupported { mapData.toDF().toArrowPayload.collect() } - runUnsupported { complexData.toArrowPayload.collect() } - - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) - val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) - val d2 = new Date(sdf.parse("2016-05-09 13:10:15.000 UTC").getTime) - runUnsupported { Seq(d1, d2).toDF("date").toArrowPayload.collect() } - - val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) - val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) - runUnsupported { Seq(ts1, ts2).toDF("timestamp").toArrowPayload.collect() } - } - - test("test Arrow Validator") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "a_i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "b_i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_i", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] - | }, { - | "name" : "b_i", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] - | } ] - | } ] - |} - """.stripMargin - val json_diff_col_order = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "b_i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | }, { - | "name" : "a_i", - | "type" : { - | "name" : "int", - | "isSigned" : true, - | "bitWidth" : 32 - | }, - | "nullable" : false, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 32 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 6, - | "columns" : [ { - | "name" : "a_i", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], - | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ] - | }, { - | "name" : "b_i", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], - | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ] - | } ] - | } ] - |} - """.stripMargin - - val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648) - val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648)) - val df = a_i.zip(b_i).toDF("a_i", "b_i") - - // Different schema - intercept[IllegalArgumentException] { - collectAndValidate(df, 
json_diff_col_order, "validator_diff_schema.json") - } - - // Different values - intercept[IllegalArgumentException] { - collectAndValidate(df.sort($"a_i".desc), json, "validator_diff_values.json") - } - } - - /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ - private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = { - // NOTE: coalesce to single partition because can only load 1 batch in validator - val arrowPayload = df.coalesce(1).toArrowPayload.collect().head - val tempFile = new File(tempDataPath, file) - Files.write(json, tempFile, StandardCharsets.UTF_8) - validateConversion(df.schema, arrowPayload, tempFile) - } - - private def validateConversion( - sparkSchema: StructType, - arrowPayload: ArrowPayload, - jsonFile: File): Unit = { - val allocator = new RootAllocator(Long.MaxValue) - val jsonReader = new JsonFileReader(jsonFile, allocator) - - val arrowSchema = ArrowConverters.schemaToArrowSchema(sparkSchema) - val jsonSchema = jsonReader.start() - Validator.compareSchemas(arrowSchema, jsonSchema) - - val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator) - val vectorLoader = new VectorLoader(arrowRoot) - val arrowRecordBatch = arrowPayload.loadBatch(allocator) - vectorLoader.load(arrowRecordBatch) - val jsonRoot = jsonReader.read() - Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) - - jsonRoot.close() - jsonReader.close() - arrowRecordBatch.close() - arrowRoot.close() - allocator.close() - } -} From e68aed70fbf1cfa59ba51df70287d718d737a193 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 28 Jun 2017 10:45:45 -0700 Subject: [PATCH 092/118] [SPARK-21216][SS] Hive strategies missed in Structured Streaming IncrementalExecution ## What changes were proposed in this pull request? If someone creates a HiveSession, the planner in `IncrementalExecution` doesn't take into account the Hive scan strategies. This causes joins of Streaming DataFrame's with Hive tables to fail. ## How was this patch tested? Regression test Author: Burak Yavuz Closes #18426 from brkyvz/hive-join. 
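To make the failure mode concrete, here is a minimal sketch of the scenario (condensed from the regression test below; it assumes a Hive-enabled `SparkSession` bound to `spark` and an existing Hive SerDe table `smallTable` with an int column `number`, so those names are illustrative rather than prescriptive):

```scala
import org.apache.spark.sql.execution.streaming.MemoryStream

import spark.implicits._
implicit val sqlCtx = spark.sqlContext

// A streaming Dataset joined against a Hive-backed table.
val inputData = MemoryStream[Int]
val joined = inputData.toDS().toDF()
  .join(spark.table("smallTable"), $"value" === $"number")

// Before this patch, IncrementalExecution planned the query with only the base
// planner strategies plus its streaming-specific extras, so the Hive scan strategy
// needed to read smallTable was missing and starting the query failed.
val sq = joined.writeStream
  .format("memory")
  .queryName("joined_stream")
  .start()
```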
--- .../streaming/IncrementalExecution.scala | 4 ++ .../sql/hive/execution/HiveDDLSuite.scala | 41 ++++++++++++++++++- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index ab89dc6b705d5..dbe652b3b1ed2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -47,6 +47,10 @@ class IncrementalExecution( sparkSession.sparkContext, sparkSession.sessionState.conf, sparkSession.sessionState.experimentalMethods) { + override def strategies: Seq[Strategy] = + extraPlanningStrategies ++ + sparkSession.sessionState.planner.strategies + override def extraPlanningStrategies: Seq[Strategy] = StatefulAggregationStrategy :: FlatMapGroupsWithStateStrategy :: diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index aca964907d4cd..31fa3d2447467 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -160,7 +160,6 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA test("drop table") { testDropTable(isDatasourceTable = false) } - } class HiveDDLSuite @@ -1956,4 +1955,44 @@ class HiveDDLSuite } } } + + test("SPARK-21216: join with a streaming DataFrame") { + import org.apache.spark.sql.execution.streaming.MemoryStream + import testImplicits._ + + implicit val _sqlContext = spark.sqlContext + + Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word").createOrReplaceTempView("t1") + // Make a table and ensure it will be broadcast. + sql("""CREATE TABLE smallTable(word string, number int) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |STORED AS TEXTFILE + """.stripMargin) + + sql( + """INSERT INTO smallTable + |SELECT word, number from t1 + """.stripMargin) + + val inputData = MemoryStream[Int] + val joined = inputData.toDS().toDF() + .join(spark.table("smallTable"), $"value" === $"number") + + val sq = joined.writeStream + .format("memory") + .queryName("t2") + .start() + try { + inputData.addData(1, 2) + + sq.processAllAvailable() + + checkAnswer( + spark.table("t2"), + Seq(Row(1, "one", 1), Row(2, "two", 2)) + ) + } finally { + sq.stop() + } + } } From b72b8521d9cad878a1a4e4dbb19cf980169dcbc7 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 29 Jun 2017 08:47:31 +0800 Subject: [PATCH 093/118] [SPARK-21222] Move elimination of Distinct clause from analyzer to optimizer ## What changes were proposed in this pull request? Move elimination of Distinct clause from analyzer to optimizer Distinct clause is useless after MAX/MIN clause. For example, "Select MAX(distinct a) FROM src from" is equivalent of "Select MAX(a) FROM src from" However, this optimization is implemented in analyzer. It should be in optimizer. ## How was this patch tested? Unit test gatorsmile cloud-fan Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Wang Gengliang Closes #18429 from gengliangwang/distinct_opt. 
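As a sketch of the rewrite the new rule performs (it mirrors the EliminateDistinctSuite added below and uses the maxDistinct/max helpers from the Catalyst DSL; illustrative only):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer.EliminateDistinct
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val relation = LocalRelation('a.int)

// MAX(DISTINCT a): the aggregate expression carries isDistinct = true after analysis.
val withDistinct = relation.select(maxDistinct('a).as('result)).analyze

// The rule simply drops the flag, because DISTINCT cannot change the result of MAX
// or MIN. The optimized plan matches relation.select(max('a).as('result)).analyze
// (the suite below verifies this with comparePlans, which normalizes expression ids).
val optimized = EliminateDistinct(withDistinct)
```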
--- .../sql/catalyst/analysis/Analyzer.scala | 5 -- .../spark/sql/catalyst/dsl/package.scala | 2 + .../sql/catalyst/optimizer/Optimizer.scala | 15 +++++ .../optimizer/EliminateDistinctSuite.scala | 56 +++++++++++++++++++ 4 files changed, 73 insertions(+), 5 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 434b6ffee37fa..53536496d0457 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1197,11 +1197,6 @@ class Analyzer( case u @ UnresolvedFunction(funcId, children, isDistinct) => withPosition(u) { catalog.lookupFunction(funcId, children) match { - // DISTINCT is not meaningful for a Max or a Min. - case max: Max if isDistinct => - AggregateExpression(max, Complete, isDistinct = false) - case min: Min if isDistinct => - AggregateExpression(min, Complete, isDistinct = false) // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within // the context of a Window clause. They do not need to be wrapped in an // AggregateExpression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index beee93d906f0f..f6792569b704e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -159,7 +159,9 @@ package object dsl { def first(e: Expression): Expression = new First(e).toAggregateExpression() def last(e: Expression): Expression = new Last(e).toAggregateExpression() def min(e: Expression): Expression = Min(e).toAggregateExpression() + def minDistinct(e: Expression): Expression = Min(e).toAggregateExpression(isDistinct = true) def max(e: Expression): Expression = Max(e).toAggregateExpression() + def maxDistinct(e: Expression): Expression = Max(e).toAggregateExpression(isDistinct = true) def upper(e: Expression): Expression = Upper(e) def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression = Sqrt(e) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b410312030c5d..946fa7bae0199 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -40,6 +40,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) protected val fixedPoint = FixedPoint(conf.optimizerMaxIterations) def batches: Seq[Batch] = { + Batch("Eliminate Distinct", Once, EliminateDistinct) :: // Technically some of the rules in Finish Analysis are not optimizer rules and belong more // in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime). // However, because we also use the analyzer to canonicalized queries (for view definition), @@ -151,6 +152,20 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil } +/** + * Remove useless DISTINCT for MAX and MIN. 
+ * This rule should be applied before RewriteDistinctAggregates. + */ +object EliminateDistinct extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformExpressions { + case ae: AggregateExpression if ae.isDistinct => + ae.aggregateFunction match { + case _: Max | _: Min => ae.copy(isDistinct = false) + case _ => ae + } + } +} + /** * An optimizer used in test code. * diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala new file mode 100644 index 0000000000000..f40691bd1a038 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class EliminateDistinctSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Operator Optimizations", Once, + EliminateDistinct) :: Nil + } + + val testRelation = LocalRelation('a.int) + + test("Eliminate Distinct in Max") { + val query = testRelation + .select(maxDistinct('a).as('result)) + .analyze + val answer = testRelation + .select(max('a).as('result)) + .analyze + assert(query != answer) + comparePlans(Optimize.execute(query), answer) + } + + test("Eliminate Distinct in Min") { + val query = testRelation + .select(minDistinct('a).as('result)) + .analyze + val answer = testRelation + .select(min('a).as('result)) + .analyze + assert(query != answer) + comparePlans(Optimize.execute(query), answer) + } +} From 376d90d556fcd4fd84f70ee42a1323e1f48f829d Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Wed, 28 Jun 2017 19:31:54 -0700 Subject: [PATCH 094/118] [SPARK-20889][SPARKR] Grouped documentation for STRING column methods ## What changes were proposed in this pull request? Grouped documentation for string column methods. Author: actuaryzhang Author: Wayne Zhang Closes #18366 from actuaryzhang/sparkRDocString. 
--- R/pkg/R/functions.R | 573 +++++++++++++++++++------------------------- R/pkg/R/generics.R | 84 ++++--- 2 files changed, 300 insertions(+), 357 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 23ccdf941a8c7..70ea620b471fe 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -111,6 +111,27 @@ NULL #' head(tmp)} NULL +#' String functions for Column operations +#' +#' String functions defined for \code{Column}. +#' +#' @param x Column to compute on except in the following methods: +#' \itemize{ +#' \item \code{instr}: \code{character}, the substring to check. See 'Details'. +#' \item \code{format_number}: \code{numeric}, the number of decimal place to +#' format to. See 'Details'. +#' } +#' @param y Column to compute on. +#' @param ... additional columns. +#' @name column_string_functions +#' @rdname column_string_functions +#' @family string functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(as.data.frame(Titanic, stringsAsFactors = FALSE))} +NULL + #' lit #' #' A new \linkS4class{Column} is created to represent the literal value. @@ -188,19 +209,17 @@ setMethod("approxCountDistinct", column(jc) }) -#' ascii -#' -#' Computes the numeric value of the first character of the string column, and returns the -#' result as a int column. -#' -#' @param x Column to compute on. +#' @details +#' \code{ascii}: Computes the numeric value of the first character of the string column, +#' and returns the result as an int column. #' -#' @rdname ascii -#' @name ascii -#' @family string functions +#' @rdname column_string_functions #' @export -#' @aliases ascii,Column-method -#' @examples \dontrun{\dontrun{ascii(df$c)}} +#' @aliases ascii ascii,Column-method +#' @examples +#' +#' \dontrun{ +#' head(select(df, ascii(df$Class), ascii(df$Sex)))} #' @note ascii since 1.5.0 setMethod("ascii", signature(x = "Column"), @@ -256,19 +275,22 @@ setMethod("avg", column(jc) }) -#' base64 -#' -#' Computes the BASE64 encoding of a binary column and returns it as a string column. -#' This is the reverse of unbase64. -#' -#' @param x Column to compute on. +#' @details +#' \code{base64}: Computes the BASE64 encoding of a binary column and returns it as +#' a string column. This is the reverse of unbase64. #' -#' @rdname base64 -#' @name base64 -#' @family string functions +#' @rdname column_string_functions #' @export -#' @aliases base64,Column-method -#' @examples \dontrun{base64(df$c)} +#' @aliases base64 base64,Column-method +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, s1 = encode(df$Class, "UTF-8")) +#' str(tmp) +#' tmp2 <- mutate(tmp, s2 = base64(tmp$s1), s3 = decode(tmp$s1, "UTF-8"), +#' s4 = soundex(tmp$Sex)) +#' head(tmp2) +#' head(select(tmp2, unbase64(tmp2$s2)))} #' @note base64 since 1.5.0 setMethod("base64", signature(x = "Column"), @@ -620,20 +642,16 @@ setMethod("dayofyear", column(jc) }) -#' decode -#' -#' Computes the first argument into a string from a binary using the provided character set -#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). +#' @details +#' \code{decode}: Computes the first argument into a string from a binary using the provided +#' character set. #' -#' @param x Column to compute on. -#' @param charset Character set to use +#' @param charset Character set to use (one of "US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", +#' "UTF-16LE", "UTF-16"). 
#' -#' @rdname decode -#' @name decode -#' @family string functions -#' @aliases decode,Column,character-method +#' @rdname column_string_functions +#' @aliases decode decode,Column,character-method #' @export -#' @examples \dontrun{decode(df$c, "UTF-8")} #' @note decode since 1.6.0 setMethod("decode", signature(x = "Column", charset = "character"), @@ -642,20 +660,13 @@ setMethod("decode", column(jc) }) -#' encode -#' -#' Computes the first argument into a binary from a string using the provided character set -#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). -#' -#' @param x Column to compute on. -#' @param charset Character set to use +#' @details +#' \code{encode}: Computes the first argument into a binary from a string using the provided +#' character set. #' -#' @rdname encode -#' @name encode -#' @family string functions -#' @aliases encode,Column,character-method +#' @rdname column_string_functions +#' @aliases encode encode,Column,character-method #' @export -#' @examples \dontrun{encode(df$c, "UTF-8")} #' @note encode since 1.6.0 setMethod("encode", signature(x = "Column", charset = "character"), @@ -788,21 +799,23 @@ setMethod("hour", column(jc) }) -#' initcap -#' -#' Returns a new string column by converting the first letter of each word to uppercase. -#' Words are delimited by whitespace. -#' -#' For example, "hello world" will become "Hello World". -#' -#' @param x Column to compute on. +#' @details +#' \code{initcap}: Returns a new string column by converting the first letter of +#' each word to uppercase. Words are delimited by whitespace. For example, "hello world" +#' will become "Hello World". #' -#' @rdname initcap -#' @name initcap -#' @family string functions -#' @aliases initcap,Column-method +#' @rdname column_string_functions +#' @aliases initcap initcap,Column-method #' @export -#' @examples \dontrun{initcap(df$c)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, sex_lower = lower(df$Sex), age_upper = upper(df$age), +#' sex_age = concat_ws(" ", lower(df$sex), lower(df$age))) +#' head(tmp) +#' tmp2 <- mutate(tmp, s1 = initcap(tmp$sex_lower), s2 = initcap(tmp$sex_age), +#' s3 = reverse(df$Sex)) +#' head(tmp2)} #' @note initcap since 1.5.0 setMethod("initcap", signature(x = "Column"), @@ -918,18 +931,12 @@ setMethod("last_day", column(jc) }) -#' length -#' -#' Computes the length of a given string or binary column. -#' -#' @param x Column to compute on. +#' @details +#' \code{length}: Computes the length of a given string or binary column. #' -#' @rdname length -#' @name length -#' @aliases length,Column-method -#' @family string functions +#' @rdname column_string_functions +#' @aliases length length,Column-method #' @export -#' @examples \dontrun{length(df$c)} #' @note length since 1.5.0 setMethod("length", signature(x = "Column"), @@ -994,18 +1001,12 @@ setMethod("log2", column(jc) }) -#' lower -#' -#' Converts a string column to lower case. -#' -#' @param x Column to compute on. +#' @details +#' \code{lower}: Converts a string column to lower case. #' -#' @rdname lower -#' @name lower -#' @family string functions -#' @aliases lower,Column-method +#' @rdname column_string_functions +#' @aliases lower lower,Column-method #' @export -#' @examples \dontrun{lower(df$c)} #' @note lower since 1.4.0 setMethod("lower", signature(x = "Column"), @@ -1014,18 +1015,24 @@ setMethod("lower", column(jc) }) -#' ltrim -#' -#' Trim the spaces from left end for the specified string value. -#' -#' @param x Column to compute on. 
+#' @details +#' \code{ltrim}: Trims the spaces from left end for the specified string value. #' -#' @rdname ltrim -#' @name ltrim -#' @family string functions -#' @aliases ltrim,Column-method +#' @rdname column_string_functions +#' @aliases ltrim ltrim,Column-method #' @export -#' @examples \dontrun{ltrim(df$c)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, SexLpad = lpad(df$Sex, 6, " "), SexRpad = rpad(df$Sex, 7, " ")) +#' head(select(tmp, length(tmp$Sex), length(tmp$SexLpad), length(tmp$SexRpad))) +#' tmp2 <- mutate(tmp, SexLtrim = ltrim(tmp$SexLpad), SexRtrim = rtrim(tmp$SexRpad), +#' SexTrim = trim(tmp$SexLpad)) +#' head(select(tmp2, length(tmp2$Sex), length(tmp2$SexLtrim), +#' length(tmp2$SexRtrim), length(tmp2$SexTrim))) +#' +#' tmp <- mutate(df, SexLpad = lpad(df$Sex, 6, "xx"), SexRpad = rpad(df$Sex, 7, "xx")) +#' head(tmp)} #' @note ltrim since 1.5.0 setMethod("ltrim", signature(x = "Column"), @@ -1198,18 +1205,12 @@ setMethod("quarter", column(jc) }) -#' reverse -#' -#' Reverses the string column and returns it as a new string column. -#' -#' @param x Column to compute on. +#' @details +#' \code{reverse}: Reverses the string column and returns it as a new string column. #' -#' @rdname reverse -#' @name reverse -#' @family string functions -#' @aliases reverse,Column-method +#' @rdname column_string_functions +#' @aliases reverse reverse,Column-method #' @export -#' @examples \dontrun{reverse(df$c)} #' @note reverse since 1.5.0 setMethod("reverse", signature(x = "Column"), @@ -1268,18 +1269,12 @@ setMethod("bround", column(jc) }) -#' rtrim -#' -#' Trim the spaces from right end for the specified string value. -#' -#' @param x Column to compute on. +#' @details +#' \code{rtrim}: Trims the spaces from right end for the specified string value. #' -#' @rdname rtrim -#' @name rtrim -#' @family string functions -#' @aliases rtrim,Column-method +#' @rdname column_string_functions +#' @aliases rtrim rtrim,Column-method #' @export -#' @examples \dontrun{rtrim(df$c)} #' @note rtrim since 1.5.0 setMethod("rtrim", signature(x = "Column"), @@ -1409,18 +1404,12 @@ setMethod("skewness", column(jc) }) -#' soundex -#' -#' Return the soundex code for the specified expression. -#' -#' @param x Column to compute on. +#' @details +#' \code{soundex}: Returns the soundex code for the specified expression. #' -#' @rdname soundex -#' @name soundex -#' @family string functions -#' @aliases soundex,Column-method +#' @rdname column_string_functions +#' @aliases soundex soundex,Column-method #' @export -#' @examples \dontrun{soundex(df$c)} #' @note soundex since 1.5.0 setMethod("soundex", signature(x = "Column"), @@ -1731,18 +1720,12 @@ setMethod("to_timestamp", column(jc) }) -#' trim -#' -#' Trim the spaces from both ends for the specified string column. -#' -#' @param x Column to compute on. +#' @details +#' \code{trim}: Trims the spaces from both ends for the specified string column. #' -#' @rdname trim -#' @name trim -#' @family string functions -#' @aliases trim,Column-method +#' @rdname column_string_functions +#' @aliases trim trim,Column-method #' @export -#' @examples \dontrun{trim(df$c)} #' @note trim since 1.5.0 setMethod("trim", signature(x = "Column"), @@ -1751,19 +1734,13 @@ setMethod("trim", column(jc) }) -#' unbase64 -#' -#' Decodes a BASE64 encoded string column and returns it as a binary column. +#' @details +#' \code{unbase64}: Decodes a BASE64 encoded string column and returns it as a binary column. #' This is the reverse of base64. #' -#' @param x Column to compute on. 
-#' -#' @rdname unbase64 -#' @name unbase64 -#' @family string functions -#' @aliases unbase64,Column-method +#' @rdname column_string_functions +#' @aliases unbase64 unbase64,Column-method #' @export -#' @examples \dontrun{unbase64(df$c)} #' @note unbase64 since 1.5.0 setMethod("unbase64", signature(x = "Column"), @@ -1787,18 +1764,12 @@ setMethod("unhex", column(jc) }) -#' upper -#' -#' Converts a string column to upper case. -#' -#' @param x Column to compute on. +#' @details +#' \code{upper}: Converts a string column to upper case. #' -#' @rdname upper -#' @name upper -#' @family string functions -#' @aliases upper,Column-method +#' @rdname column_string_functions +#' @aliases upper upper,Column-method #' @export -#' @examples \dontrun{upper(df$c)} #' @note upper since 1.4.0 setMethod("upper", signature(x = "Column"), @@ -1949,19 +1920,19 @@ setMethod("hypot", signature(y = "Column"), column(jc) }) -#' levenshtein -#' -#' Computes the Levenshtein distance of the two given string columns. -#' -#' @param x Column to compute on. -#' @param y Column to compute on. +#' @details +#' \code{levenshtein}: Computes the Levenshtein distance of the two given string columns. #' -#' @rdname levenshtein -#' @name levenshtein -#' @family string functions -#' @aliases levenshtein,Column-method +#' @rdname column_string_functions +#' @aliases levenshtein levenshtein,Column-method #' @export -#' @examples \dontrun{levenshtein(df$c, x)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, d1 = levenshtein(df$Class, df$Sex), +#' d2 = levenshtein(df$Age, df$Sex), +#' d3 = levenshtein(df$Age, df$Age)) +#' head(tmp)} #' @note levenshtein since 1.5.0 setMethod("levenshtein", signature(y = "Column"), function(y, x) { @@ -2061,20 +2032,22 @@ setMethod("countDistinct", column(jc) }) - -#' concat -#' -#' Concatenates multiple input string columns together into a single string column. -#' -#' @param x Column to compute on -#' @param ... other columns +#' @details +#' \code{concat}: Concatenates multiple input string columns together into a single string column. #' -#' @family string functions -#' @rdname concat -#' @name concat -#' @aliases concat,Column-method +#' @rdname column_string_functions +#' @aliases concat concat,Column-method #' @export -#' @examples \dontrun{concat(df$strings, df$strings2)} +#' @examples +#' +#' \dontrun{ +#' # concatenate strings +#' tmp <- mutate(df, s1 = concat(df$Class, df$Sex), +#' s2 = concat(df$Class, df$Sex, df$Age), +#' s3 = concat(df$Class, df$Sex, df$Age, df$Class), +#' s4 = concat_ws("_", df$Class, df$Sex), +#' s5 = concat_ws("+", df$Class, df$Sex, df$Age, df$Survived)) +#' head(tmp)} #' @note concat since 1.5.0 setMethod("concat", signature(x = "Column"), @@ -2243,22 +2216,21 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), column(jc) }) -#' instr -#' -#' Locate the position of the first occurrence of substr column in the given string. -#' Returns null if either of the arguments are null. -#' -#' Note: The position is not zero based, but 1 based index. Returns 0 if substr -#' could not be found in str. +#' @details +#' \code{instr}: Locates the position of the first occurrence of a substring (\code{x}) +#' in the given string column (\code{y}). Returns null if either of the arguments are null. +#' Note: The position is not zero based, but 1 based index. Returns 0 if the substring +#' could not be found in the string column. 
#' -#' @param y column to check -#' @param x substring to check -#' @family string functions -#' @aliases instr,Column,character-method -#' @rdname instr -#' @name instr +#' @rdname column_string_functions +#' @aliases instr instr,Column,character-method #' @export -#' @examples \dontrun{instr(df$c, 'b')} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, s1 = instr(df$Sex, "m"), s2 = instr(df$Sex, "M"), +#' s3 = locate("m", df$Sex), s4 = locate("m", df$Sex, pos = 4)) +#' head(tmp)} #' @note instr since 1.5.0 setMethod("instr", signature(y = "Column", x = "character"), function(y, x) { @@ -2345,22 +2317,22 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), column(jc) }) -#' format_number -#' -#' Formats numeric column y to a format like '#,###,###.##', rounded to x decimal places -#' with HALF_EVEN round mode, and returns the result as a string column. -#' -#' If x is 0, the result has no decimal point or fractional part. -#' If x < 0, the result will be null. +#' @details +#' \code{format_number}: Formats numeric column \code{y} to a format like '#,###,###.##', +#' rounded to \code{x} decimal places with HALF_EVEN round mode, and returns the result +#' as a string column. +#' If \code{x} is 0, the result has no decimal point or fractional part. +#' If \code{x} < 0, the result will be null. #' -#' @param y column to format -#' @param x number of decimal place to format to -#' @family string functions -#' @rdname format_number -#' @name format_number -#' @aliases format_number,Column,numeric-method +#' @rdname column_string_functions +#' @aliases format_number format_number,Column,numeric-method #' @export -#' @examples \dontrun{format_number(df$n, 4)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, v1 = df$Freq/3) +#' head(select(tmp, format_number(tmp$v1, 0), format_number(tmp$v1, 2), +#' format_string("%4.2f %s", tmp$v1, tmp$Sex)), 10)} #' @note format_number since 1.5.0 setMethod("format_number", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2438,21 +2410,14 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), column(jc) }) -#' concat_ws -#' -#' Concatenates multiple input string columns together into a single string column, -#' using the given separator. +#' @details +#' \code{concat_ws}: Concatenates multiple input string columns together into a single +#' string column, using the given separator. #' -#' @param x column to concatenate. #' @param sep separator to use. -#' @param ... other columns to concatenate. -#' -#' @family string functions -#' @rdname concat_ws -#' @name concat_ws -#' @aliases concat_ws,character,Column-method +#' @rdname column_string_functions +#' @aliases concat_ws concat_ws,character,Column-method #' @export -#' @examples \dontrun{concat_ws('-', df$s, df$d)} #' @note concat_ws since 1.5.0 setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { @@ -2499,19 +2464,14 @@ setMethod("expr", signature(x = "character"), column(jc) }) -#' format_string -#' -#' Formats the arguments in printf-style and returns the result as a string column. +#' @details +#' \code{format_string}: Formats the arguments in printf-style and returns the result +#' as a string column. #' #' @param format a character object of format strings. -#' @param x a Column. -#' @param ... additional Column(s). 
-#' @family string functions -#' @rdname format_string -#' @name format_string -#' @aliases format_string,character,Column-method +#' @rdname column_string_functions +#' @aliases format_string format_string,character,Column-method #' @export -#' @examples \dontrun{format_string('%d %s', df$a, df$b)} #' @note format_string since 1.5.0 setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) { @@ -2620,23 +2580,17 @@ setMethod("window", signature(x = "Column"), column(jc) }) -#' locate -#' -#' Locate the position of the first occurrence of substr. -#' +#' @details +#' \code{locate}: Locates the position of the first occurrence of substr. #' Note: The position is not zero based, but 1 based index. Returns 0 if substr #' could not be found in str. #' #' @param substr a character string to be matched. #' @param str a Column where matches are sought for each entry. #' @param pos start position of search. -#' @param ... further arguments to be passed to or from other methods. -#' @family string functions -#' @rdname locate -#' @aliases locate,character,Column-method -#' @name locate +#' @rdname column_string_functions +#' @aliases locate locate,character,Column-method #' @export -#' @examples \dontrun{locate('b', df$c, 1)} #' @note locate since 1.5.0 setMethod("locate", signature(substr = "character", str = "Column"), function(substr, str, pos = 1) { @@ -2646,19 +2600,14 @@ setMethod("locate", signature(substr = "character", str = "Column"), column(jc) }) -#' lpad -#' -#' Left-pad the string column with +#' @details +#' \code{lpad}: Left-padded with pad to a length of len. #' -#' @param x the string Column to be left-padded. #' @param len maximum length of each output result. #' @param pad a character string to be padded with. -#' @family string functions -#' @rdname lpad -#' @aliases lpad,Column,numeric,character-method -#' @name lpad +#' @rdname column_string_functions +#' @aliases lpad lpad,Column,numeric,character-method #' @export -#' @examples \dontrun{lpad(df$c, 6, '#')} #' @note lpad since 1.5.0 setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { @@ -2728,20 +2677,27 @@ setMethod("randn", signature(seed = "numeric"), column(jc) }) -#' regexp_extract -#' -#' Extract a specific \code{idx} group identified by a Java regex, from the specified string column. -#' If the regex did not match, or the specified group did not match, an empty string is returned. +#' @details +#' \code{regexp_extract}: Extracts a specific \code{idx} group identified by a Java regex, +#' from the specified string column. If the regex did not match, or the specified group did +#' not match, an empty string is returned. #' -#' @param x a string Column. #' @param pattern a regular expression. #' @param idx a group index. 
-#' @family string functions -#' @rdname regexp_extract -#' @name regexp_extract -#' @aliases regexp_extract,Column,character,numeric-method +#' @rdname column_string_functions +#' @aliases regexp_extract regexp_extract,Column,character,numeric-method #' @export -#' @examples \dontrun{regexp_extract(df$c, '(\d+)-(\d+)', 1)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, s1 = regexp_extract(df$Class, "(\\d+)\\w+", 1), +#' s2 = regexp_extract(df$Sex, "^(\\w)\\w+", 1), +#' s3 = regexp_replace(df$Class, "\\D+", ""), +#' s4 = substring_index(df$Sex, "a", 1), +#' s5 = substring_index(df$Sex, "a", -1), +#' s6 = translate(df$Sex, "ale", ""), +#' s7 = translate(df$Sex, "a", "-")) +#' head(tmp)} #' @note regexp_extract since 1.5.0 setMethod("regexp_extract", signature(x = "Column", pattern = "character", idx = "numeric"), @@ -2752,19 +2708,14 @@ setMethod("regexp_extract", column(jc) }) -#' regexp_replace -#' -#' Replace all substrings of the specified string value that match regexp with rep. +#' @details +#' \code{regexp_replace}: Replaces all substrings of the specified string value that +#' match regexp with rep. #' -#' @param x a string Column. -#' @param pattern a regular expression. #' @param replacement a character string that a matched \code{pattern} is replaced with. -#' @family string functions -#' @rdname regexp_replace -#' @name regexp_replace -#' @aliases regexp_replace,Column,character,character-method +#' @rdname column_string_functions +#' @aliases regexp_replace regexp_replace,Column,character,character-method #' @export -#' @examples \dontrun{regexp_replace(df$c, '(\\d+)', '--')} #' @note regexp_replace since 1.5.0 setMethod("regexp_replace", signature(x = "Column", pattern = "character", replacement = "character"), @@ -2775,19 +2726,12 @@ setMethod("regexp_replace", column(jc) }) -#' rpad -#' -#' Right-padded with pad to a length of len. +#' @details +#' \code{rpad}: Right-padded with pad to a length of len. #' -#' @param x the string Column to be right-padded. -#' @param len maximum length of each output result. -#' @param pad a character string to be padded with. -#' @family string functions -#' @rdname rpad -#' @name rpad -#' @aliases rpad,Column,numeric,character-method +#' @rdname column_string_functions +#' @aliases rpad rpad,Column,numeric,character-method #' @export -#' @examples \dontrun{rpad(df$c, 6, '#')} #' @note rpad since 1.5.0 setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { @@ -2797,28 +2741,20 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), column(jc) }) -#' substring_index -#' -#' Returns the substring from string str before count occurrences of the delimiter delim. -#' If count is positive, everything the left of the final delimiter (counting from left) is -#' returned. If count is negative, every to the right of the final delimiter (counting from the -#' right) is returned. substring_index performs a case-sensitive match when searching for delim. +#' @details +#' \code{substring_index}: Returns the substring from string str before count occurrences of +#' the delimiter delim. If count is positive, everything the left of the final delimiter +#' (counting from left) is returned. If count is negative, every to the right of the final +#' delimiter (counting from the right) is returned. substring_index performs a case-sensitive +#' match when searching for delim. #' -#' @param x a Column. #' @param delim a delimiter string. 
#' @param count number of occurrences of \code{delim} before the substring is returned. #' A positive number means counting from the left, while negative means #' counting from the right. -#' @family string functions -#' @rdname substring_index -#' @aliases substring_index,Column,character,numeric-method -#' @name substring_index +#' @rdname column_string_functions +#' @aliases substring_index substring_index,Column,character,numeric-method #' @export -#' @examples -#'\dontrun{ -#'substring_index(df$c, '.', 2) -#'substring_index(df$c, '.', -1) -#'} #' @note substring_index since 1.5.0 setMethod("substring_index", signature(x = "Column", delim = "character", count = "numeric"), @@ -2829,24 +2765,19 @@ setMethod("substring_index", column(jc) }) -#' translate -#' -#' Translate any character in the src by a character in replaceString. +#' @details +#' \code{translate}: Translates any character in the src by a character in replaceString. #' The characters in replaceString is corresponding to the characters in matchingString. #' The translate will happen when any character in the string matching with the character #' in the matchingString. #' -#' @param x a string Column. #' @param matchingString a source string where each character will be translated. #' @param replaceString a target string where each \code{matchingString} character will #' be replaced by the character in \code{replaceString} #' at the same location, if any. -#' @family string functions -#' @rdname translate -#' @name translate -#' @aliases translate,Column,character,character-method +#' @rdname column_string_functions +#' @aliases translate translate,Column,character,character-method #' @export -#' @examples \dontrun{translate(df$c, 'rnlt', '123')} #' @note translate since 1.5.0 setMethod("translate", signature(x = "Column", matchingString = "character", replaceString = "character"), @@ -3419,28 +3350,20 @@ setMethod("collect_set", column(jc) }) -#' split_string -#' -#' Splits string on regular expression. -#' -#' Equivalent to \code{split} SQL function -#' -#' @param x Column to compute on -#' @param pattern Java regular expression +#' @details +#' \code{split_string}: Splits string on regular expression. +#' Equivalent to \code{split} SQL function. #' -#' @rdname split_string -#' @family string functions -#' @aliases split_string,Column-method +#' @rdname column_string_functions +#' @aliases split_string split_string,Column-method #' @export #' @examples -#' \dontrun{ -#' df <- read.text("README.md") -#' -#' head(select(df, split_string(df$value, "\\s+"))) #' +#' \dontrun{ +#' head(select(df, split_string(df$Sex, "a"))) +#' head(select(df, split_string(df$Class, "\\d"))) #' # This is equivalent to the following SQL expression -#' head(selectExpr(df, "split(value, '\\\\s+')")) -#' } +#' head(selectExpr(df, "split(Class, '\\\\d')"))} #' @note split_string 2.3.0 setMethod("split_string", signature(x = "Column", pattern = "character"), @@ -3449,28 +3372,20 @@ setMethod("split_string", column(jc) }) -#' repeat_string -#' -#' Repeats string n times. -#' -#' Equivalent to \code{repeat} SQL function +#' @details +#' \code{repeat_string}: Repeats string n times. +#' Equivalent to \code{repeat} SQL function. 
#' -#' @param x Column to compute on #' @param n Number of repetitions -#' -#' @rdname repeat_string -#' @family string functions -#' @aliases repeat_string,Column-method +#' @rdname column_string_functions +#' @aliases repeat_string repeat_string,Column-method #' @export #' @examples -#' \dontrun{ -#' df <- read.text("README.md") -#' -#' first(select(df, repeat_string(df$value, 3))) #' +#' \dontrun{ +#' head(select(df, repeat_string(df$Class, 3))) #' # This is equivalent to the following SQL expression -#' first(selectExpr(df, "repeat(value, 3)")) -#' } +#' head(selectExpr(df, "repeat(Class, 3)"))} #' @note repeat_string since 2.3.0 setMethod("repeat_string", signature(x = "Column", n = "numeric"), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0248ec585d771..dc99e3d94b269 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -917,8 +917,9 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @export setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) -#' @rdname ascii +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) #' @param x Column to compute on or a GroupedData object. @@ -927,8 +928,9 @@ setGeneric("ascii", function(x) { standardGeneric("ascii") }) #' @export setGeneric("avg", function(x, ...) { standardGeneric("avg") }) -#' @rdname base64 +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("base64", function(x) { standardGeneric("base64") }) #' @rdname column_math_functions @@ -969,12 +971,14 @@ setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) #' @export setGeneric("column", function(x) { standardGeneric("column") }) -#' @rdname concat +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("concat", function(x, ...) { standardGeneric("concat") }) -#' @rdname concat_ws +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) #' @rdname column_math_functions @@ -1038,8 +1042,9 @@ setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) #' @name NULL setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) -#' @rdname decode +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("decode", function(x, charset) { standardGeneric("decode") }) #' @param x empty. Should be used with no argument. @@ -1047,8 +1052,9 @@ setGeneric("decode", function(x, charset) { standardGeneric("decode") }) #' @export setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) -#' @rdname encode +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) #' @rdname explode @@ -1068,12 +1074,14 @@ setGeneric("expr", function(x) { standardGeneric("expr") }) #' @name NULL setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) -#' @rdname format_number +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) -#' @rdname format_string +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("format_string", function(format, x, ...) 
{ standardGeneric("format_string") }) #' @rdname from_json @@ -1114,8 +1122,9 @@ setGeneric("hour", function(x) { standardGeneric("hour") }) #' @name NULL setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) -#' @rdname initcap +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("initcap", function(x) { standardGeneric("initcap") }) #' @param x empty. Should be used with no argument. @@ -1124,8 +1133,9 @@ setGeneric("initcap", function(x) { standardGeneric("initcap") }) setGeneric("input_file_name", function(x = "missing") { standardGeneric("input_file_name") }) -#' @rdname instr +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("instr", function(y, x) { standardGeneric("instr") }) #' @rdname is.nan @@ -1158,28 +1168,33 @@ setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("l #' @export setGeneric("least", function(x, ...) { standardGeneric("least") }) -#' @rdname levenshtein +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) #' @rdname lit #' @export setGeneric("lit", function(x) { standardGeneric("lit") }) -#' @rdname locate +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("locate", function(substr, str, ...) { standardGeneric("locate") }) -#' @rdname lower +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("lower", function(x) { standardGeneric("lower") }) -#' @rdname lpad +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) -#' @rdname ltrim +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) #' @rdname md5 @@ -1272,21 +1287,25 @@ setGeneric("randn", function(seed) { standardGeneric("randn") }) #' @export setGeneric("rank", function(x, ...) 
{ standardGeneric("rank") }) -#' @rdname regexp_extract +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp_extract") }) -#' @rdname regexp_replace +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("regexp_replace", function(x, pattern, replacement) { standardGeneric("regexp_replace") }) -#' @rdname repeat_string +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) -#' @rdname reverse +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("reverse", function(x) { standardGeneric("reverse") }) #' @rdname column_math_functions @@ -1299,12 +1318,14 @@ setGeneric("rint", function(x) { standardGeneric("rint") }) #' @export setGeneric("row_number", function(x = "missing") { standardGeneric("row_number") }) -#' @rdname rpad +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) -#' @rdname rtrim +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) #' @rdname column_aggregate_functions @@ -1358,12 +1379,14 @@ setGeneric("skewness", function(x) { standardGeneric("skewness") }) #' @export setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) -#' @rdname split_string +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("split_string", function(x, pattern) { standardGeneric("split_string") }) -#' @rdname soundex +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("soundex", function(x) { standardGeneric("soundex") }) #' @param x empty. Should be used with no argument. @@ -1390,8 +1413,9 @@ setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) #' @export setGeneric("struct", function(x, ...) 
{ standardGeneric("struct") }) -#' @rdname substring_index +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) #' @rdname column_aggregate_functions @@ -1428,16 +1452,19 @@ setGeneric("to_timestamp", function(x, format) { standardGeneric("to_timestamp") #' @name NULL setGeneric("to_utc_timestamp", function(y, x) { standardGeneric("to_utc_timestamp") }) -#' @rdname translate +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("translate", function(x, matchingString, replaceString) { standardGeneric("translate") }) -#' @rdname trim +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("trim", function(x) { standardGeneric("trim") }) -#' @rdname unbase64 +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) #' @rdname column_math_functions @@ -1450,8 +1477,9 @@ setGeneric("unhex", function(x) { standardGeneric("unhex") }) #' @name NULL setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timestamp") }) -#' @rdname upper +#' @rdname column_string_functions #' @export +#' @name NULL setGeneric("upper", function(x) { standardGeneric("upper") }) #' @rdname column_aggregate_functions From 0c8444cf6d0620cd219ddcf5f50b12ff648639e9 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 29 Jun 2017 10:32:32 +0800 Subject: [PATCH 095/118] [SPARK-14657][SPARKR][ML] RFormula w/o intercept should output reference category when encoding string terms ## What changes were proposed in this pull request? Please see [SPARK-14657](https://issues.apache.org/jira/browse/SPARK-14657) for detail of this bug. I searched online and test some other cases, found when we fit R glm model(or other models powered by R formula) w/o intercept on a dataset including string/category features, one of the categories in the first category feature is being used as reference category, we will not drop any category for that feature. I think we should keep consistent semantics between Spark RFormula and R formula. ## How was this patch tested? Add standard unit tests. cc mengxr Author: Yanbo Liang Closes #12414 from yanboliang/spark-14657. --- .../apache/spark/ml/feature/RFormula.scala | 10 ++- .../spark/ml/feature/RFormulaSuite.scala | 83 +++++++++++++++++++ 2 files changed, 92 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 1fad0a6fc9443..4b44878784c90 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -205,12 +205,20 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) }.toMap // Then we handle one-hot encoding and interactions between terms. + var keepReferenceCategory = false val encodedTerms = resolvedFormula.terms.map { case Seq(term) if dataset.schema(term).dataType == StringType => val encodedCol = tmpColumn("onehot") - encoderStages += new OneHotEncoder() + var encoder = new OneHotEncoder() .setInputCol(indexed(term)) .setOutputCol(encodedCol) + // Formula w/o intercept, one of the categories in the first category feature is + // being used as reference category, we will not drop any category for that feature. 
+ if (!hasIntercept && !keepReferenceCategory) { + encoder = encoder.setDropLast(false) + keepReferenceCategory = true + } + encoderStages += encoder prefixesToRewrite(encodedCol + "_") = term + "_" encodedCol case Seq(term) => diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 41d0062c2cabd..23570d6e0b4cb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -213,6 +213,89 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(result.collect() === expected.collect()) } + test("formula w/o intercept, we should output reference category when encoding string terms") { + /* + R code: + + df <- data.frame(id = c(1, 2, 3, 4), + a = c("foo", "bar", "bar", "baz"), + b = c("zq", "zz", "zz", "zz"), + c = c(4, 4, 5, 5)) + model.matrix(id ~ a + b + c - 1, df) + + abar abaz afoo bzz c + 1 0 0 1 0 4 + 2 1 0 0 1 4 + 3 1 0 0 1 5 + 4 0 1 0 1 5 + + model.matrix(id ~ a:b + c - 1, df) + + c abar:bzq abaz:bzq afoo:bzq abar:bzz abaz:bzz afoo:bzz + 1 4 0 0 1 0 0 0 + 2 4 0 0 0 1 0 0 + 3 5 0 0 0 1 0 0 + 4 5 0 0 0 0 1 0 + */ + val original = Seq((1, "foo", "zq", 4), (2, "bar", "zz", 4), (3, "bar", "zz", 5), + (4, "baz", "zz", 5)).toDF("id", "a", "b", "c") + + val formula1 = new RFormula().setFormula("id ~ a + b + c - 1") + .setStringIndexerOrderType(StringIndexer.alphabetDesc) + val model1 = formula1.fit(original) + val result1 = model1.transform(original) + val resultSchema1 = model1.transformSchema(original.schema) + // Note the column order is different between R and Spark. + val expected1 = Seq( + (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0), + (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0), + (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0), + (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0) + ).toDF("id", "a", "b", "c", "features", "label") + assert(result1.schema.toString == resultSchema1.toString) + assert(result1.collect() === expected1.collect()) + + val attrs1 = AttributeGroup.fromStructField(result1.schema("features")) + val expectedAttrs1 = new AttributeGroup( + "features", + Array[Attribute]( + new BinaryAttribute(Some("a_foo"), Some(1)), + new BinaryAttribute(Some("a_baz"), Some(2)), + new BinaryAttribute(Some("a_bar"), Some(3)), + new BinaryAttribute(Some("b_zz"), Some(4)), + new NumericAttribute(Some("c"), Some(5)))) + assert(attrs1 === expectedAttrs1) + + // There is no impact for string terms interaction. + val formula2 = new RFormula().setFormula("id ~ a:b + c - 1") + .setStringIndexerOrderType(StringIndexer.alphabetDesc) + val model2 = formula2.fit(original) + val result2 = model2.transform(original) + val resultSchema2 = model2.transformSchema(original.schema) + // Note the column order is different between R and Spark. 
+ val expected2 = Seq( + (1, "foo", "zq", 4, Vectors.sparse(7, Array(1, 6), Array(1.0, 4.0)), 1.0), + (2, "bar", "zz", 4, Vectors.sparse(7, Array(4, 6), Array(1.0, 4.0)), 2.0), + (3, "bar", "zz", 5, Vectors.sparse(7, Array(4, 6), Array(1.0, 5.0)), 3.0), + (4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0) + ).toDF("id", "a", "b", "c", "features", "label") + assert(result2.schema.toString == resultSchema2.toString) + assert(result2.collect() === expected2.collect()) + + val attrs2 = AttributeGroup.fromStructField(result2.schema("features")) + val expectedAttrs2 = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_foo:b_zz"), Some(1)), + new NumericAttribute(Some("a_foo:b_zq"), Some(2)), + new NumericAttribute(Some("a_baz:b_zz"), Some(3)), + new NumericAttribute(Some("a_baz:b_zq"), Some(4)), + new NumericAttribute(Some("a_bar:b_zz"), Some(5)), + new NumericAttribute(Some("a_bar:b_zq"), Some(6)), + new NumericAttribute(Some("c"), Some(7)))) + assert(attrs2 === expectedAttrs2) + } + test("index string label") { val formula = new RFormula().setFormula("id ~ a + b") val original = From db44f5f3e8b5bc28c33b154319539d51c05a089c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 28 Jun 2017 19:36:00 -0700 Subject: [PATCH 096/118] [SPARK-21224][R] Specify a schema by using a DDL-formatted string when reading in R ## What changes were proposed in this pull request? This PR proposes to support a DDL-formetted string as schema as below: ```r mockLines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"Justin\", \"age\":19}") jsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLines, jsonPath) df <- read.df(jsonPath, "json", "name STRING, age DOUBLE") collect(df) ``` ## How was this patch tested? Tests added in `test_streaming.R` and `test_sparkSQL.R` and manual tests. Author: hyukjinkwon Closes #18431 from HyukjinKwon/r-ddl-schema. --- R/pkg/R/SQLContext.R | 38 +++++++++++++------ R/pkg/tests/fulltests/test_sparkSQL.R | 20 +++++++++- R/pkg/tests/fulltests/test_streaming.R | 23 +++++++++++ .../org/apache/spark/sql/api/r/SQLUtils.scala | 15 -------- 4 files changed, 67 insertions(+), 29 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index e3528bc7c3135..3b7f71bbbffb8 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -584,7 +584,7 @@ tableToDF <- function(tableName) { #' #' @param path The path of files to load #' @param source The name of external data source -#' @param schema The data schema defined in structType +#' @param schema The data schema defined in structType or a DDL-formatted string. #' @param na.strings Default string value for NA when source is "csv" #' @param ... additional external data source specific named properties. 
#' @return SparkDataFrame @@ -600,6 +600,8 @@ tableToDF <- function(tableName) { #' structField("info", "map")) #' df2 <- read.df(mapTypeJsonPath, "json", schema, multiLine = TRUE) #' df3 <- loadDF("data/test_table", "parquet", mergeSchema = "true") +#' stringSchema <- "name STRING, info MAP" +#' df4 <- read.df(mapTypeJsonPath, "json", stringSchema, multiLine = TRUE) #' } #' @name read.df #' @method read.df default @@ -623,14 +625,19 @@ read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.string if (source == "csv" && is.null(options[["nullValue"]])) { options[["nullValue"]] <- na.strings } + read <- callJMethod(sparkSession, "read") + read <- callJMethod(read, "format", source) if (!is.null(schema)) { - stopifnot(class(schema) == "structType") - sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, - source, schema$jobj, options) - } else { - sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, - source, options) + if (class(schema) == "structType") { + read <- callJMethod(read, "schema", schema$jobj) + } else if (is.character(schema)) { + read <- callJMethod(read, "schema", schema) + } else { + stop("schema should be structType or character.") + } } + read <- callJMethod(read, "options", options) + sdf <- handledCallJMethod(read, "load") dataFrame(sdf) } @@ -717,8 +724,8 @@ read.jdbc <- function(url, tableName, #' "spark.sql.sources.default" will be used. #' #' @param source The name of external data source -#' @param schema The data schema defined in structType, this is required for file-based streaming -#' data source +#' @param schema The data schema defined in structType or a DDL-formatted string, this is +#' required for file-based streaming data source #' @param ... additional external data source specific named options, for instance \code{path} for #' file-based streaming data source #' @return SparkDataFrame @@ -733,6 +740,8 @@ read.jdbc <- function(url, tableName, #' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") #' #' df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) +#' stringSchema <- "name STRING, info MAP" +#' df1 <- read.stream("json", path = jsonDir, schema = stringSchema, maxFilesPerTrigger = 1) #' } #' @name read.stream #' @note read.stream since 2.2.0 @@ -750,10 +759,15 @@ read.stream <- function(source = NULL, schema = NULL, ...) { read <- callJMethod(sparkSession, "readStream") read <- callJMethod(read, "format", source) if (!is.null(schema)) { - stopifnot(class(schema) == "structType") - read <- callJMethod(read, "schema", schema$jobj) + if (class(schema) == "structType") { + read <- callJMethod(read, "schema", schema$jobj) + } else if (is.character(schema)) { + read <- callJMethod(read, "schema", schema) + } else { + stop("schema should be structType or character.") + } } read <- callJMethod(read, "options", options) sdf <- handledCallJMethod(read, "load") - dataFrame(callJMethod(sdf, "toDF")) + dataFrame(sdf) } diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 911b73b9ee551..a2bcb5aefe16d 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -3248,9 +3248,9 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume # It makes sure that we can omit path argument in read.df API and then it calls # DataFrameWriter.load() without path. 
expect_error(read.df(source = "json"), - paste("Error in loadDF : analysis error - Unable to infer schema for JSON.", + paste("Error in load : analysis error - Unable to infer schema for JSON.", "It must be specified manually")) - expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist") + expect_error(read.df("arbitrary_path"), "Error in load : analysis error - Path does not exist") expect_error(read.json("arbitrary_path"), "Error in json : analysis error - Path does not exist") expect_error(read.text("arbitrary_path"), "Error in text : analysis error - Path does not exist") expect_error(read.orc("arbitrary_path"), "Error in orc : analysis error - Path does not exist") @@ -3268,6 +3268,22 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume "Unnamed arguments ignored: 2, 3, a.") }) +test_that("Specify a schema by using a DDL-formatted string when reading", { + # Test read.df with a user defined schema in a DDL-formatted string. + df1 <- read.df(jsonPath, "json", "name STRING, age DOUBLE") + expect_is(df1, "SparkDataFrame") + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + expect_error(read.df(jsonPath, "json", "name stri"), "DataType stri is not supported.") + + # Test loadDF with a user defined schema in a DDL-formatted string. + df2 <- loadDF(jsonPath, "json", "name STRING, age DOUBLE") + expect_is(df2, "SparkDataFrame") + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) + + expect_error(loadDF(jsonPath, "json", "name stri"), "DataType stri is not supported.") +}) + test_that("Collect on DataFrame when NAs exists at the top of a timestamp column", { ldf <- data.frame(col1 = c(0, 1, 2), col2 = c(as.POSIXct("2017-01-01 00:00:01"), diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index d691de7cd725d..54f40bbd5f517 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -46,6 +46,8 @@ schema <- structType(structField("name", "string"), structField("age", "integer"), structField("count", "double")) +stringSchema <- "name STRING, age INTEGER, count DOUBLE" + test_that("read.stream, write.stream, awaitTermination, stopQuery", { df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_true(isStreaming(df)) @@ -111,6 +113,27 @@ test_that("Stream other format", { unlink(parquetPath) }) +test_that("Specify a schema by using a DDL-formatted string when reading", { + # Test read.stream with a user defined schema in a DDL-formatted string. 
+ parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + df <- read.df(jsonPath, "json", schema) + write.df(df, parquetPath, "parquet", "overwrite") + + df <- read.stream(path = parquetPath, schema = stringSchema) + expect_true(isStreaming(df)) + counts <- count(group_by(df, "name")) + q <- write.stream(counts, "memory", queryName = "people3", outputMode = "complete") + + expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") + expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3) + + expect_error(read.stream(path = parquetPath, schema = "name stri"), + "DataType stri is not supported.") + + unlink(parquetPath) +}) + test_that("Non-streaming DataFrame", { c <- as.DataFrame(cars) expect_false(isStreaming(c)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index d94e528a3ad47..9bd2987057dbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -193,21 +193,6 @@ private[sql] object SQLUtils extends Logging { } } - def loadDF( - sparkSession: SparkSession, - source: String, - options: java.util.Map[String, String]): DataFrame = { - sparkSession.read.format(source).options(options).load() - } - - def loadDF( - sparkSession: SparkSession, - source: String, - schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - sparkSession.read.format(source).schema(schema).options(options).load() - } - def readSqlObject(dis: DataInputStream, dataType: Char): Object = { dataType match { case 's' => From fc92d25f2a27e81ef2d5031dcf856af1cc1d8c31 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 28 Jun 2017 20:06:29 -0700 Subject: [PATCH 097/118] Revert "[SPARK-21094][R] Terminate R's worker processes in the parent of R's daemon to prevent a leak" This reverts commit 6b3d02285ee0debc73cbcab01b10398a498fbeb8. --- R/pkg/inst/worker/daemon.R | 59 +++----------------------------------- 1 file changed, 4 insertions(+), 55 deletions(-) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 6e385b2a27622..3a318b71ea06d 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -30,55 +30,8 @@ port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) -# Waits indefinitely for a socket connecion by default. -selectTimeout <- NULL - -# Exit code that children send to the parent to indicate they exited. -exitCode <- 1 - while (TRUE) { - ready <- socketSelect(list(inputCon), timeout = selectTimeout) - - # Note that the children should be terminated in the parent. If each child terminates - # itself, it appears that the resource is not released properly, that causes an unexpected - # termination of this daemon due to, for example, running out of file descriptors - # (see SPARK-21093). Therefore, the current implementation tries to retrieve children - # that are exited (but not terminated) and then sends a kill signal to terminate them properly - # in the parent. - # - # There are two paths that it attempts to send a signal to terminate the children in the parent. - # - # 1. Every second if any socket connection is not available and if there are child workers - # running. - # 2. Right after a socket connection is available. 
- # - # In other words, the parent attempts to send the signal to the children every second if - # any worker is running or right before launching other worker children from the following - # new socket connection. - - # Only the process IDs of children sent data to the parent are returned below. The children - # send a custom exit code to the parent after being exited and the parent tries - # to terminate them only if they sent the exit code. - children <- parallel:::selectChildren(timeout = 0) - - if (is.integer(children)) { - lapply(children, function(child) { - # This data should be raw bytes if any data was sent from this child. - # Otherwise, this returns the PID. - data <- parallel:::readChild(child) - if (is.raw(data)) { - # This checks if the data from this child is the exit code that indicates an exited child. - if (unserialize(data) == exitCode) { - # If so, we terminate this child. - tools::pskill(child, tools::SIGUSR1) - } - } - }) - } else if (is.null(children)) { - # If it is NULL, there are no children. Waits indefinitely for a socket connecion. - selectTimeout <- NULL - } - + ready <- socketSelect(list(inputCon)) if (ready) { port <- SparkR:::readInt(inputCon) # There is a small chance that it could be interrupted by signal, retry one time @@ -91,16 +44,12 @@ while (TRUE) { } p <- parallel:::mcfork() if (inherits(p, "masterProcess")) { - # Reach here because this is a child process. close(inputCon) Sys.setenv(SPARKR_WORKER_PORT = port) try(source(script)) - # Note that this mcexit does not fully terminate this child. So, this writes back - # a custom exit code so that the parent can read and terminate this child. - parallel:::mcexit(0L, send = exitCode) - } else { - # Forking succeeded and we need to check if they finished their jobs every second. - selectTimeout <- 1 + # Set SIGUSR1 so that child can exit + tools::pskill(Sys.getpid(), tools::SIGUSR1) + parallel:::mcexit(0L) } } } From 25c2edf6f9da9d4d45fc628cf97de657f2a2cc7e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 29 Jun 2017 11:21:50 +0800 Subject: [PATCH 098/118] [SPARK-21229][SQL] remove QueryPlan.preCanonicalized ## What changes were proposed in this pull request? `QueryPlan.preCanonicalized` is only overridden in a few places, and it does introduce an extra concept to `QueryPlan` which may confuse people. This PR removes it and override `canonicalized` in these places ## How was this patch tested? existing tests Author: Wenchen Fan Closes #18440 from cloud-fan/minor. 
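The migration pattern is small; a condensed before/after sketch, adapted from the `LogicalRelation` hunk below (a fragment of that class, not a complete definition):

```scala
// Before: the node only pre-stripped the fields it wanted ignored and relied on
// QueryPlan.canonicalized to normalize expression ids afterwards.
override def preCanonicalized: LogicalPlan = copy(catalogTable = None)

// After: the node overrides canonicalized directly and normalizes the exprIds of
// its own output via QueryPlan.normalizeExprId.
override lazy val canonicalized: LogicalPlan = copy(
  output = output.map(QueryPlan.normalizeExprId(_, output)),
  catalogTable = None)
```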
--- .../sql/catalyst/catalog/interface.scala | 23 +++++++++++-------- .../spark/sql/catalyst/plans/QueryPlan.scala | 13 ++++------- .../sql/execution/DataSourceScanExec.scala | 8 +++++-- .../datasources/LogicalRelation.scala | 5 +++- 4 files changed, 27 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index b63bef9193332..da50b0e7e8e42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -27,7 +27,8 @@ import com.google.common.base.Objects import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, Literal} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, ExprId, Literal} +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier @@ -425,15 +426,17 @@ case class CatalogRelation( Objects.hashCode(tableMeta.identifier, output) } - override def preCanonicalized: LogicalPlan = copy(tableMeta = CatalogTable( - identifier = tableMeta.identifier, - tableType = tableMeta.tableType, - storage = CatalogStorageFormat.empty, - schema = tableMeta.schema, - partitionColumnNames = tableMeta.partitionColumnNames, - bucketSpec = tableMeta.bucketSpec, - createTime = -1 - )) + override lazy val canonicalized: LogicalPlan = copy( + tableMeta = tableMeta.copy( + storage = CatalogStorageFormat.empty, + createTime = -1 + ), + dataCols = dataCols.zipWithIndex.map { + case (attr, index) => attr.withExprId(ExprId(index)) + }, + partitionCols = partitionCols.zipWithIndex.map { + case (attr, index) => attr.withExprId(ExprId(index + dataCols.length)) + }) override def computeStats: Statistics = { // For data source tables, we will create a `LogicalRelation` and won't call this method, for diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 01b3da3f7c482..7addbaaa9afa5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -188,12 +188,13 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same * result. * - * Some nodes should overwrite this to provide proper canonicalize logic. + * Some nodes should overwrite this to provide proper canonicalize logic, but they should remove + * expressions cosmetic variations themselves. 
*/ lazy val canonicalized: PlanType = { val canonicalizedChildren = children.map(_.canonicalized) var id = -1 - preCanonicalized.mapExpressions { + mapExpressions { case a: Alias => id += 1 // As the root of the expression, Alias will always take an arbitrary exprId, we need to @@ -206,18 +207,12 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT // Top level `AttributeReference` may also be used for output like `Alias`, we should // normalize the epxrId too. id += 1 - ar.withExprId(ExprId(id)) + ar.withExprId(ExprId(id)).canonicalized case other => QueryPlan.normalizeExprId(other, allAttributes) }.withNewChildren(canonicalizedChildren) } - /** - * Do some simple transformation on this plan before canonicalizing. Implementations can override - * this method to provide customized canonicalize logic without rewriting the whole logic. - */ - protected def preCanonicalized: PlanType = this - /** * Returns true when the given query plan will return the same results as this query plan. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 74fc23a52a141..a0def68d88e0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -138,8 +138,12 @@ case class RowDataSourceScanExec( } // Only care about `relation` and `metadata` when canonicalizing. - override def preCanonicalized: SparkPlan = - copy(rdd = null, outputPartitioning = null, metastoreTableIdentifier = None) + override lazy val canonicalized: SparkPlan = + copy( + output.map(QueryPlan.normalizeExprId(_, output)), + rdd = null, + outputPartitioning = null, + metastoreTableIdentifier = None) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index c1b2895f1747e..6ba190b9e5dcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.util.Utils @@ -43,7 +44,9 @@ case class LogicalRelation( } // Only care about relation when canonicalizing. - override def preCanonicalized: LogicalPlan = copy(catalogTable = None) + override lazy val canonicalized: LogicalPlan = copy( + output = output.map(QueryPlan.normalizeExprId(_, output)), + catalogTable = None) @transient override def computeStats: Statistics = { catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse( From 82e24912d6e15a9e4fbadd83da9a08d4f80a592b Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Thu, 29 Jun 2017 11:32:29 +0800 Subject: [PATCH 099/118] [SPARK-21237][SQL] Invalidate stats once table data is changed ## What changes were proposed in this pull request? 
Invalidate spark's stats after data changing commands: - InsertIntoHadoopFsRelationCommand - InsertIntoHiveTable - LoadDataCommand - TruncateTableCommand - AlterTableSetLocationCommand - AlterTableDropPartitionCommand ## How was this patch tested? Added test cases. Author: wangzhenhua Closes #18449 from wzhfy/removeStats. --- .../catalyst/catalog/ExternalCatalog.scala | 3 +- .../catalyst/catalog/InMemoryCatalog.scala | 4 +- .../sql/catalyst/catalog/SessionCatalog.scala | 2 +- .../catalog/ExternalCatalogSuite.scala | 2 +- .../catalog/SessionCatalogSuite.scala | 2 +- .../command/AnalyzeColumnCommand.scala | 4 +- .../command/AnalyzeTableCommand.scala | 76 +--------- .../sql/execution/command/CommandUtils.scala | 102 ++++++++++++++ .../spark/sql/execution/command/ddl.scala | 9 +- .../spark/sql/execution/command/tables.scala | 7 + .../InsertIntoHadoopFsRelationCommand.scala | 5 + .../spark/sql/StatisticsCollectionSuite.scala | 85 ++++++++++-- .../apache/spark/sql/test/SQLTestUtils.scala | 14 ++ .../spark/sql/hive/HiveExternalCatalog.scala | 24 ++-- .../hive/execution/InsertIntoHiveTable.scala | 4 +- .../spark/sql/hive/StatisticsSuite.scala | 130 ++++++++++++++---- 16 files changed, 340 insertions(+), 133 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 12ba5aedde026..0254b6bb6d136 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -160,7 +160,8 @@ abstract class ExternalCatalog */ def alterTableSchema(db: String, table: String, schema: StructType): Unit - def alterTableStats(db: String, table: String, stats: CatalogStatistics): Unit + /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. 
*/ + def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit def getTable(db: String, table: String): CatalogTable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 9820522a230e3..747190faa3c8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -315,10 +315,10 @@ class InMemoryCatalog( override def alterTableStats( db: String, table: String, - stats: CatalogStatistics): Unit = synchronized { + stats: Option[CatalogStatistics]): Unit = synchronized { requireTableExists(db, table) val origTable = catalog(db).tables(table).table - catalog(db).tables(table).table = origTable.copy(stats = Some(stats)) + catalog(db).tables(table).table = origTable.copy(stats = stats) } override def getTable(db: String, table: String): CatalogTable = synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index cf02da8993658..7ece77df7fc14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -380,7 +380,7 @@ class SessionCatalog( * Alter Spark's statistics of an existing metastore table identified by the provided table * identifier. */ - def alterTableStats(identifier: TableIdentifier, newStats: CatalogStatistics): Unit = { + def alterTableStats(identifier: TableIdentifier, newStats: Option[CatalogStatistics]): Unit = { val db = formatDatabaseName(identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(identifier.table) val tableIdentifier = TableIdentifier(table, Some(db)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 557b0970b54e5..c22d55fc96a65 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -260,7 +260,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac val oldTableStats = catalog.getTable("db2", "tbl1").stats assert(oldTableStats.isEmpty) val newStats = CatalogStatistics(sizeInBytes = 1) - catalog.alterTableStats("db2", "tbl1", newStats) + catalog.alterTableStats("db2", "tbl1", Some(newStats)) val newTableStats = catalog.getTable("db2", "tbl1").stats assert(newTableStats.get == newStats) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index a6dc21b03d446..fc3893e197792 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -454,7 +454,7 @@ abstract class SessionCatalogSuite extends AnalysisTest { val oldTableStats = catalog.getTableMetadata(tableId).stats assert(oldTableStats.isEmpty) val newStats = CatalogStatistics(sizeInBytes = 1) - 
catalog.alterTableStats(tableId, newStats) + catalog.alterTableStats(tableId, Some(newStats)) val newTableStats = catalog.getTableMetadata(tableId).stats assert(newTableStats.get == newStats) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 2f273b63e8348..6588993ef9ad9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -42,7 +42,7 @@ case class AnalyzeColumnCommand( if (tableMeta.tableType == CatalogTableType.VIEW) { throw new AnalysisException("ANALYZE TABLE is not supported on views.") } - val sizeInBytes = AnalyzeTableCommand.calculateTotalSize(sessionState, tableMeta) + val sizeInBytes = CommandUtils.calculateTotalSize(sessionState, tableMeta) // Compute stats for each column val (rowCount, newColStats) = computeColumnStats(sparkSession, tableIdentWithDB, columnNames) @@ -54,7 +54,7 @@ case class AnalyzeColumnCommand( // Newly computed column stats should override the existing ones. colStats = tableMeta.stats.map(_.colStats).getOrElse(Map.empty) ++ newColStats) - sessionState.catalog.alterTableStats(tableIdentWithDB, statistics) + sessionState.catalog.alterTableStats(tableIdentWithDB, Some(statistics)) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdentWithDB) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 13b8faff844c7..d780ef42f3fae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -17,18 +17,10 @@ package org.apache.spark.sql.execution.command -import java.net.URI - -import scala.util.control.NonFatal - -import org.apache.hadoop.fs.{FileSystem, Path} - -import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.internal.SessionState /** @@ -46,7 +38,7 @@ case class AnalyzeTableCommand( if (tableMeta.tableType == CatalogTableType.VIEW) { throw new AnalysisException("ANALYZE TABLE is not supported on views.") } - val newTotalSize = AnalyzeTableCommand.calculateTotalSize(sessionState, tableMeta) + val newTotalSize = CommandUtils.calculateTotalSize(sessionState, tableMeta) val oldTotalSize = tableMeta.stats.map(_.sizeInBytes.toLong).getOrElse(0L) val oldRowCount = tableMeta.stats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L) @@ -74,7 +66,7 @@ case class AnalyzeTableCommand( // Update the metastore if the above statistics of the table are different from those // recorded in the metastore. if (newStats.isDefined) { - sessionState.catalog.alterTableStats(tableIdentWithDB, newStats.get) + sessionState.catalog.alterTableStats(tableIdentWithDB, newStats) // Refresh the cached data source table in the catalog. 
sessionState.catalog.refreshTable(tableIdentWithDB) } @@ -82,65 +74,3 @@ case class AnalyzeTableCommand( Seq.empty[Row] } } - -object AnalyzeTableCommand extends Logging { - - def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): Long = { - if (catalogTable.partitionColumnNames.isEmpty) { - calculateLocationSize(sessionState, catalogTable.identifier, catalogTable.storage.locationUri) - } else { - // Calculate table size as a sum of the visible partitions. See SPARK-21079 - val partitions = sessionState.catalog.listPartitions(catalogTable.identifier) - partitions.map(p => - calculateLocationSize(sessionState, catalogTable.identifier, p.storage.locationUri) - ).sum - } - } - - private def calculateLocationSize( - sessionState: SessionState, - tableId: TableIdentifier, - locationUri: Option[URI]): Long = { - // This method is mainly based on - // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) - // in Hive 0.13 (except that we do not use fs.getContentSummary). - // TODO: Generalize statistics collection. - // TODO: Why fs.getContentSummary returns wrong size on Jenkins? - // Can we use fs.getContentSummary in future? - // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use - // countFileSize to count the table size. - val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") - - def calculateLocationSize(fs: FileSystem, path: Path): Long = { - val fileStatus = fs.getFileStatus(path) - val size = if (fileStatus.isDirectory) { - fs.listStatus(path) - .map { status => - if (!status.getPath.getName.startsWith(stagingDir)) { - calculateLocationSize(fs, status.getPath) - } else { - 0L - } - }.sum - } else { - fileStatus.getLen - } - - size - } - - locationUri.map { p => - val path = new Path(p) - try { - val fs = path.getFileSystem(sessionState.newHadoopConf()) - calculateLocationSize(fs, path) - } catch { - case NonFatal(e) => - logWarning( - s"Failed to get the size of table ${tableId.table} in the " + - s"database ${tableId.database} because of ${e.toString}", e) - 0L - } - }.getOrElse(0L) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala new file mode 100644 index 0000000000000..92397607f38fd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -0,0 +1,102 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.command + +import java.net.URI + +import scala.util.control.NonFatal + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable} +import org.apache.spark.sql.internal.SessionState + + +object CommandUtils extends Logging { + + /** Change statistics after changing data by commands. */ + def updateTableStats(sparkSession: SparkSession, table: CatalogTable): Unit = { + if (table.stats.nonEmpty) { + val catalog = sparkSession.sessionState.catalog + catalog.alterTableStats(table.identifier, None) + } + } + + def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): BigInt = { + if (catalogTable.partitionColumnNames.isEmpty) { + calculateLocationSize(sessionState, catalogTable.identifier, catalogTable.storage.locationUri) + } else { + // Calculate table size as a sum of the visible partitions. See SPARK-21079 + val partitions = sessionState.catalog.listPartitions(catalogTable.identifier) + partitions.map { p => + calculateLocationSize(sessionState, catalogTable.identifier, p.storage.locationUri) + }.sum + } + } + + def calculateLocationSize( + sessionState: SessionState, + identifier: TableIdentifier, + locationUri: Option[URI]): Long = { + // This method is mainly based on + // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) + // in Hive 0.13 (except that we do not use fs.getContentSummary). + // TODO: Generalize statistics collection. + // TODO: Why fs.getContentSummary returns wrong size on Jenkins? + // Can we use fs.getContentSummary in future? + // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use + // countFileSize to count the table size. 
+ val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") + + def getPathSize(fs: FileSystem, path: Path): Long = { + val fileStatus = fs.getFileStatus(path) + val size = if (fileStatus.isDirectory) { + fs.listStatus(path) + .map { status => + if (!status.getPath.getName.startsWith(stagingDir)) { + getPathSize(fs, status.getPath) + } else { + 0L + } + }.sum + } else { + fileStatus.getLen + } + + size + } + + locationUri.map { p => + val path = new Path(p) + try { + val fs = path.getFileSystem(sessionState.newHadoopConf()) + getPathSize(fs, path) + } catch { + case NonFatal(e) => + logWarning( + s"Failed to get the size of table ${identifier.table} in the " + + s"database ${identifier.database} because of ${e.toString}", e) + 0L + } + }.getOrElse(0L) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 413f5f3ba539c..ac897c1b22d77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -433,9 +433,11 @@ case class AlterTableAddPartitionCommand( sparkSession.sessionState.conf.resolver) // inherit table storage format (possibly except for location) CatalogTablePartition(normalizedSpec, table.storage.copy( - locationUri = location.map(CatalogUtils.stringToURI(_)))) + locationUri = location.map(CatalogUtils.stringToURI))) } catalog.createPartitions(table.identifier, parts, ignoreIfExists = ifNotExists) + + CommandUtils.updateTableStats(sparkSession, table) Seq.empty[Row] } @@ -519,6 +521,9 @@ case class AlterTableDropPartitionCommand( catalog.dropPartitions( table.identifier, normalizedSpecs, ignoreIfNotExists = ifExists, purge = purge, retainData = retainData) + + CommandUtils.updateTableStats(sparkSession, table) + Seq.empty[Row] } @@ -768,6 +773,8 @@ case class AlterTableSetLocationCommand( // No partition spec is specified, so we set the location for the table itself catalog.alterTable(table.withNewStorage(locationUri = Some(locUri))) } + + CommandUtils.updateTableStats(sparkSession, table) Seq.empty[Row] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index b937a8a9f375b..8ded1060f7bf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -400,6 +400,7 @@ case class LoadDataCommand( // Refresh the metadata cache to ensure the data visible to the users catalog.refreshTable(targetTable.identifier) + CommandUtils.updateTableStats(sparkSession, targetTable) Seq.empty[Row] } } @@ -487,6 +488,12 @@ case class TruncateTableCommand( case NonFatal(e) => log.warn(s"Exception when attempting to uncache table $tableIdentWithDB", e) } + + if (table.stats.nonEmpty) { + // empty table after truncation + val newStats = CatalogStatistics(sizeInBytes = 0, rowCount = Some(0)) + catalog.alterTableStats(tableName, Some(newStats)) + } Seq.empty[Row] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 00aa1240886e4..ab26f2affbce5 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -161,6 +161,11 @@ case class InsertIntoHadoopFsRelationCommand( fileIndex.foreach(_.refresh()) // refresh data cache if table is cached sparkSession.catalog.refreshByPath(outputPath.toString) + + if (catalogTable.nonEmpty) { + CommandUtils.updateTableStats(sparkSession, catalogTable.get) + } + } else { logInfo("Skipping insertion into a relation that already exists.") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 9824062f969b3..b031c52dad8b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -40,17 +40,6 @@ import org.apache.spark.sql.types._ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with SharedSQLContext { import testImplicits._ - private def checkTableStats(tableName: String, expectedRowCount: Option[Int]) - : Option[CatalogStatistics] = { - val df = spark.table(tableName) - val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation => - assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) - rel.catalogTable.get.stats - } - assert(stats.size == 1) - stats.head - } - test("estimates the size of a limit 0 on outer join") { withTempView("test") { Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") @@ -96,11 +85,11 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared // noscan won't count the number of rows sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") - checkTableStats(tableName, expectedRowCount = None) + checkTableStats(tableName, hasSizeInBytes = true, expectedRowCounts = None) // without noscan, we count the number of rows sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") - checkTableStats(tableName, expectedRowCount = Some(2)) + checkTableStats(tableName, hasSizeInBytes = true, expectedRowCounts = Some(2)) } } @@ -168,6 +157,60 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared assert(stats.simpleString == expectedString) } } + + test("change stats after truncate command") { + val table = "change_stats_truncate_table" + withTable(table) { + spark.range(100).select($"id", $"id" % 5 as "value").write.saveAsTable(table) + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS id, value") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(100)) + assert(fetched1.get.sizeInBytes > 0) + assert(fetched1.get.colStats.size == 2) + + // truncate table command + sql(s"TRUNCATE TABLE $table") + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched2.get.sizeInBytes == 0) + assert(fetched2.get.colStats.isEmpty) + } + } + + test("change stats after set location command") { + val table = "change_stats_set_location_table" + withTable(table) { + spark.range(100).select($"id", $"id" % 5 as "value").write.saveAsTable(table) + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS id, value") + val fetched1 = checkTableStats( + table, hasSizeInBytes = true, expectedRowCounts = Some(100)) + 
assert(fetched1.get.sizeInBytes > 0) + assert(fetched1.get.colStats.size == 2) + + // set location command + withTempDir { newLocation => + sql(s"ALTER TABLE $table SET LOCATION '${newLocation.toURI.toString}'") + checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + } + } + } + + test("change stats after insert command for datasource table") { + val table = "change_stats_insert_datasource_table" + withTable(table) { + sql(s"CREATE TABLE $table (i int, j string) USING PARQUET") + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched1.get.sizeInBytes == 0) + assert(fetched1.get.colStats.size == 2) + + // insert into command + sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'") + checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + } + } + } @@ -219,6 +262,22 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils private val randomName = new Random(31) + def checkTableStats( + tableName: String, + hasSizeInBytes: Boolean, + expectedRowCounts: Option[Int]): Option[CatalogStatistics] = { + val stats = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).stats + if (hasSizeInBytes || expectedRowCounts.nonEmpty) { + assert(stats.isDefined) + assert(stats.get.sizeInBytes >= 0) + assert(stats.get.rowCount === expectedRowCounts) + } else { + assert(stats.isEmpty) + } + + stats + } + /** * Compute column stats for the given DataFrame and compare it with colStats. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index f6d47734d7e83..d74a7cce25ed6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -149,6 +149,7 @@ private[sql] trait SQLTestUtils .getExecutorInfos.map(_.numRunningTasks()).sum == 0) } } + /** * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` * returns. @@ -164,6 +165,19 @@ private[sql] trait SQLTestUtils } } + /** + * Creates the specified number of temporary directories, which is then passed to `f` and will be + * deleted after `f` returns. + */ + protected def withTempPaths(numPaths: Int)(f: Seq[File] => Unit): Unit = { + val files = Array.fill[File](numPaths)(Utils.createTempDir().getCanonicalFile) + try f(files) finally { + // wait for all tasks to finish before deleting files + waitForTasksToFinish() + files.foreach(Utils.deleteRecursively) + } + } + /** * Drops functions after calling `f`. A function is represented by (functionName, isTemporary). 
*/ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 6e7c475fa34c9..2a17849fa8a34 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -631,21 +631,23 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat override def alterTableStats( db: String, table: String, - stats: CatalogStatistics): Unit = withClient { + stats: Option[CatalogStatistics]): Unit = withClient { requireTableExists(db, table) val rawTable = getRawTable(db, table) // convert table statistics to properties so that we can persist them through hive client - var statsProperties: Map[String, String] = - Map(STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString()) - if (stats.rowCount.isDefined) { - statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() - } - val colNameTypeMap: Map[String, DataType] = - rawTable.schema.fields.map(f => (f.name, f.dataType)).toMap - stats.colStats.foreach { case (colName, colStat) => - colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => - statsProperties += (columnStatKeyPropName(colName, k) -> v) + val statsProperties = new mutable.HashMap[String, String]() + if (stats.isDefined) { + statsProperties += STATISTICS_TOTAL_SIZE -> stats.get.sizeInBytes.toString() + if (stats.get.rowCount.isDefined) { + statsProperties += STATISTICS_NUM_ROWS -> stats.get.rowCount.get.toString() + } + val colNameTypeMap: Map[String, DataType] = + rawTable.schema.fields.map(f => (f.name, f.dataType)).toMap + stats.get.colStats.foreach { case (colName, colStat) => + colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => + statsProperties += (columnStatKeyPropName(colName, k) -> v) + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 392b7cfaa8eff..223d375232393 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.execution.command.{CommandUtils, RunnableCommand} import org.apache.spark.sql.execution.datasources.FileFormatWriter import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} @@ -434,6 +434,8 @@ case class InsertIntoHiveTable( sparkSession.catalog.uncacheTable(table.identifier.quotedString) sparkSession.sessionState.catalog.refreshTable(table.identifier) + CommandUtils.updateTableStats(sparkSession, table) + // It would be nice to just return the childRdd unchanged so insert operations could be chained, // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. 
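Editor's note: the Hive-side test changes below exercise the user-visible effect of the hooks added above: once a table's data changes, previously ANALYZE-d statistics are dropped rather than left stale. A rough interactive sketch of that behavior, assuming a local `SparkSession` named `spark` and a made-up table name (`sessionState` is an internal, unstable API, used here only because the new tests read statistics the same way):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.TableIdentifier

// In spark-shell the session already exists as `spark`; the builder is shown for completeness.
val spark = SparkSession.builder().master("local[*]").appName("stats-demo").getOrCreate()

spark.sql("CREATE TABLE stats_demo (i INT, j STRING) USING PARQUET")
spark.sql("ANALYZE TABLE stats_demo COMPUTE STATISTICS")

// Statistics recorded by ANALYZE are visible on the catalog entry.
println(spark.sessionState.catalog.getTableMetadata(TableIdentifier("stats_demo")).stats)

// A data-changing command now goes through CommandUtils.updateTableStats, which
// drops the stale statistics (TRUNCATE TABLE instead resets them to size 0 / 0 rows).
spark.sql("INSERT INTO stats_demo VALUES (1, 'abc')")
println(spark.sessionState.catalog.getTableMetadata(TableIdentifier("stats_demo")).stats) // None
```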
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 64deb3818d5d1..5fd266c2d033c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -30,10 +30,12 @@ import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.hive.HiveExternalCatalog._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ + class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { test("Hive serde tables should fallback to HDFS for size estimation") { @@ -219,23 +221,6 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } - private def checkTableStats( - tableName: String, - hasSizeInBytes: Boolean, - expectedRowCounts: Option[Int]): Option[CatalogStatistics] = { - val stats = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).stats - - if (hasSizeInBytes || expectedRowCounts.nonEmpty) { - assert(stats.isDefined) - assert(stats.get.sizeInBytes > 0) - assert(stats.get.rowCount === expectedRowCounts) - } else { - assert(stats.isEmpty) - } - - stats - } - test("test table-level statistics for hive tables created in HiveExternalCatalog") { val textTable = "textTable" withTable(textTable) { @@ -326,7 +311,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto descOutput: Seq[String], propKey: String): Option[BigInt] = { val str = descOutput - .filterNot(_.contains(HiveExternalCatalog.STATISTICS_PREFIX)) + .filterNot(_.contains(STATISTICS_PREFIX)) .filter(_.contains(propKey)) if (str.isEmpty) { None @@ -448,6 +433,103 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto "ALTER TABLE unset_prop_table UNSET TBLPROPERTIES ('prop1')") } + /** + * To see if stats exist, we need to check spark's stats properties instead of catalog + * statistics, because hive would change stats in metastore and thus change catalog statistics. 
+ */ + private def getStatsProperties(tableName: String): Map[String, String] = { + val hTable = hiveClient.getTable(spark.sessionState.catalog.getCurrentDatabase, tableName) + hTable.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) + } + + test("change stats after insert command for hive table") { + val table = s"change_stats_insert_hive_table" + withTable(table) { + sql(s"CREATE TABLE $table (i int, j string)") + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched1.get.sizeInBytes == 0) + assert(fetched1.get.colStats.size == 2) + + // insert into command + sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'") + assert(getStatsProperties(table).isEmpty) + } + } + + test("change stats after load data command") { + val table = "change_stats_load_table" + withTable(table) { + sql(s"CREATE TABLE $table (i INT, j STRING) STORED AS PARQUET") + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetched1.get.sizeInBytes == 0) + assert(fetched1.get.colStats.size == 2) + + withTempDir { loadPath => + // load data command + val file = new File(loadPath + "/data") + val writer = new PrintWriter(file) + writer.write("2,xyz") + writer.close() + sql(s"LOAD DATA INPATH '${loadPath.toURI.toString}' INTO TABLE $table") + assert(getStatsProperties(table).isEmpty) + } + } + } + + test("change stats after add/drop partition command") { + val table = "change_stats_part_table" + withTable(table) { + sql(s"CREATE TABLE $table (i INT, j STRING) PARTITIONED BY (ds STRING, hr STRING)") + // table has two partitions initially + for (ds <- Seq("2008-04-08"); hr <- Seq("11", "12")) { + sql(s"INSERT OVERWRITE TABLE $table PARTITION (ds='$ds',hr='$hr') SELECT 1, 'a'") + } + // analyze to get initial stats + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)) + assert(fetched1.get.sizeInBytes > 0) + assert(fetched1.get.colStats.size == 2) + + withTempPaths(numPaths = 2) { case Seq(dir1, dir2) => + val file1 = new File(dir1 + "/data") + val writer1 = new PrintWriter(file1) + writer1.write("1,a") + writer1.close() + + val file2 = new File(dir2 + "/data") + val writer2 = new PrintWriter(file2) + writer2.write("1,a") + writer2.close() + + // add partition command + sql( + s""" + |ALTER TABLE $table ADD + |PARTITION (ds='2008-04-09', hr='11') LOCATION '${dir1.toURI.toString}' + |PARTITION (ds='2008-04-09', hr='12') LOCATION '${dir2.toURI.toString}' + """.stripMargin) + assert(getStatsProperties(table).isEmpty) + + // generate stats again + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS i, j") + val fetched2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(4)) + assert(fetched2.get.sizeInBytes > 0) + assert(fetched2.get.colStats.size == 2) + + // drop partition command + sql(s"ALTER TABLE $table DROP PARTITION (ds='2008-04-08'), PARTITION (hr='12')") + // only one partition left + assert(spark.sessionState.catalog.listPartitions(TableIdentifier(table)) + .map(_.spec).toSet == Set(Map("ds" -> "2008-04-09", "hr" -> "11"))) + assert(getStatsProperties(table).isEmpty) + } + } + } + test("add/drop partitions - managed table") { val catalog = spark.sessionState.catalog val managedTable = 
"partitionedTable" @@ -483,23 +565,19 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto assert(catalog.listPartitions(TableIdentifier(managedTable)).map(_.spec).toSet == Set(Map("ds" -> "2008-04-09", "hr" -> "11"))) - val stats2 = checkTableStats( - managedTable, hasSizeInBytes = true, expectedRowCounts = Some(4)) - assert(stats1 == stats2) - sql(s"ANALYZE TABLE $managedTable COMPUTE STATISTICS") - val stats3 = checkTableStats( + val stats2 = checkTableStats( managedTable, hasSizeInBytes = true, expectedRowCounts = Some(1)) - assert(stats2.get.sizeInBytes > stats3.get.sizeInBytes) + assert(stats1.get.sizeInBytes > stats2.get.sizeInBytes) sql(s"ALTER TABLE $managedTable ADD PARTITION (ds='2008-04-08', hr='12')") sql(s"ANALYZE TABLE $managedTable COMPUTE STATISTICS") val stats4 = checkTableStats( managedTable, hasSizeInBytes = true, expectedRowCounts = Some(1)) - assert(stats2.get.sizeInBytes > stats4.get.sizeInBytes) - assert(stats4.get.sizeInBytes == stats3.get.sizeInBytes) + assert(stats1.get.sizeInBytes > stats4.get.sizeInBytes) + assert(stats4.get.sizeInBytes == stats2.get.sizeInBytes) } } From a946be35ac177737e99942ad42de6f319f186138 Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Thu, 29 Jun 2017 14:25:51 +0800 Subject: [PATCH 100/118] [SPARK-3577] Report Spill size on disk for UnsafeExternalSorter ## What changes were proposed in this pull request? Report Spill size on disk for UnsafeExternalSorter ## How was this patch tested? Tested by running a job on cluster and verify the spill size on disk. Author: Sital Kedia Closes #17471 from sitalkedia/fix_disk_spill_size. --- .../unsafe/sort/UnsafeExternalSorter.java | 9 +++---- .../sort/UnsafeExternalSorterSuite.java | 25 +++++++++++++++++++ 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index f312fa2b2ddd7..82d03e3e9190c 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -54,7 +54,6 @@ public final class UnsafeExternalSorter extends MemoryConsumer { private final BlockManager blockManager; private final SerializerManager serializerManager; private final TaskContext taskContext; - private ShuffleWriteMetrics writeMetrics; /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSizeBytes; @@ -144,10 +143,6 @@ private UnsafeExternalSorter( // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024 this.fileBufferSizeBytes = 32 * 1024; - // The spill metrics are stored in a new ShuffleWriteMetrics, - // and then discarded (this fixes SPARK-16827). - // TODO: Instead, separate spill metrics should be stored and reported (tracked in SPARK-3577). - this.writeMetrics = new ShuffleWriteMetrics(); if (existingInMemorySorter == null) { this.inMemSorter = new UnsafeInMemorySorter( @@ -199,6 +194,7 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { spillWriters.size(), spillWriters.size() > 1 ? " times" : " time"); + ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); // We only write out contents of the inMemSorter if it is not empty. 
if (inMemSorter.numRecords() > 0) {
 final UnsafeSorterSpillWriter spillWriter =
@@ -226,6 +222,7 @@ public long spill(long size, MemoryConsumer trigger) throws IOException {
 // pages, we might not be able to get memory for the pointer array.
 taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+ taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten());
 totalSpillBytes += spillSize;
 return spillSize;
 }
@@ -502,6 +499,7 @@ public long spill() throws IOException {
 UnsafeInMemorySorter.SortedIterator inMemIterator =
 ((UnsafeInMemorySorter.SortedIterator) upstream).clone();
+ ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
 // Iterate over the records that have not been returned and spill them.
 final UnsafeSorterSpillWriter spillWriter =
 new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords);
@@ -540,6 +538,7 @@ public long spill() throws IOException {
 inMemSorter.free();
 inMemSorter = null;
 taskContext.taskMetrics().incMemoryBytesSpilled(released);
+ taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten());
 totalSpillBytes += released;
 return released;
 }
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 771d39016c188..d31d7c1c0900c 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -405,6 +405,31 @@ public void forcedSpillingWithoutComparator() throws Exception {
 assertSpillFilesWereCleanedUp();
 }
+ @Test
+ public void testDiskSpilledBytes() throws Exception {
+ final UnsafeExternalSorter sorter = newSorter();
+ long[] record = new long[100];
+ int recordSize = record.length * 8;
+ int n = (int) pageSizeBytes / recordSize * 3;
+ for (int i = 0; i < n; i++) {
+ record[0] = (long) i;
+ sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false);
+ }
+ // We will have at-least 2 memory pages allocated because of rounding happening due to
+ // integer division of pageSizeBytes and recordSize.
+ assertTrue(sorter.getNumberOfAllocatedPages() >= 2);
+ assertTrue(taskContext.taskMetrics().diskBytesSpilled() == 0);
+ UnsafeExternalSorter.SpillableIterator iter =
+ (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator();
+ assertTrue(iter.spill() > 0);
+ assertTrue(taskContext.taskMetrics().diskBytesSpilled() > 0);
+ assertEquals(0, iter.spill());
+ // Even if we did not spill second time, the disk spilled bytes should still be non-zero
+ assertTrue(taskContext.taskMetrics().diskBytesSpilled() > 0);
+ sorter.cleanupResources();
+ assertSpillFilesWereCleanedUp();
+ }
+ @Test
 public void testPeakMemoryUsed() throws Exception {
 final long recordLengthBytes = 8;
From 9f6b3e65ccfa0daec31b58c5a6386b3a890c2149 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 29 Jun 2017 14:37:42 +0800
Subject: [PATCH 101/118] [SPARK-21238][SQL] allow nested SQL execution

## What changes were proposed in this pull request?

This is kind of another follow-up for https://github.com/apache/spark/pull/18064. In #18064, we wrap every SQL command with SQL execution, which makes nested SQL execution very likely to happen. #18419 tried to improve it a little bit, by introducing `SQLExecution.ignoreNestedExecutionId`.
However, this is not friendly to data source developers, they may need to update their code to use this `ignoreNestedExecutionId` API. This PR proposes a new solution, to just allow nested execution. The downside is that, we may have multiple executions for one query. We can improve this by updating the data organization in SQLListener, to have 1-n mapping from query to execution, instead of 1-1 mapping. This can be done in a follow-up. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #18450 from cloud-fan/execution-id. --- .../spark/sql/execution/SQLExecution.scala | 88 ++++--------------- .../command/AnalyzeTableCommand.scala | 4 +- .../spark/sql/execution/command/cache.scala | 16 ++-- .../datasources/csv/CSVDataSource.scala | 4 +- .../datasources/jdbc/JDBCRelation.scala | 8 +- .../sql/execution/streaming/console.scala | 12 +-- .../sql/execution/streaming/memory.scala | 32 ++++--- .../sql/execution/SQLExecutionSuite.scala | 24 ----- 8 files changed, 50 insertions(+), 138 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index ca8bed5214f87..e991da7df0bde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -22,15 +22,12 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, - SparkListenerSQLExecutionStart} +import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart} object SQLExecution { val EXECUTION_ID_KEY = "spark.sql.execution.id" - private val IGNORE_NESTED_EXECUTION_ID = "spark.sql.execution.ignoreNestedExecutionId" - private val _nextExecutionId = new AtomicLong(0) private def nextExecutionId: Long = _nextExecutionId.getAndIncrement @@ -45,10 +42,8 @@ object SQLExecution { private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = { val sc = sparkSession.sparkContext - val isNestedExecution = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null - val hasExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) != null // only throw an exception during tests. a missing execution ID should not fail a job. - if (testing && !isNestedExecution && !hasExecutionId) { + if (testing && sc.getLocalProperty(EXECUTION_ID_KEY) == null) { // Attention testers: when a test fails with this exception, it means that the action that // started execution of a query didn't call withNewExecutionId. 
The execution ID should be // set by calling withNewExecutionId in the action that begins execution, like @@ -66,56 +61,27 @@ object SQLExecution { queryExecution: QueryExecution)(body: => T): T = { val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) - if (oldExecutionId == null) { - val executionId = SQLExecution.nextExecutionId - sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) - executionIdToQueryExecution.put(executionId, queryExecution) - try { - // sparkContext.getCallSite() would first try to pick up any call site that was previously - // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on - // streaming queries would give us call site like "run at :0" - val callSite = sparkSession.sparkContext.getCallSite() - - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) - try { - body - } finally { - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) - } - } finally { - executionIdToQueryExecution.remove(executionId) - sc.setLocalProperty(EXECUTION_ID_KEY, null) - } - } else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) { - // If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore the execution id while evaluating the - // `body`, so that Spark jobs issued in the `body` won't be tracked. + val executionId = SQLExecution.nextExecutionId + sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) + executionIdToQueryExecution.put(executionId, queryExecution) + try { + // sparkContext.getCallSite() would first try to pick up any call site that was previously + // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on + // streaming queries would give us call site like "run at :0" + val callSite = sparkSession.sparkContext.getCallSite() + + sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) try { - sc.setLocalProperty(EXECUTION_ID_KEY, null) body } finally { - sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId) + sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) } - } else { - // Don't support nested `withNewExecutionId`. This is an example of the nested - // `withNewExecutionId`: - // - // class DataFrame { - // def foo: T = withNewExecutionId { something.createNewDataFrame().collect() } - // } - // - // Note: `collect` will call withNewExecutionId - // In this case, only the "executedPlan" for "collect" will be executed. The "executedPlan" - // for the outer DataFrame won't be executed. So it's meaningless to create a new Execution - // for the outer DataFrame. Even if we track it, since its "executedPlan" doesn't run, - // all accumulator metrics will be 0. It will confuse people if we show them in Web UI. - // - // A real case is the `DataFrame.count` method. 
- throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set, please wrap your " + - "action with SQLExecution.ignoreNestedExecutionId if you don't want to track the Spark " + - "jobs issued by the nested execution.") + } finally { + executionIdToQueryExecution.remove(executionId) + sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId) } } @@ -133,20 +99,4 @@ object SQLExecution { sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) } } - - /** - * Wrap an action which may have nested execution id. This method can be used to run an execution - * inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`. Note that, - * all Spark jobs issued in the body won't be tracked in UI. - */ - def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = { - val sc = sparkSession.sparkContext - val allowNestedPreviousValue = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) - try { - sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, "true") - body - } finally { - sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, allowNestedPreviousValue) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index d780ef42f3fae..42e2a9ca5c4e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -51,9 +51,7 @@ case class AnalyzeTableCommand( // 2. when total size is changed, `oldRowCount` becomes invalid. // This is to make sure that we only record the right statistics. if (!noscan) { - val newRowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) { - sparkSession.table(tableIdentWithDB).count() - } + val newRowCount = sparkSession.table(tableIdentWithDB).count() if (newRowCount >= 0 && newRowCount != oldRowCount) { newStats = if (newStats.isDefined) { newStats.map(_.copy(rowCount = Some(BigInt(newRowCount)))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index d36eb7587a3ef..47952f2f227a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -34,16 +34,14 @@ case class CacheTableCommand( override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq override def run(sparkSession: SparkSession): Seq[Row] = { - SQLExecution.ignoreNestedExecutionId(sparkSession) { - plan.foreach { logicalPlan => - Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString) - } - sparkSession.catalog.cacheTable(tableIdent.quotedString) + plan.foreach { logicalPlan => + Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString) + } + sparkSession.catalog.cacheTable(tableIdent.quotedString) - if (!isLazy) { - // Performs eager caching - sparkSession.table(tableIdent).count() - } + if (!isLazy) { + // Performs eager caching + sparkSession.table(tableIdent).count() } Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 99133bd70989a..2031381dd2e10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -145,9 +145,7 @@ object TextInputCSVDataSource extends CSVDataSource { inputPaths: Seq[FileStatus], parsedOptions: CSVOptions): StructType = { val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions) - val maybeFirstLine = SQLExecution.ignoreNestedExecutionId(sparkSession) { - CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption - } + val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index b11da7045de22..a521fd1323852 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -130,11 +130,9 @@ private[sql] case class JDBCRelation( } override def insert(data: DataFrame, overwrite: Boolean): Unit = { - SQLExecution.ignoreNestedExecutionId(data.sparkSession) { - data.write - .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) - .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties) - } + data.write + .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) + .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties) } override def toString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 6fa7c113defaa..3baea6376069f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -48,11 +48,9 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging { println(batchIdStr) println("-------------------------------------------") // scalastyle:off println - SQLExecution.ignoreNestedExecutionId(data.sparkSession) { - data.sparkSession.createDataFrame( - data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) - .show(numRowsToShow, isTruncated) - } + data.sparkSession.createDataFrame( + data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) + .show(numRowsToShow, isTruncated) } } @@ -82,9 +80,7 @@ class ConsoleSinkProvider extends StreamSinkProvider // Truncate the displayed data if it is too long, by default it is true val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true) - SQLExecution.ignoreNestedExecutionId(sqlContext.sparkSession) { - data.show(numRowsToShow, isTruncated) - } + data.show(numRowsToShow, isTruncated) ConsoleRelation(sqlContext, data) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 198a342582804..4979873ee3c7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -194,23 +194,21 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi } if (notCommitted) { logDebug(s"Committing batch $batchId to $this") - SQLExecution.ignoreNestedExecutionId(data.sparkSession) { - 
outputMode match { - case Append | Update => - val rows = AddedData(batchId, data.collect()) - synchronized { batches += rows } - - case Complete => - val rows = AddedData(batchId, data.collect()) - synchronized { - batches.clear() - batches += rows - } - - case _ => - throw new IllegalArgumentException( - s"Output mode $outputMode is not supported by MemorySink") - } + outputMode match { + case Append | Update => + val rows = AddedData(batchId, data.collect()) + synchronized { batches += rows } + + case Complete => + val rows = AddedData(batchId, data.collect()) + synchronized { + batches.clear() + batches += rows + } + + case _ => + throw new IllegalArgumentException( + s"Output mode $outputMode is not supported by MemorySink") } } else { logDebug(s"Skipping already committed batch: $batchId") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala index fe78a76568837..f6b006b98edd1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -26,22 +26,9 @@ import org.apache.spark.sql.SparkSession class SQLExecutionSuite extends SparkFunSuite { test("concurrent query execution (SPARK-10548)") { - // Try to reproduce the issue with the old SparkContext val conf = new SparkConf() .setMaster("local[*]") .setAppName("test") - val badSparkContext = new BadSparkContext(conf) - try { - testConcurrentQueryExecution(badSparkContext) - fail("unable to reproduce SPARK-10548") - } catch { - case e: IllegalArgumentException => - assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY)) - } finally { - badSparkContext.stop() - } - - // Verify that the issue is fixed with the latest SparkContext val goodSparkContext = new SparkContext(conf) try { testConcurrentQueryExecution(goodSparkContext) @@ -134,17 +121,6 @@ class SQLExecutionSuite extends SparkFunSuite { } } -/** - * A bad [[SparkContext]] that does not clone the inheritable thread local properties - * when passing them to children threads. - */ -private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) { - protected[spark] override val localProperties = new InheritableThreadLocal[Properties] { - override protected def childValue(parent: Properties): Properties = new Properties(parent) - override protected def initialValue(): Properties = new Properties() - } -} - object SQLExecutionSuite { @volatile var canProgress = false } From a2d5623548194f15989e7b68118d744673e33819 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Thu, 29 Jun 2017 01:23:13 -0700 Subject: [PATCH 102/118] [SPARK-20889][SPARKR] Grouped documentation for NONAGGREGATE column methods ## What changes were proposed in this pull request? Grouped documentation for nonaggregate column methods. Author: actuaryzhang Author: Wayne Zhang Closes #18422 from actuaryzhang/sparkRDocNonAgg. --- R/pkg/R/functions.R | 360 ++++++++++++++++++-------------------------- R/pkg/R/generics.R | 55 ++++--- 2 files changed, 182 insertions(+), 233 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 70ea620b471fe..cb09e847d739a 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -132,23 +132,39 @@ NULL #' df <- createDataFrame(as.data.frame(Titanic, stringsAsFactors = FALSE))} NULL -#' lit +#' Non-aggregate functions for Column operations #' -#' A new \linkS4class{Column} is created to represent the literal value. 
-#' If the parameter is a \linkS4class{Column}, it is returned unchanged. +#' Non-aggregate functions defined for \code{Column}. #' -#' @param x a literal value or a Column. +#' @param x Column to compute on. In \code{lit}, it is a literal value or a Column. +#' In \code{expr}, it contains an expression character object to be parsed. +#' @param y Column to compute on. +#' @param ... additional Columns. +#' @name column_nonaggregate_functions +#' @rdname column_nonaggregate_functions +#' @seealso coalesce,SparkDataFrame-method #' @family non-aggregate functions -#' @rdname lit -#' @name lit +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))} +NULL + +#' @details +#' \code{lit}: A new Column is created to represent the literal value. +#' If the parameter is a Column, it is returned unchanged. +#' +#' @rdname column_nonaggregate_functions #' @export -#' @aliases lit,ANY-method +#' @aliases lit lit,ANY-method #' @examples +#' #' \dontrun{ -#' lit(df$name) -#' select(df, lit("x")) -#' select(df, lit("2015-01-01")) -#'} +#' tmp <- mutate(df, v1 = lit(df$mpg), v2 = lit("x"), v3 = lit("2015-01-01"), +#' v4 = negate(df$mpg), v5 = expr('length(model)'), +#' v6 = greatest(df$vs, df$am), v7 = least(df$vs, df$am), +#' v8 = column("mpg")) +#' head(tmp)} #' @note lit since 1.5.0 setMethod("lit", signature("ANY"), function(x) { @@ -314,18 +330,16 @@ setMethod("bin", column(jc) }) -#' bitwiseNOT -#' -#' Computes bitwise NOT. -#' -#' @param x Column to compute on. +#' @details +#' \code{bitwiseNOT}: Computes bitwise NOT. #' -#' @rdname bitwiseNOT -#' @name bitwiseNOT -#' @family non-aggregate functions +#' @rdname column_nonaggregate_functions #' @export -#' @aliases bitwiseNOT,Column-method -#' @examples \dontrun{bitwiseNOT(df$c)} +#' @aliases bitwiseNOT bitwiseNOT,Column-method +#' @examples +#' +#' \dontrun{ +#' head(select(df, bitwiseNOT(cast(df$vs, "int"))))} #' @note bitwiseNOT since 1.5.0 setMethod("bitwiseNOT", signature(x = "Column"), @@ -375,16 +389,12 @@ setMethod("ceiling", ceil(x) }) -#' Returns the first column that is not NA -#' -#' Returns the first column that is not NA, or NA if all inputs are. +#' @details +#' \code{coalesce}: Returns the first column that is not NA, or NA if all inputs are. #' -#' @rdname coalesce -#' @name coalesce -#' @family non-aggregate functions +#' @rdname column_nonaggregate_functions #' @export #' @aliases coalesce,Column-method -#' @examples \dontrun{coalesce(df$c, df$d, df$e)} #' @note coalesce(Column) since 2.1.1 setMethod("coalesce", signature(x = "Column"), @@ -824,22 +834,24 @@ setMethod("initcap", column(jc) }) -#' is.nan -#' -#' Return true if the column is NaN, alias for \link{isnan} -#' -#' @param x Column to compute on. +#' @details +#' \code{isnan}: Returns true if the column is NaN. +#' @rdname column_nonaggregate_functions +#' @aliases isnan isnan,Column-method +#' @note isnan since 2.0.0 +setMethod("isnan", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "isnan", x@jc) + column(jc) + }) + +#' @details +#' \code{is.nan}: Alias for \link{isnan}. 
#' -#' @rdname is.nan -#' @name is.nan -#' @family non-aggregate functions -#' @aliases is.nan,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases is.nan is.nan,Column-method #' @export -#' @examples -#' \dontrun{ -#' is.nan(df$c) -#' isnan(df$c) -#' } #' @note is.nan since 2.0.0 setMethod("is.nan", signature(x = "Column"), @@ -847,17 +859,6 @@ setMethod("is.nan", isnan(x) }) -#' @rdname is.nan -#' @name isnan -#' @aliases isnan,Column-method -#' @note isnan since 2.0.0 -setMethod("isnan", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "isnan", x@jc) - column(jc) - }) - #' @details #' \code{kurtosis}: Returns the kurtosis of the values in a group. #' @@ -1129,27 +1130,24 @@ setMethod("minute", column(jc) }) -#' monotonically_increasing_id -#' -#' Return a column that generates monotonically increasing 64-bit integers. -#' -#' The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. -#' The current implementation puts the partition ID in the upper 31 bits, and the record number -#' within each partition in the lower 33 bits. The assumption is that the SparkDataFrame has -#' less than 1 billion partitions, and each partition has less than 8 billion records. -#' -#' As an example, consider a SparkDataFrame with two partitions, each with 3 records. +#' @details +#' \code{monotonically_increasing_id}: Returns a column that generates monotonically increasing +#' 64-bit integers. The generated ID is guaranteed to be monotonically increasing and unique, +#' but not consecutive. The current implementation puts the partition ID in the upper 31 bits, +#' and the record number within each partition in the lower 33 bits. The assumption is that the +#' SparkDataFrame has less than 1 billion partitions, and each partition has less than 8 billion +#' records. As an example, consider a SparkDataFrame with two partitions, each with 3 records. #' This expression would return the following IDs: #' 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. -#' #' This is equivalent to the MONOTONICALLY_INCREASING_ID function in SQL. +#' The method should be used with no argument. #' -#' @rdname monotonically_increasing_id -#' @aliases monotonically_increasing_id,missing-method -#' @name monotonically_increasing_id -#' @family misc functions +#' @rdname column_nonaggregate_functions +#' @aliases monotonically_increasing_id monotonically_increasing_id,missing-method #' @export -#' @examples \dontrun{select(df, monotonically_increasing_id())} +#' @examples +#' +#' \dontrun{head(select(df, monotonically_increasing_id()))} setMethod("monotonically_increasing_id", signature("missing"), function() { @@ -1171,18 +1169,12 @@ setMethod("month", column(jc) }) -#' negate -#' -#' Unary minus, i.e. negate the expression. -#' -#' @param x Column to compute on. +#' @details +#' \code{negate}: Unary minus, i.e. negate the expression. #' -#' @rdname negate -#' @name negate -#' @family non-aggregate functions -#' @aliases negate,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases negate negate,Column-method #' @export -#' @examples \dontrun{negate(df$c)} #' @note negate since 1.5.0 setMethod("negate", signature(x = "Column"), @@ -1481,23 +1473,19 @@ setMethod("stddev_samp", column(jc) }) -#' struct -#' -#' Creates a new struct column that composes multiple input columns. -#' -#' @param x a column to compute on. -#' @param ... optional column(s) to be included. 
+#' @details +#' \code{struct}: Creates a new struct column that composes multiple input columns. #' -#' @rdname struct -#' @name struct -#' @family non-aggregate functions -#' @aliases struct,characterOrColumn-method +#' @rdname column_nonaggregate_functions +#' @aliases struct struct,characterOrColumn-method #' @export #' @examples +#' #' \dontrun{ -#' struct(df$c, df$d) -#' struct("col1", "col2") -#' } +#' tmp <- mutate(df, v1 = struct(df$mpg, df$cyl), v2 = struct("hp", "wt", "vs"), +#' v3 = create_array(df$mpg, df$cyl, df$hp), +#' v4 = create_map(lit("x"), lit(1.0), lit("y"), lit(-1.0))) +#' head(tmp)} #' @note struct since 1.6.0 setMethod("struct", signature(x = "characterOrColumn"), @@ -1959,20 +1947,13 @@ setMethod("months_between", signature(y = "Column"), column(jc) }) -#' nanvl -#' -#' Returns col1 if it is not NaN, or col2 if col1 is NaN. -#' Both inputs should be floating point columns (DoubleType or FloatType). -#' -#' @param x first Column. -#' @param y second Column. +#' @details +#' \code{nanvl}: Returns the first column (\code{y}) if it is not NaN, or the second column (\code{x}) if +#' the first column is NaN. Both inputs should be floating point columns (DoubleType or FloatType). #' -#' @rdname nanvl -#' @name nanvl -#' @family non-aggregate functions -#' @aliases nanvl,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases nanvl nanvl,Column-method #' @export -#' @examples \dontrun{nanvl(df$c, x)} #' @note nanvl since 1.5.0 setMethod("nanvl", signature(y = "Column"), function(y, x) { @@ -2060,20 +2041,13 @@ setMethod("concat", column(jc) }) -#' greatest -#' -#' Returns the greatest value of the list of column names, skipping null values. +#' @details +#' \code{greatest}: Returns the greatest value of the list of column names, skipping null values. #' This function takes at least 2 parameters. It will return null if all parameters are null. #' -#' @param x Column to compute on -#' @param ... other columns -#' -#' @family non-aggregate functions -#' @rdname greatest -#' @name greatest -#' @aliases greatest,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases greatest greatest,Column-method #' @export -#' @examples \dontrun{greatest(df$c, df$d)} #' @note greatest since 1.5.0 setMethod("greatest", signature(x = "Column"), @@ -2087,20 +2061,13 @@ setMethod("greatest", column(jc) }) -#' least -#' -#' Returns the least value of the list of column names, skipping null values. +#' @details +#' \code{least}: Returns the least value of the list of column names, skipping null values. #' This function takes at least 2 parameters. It will return null if all parameters are null. #' -#' @param x Column to compute on -#' @param ... other columns -#' -#' @family non-aggregate functions -#' @rdname least -#' @aliases least,Column-method -#' @name least +#' @rdname column_nonaggregate_functions +#' @aliases least least,Column-method #' @export -#' @examples \dontrun{least(df$c, df$d)} #' @note least since 1.5.0 setMethod("least", signature(x = "Column"), @@ -2445,18 +2412,13 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri column(jc) }) -#' expr -#' -#' Parses the expression string into the column that it represents, similar to -#' SparkDataFrame.selectExpr +#' @details +#' \code{expr}: Parses the expression string into the column that it represents, similar to +#' \code{SparkDataFrame.selectExpr} #' -#' @param x an expression character object to be parsed. 
-#' @family non-aggregate functions -#' @rdname expr -#' @aliases expr,character-method -#' @name expr +#' @rdname column_nonaggregate_functions +#' @aliases expr expr,character-method #' @export -#' @examples \dontrun{expr('length(name)')} #' @note expr since 1.5.0 setMethod("expr", signature(x = "character"), function(x) { @@ -2617,18 +2579,19 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), column(jc) }) -#' rand -#' -#' Generate a random column with independent and identically distributed (i.i.d.) samples +#' @details +#' \code{rand}: Generates a random column with independent and identically distributed (i.i.d.) samples #' from U[0.0, 1.0]. #' +#' @rdname column_nonaggregate_functions #' @param seed a random seed. Can be missing. -#' @family non-aggregate functions -#' @rdname rand -#' @name rand -#' @aliases rand,missing-method +#' @aliases rand rand,missing-method #' @export -#' @examples \dontrun{rand()} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, r1 = rand(), r2 = rand(10), r3 = randn(), r4 = randn(10)) +#' head(tmp)} #' @note rand since 1.5.0 setMethod("rand", signature(seed = "missing"), function(seed) { @@ -2636,8 +2599,7 @@ setMethod("rand", signature(seed = "missing"), column(jc) }) -#' @rdname rand -#' @name rand +#' @rdname column_nonaggregate_functions #' @aliases rand,numeric-method #' @export #' @note rand(numeric) since 1.5.0 @@ -2647,18 +2609,13 @@ setMethod("rand", signature(seed = "numeric"), column(jc) }) -#' randn -#' -#' Generate a column with independent and identically distributed (i.i.d.) samples from +#' @details +#' \code{randn}: Generates a column with independent and identically distributed (i.i.d.) samples from #' the standard normal distribution. #' -#' @param seed a random seed. Can be missing. -#' @family non-aggregate functions -#' @rdname randn -#' @name randn -#' @aliases randn,missing-method +#' @rdname column_nonaggregate_functions +#' @aliases randn randn,missing-method #' @export -#' @examples \dontrun{randn()} #' @note randn since 1.5.0 setMethod("randn", signature(seed = "missing"), function(seed) { @@ -2666,8 +2623,7 @@ setMethod("randn", signature(seed = "missing"), column(jc) }) -#' @rdname randn -#' @name randn +#' @rdname column_nonaggregate_functions #' @aliases randn,numeric-method #' @export #' @note randn(numeric) since 1.5.0 @@ -2819,20 +2775,26 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc, format) column(jc) }) -#' when -#' -#' Evaluates a list of conditions and returns one of multiple possible result expressions. + +#' @details +#' \code{when}: Evaluates a list of conditions and returns one of multiple possible result expressions. #' For unmatched expressions null is returned. #' +#' @rdname column_nonaggregate_functions #' @param condition the condition to test on. Must be a Column expression. #' @param value result expression. 
-#' @family non-aggregate functions -#' @rdname when -#' @name when -#' @aliases when,Column-method -#' @seealso \link{ifelse} +#' @aliases when when,Column-method #' @export -#' @examples \dontrun{when(df$age == 2, df$age + 1)} +#' @examples +#' +#' \dontrun{ +#' tmp <- mutate(df, mpg_na = otherwise(when(df$mpg > 20, df$mpg), lit(NaN)), +#' mpg2 = ifelse(df$mpg > 20 & df$am > 0, 0, 1), +#' mpg3 = ifelse(df$mpg > 20, df$mpg, 20.0)) +#' head(tmp) +#' tmp <- mutate(tmp, ind_na1 = is.nan(tmp$mpg_na), ind_na2 = isnan(tmp$mpg_na)) +#' head(select(tmp, coalesce(tmp$mpg_na, tmp$mpg))) +#' head(select(tmp, nanvl(tmp$mpg_na, tmp$hp)))} #' @note when since 1.5.0 setMethod("when", signature(condition = "Column", value = "ANY"), function(condition, value) { @@ -2842,25 +2804,16 @@ setMethod("when", signature(condition = "Column", value = "ANY"), column(jc) }) -#' ifelse -#' -#' Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied. +#' @details +#' \code{ifelse}: Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied. #' Otherwise \code{no} is returned for unmatched conditions. #' +#' @rdname column_nonaggregate_functions #' @param test a Column expression that describes the condition. #' @param yes return values for \code{TRUE} elements of test. #' @param no return values for \code{FALSE} elements of test. -#' @family non-aggregate functions -#' @rdname ifelse -#' @name ifelse -#' @aliases ifelse,Column-method -#' @seealso \link{when} +#' @aliases ifelse ifelse,Column-method #' @export -#' @examples -#' \dontrun{ -#' ifelse(df$a > 1 & df$b > 2, 0, 1) -#' ifelse(df$a > 1, df$a, 1) -#' } #' @note ifelse since 1.5.0 setMethod("ifelse", signature(test = "Column", yes = "ANY", no = "ANY"), @@ -3263,19 +3216,12 @@ setMethod("posexplode", column(jc) }) -#' create_array -#' -#' Creates a new array column. The input columns must all have the same data type. -#' -#' @param x Column to compute on -#' @param ... additional Column(s). +#' @details +#' \code{create_array}: Creates a new array column. The input columns must all have the same data type. #' -#' @family non-aggregate functions -#' @rdname create_array -#' @name create_array -#' @aliases create_array,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases create_array create_array,Column-method #' @export -#' @examples \dontrun{create_array(df$x, df$y, df$z)} #' @note create_array since 2.3.0 setMethod("create_array", signature(x = "Column"), @@ -3288,22 +3234,15 @@ setMethod("create_array", column(jc) }) -#' create_map -#' -#' Creates a new map column. The input columns must be grouped as key-value pairs, +#' @details +#' \code{create_map}: Creates a new map column. The input columns must be grouped as key-value pairs, #' e.g. (key1, value1, key2, value2, ...). #' The key columns must all have the same data type, and can't be null. #' The value columns must all have the same data type. #' -#' @param x Column to compute on -#' @param ... additional Column(s). 
-#' -#' @family non-aggregate functions -#' @rdname create_map -#' @name create_map -#' @aliases create_map,Column-method +#' @rdname column_nonaggregate_functions +#' @aliases create_map create_map,Column-method #' @export -#' @examples \dontrun{create_map(lit("x"), lit(1.0), lit("y"), lit(-1.0))} #' @note create_map since 2.3.0 setMethod("create_map", signature(x = "Column"), @@ -3554,21 +3493,18 @@ setMethod("grouping_id", column(jc) }) -#' input_file_name -#' -#' Creates a string column with the input file name for a given row +#' @details +#' \code{input_file_name}: Creates a string column with the input file name for a given row. +#' The method should be used with no argument. #' -#' @rdname input_file_name -#' @name input_file_name -#' @family non-aggregate functions -#' @aliases input_file_name,missing-method +#' @rdname column_nonaggregate_functions +#' @aliases input_file_name input_file_name,missing-method #' @export #' @examples -#' \dontrun{ -#' df <- read.text("README.md") #' -#' head(select(df, input_file_name())) -#' } +#' \dontrun{ +#' tmp <- read.text("README.md") +#' head(select(tmp, input_file_name()))} #' @note input_file_name since 2.3.0 setMethod("input_file_name", signature("missing"), function() { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index dc99e3d94b269..1deb057bb1b82 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -422,9 +422,8 @@ setGeneric("cache", function(x) { standardGeneric("cache") }) setGeneric("checkpoint", function(x, eager = TRUE) { standardGeneric("checkpoint") }) #' @rdname coalesce -#' @param x a Column or a SparkDataFrame. -#' @param ... additional argument(s). If \code{x} is a Column, additional Columns can be optionally -#' provided. +#' @param x a SparkDataFrame. +#' @param ... additional argument(s). #' @export setGeneric("coalesce", function(x, ...) { standardGeneric("coalesce") }) @@ -863,8 +862,9 @@ setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) #' @export setGeneric("startsWith", function(x, prefix) { standardGeneric("startsWith") }) -#' @rdname when +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("when", function(condition, value) { standardGeneric("when") }) #' @rdname otherwise @@ -938,8 +938,9 @@ setGeneric("base64", function(x) { standardGeneric("base64") }) #' @name NULL setGeneric("bin", function(x) { standardGeneric("bin") }) -#' @rdname bitwiseNOT +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) #' @rdname column_math_functions @@ -995,12 +996,14 @@ setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") #' @export setGeneric("crc32", function(x) { standardGeneric("crc32") }) -#' @rdname create_array +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("create_array", function(x, ...) { standardGeneric("create_array") }) -#' @rdname create_map +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("create_map", function(x, ...) 
{ standardGeneric("create_map") }) #' @rdname hash @@ -1065,8 +1068,9 @@ setGeneric("explode", function(x) { standardGeneric("explode") }) #' @export setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) -#' @rdname expr +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("expr", function(x) { standardGeneric("expr") }) #' @rdname column_datetime_diff_functions @@ -1093,8 +1097,9 @@ setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") #' @name NULL setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) -#' @rdname greatest +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) #' @rdname column_aggregate_functions @@ -1127,9 +1132,9 @@ setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) #' @name NULL setGeneric("initcap", function(x) { standardGeneric("initcap") }) -#' @param x empty. Should be used with no argument. -#' @rdname input_file_name +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("input_file_name", function(x = "missing") { standardGeneric("input_file_name") }) @@ -1138,8 +1143,9 @@ setGeneric("input_file_name", #' @name NULL setGeneric("instr", function(y, x) { standardGeneric("instr") }) -#' @rdname is.nan +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("isnan", function(x) { standardGeneric("isnan") }) #' @rdname column_aggregate_functions @@ -1164,8 +1170,9 @@ setGeneric("last_day", function(x) { standardGeneric("last_day") }) #' @export setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("lead") }) -#' @rdname least +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("least", function(x, ...) { standardGeneric("least") }) #' @rdname column_string_functions @@ -1173,8 +1180,9 @@ setGeneric("least", function(x, ...) { standardGeneric("least") }) #' @name NULL setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) -#' @rdname lit +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("lit", function(x) { standardGeneric("lit") }) #' @rdname column_string_functions @@ -1206,9 +1214,9 @@ setGeneric("md5", function(x) { standardGeneric("md5") }) #' @name NULL setGeneric("minute", function(x) { standardGeneric("minute") }) -#' @param x empty. Should be used with no argument. 
-#' @rdname monotonically_increasing_id +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("monotonically_increasing_id", function(x = "missing") { standardGeneric("monotonically_increasing_id") }) @@ -1226,12 +1234,14 @@ setGeneric("months_between", function(y, x) { standardGeneric("months_between") #' @export setGeneric("n", function(x) { standardGeneric("n") }) -#' @rdname nanvl +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) -#' @rdname negate +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("negate", function(x) { standardGeneric("negate") }) #' @rdname not @@ -1275,12 +1285,14 @@ setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") #' @name NULL setGeneric("quarter", function(x) { standardGeneric("quarter") }) -#' @rdname rand +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("rand", function(seed) { standardGeneric("rand") }) -#' @rdname randn +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("randn", function(seed) { standardGeneric("randn") }) #' @rdname rank @@ -1409,8 +1421,9 @@ setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) #' @name NULL setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) -#' @rdname struct +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("struct", function(x, ...) { standardGeneric("struct") }) #' @rdname column_string_functions From 70085e83d1ee728b23f7df15f570eb8d77f67a7a Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Thu, 29 Jun 2017 09:51:12 +0100 Subject: [PATCH 103/118] [SPARK-21210][DOC][ML] Javadoc 8 fixes for ML shared param traits PR #15999 included fixes for doc strings in the ML shared param traits (occurrences of `>` and `>=`). This PR simply uses the HTML-escaped version of the param doc to embed into the Scaladoc, to ensure that when `SharedParamsCodeGen` is run, the generated javadoc will be compliant for Java 8. ## How was this patch tested? Existing tests Author: Nick Pentreath Closes #18420 from MLnick/shared-params-javadoc8. --- .../apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 5 ++++- .../org/apache/spark/ml/param/shared/sharedParams.scala | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index c94b8b4e9dfda..013817a41baf5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.param.shared import java.io.PrintWriter import scala.reflect.ClassTag +import scala.xml.Utility /** * Code generator for shared params (sharedParams.scala). Run under the Spark folder with @@ -167,6 +168,8 @@ private[shared] object SharedParamsCodeGen { "def" } + val htmlCompliantDoc = Utility.escape(doc) + s""" |/** | * Trait for shared param $name$defaultValueDoc. @@ -174,7 +177,7 @@ private[shared] object SharedParamsCodeGen { |private[ml] trait Has$Name extends Params { | | /** - | * Param for $doc. + | * Param for $htmlCompliantDoc. 
| * @group ${groupStr(0)} | */ | final val $name: $Param = new $Param(this, "$name", "$doc"$isValid)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index e3e03dfd43dd6..50619607a5054 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -176,7 +176,7 @@ private[ml] trait HasThreshold extends Params { private[ml] trait HasThresholds extends Params { /** - * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold. + * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold. * @group param */ final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold", (t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) <= 1)

From d106a74c53f493c3c18741a9b19cb821dace4ba2 Mon Sep 17 00:00:00 2001
From: jinxing
Date: Thu, 29 Jun 2017 09:59:36 +0100
Subject: [PATCH 104/118] [SPARK-21240] Fix code style for constructing and stopping a SparkContext in UT.

## What changes were proposed in this pull request?

Same as SPARK-20985: fix the code style for constructing and stopping a `SparkContext`. Ensure the context is always stopped, so that other tests do not fail complaining that only one `SparkContext` can exist at a time.

Author: jinxing

Closes #18454 from jinxing64/SPARK-21240.
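For reference, the `withSpark` helper imported below via `org.apache.spark.LocalSparkContext._` follows the loan pattern. The snippet here is only a minimal sketch of that idea; the object name `SparkContextLoan` and the exact signature are illustrative assumptions, not Spark's actual implementation:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical loan-pattern helper: the context is always stopped, even if the
// test body throws, so later tests can safely create their own SparkContext.
object SparkContextLoan {
  def withSpark[T](sc: SparkContext)(body: SparkContext => T): T = {
    try {
      body(sc)
    } finally {
      sc.stop()
    }
  }
}

// Example usage in a test body:
// SparkContextLoan.withSpark(new SparkContext(
//   new SparkConf().setMaster("local").setAppName("loan-example"))) { sc =>
//   assert(sc.parallelize(1 to 10).count() == 10)
// }
```

This is why the try/finally blocks in the two suites below can be collapsed into a single `withSpark(...) { sc => ... }` call.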
--- .../scala/org/apache/spark/scheduler/MapStatusSuite.scala | 6 ++---- .../apache/spark/sql/execution/ui/SQLListenerSuite.scala | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
index e6120139f4958..276169e02f01d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
@@ -26,6 +26,7 @@ import org.roaringbitmap.RoaringBitmap import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.internal.config +import org.apache.spark.LocalSparkContext._ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.storage.BlockManagerId
@@ -160,12 +161,9 @@ class MapStatusSuite extends SparkFunSuite { .set("spark.serializer", classOf[KryoSerializer].getName) .setMaster("local") .setAppName("SPARK-21133") - val sc = new SparkContext(conf) - try { + withSpark(new SparkContext(conf)) { sc => val count = sc.parallelize(0 until 3000, 10).repartition(2001).collect().length assert(count === 3000) - } finally { - sc.stop() } } }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index e6cd41e4facf1..82eff5e6491ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -25,6 +25,7 @@ import org.mockito.Mockito.mock import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.config +import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SparkSession}
@@ -496,8 +497,7 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { .setAppName("test") .set(config.MAX_TASK_FAILURES, 1) // Don't retry the tasks to run this test quickly .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly - val sc = new SparkContext(conf) - try { + withSpark(new SparkContext(conf)) { sc => SparkSession.sqlListener.set(null) val spark = new SparkSession(sc) import spark.implicits._
@@ -522,8 +522,6 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { assert(spark.sharedState.listener.executionIdToData.size <= 100) assert(spark.sharedState.listener.jobIdToExecutionId.size <= 100) assert(spark.sharedState.listener.stageIdToStageMetrics.size <= 100) - } finally { - sc.stop() } } }

From d7da2b94d6107341b33ca9224e9bfa4c9a92ed88 Mon Sep 17 00:00:00 2001
From: fjh100456
Date: Thu, 29 Jun 2017 10:01:12 +0100
Subject: [PATCH 105/118] =?UTF-8?q?[SPARK-21135][WEB=20UI]=20On=20history?= =?UTF-8?q?=20server=20page=EF=BC=8Cduration=20of=20incompleted=20applicat?= =?UTF-8?q?ions=20should=20be=20hidden=20instead=20of=20showing=20up=20as?= =?UTF-8?q?=200?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## What changes were proposed in this pull request?

Hide the duration of incomplete applications instead of showing it as 0.

## How was this patch tested?

Manual tests.

Author: fjh100456

Closes #18351 from fjh100456/master.
--- .../spark/ui/static/historypage-template.html | 4 ++-- .../org/apache/spark/ui/static/historypage.js | 15 ++++++++++----- 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html
index bfe31aae555ba..6cff0068d8bcb 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html
+++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html
@@ -44,7 +44,7 @@ Completed - + Duration
@@ -74,7 +74,7 @@ {{attemptId}} {{startTime}} {{endTime}} - {{duration}} + {{duration}} {{sparkUser}} {{lastUpdated}} Download
diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js
index 5ec1ce15a2127..9edd3ba0e0ba6 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js
@@ -182,12 +182,17 @@ $(document).ready(function() { for (i = 0; i < completedCells.length; i++) { completedCells[i].style.display='none'; } - } - var durationCells = document.getElementsByClassName("durationClass"); - for (i = 0; i < durationCells.length; i++) { - var timeInMilliseconds = parseInt(durationCells[i].title); - durationCells[i].innerHTML = formatDuration(timeInMilliseconds); + var durationCells = document.getElementsByClassName("durationColumn"); + for (i = 0; i < durationCells.length; i++) { + durationCells[i].style.display='none'; + } + } else { + var durationCells = document.getElementsByClassName("durationClass"); + for (i = 0; i < durationCells.length; i++) { + var timeInMilliseconds = parseInt(durationCells[i].title); + durationCells[i].innerHTML = formatDuration(timeInMilliseconds); + } } if ($(selector.concat(" tr")).length < 20) {

From 29bd251dd5914fc3b6146eb4fe0b45f1c84dba62 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=A8=E6=B2=BB=E5=9B=BD10192065?=
Date: Thu, 29 Jun 2017 20:53:48 +0800
Subject: [PATCH 106/118] [SPARK-21225][CORE] Considering CPUS_PER_TASK when allocating task slots for each WorkerOffer
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

JIRA Issue: https://issues.apache.org/jira/browse/SPARK-21225

In the function `resourceOffers`, a variable `tasks` is declared to hold the tasks that have been assigned to an executor. It is declared like this:

`val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores))`

However, this sizing only considers the case of one task per core. If the user sets "spark.task.cpus" to 2 or 3, far less capacity is actually needed, so it can be changed to

val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores / CPUS_PER_TASK))

instead. A further benefit of this change is that it makes it easier to understand how tasks are allocated to offers.

Author: 杨治国10192065

Closes #18435 from JackYangzg/motifyTaskCoreDisp.
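To make the sizing argument concrete, the sketch below works through the arithmetic with a simplified, hypothetical `WorkerOffer` and a plain `ArrayBuffer[String]` standing in for `ArrayBuffer[TaskDescription]`; it illustrates the idea only and is not the actual `TaskSchedulerImpl` code:

```scala
import scala.collection.mutable.ArrayBuffer

object SlotSizingExample extends App {
  // Simplified, hypothetical stand-in for the scheduler's WorkerOffer.
  case class WorkerOffer(executorId: String, host: String, cores: Int)

  val cpusPerTask = 2  // e.g. spark.task.cpus = 2
  val offers = Seq(WorkerOffer("0", "host-a", 16), WorkerOffer("1", "host-b", 8))

  // Sizing each per-offer buffer with o.cores would hint 16 and 8 slots, but
  // with spark.task.cpus = 2 only 16 / 2 = 8 and 8 / 2 = 4 tasks can ever run
  // concurrently on these offers, so o.cores / cpusPerTask is the tight bound.
  val tasks = offers.map(o => new ArrayBuffer[String](o.cores / cpusPerTask))

  offers.foreach { o =>
    println(s"executor ${o.executorId}: at most ${o.cores / cpusPerTask} concurrent tasks")
  }
}
```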
--- .../scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 91ec172ffeda1..737b383631148 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -345,7 +345,7 @@ private[spark] class TaskSchedulerImpl( val shuffledOffers = shuffleOffers(filteredOffers) // Build a list of tasks to assign to each worker. - val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores)) + val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores / CPUS_PER_TASK)) val availableCpus = shuffledOffers.map(o => o.cores).toArray val sortedTaskSets = rootPool.getSortedTaskSetQueue for (taskSet <- sortedTaskSets) { From 18066f2e61f430b691ed8a777c9b4e5786bf9dbc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 29 Jun 2017 21:28:48 +0800 Subject: [PATCH 107/118] [SPARK-21052][SQL] Add hash map metrics to join ## What changes were proposed in this pull request? This adds the average hash map probe metrics to join operator such as `BroadcastHashJoin` and `ShuffledHashJoin`. This PR adds the API to `HashedRelation` to get average hash map probe. ## How was this patch tested? Related test cases are added. Author: Liang-Chi Hsieh Closes #18301 from viirya/SPARK-21052. --- .../aggregate/HashAggregateExec.scala | 15 +- .../TungstenAggregationIterator.scala | 34 ++-- .../joins/BroadcastHashJoinExec.scala | 30 ++- .../spark/sql/execution/joins/HashJoin.scala | 8 +- .../sql/execution/joins/HashedRelation.scala | 43 +++- .../joins/ShuffledHashJoinExec.scala | 6 +- .../sql/execution/metric/SQLMetrics.scala | 32 ++- .../execution/metric/SQLMetricsSuite.scala | 188 ++++++++++++++++-- 8 files changed, 296 insertions(+), 60 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 5027a615ced7a..56f61c30c4a38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -60,7 +60,7 @@ case class HashAggregateExec( "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"), "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"), - "avgHashmapProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hashmap probe")) + "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -94,7 +94,7 @@ case class HashAggregateExec( val numOutputRows = longMetric("numOutputRows") val peakMemory = longMetric("peakMemory") val spillSize = longMetric("spillSize") - val avgHashmapProbe = longMetric("avgHashmapProbe") + val avgHashProbe = longMetric("avgHashProbe") child.execute().mapPartitions { iter => @@ -119,7 +119,7 @@ case class HashAggregateExec( numOutputRows, peakMemory, spillSize, - avgHashmapProbe) + avgHashProbe) if (!hasInput && groupingExpressions.isEmpty) { numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) @@ -344,7 +344,7 @@ 
case class HashAggregateExec( sorter: UnsafeKVExternalSorter, peakMemory: SQLMetric, spillSize: SQLMetric, - avgHashmapProbe: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = { + avgHashProbe: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = { // update peak execution memory val mapMemory = hashMap.getPeakMemoryUsedBytes @@ -355,8 +355,7 @@ case class HashAggregateExec( metrics.incPeakExecutionMemory(maxMemory) // Update average hashmap probe - val avgProbes = hashMap.getAverageProbesPerLookup() - avgHashmapProbe.add(avgProbes.ceil.toLong) + avgHashProbe.set(hashMap.getAverageProbesPerLookup()) if (sorter == null) { // not spilled @@ -584,7 +583,7 @@ case class HashAggregateExec( val doAgg = ctx.freshName("doAggregateWithKeys") val peakMemory = metricTerm(ctx, "peakMemory") val spillSize = metricTerm(ctx, "spillSize") - val avgHashmapProbe = metricTerm(ctx, "avgHashmapProbe") + val avgHashProbe = metricTerm(ctx, "avgHashProbe") def generateGenerateCode(): String = { if (isFastHashMapEnabled) { @@ -611,7 +610,7 @@ case class HashAggregateExec( s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();"} else ""} $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm, $peakMemory, $spillSize, - $avgHashmapProbe); + $avgHashProbe); } """) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 8efa95d48aea0..cfa930607360c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -89,7 +89,7 @@ class TungstenAggregationIterator( numOutputRows: SQLMetric, peakMemory: SQLMetric, spillSize: SQLMetric, - avgHashmapProbe: SQLMetric) + avgHashProbe: SQLMetric) extends AggregationIterator( groupingExpressions, originalInputAttributes, @@ -367,6 +367,22 @@ class TungstenAggregationIterator( } } + TaskContext.get().addTaskCompletionListener(_ => { + // At the end of the task, update the task's peak memory usage. Since we destroy + // the map to create the sorter, their memory usages should not overlap, so it is safe + // to just use the max of the two. + val mapMemory = hashMap.getPeakMemoryUsedBytes + val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) + val maxMemory = Math.max(mapMemory, sorterMemory) + val metrics = TaskContext.get().taskMetrics() + peakMemory.set(maxMemory) + spillSize.set(metrics.memoryBytesSpilled - spillSizeBefore) + metrics.incPeakExecutionMemory(maxMemory) + + // Updating average hashmap probe + avgHashProbe.set(hashMap.getAverageProbesPerLookup()) + }) + /////////////////////////////////////////////////////////////////////////// // Part 7: Iterator's public methods. /////////////////////////////////////////////////////////////////////////// @@ -409,22 +425,6 @@ class TungstenAggregationIterator( } } - // If this is the last record, update the task's peak memory usage. Since we destroy - // the map to create the sorter, their memory usages should not overlap, so it is safe - // to just use the max of the two. 
- if (!hasNext) { - val mapMemory = hashMap.getPeakMemoryUsedBytes - val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) - val maxMemory = Math.max(mapMemory, sorterMemory) - val metrics = TaskContext.get().taskMetrics() - peakMemory += maxMemory - spillSize += metrics.memoryBytesSpilled - spillSizeBefore - metrics.incPeakExecutionMemory(maxMemory) - - // Update average hashmap probe if this is the last record. - val averageProbes = hashMap.getAverageProbesPerLookup() - avgHashmapProbe.add(averageProbes.ceil.toLong) - } numOutputRows += 1 res } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 0bc261d593df4..bfa1e9d49a545 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Dist import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType +import org.apache.spark.util.TaskCompletionListener /** * Performs an inner hash join of two child relations. When the output RDD of this operator is @@ -46,7 +47,8 @@ case class BroadcastHashJoinExec( extends BinaryExecNode with HashJoin with CodegenSupport { override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) override def requiredChildDistribution: Seq[Distribution] = { val mode = HashedRelationBroadcastMode(buildKeys) @@ -60,12 +62,13 @@ case class BroadcastHashJoinExec( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") + val avgHashProbe = longMetric("avgHashProbe") val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() streamedPlan.execute().mapPartitions { streamedIter => val hashed = broadcastRelation.value.asReadOnlyCopy() TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize) - join(streamedIter, hashed, numOutputRows) + join(streamedIter, hashed, numOutputRows, avgHashProbe) } } @@ -90,6 +93,23 @@ case class BroadcastHashJoinExec( } } + /** + * Returns the codes used to add a task completion listener to update avg hash probe + * at the end of the task. + */ + private def genTaskListener(avgHashProbe: String, relationTerm: String): String = { + val listenerClass = classOf[TaskCompletionListener].getName + val taskContextClass = classOf[TaskContext].getName + s""" + | $taskContextClass$$.MODULE$$.get().addTaskCompletionListener(new $listenerClass() { + | @Override + | public void onTaskCompletion($taskContextClass context) { + | $avgHashProbe.set($relationTerm.getAverageProbesPerLookup()); + | } + | }); + """.stripMargin + } + /** * Returns a tuple of Broadcast of HashedRelation and the variable name for it. */ @@ -99,10 +119,16 @@ case class BroadcastHashJoinExec( val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) val relationTerm = ctx.freshName("relation") val clsName = broadcastRelation.value.getClass.getName + + // At the end of the task, we update the avg hash probe. 
+ val avgHashProbe = metricTerm(ctx, "avgHashProbe") + val addTaskListener = genTaskListener(avgHashProbe, relationTerm) + ctx.addMutableState(clsName, relationTerm, s""" | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy(); | incPeakExecutionMemory($relationTerm.estimatedSize()); + | $addTaskListener """.stripMargin) (broadcastRelation, relationTerm) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 1aef5f6864263..b09edf380c2d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -193,7 +194,8 @@ trait HashJoin { protected def join( streamedIter: Iterator[InternalRow], hashed: HashedRelation, - numOutputRows: SQLMetric): Iterator[InternalRow] = { + numOutputRows: SQLMetric, + avgHashProbe: SQLMetric): Iterator[InternalRow] = { val joinedIter = joinType match { case _: InnerLike => @@ -211,6 +213,10 @@ trait HashJoin { s"BroadcastHashJoin should not take $x as the JoinType") } + // At the end of the task, we update the avg hash probe. + TaskContext.get().addTaskCompletionListener(_ => + avgHashProbe.set(hashed.getAverageProbesPerLookup())) + val resultProj = createResultProjection joinedIter.map { r => numOutputRows += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 2dd1dc3da96c9..3c702856114f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -79,6 +79,11 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation { * Release any used resources. */ def close(): Unit + + /** + * Returns the average number of probes per key lookup. + */ + def getAverageProbesPerLookup(): Double } private[execution] object HashedRelation { @@ -242,7 +247,8 @@ private[joins] class UnsafeHashedRelation( binaryMap = new BytesToBytesMap( taskMemoryManager, (nKeys * 1.5 + 1).toInt, // reduce hash collision - pageSizeBytes) + pageSizeBytes, + true) var i = 0 var keyBuffer = new Array[Byte](1024) @@ -273,6 +279,8 @@ private[joins] class UnsafeHashedRelation( override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException { read(in.readInt, in.readLong, in.readBytes) } + + override def getAverageProbesPerLookup(): Double = binaryMap.getAverageProbesPerLookup() } private[joins] object UnsafeHashedRelation { @@ -290,7 +298,8 @@ private[joins] object UnsafeHashedRelation { taskMemoryManager, // Only 70% of the slots can be used before growing, more capacity help to reduce collision (sizeEstimate * 1.5 + 1).toInt, - pageSizeBytes) + pageSizeBytes, + true) // Create a mapping of buildKeys -> rows val keyGenerator = UnsafeProjection.create(key) @@ -344,7 +353,7 @@ private[joins] object UnsafeHashedRelation { * determined by `key1 - minKey`. * * The map is created as sparse mode, then key-value could be appended into it. 
Once finish - * appending, caller could all optimize() to try to turn the map into dense mode, which is faster + * appending, caller could call optimize() to try to turn the map into dense mode, which is faster * to probe. * * see http://java-performance.info/implementing-world-fastest-java-int-to-int-hash-map/ @@ -385,6 +394,10 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap // The number of unique keys. private var numKeys = 0L + // Tracking average number of probes per key lookup. + private var numKeyLookups = 0L + private var numProbes = 0L + // needed by serializer def this() = { this( @@ -469,6 +482,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { if (isDense) { + numKeyLookups += 1 + numProbes += 1 if (key >= minKey && key <= maxKey) { val value = array((key - minKey).toInt) if (value > 0) { @@ -477,11 +492,14 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } else { var pos = firstSlot(key) + numKeyLookups += 1 + numProbes += 1 while (array(pos + 1) != 0) { if (array(pos) == key) { return getRow(array(pos + 1), resultRow) } pos = nextSlot(pos) + numProbes += 1 } } null @@ -509,6 +527,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { if (isDense) { + numKeyLookups += 1 + numProbes += 1 if (key >= minKey && key <= maxKey) { val value = array((key - minKey).toInt) if (value > 0) { @@ -517,11 +537,14 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } else { var pos = firstSlot(key) + numKeyLookups += 1 + numProbes += 1 while (array(pos + 1) != 0) { if (array(pos) == key) { return valueIter(array(pos + 1), resultRow) } pos = nextSlot(pos) + numProbes += 1 } } null @@ -573,8 +596,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap private def updateIndex(key: Long, address: Long): Unit = { var pos = firstSlot(key) assert(numKeys < array.length / 2) + numKeyLookups += 1 + numProbes += 1 while (array(pos) != key && array(pos + 1) != 0) { pos = nextSlot(pos) + numProbes += 1 } if (array(pos + 1) == 0) { // this is the first value for this key, put the address in array. @@ -686,6 +712,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap writeLong(maxKey) writeLong(numKeys) writeLong(numValues) + writeLong(numKeyLookups) + writeLong(numProbes) writeLong(array.length) writeLongArray(writeBuffer, array, array.length) @@ -727,6 +755,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap maxKey = readLong() numKeys = readLong() numValues = readLong() + numKeyLookups = readLong() + numProbes = readLong() val length = readLong().toInt mask = length - 2 @@ -742,6 +772,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap override def read(kryo: Kryo, in: Input): Unit = { read(in.readBoolean, in.readLong, in.readBytes) } + + /** + * Returns the average number of probes per key lookup. 
+ */ + def getAverageProbesPerLookup(): Double = numProbes.toDouble / numKeyLookups } private[joins] class LongHashedRelation( @@ -793,6 +828,8 @@ private[joins] class LongHashedRelation( resultRow = new UnsafeRow(nFields) map = in.readObject().asInstanceOf[LongToUnsafeRowMap] } + + override def getAverageProbesPerLookup(): Double = map.getAverageProbesPerLookup() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index afb6e5e3dd235..f1df41ca49c27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -42,7 +42,8 @@ case class ShuffledHashJoinExec( override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"), - "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map")) + "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"), + "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil @@ -62,9 +63,10 @@ case class ShuffledHashJoinExec( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") + val avgHashProbe = longMetric("avgHashProbe") streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => val hashed = buildHashedRelation(buildIter) - join(streamIter, hashed, numOutputRows) + join(streamIter, hashed, numOutputRows, avgHashProbe) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 49cab04de2bf0..b4653c1b564f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -57,6 +57,12 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato override def add(v: Long): Unit = _value += v + // We can set a double value to `SQLMetric` which stores only long value, if it is + // average metrics. + def set(v: Double): Unit = SQLMetrics.setDoubleForAverageMetrics(this, v) + + def set(v: Long): Unit = _value = v + def +=(v: Long): Unit = _value += v override def value: Long = _value @@ -74,6 +80,19 @@ object SQLMetrics { private val TIMING_METRIC = "timing" private val AVERAGE_METRIC = "average" + private val baseForAvgMetric: Int = 10 + + /** + * Converts a double value to long value by multiplying a base integer, so we can store it in + * `SQLMetrics`. It only works for average metrics. When showing the metrics on UI, we restore + * it back to a double value up to the decimal places bound by the base integer. 
+ */ + private[sql] def setDoubleForAverageMetrics(metric: SQLMetric, v: Double): Unit = { + assert(metric.metricType == AVERAGE_METRIC, + s"Can't set a double to a metric of metrics type: ${metric.metricType}") + metric.set((v * baseForAvgMetric).toLong) + } + def createMetric(sc: SparkContext, name: String): SQLMetric = { val acc = new SQLMetric(SUM_METRIC) acc.register(sc, name = Some(name), countFailedValues = false) @@ -104,15 +123,14 @@ object SQLMetrics { /** * Create a metric to report the average information (including min, med, max) like - * avg hashmap probe. Because `SQLMetric` stores long values, we take the ceil of the average - * values before storing them. This metric is used to record an average value computed in the - * end of a task. It should be set once. The initial values (zeros) of this metrics will be - * excluded after. + * avg hash probe. As average metrics are double values, this kind of metrics should be + * only set with `SQLMetric.set` method instead of other methods like `SQLMetric.add`. + * The initial values (zeros) of this metrics will be excluded after. */ def createAverageMetric(sc: SparkContext, name: String): SQLMetric = { // The final result of this metric in physical operator UI may looks like: // probe avg (min, med, max): - // (1, 2, 6) + // (1.2, 2.2, 6.3) val acc = new SQLMetric(AVERAGE_METRIC) acc.register(sc, name = Some(s"$name (min, med, max)"), countFailedValues = false) acc @@ -127,7 +145,7 @@ object SQLMetrics { val numberFormat = NumberFormat.getIntegerInstance(Locale.US) numberFormat.format(values.sum) } else if (metricsType == AVERAGE_METRIC) { - val numberFormat = NumberFormat.getIntegerInstance(Locale.US) + val numberFormat = NumberFormat.getNumberInstance(Locale.US) val validValues = values.filter(_ > 0) val Seq(min, med, max) = { @@ -137,7 +155,7 @@ object SQLMetrics { val sorted = validValues.sorted Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) } - metric.map(numberFormat.format) + metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric)) } s"\n($min, $med, $max)" } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index a12ce2b9eba34..cb3405b2fe19b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -47,9 +47,10 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { private def getSparkPlanMetrics( df: DataFrame, expectedNumOfJobs: Int, - expectedNodeIds: Set[Long]): Option[Map[Long, (String, Map[String, Any])]] = { + expectedNodeIds: Set[Long], + enableWholeStage: Boolean = false): Option[Map[Long, (String, Map[String, Any])]] = { val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + withSQLConf("spark.sql.codegen.wholeStage" -> enableWholeStage.toString) { df.collect() } sparkContext.listenerBus.waitUntilEmpty(10000) @@ -110,6 +111,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } } + /** + * Generates a `DataFrame` by filling randomly generated bytes for hash collision. 
+ */ + private def generateRandomBytesDF(numRows: Int = 65535): DataFrame = { + val random = new Random() + val manyBytes = (0 until numRows).map { _ => + val byteArrSize = random.nextInt(100) + val bytes = new Array[Byte](byteArrSize) + random.nextBytes(bytes) + (bytes, random.nextInt(100)) + } + manyBytes.toSeq.toDF("a", "b") + } + test("LocalTableScanExec computes metrics in collect and take") { val df1 = spark.createDataset(Seq(1, 2, 3)) val logical = df1.queryExecution.logical @@ -151,9 +166,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df = testData2.groupBy().count() // 2 partitions val expected1 = Seq( Map("number of output rows" -> 2L, - "avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)"), + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"), Map("number of output rows" -> 1L, - "avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)")) + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")) testSparkPlanMetrics(df, 1, Map( 2L -> ("HashAggregate", expected1(0)), 0L -> ("HashAggregate", expected1(1))) @@ -163,9 +178,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df2 = testData2.groupBy('a).count() val expected2 = Seq( Map("number of output rows" -> 4L, - "avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)"), + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"), Map("number of output rows" -> 3L, - "avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)")) + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")) testSparkPlanMetrics(df2, 1, Map( 2L -> ("HashAggregate", expected2(0)), 0L -> ("HashAggregate", expected2(1))) @@ -173,19 +188,42 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } test("Aggregate metrics: track avg probe") { - val random = new Random() - val manyBytes = (0 until 65535).map { _ => - val byteArrSize = random.nextInt(100) - val bytes = new Array[Byte](byteArrSize) - random.nextBytes(bytes) - (bytes, random.nextInt(100)) - } - val df = manyBytes.toSeq.toDF("a", "b").repartition(1).groupBy('a).count() - val metrics = getSparkPlanMetrics(df, 1, Set(2L, 0L)).get - Seq(metrics(2L)._2("avg hashmap probe (min, med, max)"), - metrics(0L)._2("avg hashmap probe (min, med, max)")).foreach { probes => - probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => - assert(probe.toInt > 1) + // The executed plan looks like: + // HashAggregate(keys=[a#61], functions=[count(1)], output=[a#61, count#71L]) + // +- Exchange hashpartitioning(a#61, 5) + // +- HashAggregate(keys=[a#61], functions=[partial_count(1)], output=[a#61, count#76L]) + // +- Exchange RoundRobinPartitioning(1) + // +- LocalTableScan [a#61] + // + // Assume the execution plan with node id is: + // Wholestage disabled: + // HashAggregate(nodeId = 0) + // Exchange(nodeId = 1) + // HashAggregate(nodeId = 2) + // Exchange (nodeId = 3) + // LocalTableScan(nodeId = 4) + // + // Wholestage enabled: + // WholeStageCodegen(nodeId = 0) + // HashAggregate(nodeId = 1) + // Exchange(nodeId = 2) + // WholeStageCodegen(nodeId = 3) + // HashAggregate(nodeId = 4) + // Exchange(nodeId = 5) + // LocalTableScan(nodeId = 6) + Seq(true, false).foreach { enableWholeStage => + val df = generateRandomBytesDF().repartition(1).groupBy('a).count() + val nodeIds = if (enableWholeStage) { + Set(4L, 1L) + } else { + Set(2L, 0L) + } + val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get + nodeIds.foreach { nodeId => + val probes = metrics(nodeId)._2("avg hash probe (min, med, max)") + 
probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => + assert(probe.toDouble > 1.0) + } } } } @@ -267,10 +305,120 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df = df1.join(broadcast(df2), "key") testSparkPlanMetrics(df, 2, Map( 1L -> ("BroadcastHashJoin", Map( - "number of output rows" -> 2L))) + "number of output rows" -> 2L, + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"))) ) } + test("BroadcastHashJoin metrics: track avg probe") { + // The executed plan looks like: + // Project [a#210, b#211, b#221] + // +- BroadcastHashJoin [a#210], [a#220], Inner, BuildRight + // :- Project [_1#207 AS a#210, _2#208 AS b#211] + // : +- Filter isnotnull(_1#207) + // : +- LocalTableScan [_1#207, _2#208] + // +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, binary, true])) + // +- Project [_1#217 AS a#220, _2#218 AS b#221] + // +- Filter isnotnull(_1#217) + // +- LocalTableScan [_1#217, _2#218] + // + // Assume the execution plan with node id is + // WholeStageCodegen disabled: + // Project(nodeId = 0) + // BroadcastHashJoin(nodeId = 1) + // ...(ignored) + // + // WholeStageCodegen enabled: + // WholeStageCodegen(nodeId = 0) + // Project(nodeId = 1) + // BroadcastHashJoin(nodeId = 2) + // Project(nodeId = 3) + // Filter(nodeId = 4) + // ...(ignored) + Seq(true, false).foreach { enableWholeStage => + val df1 = generateRandomBytesDF() + val df2 = generateRandomBytesDF() + val df = df1.join(broadcast(df2), "a") + val nodeIds = if (enableWholeStage) { + Set(2L) + } else { + Set(1L) + } + val metrics = getSparkPlanMetrics(df, 2, nodeIds, enableWholeStage).get + nodeIds.foreach { nodeId => + val probes = metrics(nodeId)._2("avg hash probe (min, med, max)") + probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => + assert(probe.toDouble > 1.0) + } + } + } + } + + test("ShuffledHashJoin metrics") { + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "40", + "spark.sql.shuffle.partitions" -> "2", + "spark.sql.join.preferSortMergeJoin" -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = (1 to 10).map(i => (i, i.toString)).toSeq.toDF("key", "value") + // Assume the execution plan is + // ... 
-> ShuffledHashJoin(nodeId = 1) -> Project(nodeId = 0) + val df = df1.join(df2, "key") + val metrics = getSparkPlanMetrics(df, 1, Set(1L)) + testSparkPlanMetrics(df, 1, Map( + 1L -> ("ShuffledHashJoin", Map( + "number of output rows" -> 2L, + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"))) + ) + } + } + + test("ShuffledHashJoin metrics: track avg probe") { + // The executed plan looks like: + // Project [a#308, b#309, b#319] + // +- ShuffledHashJoin [a#308], [a#318], Inner, BuildRight + // :- Exchange hashpartitioning(a#308, 2) + // : +- Project [_1#305 AS a#308, _2#306 AS b#309] + // : +- Filter isnotnull(_1#305) + // : +- LocalTableScan [_1#305, _2#306] + // +- Exchange hashpartitioning(a#318, 2) + // +- Project [_1#315 AS a#318, _2#316 AS b#319] + // +- Filter isnotnull(_1#315) + // +- LocalTableScan [_1#315, _2#316] + // + // Assume the execution plan with node id is + // WholeStageCodegen disabled: + // Project(nodeId = 0) + // ShuffledHashJoin(nodeId = 1) + // ...(ignored) + // + // WholeStageCodegen enabled: + // WholeStageCodegen(nodeId = 0) + // Project(nodeId = 1) + // ShuffledHashJoin(nodeId = 2) + // ...(ignored) + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "5000000", + "spark.sql.shuffle.partitions" -> "2", + "spark.sql.join.preferSortMergeJoin" -> "false") { + Seq(true, false).foreach { enableWholeStage => + val df1 = generateRandomBytesDF(65535 * 5) + val df2 = generateRandomBytesDF(65535) + val df = df1.join(df2, "a") + val nodeIds = if (enableWholeStage) { + Set(2L) + } else { + Set(1L) + } + val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get + nodeIds.foreach { nodeId => + val probes = metrics(nodeId)._2("avg hash probe (min, med, max)") + probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => + assert(probe.toDouble > 1.0) + } + } + } + } + } + test("BroadcastHashJoin(outer) metrics") { val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") From f9151bebca986d44cdab7699959fec2bc050773a Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Thu, 29 Jun 2017 16:03:15 -0700 Subject: [PATCH 108/118] [SPARK-21188][CORE] releaseAllLocksForTask should synchronize the whole method ## What changes were proposed in this pull request? Since the objects `readLocksByTask`, `writeLocksByTask` and `info`s are coupled and supposed to be modified by other threads concurrently, all the read and writes of them in the method `releaseAllLocksForTask` should be protected by a single synchronized block like other similar methods. ## How was this patch tested? existing tests Author: Feng Liu Closes #18400 from liufengdb/synchronize. 
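To make the hazard concrete, here is a minimal, self-contained sketch; it is not taken from the patch, and the class and field names only loosely mirror `BlockInfoManager`. When the per-task lock bookkeeping and the shared reader counts are updated under separate synchronized blocks, another thread can observe a half-released state; holding the monitor for the whole method, as the change below does, makes the release one atomic step.

```scala
import scala.collection.mutable

// Hypothetical stand-in for BlockInfoManager, reduced to read-lock bookkeeping only.
class LockRegistry {
  private val readLocksByTask = mutable.Map.empty[Long, mutable.Buffer[String]]
  private val infos = mutable.Map.empty[String, Int] // blockId -> reader count

  def lockForReading(taskId: Long, blockId: String): Unit = synchronized {
    readLocksByTask.getOrElseUpdate(taskId, mutable.Buffer.empty) += blockId
    infos(blockId) = infos.getOrElse(blockId, 0) + 1
  }

  // The whole method holds the monitor, so removing the task's locks and decrementing
  // the per-block reader counts cannot interleave with a concurrent lock/unlock.
  def releaseAllLocksForTask(taskId: Long): Seq[String] = synchronized {
    val released = readLocksByTask.remove(taskId).getOrElse(mutable.Buffer.empty[String])
    released.foreach { blockId =>
      infos.get(blockId).foreach(count => infos(blockId) = count - 1)
    }
    notifyAll() // wake up threads waiting for the reader counts to drop
    released.toSeq
  }
}
```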
--- .../spark/storage/BlockInfoManager.scala | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index 7064872ec1c77..219a0e799cc73 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -341,15 +341,11 @@ private[storage] class BlockInfoManager extends Logging { * * @return the ids of blocks whose pins were released */ - def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = { + def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = synchronized { val blocksWithReleasedLocks = mutable.ArrayBuffer[BlockId]() - val readLocks = synchronized { - readLocksByTask.remove(taskAttemptId).getOrElse(ImmutableMultiset.of[BlockId]()) - } - val writeLocks = synchronized { - writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty) - } + val readLocks = readLocksByTask.remove(taskAttemptId).getOrElse(ImmutableMultiset.of[BlockId]()) + val writeLocks = writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty) for (blockId <- writeLocks) { infos.get(blockId).foreach { info => @@ -358,21 +354,19 @@ private[storage] class BlockInfoManager extends Logging { } blocksWithReleasedLocks += blockId } + readLocks.entrySet().iterator().asScala.foreach { entry => val blockId = entry.getElement val lockCount = entry.getCount blocksWithReleasedLocks += blockId - synchronized { - get(blockId).foreach { info => - info.readerCount -= lockCount - assert(info.readerCount >= 0) - } + get(blockId).foreach { info => + info.readerCount -= lockCount + assert(info.readerCount >= 0) } } - synchronized { - notifyAll() - } + notifyAll() + blocksWithReleasedLocks } From 4996c53949376153f9ebdc74524fed7226968808 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 30 Jun 2017 10:56:48 +0800 Subject: [PATCH 109/118] [SPARK-21253][CORE] Fix a bug that StreamCallback may not be notified if network errors happen ## What changes were proposed in this pull request? If a network error happens before processing StreamResponse/StreamFailure events, StreamCallback.onFailure won't be called. This PR fixes `failOutstandingRequests` to also notify outstanding StreamCallbacks. ## How was this patch tested? The new unit tests. Author: Shixiong Zhu Closes #18472 from zsxwing/fix-stream-2. --- .../spark/network/client/TransportClient.java | 2 +- .../client/TransportResponseHandler.java | 38 ++++++++++++++----- .../TransportResponseHandlerSuite.java | 31 ++++++++++++++- 3 files changed, 59 insertions(+), 12 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index a6f527c118218..8f354ad78bbaa 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -179,7 +179,7 @@ public void stream(String streamId, StreamCallback callback) { // written to the socket atomically, so that callbacks are called in the right order // when responses arrive. 
synchronized (this) { - handler.addStreamCallback(callback); + handler.addStreamCallback(streamId, callback); channel.writeAndFlush(new StreamRequest(streamId)).addListener(future -> { if (future.isSuccess()) { long timeTaken = System.currentTimeMillis() - startTime; diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 41bead546cad6..be9f18203c8e4 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -24,6 +24,8 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicLong; +import scala.Tuple2; + import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; import org.slf4j.Logger; @@ -56,7 +58,7 @@ public class TransportResponseHandler extends MessageHandler { private final Map outstandingRpcs; - private final Queue streamCallbacks; + private final Queue> streamCallbacks; private volatile boolean streamActive; /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ @@ -88,9 +90,9 @@ public void removeRpcRequest(long requestId) { outstandingRpcs.remove(requestId); } - public void addStreamCallback(StreamCallback callback) { + public void addStreamCallback(String streamId, StreamCallback callback) { timeOfLastRequestNs.set(System.nanoTime()); - streamCallbacks.offer(callback); + streamCallbacks.offer(Tuple2.apply(streamId, callback)); } @VisibleForTesting @@ -104,15 +106,31 @@ public void deactivateStream() { */ private void failOutstandingRequests(Throwable cause) { for (Map.Entry entry : outstandingFetches.entrySet()) { - entry.getValue().onFailure(entry.getKey().chunkIndex, cause); + try { + entry.getValue().onFailure(entry.getKey().chunkIndex, cause); + } catch (Exception e) { + logger.warn("ChunkReceivedCallback.onFailure throws exception", e); + } } for (Map.Entry entry : outstandingRpcs.entrySet()) { - entry.getValue().onFailure(cause); + try { + entry.getValue().onFailure(cause); + } catch (Exception e) { + logger.warn("RpcResponseCallback.onFailure throws exception", e); + } + } + for (Tuple2 entry : streamCallbacks) { + try { + entry._2().onFailure(entry._1(), cause); + } catch (Exception e) { + logger.warn("StreamCallback.onFailure throws exception", e); + } } // It's OK if new fetches appear, as they will fail immediately. 
outstandingFetches.clear(); outstandingRpcs.clear(); + streamCallbacks.clear(); } @Override @@ -190,8 +208,9 @@ public void handle(ResponseMessage message) throws Exception { } } else if (message instanceof StreamResponse) { StreamResponse resp = (StreamResponse) message; - StreamCallback callback = streamCallbacks.poll(); - if (callback != null) { + Tuple2 entry = streamCallbacks.poll(); + if (entry != null) { + StreamCallback callback = entry._2(); if (resp.byteCount > 0) { StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, callback); @@ -216,8 +235,9 @@ public void handle(ResponseMessage message) throws Exception { } } else if (message instanceof StreamFailure) { StreamFailure resp = (StreamFailure) message; - StreamCallback callback = streamCallbacks.poll(); - if (callback != null) { + Tuple2 entry = streamCallbacks.poll(); + if (entry != null) { + StreamCallback callback = entry._2(); try { callback.onFailure(resp.streamId, new RuntimeException(resp.error)); } catch (IOException ioe) { diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 09fc80d12d510..b4032c4c3f031 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import java.io.IOException; import java.nio.ByteBuffer; import io.netty.channel.Channel; @@ -127,7 +128,7 @@ public void testActiveStreams() throws Exception { StreamResponse response = new StreamResponse("stream", 1234L, null); StreamCallback cb = mock(StreamCallback.class); - handler.addStreamCallback(cb); + handler.addStreamCallback("stream", cb); assertEquals(1, handler.numOutstandingRequests()); handler.handle(response); assertEquals(1, handler.numOutstandingRequests()); @@ -135,9 +136,35 @@ public void testActiveStreams() throws Exception { assertEquals(0, handler.numOutstandingRequests()); StreamFailure failure = new StreamFailure("stream", "uh-oh"); - handler.addStreamCallback(cb); + handler.addStreamCallback("stream", cb); assertEquals(1, handler.numOutstandingRequests()); handler.handle(failure); assertEquals(0, handler.numOutstandingRequests()); } + + @Test + public void failOutstandingStreamCallbackOnClose() throws Exception { + Channel c = new LocalChannel(); + c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + TransportResponseHandler handler = new TransportResponseHandler(c); + + StreamCallback cb = mock(StreamCallback.class); + handler.addStreamCallback("stream-1", cb); + handler.channelInactive(); + + verify(cb).onFailure(eq("stream-1"), isA(IOException.class)); + } + + @Test + public void failOutstandingStreamCallbackOnException() throws Exception { + Channel c = new LocalChannel(); + c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + TransportResponseHandler handler = new TransportResponseHandler(c); + + StreamCallback cb = mock(StreamCallback.class); + handler.addStreamCallback("stream-1", cb); + handler.exceptionCaught(new IOException("Oops!")); + + verify(cb).onFailure(eq("stream-1"), isA(IOException.class)); + } } From 80f7ac3a601709dd9471092244612023363f54cd Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 30 Jun 2017 11:02:22 +0800 Subject: [PATCH 110/118] 
[SPARK-21253][CORE] Disable spark.reducer.maxReqSizeShuffleToMem ## What changes were proposed in this pull request? Disable spark.reducer.maxReqSizeShuffleToMem because it breaks the old shuffle service. Credits to wangyum Closes #18466 ## How was this patch tested? Jenkins Author: Shixiong Zhu Author: Yuming Wang Closes #18467 from zsxwing/SPARK-21253. --- .../scala/org/apache/spark/internal/config/package.scala | 3 ++- docs/configuration.md | 8 -------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index be63c637a3a13..8dee0d970c4c6 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -323,10 +323,11 @@ package object config { private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM = ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem") + .internal() .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " + "above this threshold. This is to avoid a giant request takes too much memory.") .bytesConf(ByteUnit.BYTE) - .createWithDefaultString("200m") + .createWithDefault(Long.MaxValue) private[spark] val TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES = ConfigBuilder("spark.taskMetrics.trackUpdatedBlockStatuses") diff --git a/docs/configuration.md b/docs/configuration.md index c8e61537a457c..bd6a1f9e240e2 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -528,14 +528,6 @@ Apart from these, the following properties are also available, and may be useful By allowing it to limit the number of fetch requests, this scenario can be mitigated. - - spark.reducer.maxReqSizeShuffleToMem - 200m - - The blocks of a shuffle request will be fetched to disk when size of the request is above - this threshold. This is to avoid a giant request takes too much memory. - - spark.shuffle.compress true From 88a536babf119b7e331d02aac5d52b57658803bf Mon Sep 17 00:00:00 2001 From: IngoSchuster Date: Fri, 30 Jun 2017 11:16:09 +0800 Subject: [PATCH 111/118] [SPARK-21176][WEB UI] Limit number of selector threads for admin ui proxy servlets to 8 ## What changes were proposed in this pull request? Please see also https://issues.apache.org/jira/browse/SPARK-21176 This change limits the number of selector threads that Jetty creates to a maximum of 8 per proxy servlet (the Jetty default is the number of processors / 2). The newHttpClient method of Jetty's ProxyServlet class is overridden to avoid the Jetty defaults (which are designed for high-performance http servers). Once https://github.com/eclipse/jetty.project/issues/1643 is available, the code could be cleaned up to avoid the method override. I really need this on v2.1.1 - what is the best way for a backport (automatic merge works fine)? Shall I create another PR? ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) The patch was tested manually on a Spark cluster with a head node that has 88 processors, using JMX to verify that the number of selector threads is now limited to 8 per proxy. gurvindersingh zsxwing can you please review the change? Author: IngoSchuster Author: Ingo Schuster Closes #18437 from IngoSchuster/master.
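A back-of-the-envelope check of the numbers in the description above: the 88 cores and the cap of 8 come from the text, while the executor (and hence proxy servlet) count is an assumed figure used only for illustration.

```scala
object SelectorThreadEstimate extends App {
  val cores = 88      // head node size from the description above
  val proxies = 100   // assumed: one proxy servlet per executor in reverse-proxy mode

  val perProxyDefault = math.max(1, cores / 2)              // Jetty default: 44 selectors
  val perProxyCapped  = math.max(1, math.min(8, cores / 2)) // with this change: 8 selectors

  println(s"default: ${perProxyDefault * proxies} selector threads") // 4400
  println(s"capped:  ${perProxyCapped * proxies} selector threads")  // 800
}
```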
--- .../main/scala/org/apache/spark/ui/JettyUtils.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index edf328b5ae538..b9371c7ad7b45 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -26,6 +26,8 @@ import scala.language.implicitConversions import scala.xml.Node import org.eclipse.jetty.client.api.Response +import org.eclipse.jetty.client.HttpClient +import org.eclipse.jetty.client.http.HttpClientTransportOverHTTP import org.eclipse.jetty.proxy.ProxyServlet import org.eclipse.jetty.server._ import org.eclipse.jetty.server.handler._ @@ -208,6 +210,16 @@ private[spark] object JettyUtils extends Logging { rewrittenURI.toString() } + override def newHttpClient(): HttpClient = { + // SPARK-21176: Use the Jetty logic to calculate the number of selector threads (#CPUs/2), + // but limit it to 8 max. + // Otherwise, it might happen that we exhaust the threadpool since in reverse proxy mode + // a proxy is instantiated for each executor. If the head node has many processors, this + // can quickly add up to an unreasonably high number of threads. + val numSelectors = math.max(1, math.min(8, Runtime.getRuntime().availableProcessors() / 2)) + new HttpClient(new HttpClientTransportOverHTTP(numSelectors), null) + } + override def filterServerResponseHeader( clientRequest: HttpServletRequest, serverResponse: Response, From cfc696f4a4289acf132cb26baf7c02c5b6305277 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 29 Jun 2017 20:56:37 -0700 Subject: [PATCH 112/118] [SPARK-21253][CORE][HOTFIX] Fix Scala 2.10 build ## What changes were proposed in this pull request? A follow up PR to fix Scala 2.10 build for #18472 ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #18478 from zsxwing/SPARK-21253-2. --- .../apache/spark/network/client/TransportResponseHandler.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index be9f18203c8e4..340b8b96aabc6 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -92,7 +92,7 @@ public void removeRpcRequest(long requestId) { public void addStreamCallback(String streamId, StreamCallback callback) { timeOfLastRequestNs.set(System.nanoTime()); - streamCallbacks.offer(Tuple2.apply(streamId, callback)); + streamCallbacks.offer(new Tuple2<>(streamId, callback)); } @VisibleForTesting From e2f32ee45ac907f1f53fde7e412676a849a94872 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 30 Jun 2017 12:34:09 +0800 Subject: [PATCH 113/118] [SPARK-21258][SQL] Fix WindowExec complex object aggregation with spilling ## What changes were proposed in this pull request? `WindowExec` currently improperly stores complex objects (UnsafeRow, UnsafeArrayData, UnsafeMapData, UTF8String) during aggregation by keeping a reference in the buffer used by `GeneratedMutableProjections` to the actual input data. Things go wrong when the input object (or the backing bytes) are reused for other things. This could happen in window functions when it starts spilling to disk. 
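As a minimal illustration of that reuse problem (simplified types, not Spark's actual `UnsafeRow` or window machinery): an aggregate that stores a reference into a buffer its producer keeps overwriting ends up observing whatever was written last, which is how a `FIRST` can silently turn into a `LAST`; a defensive copy of the input avoids it.

```scala
// One mutable row object reused by the producer, like a spill reader's backing buffer.
final class MutableRow(var value: String)

// Toy FIRST aggregate; `copyInput` switches between keeping a reference and copying.
class FirstAgg(copyInput: Boolean) {
  private var first: MutableRow = _
  def update(row: MutableRow): Unit =
    if (first == null) first = if (copyInput) new MutableRow(row.value) else row
  def result: String = first.value
}

object ReuseDemo extends App {
  val reused = new MutableRow("")
  val unsafe = new FirstAgg(copyInput = false)
  val safe   = new FirstAgg(copyInput = true)
  Seq("a", "b", "c").foreach { v =>
    reused.value = v      // producer overwrites the same object before each update
    unsafe.update(reused)
    safe.update(reused)
  }
  println(unsafe.result)  // "c" -- corrupted: the buffered reference tracked the reused object
  println(safe.result)    // "a" -- copying the input preserves the intended FIRST value
}
```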
When reading back the spill files, the `UnsafeSorterSpillReader` reuses the buffer to which the `UnsafeRow` points, leading to weird corruption scenarios. Note that this only happens for aggregate functions that preserve (parts of) their input, for example `FIRST`, `LAST`, `MIN` & `MAX`. This was not seen before, because the spilling logic was not doing actual spills as much and actually used an in-memory page. This page was not cleaned up during window processing and made sure unsafe objects point to their own dedicated memory location. This was changed by https://github.com/apache/spark/pull/16909; after that PR, Spark spills more eagerly. This PR provides a surgical fix because we are close to releasing Spark 2.2. This change just makes sure that there cannot be any object reuse, at the expense of a little bit of performance. We will follow up with a more subtle solution at a later point. ## How was this patch tested? Added a regression test to `DataFrameWindowFunctionsSuite`. Author: Herman van Hovell Closes #18470 from hvanhovell/SPARK-21258. --- .../execution/window/AggregateProcessor.scala | 7 ++- .../sql/DataFrameWindowFunctionsSuite.scala | 47 ++++++++++++++++++- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala index bc141b36e63b4..2195c6ea95948 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala @@ -145,10 +145,13 @@ private[window] final class AggregateProcessor( /** Update the buffer. */ def update(input: InternalRow): Unit = { - updateProjection(join(buffer, input)) + // TODO(hvanhovell) this sacrifices performance for correctness. We should make sure that + // MutableProjection makes copies of the complex input objects it buffers. + val copy = input.copy() + updateProjection(join(buffer, copy)) var i = 0 while (i < numImperatives) { - imperatives(i).update(buffer, input) + imperatives(i).update(buffer, copy) i += 1 } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 1255c49104718..204858fa29787 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{DataType, LongType, StructType} +import org.apache.spark.sql.types._ /** * Window function testing for DataFrame API. @@ -423,4 +424,48 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { df.select(selectList: _*).where($"value" < 2), Seq(Row(3, "1", null, 3.0, 4.0, 3.0), Row(5, "1", false, 4.0, 5.0, 5.0))) } + + test("SPARK-21258: complex object in combination with spilling") { + // Make sure we trigger the spilling path. + withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "17") { + val sampleSchema = new StructType(). + add("f0", StringType). + add("f1", LongType).
+ add("f2", ArrayType(new StructType(). + add("f20", StringType))). + add("f3", ArrayType(new StructType(). + add("f30", StringType))) + + val w0 = Window.partitionBy("f0").orderBy("f1") + val w1 = w0.rowsBetween(Long.MinValue, Long.MaxValue) + + val c0 = first(struct($"f2", $"f3")).over(w0) as "c0" + val c1 = last(struct($"f2", $"f3")).over(w1) as "c1" + + val input = + """{"f1":1497820153720,"f2":[{"f20":"x","f21":0}],"f3":[{"f30":"x","f31":0}]} + |{"f1":1497802179638} + |{"f1":1497802189347} + |{"f1":1497802189593} + |{"f1":1497802189597} + |{"f1":1497802189599} + |{"f1":1497802192103} + |{"f1":1497802193414} + |{"f1":1497802193577} + |{"f1":1497802193709} + |{"f1":1497802202883} + |{"f1":1497802203006} + |{"f1":1497802203743} + |{"f1":1497802203834} + |{"f1":1497802203887} + |{"f1":1497802203893} + |{"f1":1497802203976} + |{"f1":1497820168098} + |""".stripMargin.split("\n").toSeq + + import testImplicits._ + + spark.read.schema(sampleSchema).json(input.toDS()).select(c0, c1).foreach { _ => () } + } + } } From fddb63f46345be36c40d9a7f3660920af6502bbd Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Thu, 29 Jun 2017 21:35:01 -0700 Subject: [PATCH 114/118] [SPARK-20889][SPARKR] Grouped documentation for MISC column methods ## What changes were proposed in this pull request? Grouped documentation for column misc methods. Author: actuaryzhang Author: Wayne Zhang Closes #18448 from actuaryzhang/sparkRDocMisc. --- R/pkg/R/functions.R | 98 +++++++++++++++++++++------------------------ R/pkg/R/generics.R | 15 ++++--- 2 files changed, 55 insertions(+), 58 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index cb09e847d739a..67cb7a7f6db08 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -150,6 +150,27 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))} NULL +#' Miscellaneous functions for Column operations +#' +#' Miscellaneous functions defined for \code{Column}. +#' +#' @param x Column to compute on. In \code{sha2}, it is one of 224, 256, 384, or 512. +#' @param y Column to compute on. +#' @param ... additional Columns. +#' @name column_misc_functions +#' @rdname column_misc_functions +#' @family misc functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)[, 1:2]) +#' tmp <- mutate(df, v1 = crc32(df$model), v2 = hash(df$model), +#' v3 = hash(df$model, df$mpg), v4 = md5(df$model), +#' v5 = sha1(df$model), v6 = sha2(df$model, 256)) +#' head(tmp) +#' } +NULL + #' @details #' \code{lit}: A new Column is created to represent the literal value. #' If the parameter is a Column, it is returned unchanged. @@ -569,19 +590,13 @@ setMethod("count", column(jc) }) -#' crc32 -#' -#' Calculates the cyclic redundancy check value (CRC32) of a binary column and -#' returns the value as a bigint. -#' -#' @param x Column to compute on. +#' @details +#' \code{crc32}: Calculates the cyclic redundancy check value (CRC32) of a binary column +#' and returns the value as a bigint. #' -#' @rdname crc32 -#' @name crc32 -#' @family misc functions -#' @aliases crc32,Column-method +#' @rdname column_misc_functions +#' @aliases crc32 crc32,Column-method #' @export -#' @examples \dontrun{crc32(df$c)} #' @note crc32 since 1.5.0 setMethod("crc32", signature(x = "Column"), @@ -590,19 +605,13 @@ setMethod("crc32", column(jc) }) -#' hash -#' -#' Calculates the hash code of given columns, and returns the result as a int column. -#' -#' @param x Column to compute on. -#' @param ... 
additional Column(s) to be included. +#' @details +#' \code{hash}: Calculates the hash code of given columns, and returns the result +#' as an int column. #' -#' @rdname hash -#' @name hash -#' @family misc functions -#' @aliases hash,Column-method +#' @rdname column_misc_functions +#' @aliases hash hash,Column-method #' @export -#' @examples \dontrun{hash(df$c)} #' @note hash since 2.0.0 setMethod("hash", signature(x = "Column"), @@ -1055,19 +1064,13 @@ setMethod("max", column(jc) }) -#' md5 -#' -#' Calculates the MD5 digest of a binary column and returns the value +#' @details +#' \code{md5}: Calculates the MD5 digest of a binary column and returns the value #' as a 32 character hex string. #' -#' @param x Column to compute on. -#' -#' @rdname md5 -#' @name md5 -#' @family misc functions -#' @aliases md5,Column-method +#' @rdname column_misc_functions +#' @aliases md5 md5,Column-method #' @export -#' @examples \dontrun{md5(df$c)} #' @note md5 since 1.5.0 setMethod("md5", signature(x = "Column"), @@ -1307,19 +1310,13 @@ setMethod("second", column(jc) }) -#' sha1 -#' -#' Calculates the SHA-1 digest of a binary column and returns the value +#' @details +#' \code{sha1}: Calculates the SHA-1 digest of a binary column and returns the value #' as a 40 character hex string. #' -#' @param x Column to compute on. -#' -#' @rdname sha1 -#' @name sha1 -#' @family misc functions -#' @aliases sha1,Column-method +#' @rdname column_misc_functions +#' @aliases sha1 sha1,Column-method #' @export -#' @examples \dontrun{sha1(df$c)} #' @note sha1 since 1.5.0 setMethod("sha1", signature(x = "Column"), @@ -2309,19 +2306,14 @@ setMethod("format_number", signature(y = "Column", x = "numeric"), column(jc) }) -#' sha2 -#' -#' Calculates the SHA-2 family of hash functions of a binary column and -#' returns the value as a hex string. +#' @details +#' \code{sha2}: Calculates the SHA-2 family of hash functions of a binary column and +#' returns the value as a hex string. The second argument \code{x} specifies the number +#' of bits, and is one of 224, 256, 384, or 512. #' -#' @param y column to compute SHA-2 on. -#' @param x one of 224, 256, 384, or 512. -#' @family misc functions -#' @rdname sha2 -#' @name sha2 -#' @aliases sha2,Column,numeric-method +#' @rdname column_misc_functions +#' @aliases sha2 sha2,Column,numeric-method #' @export -#' @examples \dontrun{sha2(df$c, 256)} #' @note sha2 since 1.5.0 setMethod("sha2", signature(y = "Column", x = "numeric"), function(y, x) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 1deb057bb1b82..bdd4b360f4973 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -992,8 +992,9 @@ setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) #' @name NULL setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) -#' @rdname crc32 +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("crc32", function(x) { standardGeneric("crc32") }) #' @rdname column_nonaggregate_functions @@ -1006,8 +1007,9 @@ setGeneric("create_array", function(x, ...) { standardGeneric("create_array") }) #' @name NULL setGeneric("create_map", function(x, ...) { standardGeneric("create_map") }) -#' @rdname hash +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("hash", function(x, ...) { standardGeneric("hash") }) #' @param x empty. Should be used with no argument. 
@@ -1205,8 +1207,9 @@ setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) #' @name NULL setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) -#' @rdname md5 +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("md5", function(x) { standardGeneric("md5") }) #' @rdname column_datetime_functions @@ -1350,12 +1353,14 @@ setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") }) #' @name NULL setGeneric("second", function(x) { standardGeneric("second") }) -#' @rdname sha1 +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("sha1", function(x) { standardGeneric("sha1") }) -#' @rdname sha2 +#' @rdname column_misc_functions #' @export +#' @name NULL setGeneric("sha2", function(y, x) { standardGeneric("sha2") }) #' @rdname column_math_functions From 52981715bb8d653a1141f55b36da804412eb783a Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Thu, 29 Jun 2017 23:00:50 -0700 Subject: [PATCH 115/118] [SPARK-20889][SPARKR] Grouped documentation for COLLECTION column methods ## What changes were proposed in this pull request? Grouped documentation for column collection methods. Author: actuaryzhang Author: Wayne Zhang Closes #18458 from actuaryzhang/sparkRDocCollection. --- R/pkg/R/functions.R | 204 +++++++++++++++++++------------------------- R/pkg/R/generics.R | 27 ++++-- 2 files changed, 108 insertions(+), 123 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 67cb7a7f6db08..a1f5c4f8cc18d 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -171,6 +171,35 @@ NULL #' } NULL +#' Collection functions for Column operations +#' +#' Collection functions defined for \code{Column}. +#' +#' @param x Column to compute on. Note the difference in the following methods: +#' \itemize{ +#' \item \code{to_json}: it is the column containing the struct or array of the structs. +#' \item \code{from_json}: it is the column containing the JSON string. +#' } +#' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains +#' additional named properties to control how it is converted, accepts the same +#' options as the JSON data source. +#' @name column_collection_functions +#' @rdname column_collection_functions +#' @family collection functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) +#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) +#' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) +#' head(tmp2) +#' head(select(tmp, posexplode(tmp$v1))) +#' head(select(tmp, sort_array(tmp$v1))) +#' head(select(tmp, sort_array(tmp$v1, asc = FALSE)))} +NULL + #' @details #' \code{lit}: A new Column is created to represent the literal value. #' If the parameter is a Column, it is returned unchanged. @@ -1642,30 +1671,23 @@ setMethod("to_date", column(jc) }) -#' to_json -#' -#' Converts a column containing a \code{structType} or array of \code{structType} into a Column -#' of JSON string. Resolving the Column can fail if an unsupported type is encountered. -#' -#' @param x Column containing the struct or array of the structs -#' @param ... additional named properties to control how it is converted, accepts the same options -#' as the JSON data source. 
+#' @details +#' \code{to_json}: Converts a column containing a \code{structType} or array of \code{structType} +#' into a Column of JSON string. Resolving the Column can fail if an unsupported type is encountered. #' -#' @family non-aggregate functions -#' @rdname to_json -#' @name to_json -#' @aliases to_json,Column-method +#' @rdname column_collection_functions +#' @aliases to_json to_json,Column-method #' @export #' @examples +#' #' \dontrun{ #' # Converts a struct into a JSON object -#' df <- sql("SELECT named_struct('date', cast('2000-01-01' as date)) as d") -#' select(df, to_json(df$d, dateFormat = 'dd/MM/yyyy')) +#' df2 <- sql("SELECT named_struct('date', cast('2000-01-01' as date)) as d") +#' select(df2, to_json(df2$d, dateFormat = 'dd/MM/yyyy')) #' #' # Converts an array of structs into a JSON array -#' df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") -#' select(df, to_json(df$people)) -#'} +#' df2 <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") +#' df2 <- mutate(df2, people_json = to_json(df2$people))} #' @note to_json since 2.2.0 setMethod("to_json", signature(x = "Column"), function(x, ...) { @@ -2120,28 +2142,28 @@ setMethod("date_format", signature(y = "Column", x = "character"), column(jc) }) -#' from_json -#' -#' Parses a column containing a JSON string into a Column of \code{structType} with the specified -#' \code{schema} or array of \code{structType} if \code{as.json.array} is set to \code{TRUE}. -#' If the string is unparseable, the Column will contains the value NA. +#' @details +#' \code{from_json}: Parses a column containing a JSON string into a Column of \code{structType} +#' with the specified \code{schema} or array of \code{structType} if \code{as.json.array} is set +#' to \code{TRUE}. If the string is unparseable, the Column will contain the value NA. #' -#' @param x Column containing the JSON string. +#' @rdname column_collection_functions #' @param schema a structType object to use as the schema to use when parsing the JSON string. #' @param as.json.array indicating if input string is JSON array of objects or a single object. -#' @param ... additional named properties to control how the json is parsed, accepts the same -#' options as the JSON data source. -#' -#' @family non-aggregate functions -#' @rdname from_json -#' @name from_json -#' @aliases from_json,Column,structType-method +#' @aliases from_json from_json,Column,structType-method #' @export #' @examples +#' #' \dontrun{ -#' schema <- structType(structField("name", "string"), -#' select(df, from_json(df$value, schema, dateFormat = "dd/MM/yyyy")) -#'} +#' df2 <- sql("SELECT named_struct('date', cast('2000-01-01' as date)) as d") +#' df2 <- mutate(df2, d2 = to_json(df2$d, dateFormat = 'dd/MM/yyyy')) +#' schema <- structType(structField("date", "string")) +#' head(select(df2, from_json(df2$d2, schema, dateFormat = 'dd/MM/yyyy'))) + +#' df2 <- sql("SELECT named_struct('name', 'Bob') as people") +#' df2 <- mutate(df2, people_json = to_json(df2$people)) +#' schema <- structType(structField("name", "string")) +#' head(select(df2, from_json(df2$people_json, schema)))} #' @note from_json since 2.2.0 setMethod("from_json", signature(x = "Column", schema = "structType"), function(x, schema, as.json.array = FALSE, ...) 
{ @@ -3101,18 +3123,14 @@ setMethod("row_number", ###################### Collection functions###################### -#' array_contains -#' -#' Returns null if the array is null, true if the array contains the value, and false otherwise. +#' @details +#' \code{array_contains}: Returns null if the array is null, true if the array contains +#' the value, and false otherwise. #' -#' @param x A Column #' @param value A value to be checked if contained in the column -#' @rdname array_contains -#' @aliases array_contains,Column-method -#' @name array_contains -#' @family collection functions +#' @rdname column_collection_functions +#' @aliases array_contains array_contains,Column-method #' @export -#' @examples \dontrun{array_contains(df$c, 1)} #' @note array_contains since 1.6.0 setMethod("array_contains", signature(x = "Column", value = "ANY"), @@ -3121,18 +3139,12 @@ setMethod("array_contains", column(jc) }) -#' explode -#' -#' Creates a new row for each element in the given array or map column. -#' -#' @param x Column to compute on +#' @details +#' \code{explode}: Creates a new row for each element in the given array or map column. #' -#' @rdname explode -#' @name explode -#' @family collection functions -#' @aliases explode,Column-method +#' @rdname column_collection_functions +#' @aliases explode explode,Column-method #' @export -#' @examples \dontrun{explode(df$c)} #' @note explode since 1.5.0 setMethod("explode", signature(x = "Column"), @@ -3141,18 +3153,12 @@ setMethod("explode", column(jc) }) -#' size -#' -#' Returns length of array or map. -#' -#' @param x Column to compute on +#' @details +#' \code{size}: Returns length of array or map. #' -#' @rdname size -#' @name size -#' @aliases size,Column-method -#' @family collection functions +#' @rdname column_collection_functions +#' @aliases size size,Column-method #' @export -#' @examples \dontrun{size(df$c)} #' @note size since 1.5.0 setMethod("size", signature(x = "Column"), @@ -3161,25 +3167,16 @@ setMethod("size", column(jc) }) -#' sort_array -#' -#' Sorts the input array in ascending or descending order according +#' @details +#' \code{sort_array}: Sorts the input array in ascending or descending order according #' to the natural ordering of the array elements. #' -#' @param x A Column to sort +#' @rdname column_collection_functions #' @param asc A logical flag indicating the sorting order. #' TRUE, sorting is in ascending order. #' FALSE, sorting is in descending order. -#' @rdname sort_array -#' @name sort_array -#' @aliases sort_array,Column-method -#' @family collection functions +#' @aliases sort_array sort_array,Column-method #' @export -#' @examples -#' \dontrun{ -#' sort_array(df$c) -#' sort_array(df$c, FALSE) -#' } #' @note sort_array since 1.6.0 setMethod("sort_array", signature(x = "Column"), @@ -3188,18 +3185,13 @@ setMethod("sort_array", column(jc) }) -#' posexplode -#' -#' Creates a new row for each element with position in the given array or map column. -#' -#' @param x Column to compute on +#' @details +#' \code{posexplode}: Creates a new row for each element with position in the given array +#' or map column. 
#' -#' @rdname posexplode -#' @name posexplode -#' @family collection functions -#' @aliases posexplode,Column-method +#' @rdname column_collection_functions +#' @aliases posexplode posexplode,Column-method #' @export -#' @examples \dontrun{posexplode(df$c)} #' @note posexplode since 2.1.0 setMethod("posexplode", signature(x = "Column"), @@ -3325,27 +3317,24 @@ setMethod("repeat_string", column(jc) }) -#' explode_outer -#' -#' Creates a new row for each element in the given array or map column. +#' @details +#' \code{explode}: Creates a new row for each element in the given array or map column. #' Unlike \code{explode}, if the array/map is \code{null} or empty #' then \code{null} is produced. #' -#' @param x Column to compute on #' -#' @rdname explode_outer -#' @name explode_outer -#' @family collection functions -#' @aliases explode_outer,Column-method +#' @rdname column_collection_functions +#' @aliases explode_outer explode_outer,Column-method #' @export #' @examples +#' #' \dontrun{ -#' df <- createDataFrame(data.frame( +#' df2 <- createDataFrame(data.frame( #' id = c(1, 2, 3), text = c("a,b,c", NA, "d,e") #' )) #' -#' head(select(df, df$id, explode_outer(split_string(df$text, ",")))) -#' } +#' head(select(df2, df2$id, explode_outer(split_string(df2$text, ",")))) +#' head(select(df2, df2$id, posexplode_outer(split_string(df2$text, ","))))} #' @note explode_outer since 2.3.0 setMethod("explode_outer", signature(x = "Column"), @@ -3354,27 +3343,14 @@ setMethod("explode_outer", column(jc) }) -#' posexplode_outer -#' -#' Creates a new row for each element with position in the given array or map column. -#' Unlike \code{posexplode}, if the array/map is \code{null} or empty +#' @details +#' \code{posexplode_outer}: Creates a new row for each element with position in the given +#' array or map column. Unlike \code{posexplode}, if the array/map is \code{null} or empty #' then the row (\code{null}, \code{null}) is produced. #' -#' @param x Column to compute on -#' -#' @rdname posexplode_outer -#' @name posexplode_outer -#' @family collection functions -#' @aliases posexplode_outer,Column-method +#' @rdname column_collection_functions +#' @aliases posexplode_outer posexplode_outer,Column-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(data.frame( -#' id = c(1, 2, 3), text = c("a,b,c", NA, "d,e") -#' )) -#' -#' head(select(df, df$id, posexplode_outer(split_string(df$text, ",")))) -#' } #' @note posexplode_outer since 2.3.0 setMethod("posexplode_outer", signature(x = "Column"), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index bdd4b360f4973..b901b74e4728d 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -913,8 +913,9 @@ setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) #' @name NULL setGeneric("approxCountDistinct", function(x, ...) 
{ standardGeneric("approxCountDistinct") }) -#' @rdname array_contains +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) #' @rdname column_string_functions @@ -1062,12 +1063,14 @@ setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") #' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) -#' @rdname explode +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("explode", function(x) { standardGeneric("explode") }) -#' @rdname explode_outer +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) #' @rdname column_nonaggregate_functions @@ -1090,8 +1093,9 @@ setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) #' @name NULL setGeneric("format_string", function(format, x, ...) { standardGeneric("format_string") }) -#' @rdname from_json +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") }) #' @rdname column_datetime_functions @@ -1275,12 +1279,14 @@ setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_ra #' @name NULL setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) -#' @rdname posexplode +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("posexplode", function(x) { standardGeneric("posexplode") }) -#' @rdname posexplode_outer +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") }) #' @rdname column_datetime_functions @@ -1383,8 +1389,9 @@ setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUns #' @name NULL setGeneric("signum", function(x) { standardGeneric("signum") }) -#' @rdname size +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("size", function(x) { standardGeneric("size") }) #' @rdname column_aggregate_functions @@ -1392,8 +1399,9 @@ setGeneric("size", function(x) { standardGeneric("size") }) #' @name NULL setGeneric("skewness", function(x) { standardGeneric("skewness") }) -#' @rdname sort_array +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) #' @rdname column_string_functions @@ -1456,8 +1464,9 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @name NULL setGeneric("to_date", function(x, format) { standardGeneric("to_date") }) -#' @rdname to_json +#' @rdname column_collection_functions #' @export +#' @name NULL setGeneric("to_json", function(x, ...) { standardGeneric("to_json") }) #' @rdname column_datetime_functions From 49d767d838691fc7d964be2c4349662f5500ff2b Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Fri, 30 Jun 2017 20:02:15 +0800 Subject: [PATCH 116/118] [SPARK-18710][ML] Add offset in GLM ## What changes were proposed in this pull request? Add support for offset in GLM. This is useful for at least two reasons: 1. Account for exposure: e.g., when modeling the number of accidents, we may need to use miles driven as an offset to access factors on frequency. 2. Test incremental effects of new variables: we can use predictions from the existing model as offset and run a much smaller model on only new variables. 
This avoids re-estimating the large model with all variables (old + new) and can be very important for efficient large-scaled analysis. ## How was this patch tested? New test. yanboliang srowen felixcheung sethah Author: actuaryzhang Closes #16699 from actuaryzhang/offset. --- .../apache/spark/ml/feature/Instance.scala | 21 + .../IterativelyReweightedLeastSquares.scala | 14 +- .../spark/ml/optim/WeightedLeastSquares.scala | 2 +- .../GeneralizedLinearRegression.scala | 184 +++-- ...erativelyReweightedLeastSquaresSuite.scala | 40 +- .../GeneralizedLinearRegressionSuite.scala | 634 ++++++++++-------- 6 files changed, 534 insertions(+), 361 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala index cce3ca45ccd8f..dd56fbbfa2b63 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala @@ -27,3 +27,24 @@ import org.apache.spark.ml.linalg.Vector * @param features The vector of features for this data point. */ private[ml] case class Instance(label: Double, weight: Double, features: Vector) + +/** + * Case class that represents an instance of data point with + * label, weight, offset and features. + * This is mainly used in GeneralizedLinearRegression currently. + * + * @param label Label for this data point. + * @param weight The weight of this instance. + * @param offset The offset used for this data point. + * @param features The vector of features for this data point. + */ +private[ml] case class OffsetInstance( + label: Double, + weight: Double, + offset: Double, + features: Vector) { + + /** Converts to an [[Instance]] object by leaving out the offset. */ + def toInstance: Instance = Instance(label, weight, features) + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala index 9c495512422ba..6961b45f55e4d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.optim import org.apache.spark.internal.Logging -import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.linalg._ import org.apache.spark.rdd.RDD @@ -43,7 +43,7 @@ private[ml] class IterativelyReweightedLeastSquaresModel( * find M-estimator in robust regression and other optimization problems. * * @param initialModel the initial guess model. - * @param reweightFunc the reweight function which is used to update offsets and weights + * @param reweightFunc the reweight function which is used to update working labels and weights * at each iteration. * @param fitIntercept whether to fit intercept. * @param regParam L2 regularization parameter used by WLS. 
@@ -57,13 +57,13 @@ private[ml] class IterativelyReweightedLeastSquaresModel( */ private[ml] class IterativelyReweightedLeastSquares( val initialModel: WeightedLeastSquaresModel, - val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double), + val reweightFunc: (OffsetInstance, WeightedLeastSquaresModel) => (Double, Double), val fitIntercept: Boolean, val regParam: Double, val maxIter: Int, val tol: Double) extends Logging with Serializable { - def fit(instances: RDD[Instance]): IterativelyReweightedLeastSquaresModel = { + def fit(instances: RDD[OffsetInstance]): IterativelyReweightedLeastSquaresModel = { var converged = false var iter = 0 @@ -75,10 +75,10 @@ private[ml] class IterativelyReweightedLeastSquares( oldModel = model - // Update offsets and weights using reweightFunc + // Update working labels and weights using reweightFunc val newInstances = instances.map { instance => - val (newOffset, newWeight) = reweightFunc(instance, oldModel) - Instance(newOffset, newWeight, instance.features) + val (newLabel, newWeight) = reweightFunc(instance, oldModel) + Instance(newLabel, newWeight, instance.features) } // Estimate new model diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 56ab9675700a0..32b0af72ba9bb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.optim import org.apache.spark.internal.Logging -import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.linalg._ import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index bff0d9bbb46ff..ce3460ae43566 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -26,8 +26,8 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.PredictorParams -import org.apache.spark.ml.feature.Instance -import org.apache.spark.ml.linalg.{BLAS, Vector} +import org.apache.spark.ml.feature.{Instance, OffsetInstance} +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} import org.apache.spark.ml.optim._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -138,6 +138,27 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam @Since("2.0.0") def getLinkPredictionCol: String = $(linkPredictionCol) + /** + * Param for offset column name. If this is not set or empty, we treat all instance offsets + * as 0.0. The feature specified as offset has a constant coefficient of 1.0. + * @group param + */ + @Since("2.3.0") + final val offsetCol: Param[String] = new Param[String](this, "offsetCol", "The offset " + + "column name. If this is not set or empty, we treat all instance offsets as 0.0") + + /** @group getParam */ + @Since("2.3.0") + def getOffsetCol: String = $(offsetCol) + + /** Checks whether weight column is set and nonempty. 
*/ + private[regression] def hasWeightCol: Boolean = + isSet(weightCol) && $(weightCol).nonEmpty + + /** Checks whether offset column is set and nonempty. */ + private[regression] def hasOffsetCol: Boolean = + isSet(offsetCol) && $(offsetCol).nonEmpty + /** Checks whether we should output link prediction. */ private[regression] def hasLinkPredictionCol: Boolean = { isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty @@ -172,6 +193,11 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam } val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) + + if (hasOffsetCol) { + SchemaUtils.checkNumericType(schema, $(offsetCol)) + } + if (hasLinkPredictionCol) { SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType) } else { @@ -306,6 +332,16 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val @Since("2.0.0") def setWeightCol(value: String): this.type = set(weightCol, value) + /** + * Sets the value of param [[offsetCol]]. + * If this is not set or empty, we treat all instance offsets as 0.0. + * Default is not set, so all instances have offset 0.0. + * + * @group setParam + */ + @Since("2.3.0") + def setOffsetCol(value: String): this.type = set(offsetCol, value) + /** * Sets the solver algorithm used for optimization. * Currently only supports "irls" which is also the default solver. @@ -329,7 +365,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, weightCol, predictionCol, linkPredictionCol, + instr.logParams(labelCol, featuresCol, weightCol, offsetCol, predictionCol, linkPredictionCol, family, solver, fitIntercept, link, maxIter, regParam, tol) instr.logNumFeatures(numFeatures) @@ -343,15 +379,16 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val "GeneralizedLinearRegression was given data with 0 features, and with Param fitIntercept " + "set to false. To fit a model with 0 features, fitIntercept must be set to true." ) - val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - val instances: RDD[Instance] = - dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } + val w = if (!hasWeightCol) lit(1.0) else col($(weightCol)) + val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType) val model = if (familyAndLink.family == Gaussian && familyAndLink.link == Identity) { // TODO: Make standardizeFeatures and standardizeLabel configurable. 
+ val instances: RDD[Instance] = + dataset.select(col($(labelCol)), w, offset, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, offset: Double, features: Vector) => + Instance(label - offset, weight, features) + } val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) val wlsModel = optimizer.fit(instances) @@ -362,6 +399,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val wlsModel.diagInvAtWA.toArray, 1, getSolver) model.setSummary(Some(trainingSummary)) } else { + val instances: RDD[OffsetInstance] = + dataset.select(col($(labelCol)), w, offset, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, offset: Double, features: Vector) => + OffsetInstance(label, weight, offset, features) + } // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) val optimizer = new IterativelyReweightedLeastSquares(initialModel, @@ -425,12 +467,12 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * Get the initial guess model for [[IterativelyReweightedLeastSquares]]. */ def initialize( - instances: RDD[Instance], + instances: RDD[OffsetInstance], fitIntercept: Boolean, regParam: Double): WeightedLeastSquaresModel = { val newInstances = instances.map { instance => val mu = family.initialize(instance.label, instance.weight) - val eta = predict(mu) + val eta = predict(mu) - instance.offset Instance(eta, instance.weight, instance.features) } // TODO: Make standardizeFeatures and standardizeLabel configurable. @@ -441,16 +483,16 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine } /** - * The reweight function used to update offsets and weights + * The reweight function used to update working labels and weights * at each iteration of [[IterativelyReweightedLeastSquares]]. */ - val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double) = { - (instance: Instance, model: WeightedLeastSquaresModel) => { - val eta = model.predict(instance.features) + val reweightFunc: (OffsetInstance, WeightedLeastSquaresModel) => (Double, Double) = { + (instance: OffsetInstance, model: WeightedLeastSquaresModel) => { + val eta = model.predict(instance.features) + instance.offset val mu = fitted(eta) - val offset = eta + (instance.label - mu) * link.deriv(mu) - val weight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) - (offset, weight) + val newLabel = eta - instance.offset + (instance.label - mu) * link.deriv(mu) + val newWeight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) + (newLabel, newWeight) } } } @@ -950,15 +992,22 @@ class GeneralizedLinearRegressionModel private[ml] ( private lazy val familyAndLink = FamilyAndLink(this) override protected def predict(features: Vector): Double = { - val eta = predictLink(features) + predict(features, 0.0) + } + + /** + * Calculates the predicted value when offset is set. + */ + private def predict(features: Vector, offset: Double): Double = { + val eta = predictLink(features, offset) familyAndLink.fitted(eta) } /** - * Calculate the link prediction (linear predictor) of the given instance. + * Calculates the link prediction (linear predictor) of the given instance. 
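To make the reweighting arithmetic above explicit: with an offset $o_i$, each IRLS step uses the linear predictor $\eta_i = x_i^\top\beta + \beta_0 + o_i$ and mean $\mu_i = g^{-1}(\eta_i)$, and hands WLS the working label and weight

$$z_i = \eta_i - o_i + (y_i - \mu_i)\,g'(\mu_i), \qquad \tilde w_i = \frac{w_i}{g'(\mu_i)^2\,V(\mu_i)},$$

so the offset enters only through $\eta_i$ and is excluded from the label that WLS fits. This is standard GLM/IRLS algebra, stated here as a reading aid for the `reweightFunc` change rather than taken from the patch text.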
*/ - private def predictLink(features: Vector): Double = { - BLAS.dot(features, coefficients) + intercept + private def predictLink(features: Vector, offset: Double): Double = { + BLAS.dot(features, coefficients) + intercept + offset } override def transform(dataset: Dataset[_]): DataFrame = { @@ -967,14 +1016,16 @@ class GeneralizedLinearRegressionModel private[ml] ( } override protected def transformImpl(dataset: Dataset[_]): DataFrame = { - val predictUDF = udf { (features: Vector) => predict(features) } - val predictLinkUDF = udf { (features: Vector) => predictLink(features) } + val predictUDF = udf { (features: Vector, offset: Double) => predict(features, offset) } + val predictLinkUDF = udf { (features: Vector, offset: Double) => predictLink(features, offset) } + + val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType) var output = dataset if ($(predictionCol).nonEmpty) { - output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)), offset)) } if (hasLinkPredictionCol) { - output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)))) + output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)), offset)) } output.toDF() } @@ -1146,9 +1197,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( /** Degrees of freedom. */ @Since("2.0.0") - lazy val degreesOfFreedom: Long = { - numInstances - rank - } + lazy val degreesOfFreedom: Long = numInstances - rank /** The residual degrees of freedom. */ @Since("2.0.0") @@ -1156,18 +1205,20 @@ class GeneralizedLinearRegressionSummary private[regression] ( /** The residual degrees of freedom for the null model. */ @Since("2.0.0") - lazy val residualDegreeOfFreedomNull: Long = if (model.getFitIntercept) { - numInstances - 1 - } else { - numInstances + lazy val residualDegreeOfFreedomNull: Long = { + if (model.getFitIntercept) numInstances - 1 else numInstances } - private def weightCol: Column = { - if (!model.isDefined(model.weightCol) || model.getWeightCol.isEmpty) { - lit(1.0) - } else { - col(model.getWeightCol) - } + private def label: Column = col(model.getLabelCol).cast(DoubleType) + + private def prediction: Column = col(predictionCol) + + private def weight: Column = { + if (!model.hasWeightCol) lit(1.0) else col(model.getWeightCol) + } + + private def offset: Column = { + if (!model.hasOffsetCol) lit(0.0) else col(model.getOffsetCol).cast(DoubleType) } private[regression] lazy val devianceResiduals: DataFrame = { @@ -1175,25 +1226,23 @@ class GeneralizedLinearRegressionSummary private[regression] ( val r = math.sqrt(math.max(family.deviance(y, mu, weight), 0.0)) if (y > mu) r else -1.0 * r } - val w = weightCol predictions.select( - drUDF(col(model.getLabelCol), col(predictionCol), w).as("devianceResiduals")) + drUDF(label, prediction, weight).as("devianceResiduals")) } private[regression] lazy val pearsonResiduals: DataFrame = { val prUDF = udf { mu: Double => family.variance(mu) } - val w = weightCol - predictions.select(col(model.getLabelCol).minus(col(predictionCol)) - .multiply(sqrt(w)).divide(sqrt(prUDF(col(predictionCol)))).as("pearsonResiduals")) + predictions.select(label.minus(prediction) + .multiply(sqrt(weight)).divide(sqrt(prUDF(prediction))).as("pearsonResiduals")) } private[regression] lazy val workingResiduals: DataFrame = { val wrUDF = udf { (y: Double, mu: Double) => (y - mu) * link.deriv(mu) } - 
predictions.select(wrUDF(col(model.getLabelCol), col(predictionCol)).as("workingResiduals")) + predictions.select(wrUDF(label, prediction).as("workingResiduals")) } private[regression] lazy val responseResiduals: DataFrame = { - predictions.select(col(model.getLabelCol).minus(col(predictionCol)).as("responseResiduals")) + predictions.select(label.minus(prediction).as("responseResiduals")) } /** @@ -1225,16 +1274,35 @@ class GeneralizedLinearRegressionSummary private[regression] ( */ @Since("2.0.0") lazy val nullDeviance: Double = { - val w = weightCol - val wtdmu: Double = if (model.getFitIntercept) { - val agg = predictions.agg(sum(w.multiply(col(model.getLabelCol))), sum(w)).first() - agg.getDouble(0) / agg.getDouble(1) + val intercept: Double = if (!model.getFitIntercept) { + 0.0 } else { - link.unlink(0.0) + /* + Estimate intercept analytically when there is no offset, or when there is offset but + the model is Gaussian family with identity link. Otherwise, fit an intercept only model. + */ + if (!model.hasOffsetCol || + (model.hasOffsetCol && family == Gaussian && link == Identity)) { + val agg = predictions.agg(sum(weight.multiply( + label.minus(offset))), sum(weight)).first() + link.link(agg.getDouble(0) / agg.getDouble(1)) + } else { + // Create empty feature column and fit intercept only model using param setting from model + val featureNull = "feature_" + java.util.UUID.randomUUID.toString + val paramMap = model.extractParamMap() + paramMap.put(model.featuresCol, featureNull) + if (family.name != "tweedie") { + paramMap.remove(model.variancePower) + } + val emptyVectorUDF = udf{ () => Vectors.zeros(0) } + model.parent.fit( + dataset.withColumn(featureNull, emptyVectorUDF()), paramMap + ).intercept + } } - predictions.select(col(model.getLabelCol).cast(DoubleType), w).rdd.map { - case Row(y: Double, weight: Double) => - family.deviance(y, wtdmu, weight) + predictions.select(label, offset, weight).rdd.map { + case Row(y: Double, offset: Double, weight: Double) => + family.deviance(y, link.unlink(intercept + offset), weight) }.sum() } @@ -1243,8 +1311,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( */ @Since("2.0.0") lazy val deviance: Double = { - val w = weightCol - predictions.select(col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map { + predictions.select(label, prediction, weight).rdd.map { case Row(label: Double, pred: Double, weight: Double) => family.deviance(label, pred, weight) }.sum() @@ -1269,10 +1336,9 @@ class GeneralizedLinearRegressionSummary private[regression] ( /** Akaike Information Criterion (AIC) for the fitted model. 
*/ @Since("2.0.0") lazy val aic: Double = { - val w = weightCol - val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0) + val weightSum = predictions.select(weight).agg(sum(weight)).first().getDouble(0) val t = predictions.select( - col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map { + label, prediction, weight).rdd.map { case Row(label: Double, pred: Double, weight: Double) => (label, pred, weight) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala index 50260952ecb66..6d143504fcf58 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.optim import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -26,8 +26,8 @@ import org.apache.spark.rdd.RDD class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext { - private var instances1: RDD[Instance] = _ - private var instances2: RDD[Instance] = _ + private var instances1: RDD[OffsetInstance] = _ + private var instances2: RDD[OffsetInstance] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -39,10 +39,10 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes w <- c(1, 2, 3, 4) */ instances1 = sc.parallelize(Seq( - Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)), - Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), - Instance(0.0, 4.0, Vectors.dense(3.0, 3.0)) + OffsetInstance(1.0, 1.0, 0.0, Vectors.dense(0.0, 5.0).toSparse), + OffsetInstance(0.0, 2.0, 0.0, Vectors.dense(1.0, 2.0)), + OffsetInstance(1.0, 3.0, 0.0, Vectors.dense(2.0, 1.0)), + OffsetInstance(0.0, 4.0, 0.0, Vectors.dense(3.0, 3.0)) ), 2) /* R code: @@ -52,10 +52,10 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes w <- c(1, 2, 3, 4) */ instances2 = sc.parallelize(Seq( - Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), - Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) + OffsetInstance(2.0, 1.0, 0.0, Vectors.dense(0.0, 5.0).toSparse), + OffsetInstance(8.0, 2.0, 0.0, Vectors.dense(1.0, 7.0)), + OffsetInstance(3.0, 3.0, 0.0, Vectors.dense(2.0, 11.0)), + OffsetInstance(9.0, 4.0, 0.0, Vectors.dense(3.0, 13.0)) ), 2) } @@ -156,7 +156,7 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes var idx = 0 for (fitIntercept <- Seq(false, true)) { val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, elasticNetParam = 0.0, - standardizeFeatures = false, standardizeLabel = false).fit(instances2) + standardizeFeatures = false, standardizeLabel = false).fit(instances2.map(_.toInstance)) val irls = new IterativelyReweightedLeastSquares(initial, L1RegressionReweightFunc, fitIntercept, regParam = 0.0, maxIter = 200, tol = 1e-7).fit(instances2) val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1)) @@ -169,29 +169,29 @@ class IterativelyReweightedLeastSquaresSuite extends 
SparkFunSuite with MLlibTes object IterativelyReweightedLeastSquaresSuite { def BinomialReweightFunc( - instance: Instance, + instance: OffsetInstance, model: WeightedLeastSquaresModel): (Double, Double) = { - val eta = model.predict(instance.features) + val eta = model.predict(instance.features) + instance.offset val mu = 1.0 / (1.0 + math.exp(-1.0 * eta)) - val z = eta + (instance.label - mu) / (mu * (1.0 - mu)) + val z = eta - instance.offset + (instance.label - mu) / (mu * (1.0 - mu)) val w = mu * (1 - mu) * instance.weight (z, w) } def PoissonReweightFunc( - instance: Instance, + instance: OffsetInstance, model: WeightedLeastSquaresModel): (Double, Double) = { - val eta = model.predict(instance.features) + val eta = model.predict(instance.features) + instance.offset val mu = math.exp(eta) - val z = eta + (instance.label - mu) / mu + val z = eta - instance.offset + (instance.label - mu) / mu val w = mu * instance.weight (z, w) } def L1RegressionReweightFunc( - instance: Instance, + instance: OffsetInstance, model: WeightedLeastSquaresModel): (Double, Double) = { - val eta = model.predict(instance.features) + val eta = model.predict(instance.features) + instance.offset val e = math.max(math.abs(eta - instance.label), 1e-7) val w = 1 / e val y = instance.label diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index f7c7c001a36af..cfaa57314bd66 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LogisticRegressionSuite._ -import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} @@ -797,77 +797,160 @@ class GeneralizedLinearRegressionSuite } } - test("glm summary: gaussian family with weight") { + test("generalized linear regression with weight and offset") { /* - R code: + R code: + library(statmod) - A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) - b <- c(17, 19, 23, 29) - w <- c(1, 2, 3, 4) - df <- as.data.frame(cbind(A, b)) - */ - val datasetWithWeight = Seq( - Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), - Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) + df <- as.data.frame(matrix(c( + 0.2, 1.0, 2.0, 0.0, 5.0, + 0.5, 2.1, 0.5, 1.0, 2.0, + 0.9, 0.4, 1.0, 2.0, 1.0, + 0.7, 0.7, 0.0, 3.0, 3.0), 4, 5, byrow = TRUE)) + families <- list(gaussian, binomial, poisson, Gamma, tweedie(1.5)) + f1 <- V1 ~ -1 + V4 + V5 + f2 <- V1 ~ V4 + V5 + for (f in c(f1, f2)) { + for (fam in families) { + model <- glm(f, df, family = fam, weights = V2, offset = V3) + print(as.vector(coef(model))) + } + } + [1] 0.5169222 -0.3344444 + [1] 0.9419107 -0.6864404 + [1] 0.1812436 -0.6568422 + [1] -0.2869094 0.7857710 + [1] 0.1055254 0.2979113 + [1] -0.05990345 0.53188982 -0.32118415 + [1] -0.2147117 0.9911750 -0.6356096 + [1] -1.5616130 0.6646470 -0.3192581 + [1] 0.3390397 -0.3406099 0.6870259 + [1] 0.3665034 0.1039416 0.1484616 + */ + val dataset = Seq( + OffsetInstance(0.2, 1.0, 
2.0, Vectors.dense(0.0, 5.0)), + OffsetInstance(0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)), + OffsetInstance(0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)), + OffsetInstance(0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0)) ).toDF() + + val expected = Seq( + Vectors.dense(0, 0.5169222, -0.3344444), + Vectors.dense(0, 0.9419107, -0.6864404), + Vectors.dense(0, 0.1812436, -0.6568422), + Vectors.dense(0, -0.2869094, 0.785771), + Vectors.dense(0, 0.1055254, 0.2979113), + Vectors.dense(-0.05990345, 0.53188982, -0.32118415), + Vectors.dense(-0.2147117, 0.991175, -0.6356096), + Vectors.dense(-1.561613, 0.664647, -0.3192581), + Vectors.dense(0.3390397, -0.3406099, 0.6870259), + Vectors.dense(0.3665034, 0.1039416, 0.1484616)) + + import GeneralizedLinearRegression._ + + var idx = 0 + + for (fitIntercept <- Seq(false, true)) { + for (family <- Seq("gaussian", "binomial", "poisson", "gamma", "tweedie")) { + val trainer = new GeneralizedLinearRegression().setFamily(family) + .setFitIntercept(fitIntercept).setOffsetCol("offset") + .setWeightCol("weight").setLinkPredictionCol("linkPrediction") + if (family == "tweedie") trainer.setVariancePower(1.5) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, s"Model mismatch: GLM with family = $family," + + s" and fitIntercept = $fitIntercept.") + + val familyLink = FamilyAndLink(trainer) + model.transform(dataset).select("features", "offset", "prediction", "linkPrediction") + .collect().foreach { + case Row(features: DenseVector, offset: Double, prediction1: Double, + linkPrediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + offset + val prediction2 = familyLink.fitted(eta) + val linkPrediction2 = eta + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"family = $family, and fitIntercept = $fitIntercept.") + assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " + + s"GLM with family = $family, and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("glm summary: gaussian family with weight and offset") { /* - R code: + R code: - model <- glm(formula = "b ~ .", family="gaussian", data = df, weights = w) - summary(model) + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(17, 19, 23, 29) + w <- c(1, 2, 3, 4) + off <- c(2, 3, 1, 4) + df <- as.data.frame(cbind(A, b)) + */ + val dataset = Seq( + OffsetInstance(17.0, 1.0, 2.0, Vectors.dense(0.0, 5.0).toSparse), + OffsetInstance(19.0, 2.0, 3.0, Vectors.dense(1.0, 7.0)), + OffsetInstance(23.0, 3.0, 1.0, Vectors.dense(2.0, 11.0)), + OffsetInstance(29.0, 4.0, 4.0, Vectors.dense(3.0, 13.0)) + ).toDF() + /* + R code: - Deviance Residuals: - 1 2 3 4 - 1.920 -1.358 -1.109 0.960 + model <- glm(formula = "b ~ .", family = "gaussian", data = df, + weights = w, offset = off) + summary(model) - Coefficients: - Estimate Std. Error t value Pr(>|t|) - (Intercept) 18.080 9.608 1.882 0.311 - V1 6.080 5.556 1.094 0.471 - V2 -0.600 1.960 -0.306 0.811 + Deviance Residuals: + 1 2 3 4 + 0.9600 -0.6788 -0.5543 0.4800 - (Dispersion parameter for gaussian family taken to be 7.68) + Coefficients: + Estimate Std. 
Error t value Pr(>|t|) + (Intercept) 5.5400 4.8040 1.153 0.455 + V1 -0.9600 2.7782 -0.346 0.788 + V2 1.7000 0.9798 1.735 0.333 - Null deviance: 202.00 on 3 degrees of freedom - Residual deviance: 7.68 on 1 degrees of freedom - AIC: 18.783 + (Dispersion parameter for gaussian family taken to be 1.92) - Number of Fisher Scoring iterations: 2 + Null deviance: 152.10 on 3 degrees of freedom + Residual deviance: 1.92 on 1 degrees of freedom + AIC: 13.238 - residuals(model, type="pearson") - 1 2 3 4 - 1.920000 -1.357645 -1.108513 0.960000 + Number of Fisher Scoring iterations: 2 - residuals(model, type="working") + residuals(model, type = "pearson") + 1 2 3 4 + 0.9600000 -0.6788225 -0.5542563 0.4800000 + residuals(model, type = "working") 1 2 3 4 - 1.92 -0.96 -0.64 0.48 - - residuals(model, type="response") + 0.96 -0.48 -0.32 0.24 + residuals(model, type = "response") 1 2 3 4 - 1.92 -0.96 -0.64 0.48 + 0.96 -0.48 -0.32 0.24 */ val trainer = new GeneralizedLinearRegression() - .setWeightCol("weight") + .setWeightCol("weight").setOffsetCol("offset") + + val model = trainer.fit(dataset) - val model = trainer.fit(datasetWithWeight) - - val coefficientsR = Vectors.dense(Array(6.080, -0.600)) - val interceptR = 18.080 - val devianceResidualsR = Array(1.920, -1.358, -1.109, 0.960) - val pearsonResidualsR = Array(1.920000, -1.357645, -1.108513, 0.960000) - val workingResidualsR = Array(1.92, -0.96, -0.64, 0.48) - val responseResidualsR = Array(1.92, -0.96, -0.64, 0.48) - val seCoefR = Array(5.556, 1.960, 9.608) - val tValsR = Array(1.094, -0.306, 1.882) - val pValsR = Array(0.471, 0.811, 0.311) - val dispersionR = 7.68 - val nullDevianceR = 202.00 - val residualDevianceR = 7.68 + val coefficientsR = Vectors.dense(Array(-0.96, 1.7)) + val interceptR = 5.54 + val devianceResidualsR = Array(0.96, -0.67882, -0.55426, 0.48) + val pearsonResidualsR = Array(0.96, -0.67882, -0.55426, 0.48) + val workingResidualsR = Array(0.96, -0.48, -0.32, 0.24) + val responseResidualsR = Array(0.96, -0.48, -0.32, 0.24) + val seCoefR = Array(2.7782, 0.9798, 4.804) + val tValsR = Array(-0.34555, 1.73506, 1.15321) + val pValsR = Array(0.78819, 0.33286, 0.45478) + val dispersionR = 1.92 + val nullDevianceR = 152.1 + val residualDevianceR = 1.92 val residualDegreeOfFreedomNullR = 3 val residualDegreeOfFreedomR = 1 - val aicR = 18.783 + val aicR = 13.23758 assert(model.hasSummary) val summary = model.summary @@ -912,7 +995,7 @@ class GeneralizedLinearRegressionSuite assert(summary.aic ~== aicR absTol 1E-3) assert(summary.solver === "irls") - val summary2: GeneralizedLinearRegressionSummary = model.evaluate(datasetWithWeight) + val summary2: GeneralizedLinearRegressionSummary = model.evaluate(dataset) assert(summary.predictions.columns.toSet === summary2.predictions.columns.toSet) assert(summary.predictionCol === summary2.predictionCol) assert(summary.rank === summary2.rank) @@ -925,79 +1008,79 @@ class GeneralizedLinearRegressionSuite assert(summary.aic === summary2.aic) } - test("glm summary: binomial family with weight") { + test("glm summary: binomial family with weight and offset") { /* - R code: + R code: - A <- matrix(c(0, 1, 2, 3, 5, 2, 1, 3), 4, 2) - b <- c(1, 0.5, 1, 0) - w <- c(1, 2.0, 0.3, 4.7) - df <- as.data.frame(cbind(A, b)) + df <- as.data.frame(matrix(c( + 0.2, 1.0, 2.0, 0.0, 5.0, + 0.5, 2.1, 0.5, 1.0, 2.0, + 0.9, 0.4, 1.0, 2.0, 1.0, + 0.7, 0.7, 0.0, 3.0, 3.0), 4, 5, byrow = TRUE)) */ - val datasetWithWeight = Seq( - Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(0.5, 2.0, Vectors.dense(1.0, 2.0)), 
- Instance(1.0, 0.3, Vectors.dense(2.0, 1.0)), - Instance(0.0, 4.7, Vectors.dense(3.0, 3.0)) + val dataset = Seq( + OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)), + OffsetInstance(0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)), + OffsetInstance(0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)), + OffsetInstance(0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0)) ).toDF() - /* - R code: - - model <- glm(formula = "b ~ . -1", family="binomial", data = df, weights = w) - summary(model) - - Deviance Residuals: - 1 2 3 4 - 0.2404 0.1965 1.2824 -0.6916 + R code: - Coefficients: - Estimate Std. Error z value Pr(>|z|) - x1 -1.6901 1.2764 -1.324 0.185 - x2 0.7059 0.9449 0.747 0.455 + model <- glm(formula = "V1 ~ V4 + V5", family = "binomial", data = df, + weights = V2, offset = V3) + summary(model) - (Dispersion parameter for binomial family taken to be 1) + Deviance Residuals: + 1 2 3 4 + 0.002584 -0.003800 0.012478 -0.001796 - Null deviance: 8.3178 on 4 degrees of freedom - Residual deviance: 2.2193 on 2 degrees of freedom - AIC: 5.9915 + Coefficients: + Estimate Std. Error z value Pr(>|z|) + (Intercept) -0.2147 3.5687 -0.060 0.952 + V4 0.9912 1.2344 0.803 0.422 + V5 -0.6356 0.9669 -0.657 0.511 - Number of Fisher Scoring iterations: 5 + (Dispersion parameter for binomial family taken to be 1) - residuals(model, type="pearson") - 1 2 3 4 - 0.171217 0.197406 2.085864 -0.495332 + Null deviance: 2.17560881 on 3 degrees of freedom + Residual deviance: 0.00018005 on 1 degrees of freedom + AIC: 10.245 - residuals(model, type="working") - 1 2 3 4 - 1.029315 0.281881 15.502768 -1.052203 + Number of Fisher Scoring iterations: 4 - residuals(model, type="response") - 1 2 3 4 - 0.028480 0.069123 0.935495 -0.049613 + residuals(model, type = "pearson") + 1 2 3 4 + 0.002586113 -0.003799744 0.012372235 -0.001796892 + residuals(model, type = "working") + 1 2 3 4 + 0.006477857 -0.005244163 0.063541250 -0.004691064 + residuals(model, type = "response") + 1 2 3 4 + 0.0010324375 -0.0013110318 0.0060225522 -0.0009832738 */ val trainer = new GeneralizedLinearRegression() .setFamily("Binomial") .setWeightCol("weight") - .setFitIntercept(false) - - val model = trainer.fit(datasetWithWeight) - - val coefficientsR = Vectors.dense(Array(-1.690134, 0.705929)) - val interceptR = 0.0 - val devianceResidualsR = Array(0.2404, 0.1965, 1.2824, -0.6916) - val pearsonResidualsR = Array(0.171217, 0.197406, 2.085864, -0.495332) - val workingResidualsR = Array(1.029315, 0.281881, 15.502768, -1.052203) - val responseResidualsR = Array(0.02848, 0.069123, 0.935495, -0.049613) - val seCoefR = Array(1.276417, 0.944934) - val tValsR = Array(-1.324124, 0.747068) - val pValsR = Array(0.185462, 0.455023) - val dispersionR = 1.0 - val nullDevianceR = 8.3178 - val residualDevianceR = 2.2193 - val residualDegreeOfFreedomNullR = 4 - val residualDegreeOfFreedomR = 2 - val aicR = 5.991537 + .setOffsetCol("offset") + + val model = trainer.fit(dataset) + + val coefficientsR = Vectors.dense(Array(0.99117, -0.63561)) + val interceptR = -0.21471 + val devianceResidualsR = Array(0.00258, -0.0038, 0.01248, -0.0018) + val pearsonResidualsR = Array(0.00259, -0.0038, 0.01237, -0.0018) + val workingResidualsR = Array(0.00648, -0.00524, 0.06354, -0.00469) + val responseResidualsR = Array(0.00103, -0.00131, 0.00602, -0.00098) + val seCoefR = Array(1.23439, 0.9669, 3.56866) + val tValsR = Array(0.80297, -0.65737, -0.06017) + val pValsR = Array(0.42199, 0.51094, 0.95202) + val dispersionR = 1 + val nullDevianceR = 2.17561 + val residualDevianceR = 0.00018 + val 
residualDegreeOfFreedomNullR = 3 + val residualDegreeOfFreedomR = 1 + val aicR = 10.24453 val summary = model.summary val devianceResiduals = summary.residuals() @@ -1040,81 +1123,79 @@ class GeneralizedLinearRegressionSuite assert(summary.solver === "irls") } - test("glm summary: poisson family with weight") { + test("glm summary: poisson family with weight and offset") { /* - R code: + R code: - A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) - b <- c(2, 8, 3, 9) - w <- c(1, 2, 3, 4) - df <- as.data.frame(cbind(A, b)) + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(2, 8, 3, 9) + w <- c(1, 2, 3, 4) + off <- c(2, 3, 1, 4) + df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = Seq( - Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), - Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) + val dataset = Seq( + OffsetInstance(2.0, 1.0, 2.0, Vectors.dense(0.0, 5.0).toSparse), + OffsetInstance(8.0, 2.0, 3.0, Vectors.dense(1.0, 7.0)), + OffsetInstance(3.0, 3.0, 1.0, Vectors.dense(2.0, 11.0)), + OffsetInstance(9.0, 4.0, 4.0, Vectors.dense(3.0, 13.0)) ).toDF() /* - R code: - - model <- glm(formula = "b ~ .", family="poisson", data = df, weights = w) - summary(model) - - Deviance Residuals: - 1 2 3 4 - -0.28952 0.11048 0.14839 -0.07268 - - Coefficients: - Estimate Std. Error z value Pr(>|z|) - (Intercept) 6.2999 1.6086 3.916 8.99e-05 *** - V1 3.3241 1.0184 3.264 0.00110 ** - V2 -1.0818 0.3522 -3.071 0.00213 ** - --- - Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 - - (Dispersion parameter for poisson family taken to be 1) - - Null deviance: 15.38066 on 3 degrees of freedom - Residual deviance: 0.12333 on 1 degrees of freedom - AIC: 41.803 - - Number of Fisher Scoring iterations: 3 + R code: - residuals(model, type="pearson") - 1 2 3 4 - -0.28043145 0.11099310 0.14963714 -0.07253611 + model <- glm(formula = "b ~ .", family = "poisson", data = df, + weights = w, offset = off) + summary(model) - residuals(model, type="working") - 1 2 3 4 - -0.17960679 0.02813593 0.05113852 -0.01201650 + Deviance Residuals: + 1 2 3 4 + -2.0480 1.2315 1.8293 -0.7107 - residuals(model, type="response") - 1 2 3 4 - -0.4378554 0.2189277 0.1459518 -0.1094638 + Coefficients: + Estimate Std. 
Error z value Pr(>|z|) + (Intercept) -4.5678 1.9625 -2.328 0.0199 + V1 -2.8784 1.1683 -2.464 0.0137 + V2 0.8859 0.4170 2.124 0.0336 + + (Dispersion parameter for poisson family taken to be 1) + + Null deviance: 22.5585 on 3 degrees of freedom + Residual deviance: 9.5622 on 1 degrees of freedom + AIC: 51.242 + + Number of Fisher Scoring iterations: 5 + + residuals(model, type = "pearson") + 1 2 3 4 + -1.7480418 1.3037611 2.0750099 -0.6972966 + residuals(model, type = "working") + 1 2 3 4 + -0.6891489 0.3833588 0.9710682 -0.1096590 + residuals(model, type = "response") + 1 2 3 4 + -4.433948 2.216974 1.477983 -1.108487 */ val trainer = new GeneralizedLinearRegression() .setFamily("Poisson") .setWeightCol("weight") - .setFitIntercept(true) - - val model = trainer.fit(datasetWithWeight) - - val coefficientsR = Vectors.dense(Array(3.3241, -1.0818)) - val interceptR = 6.2999 - val devianceResidualsR = Array(-0.28952, 0.11048, 0.14839, -0.07268) - val pearsonResidualsR = Array(-0.28043145, 0.11099310, 0.14963714, -0.07253611) - val workingResidualsR = Array(-0.17960679, 0.02813593, 0.05113852, -0.01201650) - val responseResidualsR = Array(-0.4378554, 0.2189277, 0.1459518, -0.1094638) - val seCoefR = Array(1.0184, 0.3522, 1.6086) - val tValsR = Array(3.264, -3.071, 3.916) - val pValsR = Array(0.00110, 0.00213, 0.00009) - val dispersionR = 1.0 - val nullDevianceR = 15.38066 - val residualDevianceR = 0.12333 + .setOffsetCol("offset") + + val model = trainer.fit(dataset) + + val coefficientsR = Vectors.dense(Array(-2.87843, 0.88589)) + val interceptR = -4.56784 + val devianceResidualsR = Array(-2.04796, 1.23149, 1.82933, -0.71066) + val pearsonResidualsR = Array(-1.74804, 1.30376, 2.07501, -0.6973) + val workingResidualsR = Array(-0.68915, 0.38336, 0.97107, -0.10966) + val responseResidualsR = Array(-4.43395, 2.21697, 1.47798, -1.10849) + val seCoefR = Array(1.16826, 0.41703, 1.96249) + val tValsR = Array(-2.46387, 2.12428, -2.32757) + val pValsR = Array(0.01374, 0.03365, 0.01993) + val dispersionR = 1 + val nullDevianceR = 22.55853 + val residualDevianceR = 9.5622 val residualDegreeOfFreedomNullR = 3 val residualDegreeOfFreedomR = 1 - val aicR = 41.803 + val aicR = 51.24218 val summary = model.summary val devianceResiduals = summary.residuals() @@ -1157,78 +1238,79 @@ class GeneralizedLinearRegressionSuite assert(summary.solver === "irls") } - test("glm summary: gamma family with weight") { + test("glm summary: gamma family with weight and offset") { /* - R code: + R code: - A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) - b <- c(2, 8, 3, 9) - w <- c(1, 2, 3, 4) - df <- as.data.frame(cbind(A, b)) + A <- matrix(c(0, 5, 1, 2, 2, 1, 3, 3), 4, 2, byrow = TRUE) + b <- c(1, 2, 1, 2) + w <- c(1, 2, 3, 4) + off <- c(0, 0.5, 1, 0) + df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = Seq( - Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), - Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), - Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) + val dataset = Seq( + OffsetInstance(1.0, 1.0, 0.0, Vectors.dense(0.0, 5.0)), + OffsetInstance(2.0, 2.0, 0.5, Vectors.dense(1.0, 2.0)), + OffsetInstance(1.0, 3.0, 1.0, Vectors.dense(2.0, 1.0)), + OffsetInstance(2.0, 4.0, 0.0, Vectors.dense(3.0, 3.0)) ).toDF() /* - R code: - - model <- glm(formula = "b ~ .", family="Gamma", data = df, weights = w) - summary(model) + R code: - Deviance Residuals: - 1 2 3 4 - -0.26343 0.05761 0.12818 -0.03484 + model <- glm(formula = "b ~ .", family = "Gamma", data = df, + weights = w, 
offset = off) + summary(model) - Coefficients: - Estimate Std. Error t value Pr(>|t|) - (Intercept) -0.81511 0.23449 -3.476 0.178 - V1 -0.72730 0.16137 -4.507 0.139 - V2 0.23894 0.05481 4.359 0.144 + Deviance Residuals: + 1 2 3 4 + -0.17095 0.19867 -0.23604 0.03241 - (Dispersion parameter for Gamma family taken to be 0.07986091) + Coefficients: + Estimate Std. Error t value Pr(>|t|) + (Intercept) -0.56474 0.23866 -2.366 0.255 + V1 0.07695 0.06931 1.110 0.467 + V2 0.28068 0.07320 3.835 0.162 - Null deviance: 2.937462 on 3 degrees of freedom - Residual deviance: 0.090358 on 1 degrees of freedom - AIC: 23.202 + (Dispersion parameter for Gamma family taken to be 0.1212174) - Number of Fisher Scoring iterations: 4 + Null deviance: 2.02568 on 3 degrees of freedom + Residual deviance: 0.12546 on 1 degrees of freedom + AIC: 0.93388 - residuals(model, type="pearson") - 1 2 3 4 - -0.24082508 0.05839241 0.13135766 -0.03463621 + Number of Fisher Scoring iterations: 4 - residuals(model, type="working") + residuals(model, type = "pearson") + 1 2 3 4 + -0.16134949 0.20807694 -0.22544551 0.03258777 + residuals(model, type = "working") 1 2 3 4 - 0.091414181 -0.005374314 -0.027196998 0.001890910 - - residuals(model, type="response") - 1 2 3 4 - -0.6344390 0.3172195 0.2114797 -0.1586097 + 0.135315831 -0.084390309 0.113219135 -0.008279688 + residuals(model, type = "response") + 1 2 3 4 + -0.1923918 0.2565224 -0.1496381 0.0320653 */ val trainer = new GeneralizedLinearRegression() .setFamily("Gamma") .setWeightCol("weight") + .setOffsetCol("offset") + + val model = trainer.fit(dataset) - val model = trainer.fit(datasetWithWeight) - - val coefficientsR = Vectors.dense(Array(-0.72730, 0.23894)) - val interceptR = -0.81511 - val devianceResidualsR = Array(-0.26343, 0.05761, 0.12818, -0.03484) - val pearsonResidualsR = Array(-0.24082508, 0.05839241, 0.13135766, -0.03463621) - val workingResidualsR = Array(0.091414181, -0.005374314, -0.027196998, 0.001890910) - val responseResidualsR = Array(-0.6344390, 0.3172195, 0.2114797, -0.1586097) - val seCoefR = Array(0.16137, 0.05481, 0.23449) - val tValsR = Array(-4.507, 4.359, -3.476) - val pValsR = Array(0.139, 0.144, 0.178) - val dispersionR = 0.07986091 - val nullDevianceR = 2.937462 - val residualDevianceR = 0.090358 + val coefficientsR = Vectors.dense(Array(0.07695, 0.28068)) + val interceptR = -0.56474 + val devianceResidualsR = Array(-0.17095, 0.19867, -0.23604, 0.03241) + val pearsonResidualsR = Array(-0.16135, 0.20808, -0.22545, 0.03259) + val workingResidualsR = Array(0.13532, -0.08439, 0.11322, -0.00828) + val responseResidualsR = Array(-0.19239, 0.25652, -0.14964, 0.03207) + val seCoefR = Array(0.06931, 0.0732, 0.23866) + val tValsR = Array(1.11031, 3.83453, -2.3663) + val pValsR = Array(0.46675, 0.16241, 0.25454) + val dispersionR = 0.12122 + val nullDevianceR = 2.02568 + val residualDevianceR = 0.12546 val residualDegreeOfFreedomNullR = 3 val residualDegreeOfFreedomR = 1 - val aicR = 23.202 + val aicR = 0.93388 val summary = model.summary val devianceResiduals = summary.residuals() @@ -1271,77 +1353,81 @@ class GeneralizedLinearRegressionSuite assert(summary.solver === "irls") } - test("glm summary: tweedie family with weight") { + test("glm summary: tweedie family with weight and offset") { /* R code: - library(statmod) df <- as.data.frame(matrix(c( - 1.0, 1.0, 0.0, 5.0, - 0.5, 2.0, 1.0, 2.0, - 1.0, 3.0, 2.0, 1.0, - 0.0, 4.0, 3.0, 3.0), 4, 4, byrow = TRUE)) + 1.0, 1.0, 1.0, 0.0, 5.0, + 0.5, 2.0, 3.0, 1.0, 2.0, + 1.0, 3.0, 2.0, 2.0, 1.0, + 0.0, 4.0, 0.0, 3.0, 
3.0), 4, 5, byrow = TRUE)) + */ + val dataset = Seq( + OffsetInstance(1.0, 1.0, 1.0, Vectors.dense(0.0, 5.0)), + OffsetInstance(0.5, 2.0, 3.0, Vectors.dense(1.0, 2.0)), + OffsetInstance(1.0, 3.0, 2.0, Vectors.dense(2.0, 1.0)), + OffsetInstance(0.0, 4.0, 0.0, Vectors.dense(3.0, 3.0)) + ).toDF() + /* + R code: - model <- glm(V1 ~ -1 + V3 + V4, data = df, weights = V2, - family = tweedie(var.power = 1.6, link.power = 0)) + library(statmod) + model <- glm(V1 ~ V4 + V5, data = df, weights = V2, offset = V3, + family = tweedie(var.power = 1.6, link.power = 0.0)) summary(model) Deviance Residuals: 1 2 3 4 - 0.6210 -0.0515 1.6935 -3.2539 + 0.8917 -2.1396 1.2252 -1.7946 Coefficients: - Estimate Std. Error t value Pr(>|t|) - V3 -0.4087 0.5205 -0.785 0.515 - V4 -0.1212 0.4082 -0.297 0.794 + Estimate Std. Error t value Pr(>|t|) + (Intercept) -0.03047 3.65000 -0.008 0.995 + V4 -1.14577 1.41674 -0.809 0.567 + V5 -0.36585 0.97065 -0.377 0.771 - (Dispersion parameter for Tweedie family taken to be 3.830036) + (Dispersion parameter for Tweedie family taken to be 6.334961) - Null deviance: 20.702 on 4 degrees of freedom - Residual deviance: 13.844 on 2 degrees of freedom + Null deviance: 12.784 on 3 degrees of freedom + Residual deviance: 10.095 on 1 degrees of freedom AIC: NA - Number of Fisher Scoring iterations: 11 - - residuals(model, type="pearson") - 1 2 3 4 - 0.7383616 -0.0509458 2.2348337 -1.4552090 - residuals(model, type="working") - 1 2 3 4 - 0.83354150 -0.04103552 1.55676369 -1.00000000 - residuals(model, type="response") - 1 2 3 4 - 0.45460738 -0.02139574 0.60888055 -0.20392801 + Number of Fisher Scoring iterations: 18 + + residuals(model, type = "pearson") + 1 2 3 4 + 1.1472554 -1.4642569 1.4935199 -0.8025842 + residuals(model, type = "working") + 1 2 3 4 + 1.3624928 -0.8322375 0.9894580 -1.0000000 + residuals(model, type = "response") + 1 2 3 4 + 0.57671828 -2.48040354 0.49735052 -0.01040646 */ - val datasetWithWeight = Seq( - Instance(1.0, 1.0, Vectors.dense(0.0, 5.0)), - Instance(0.5, 2.0, Vectors.dense(1.0, 2.0)), - Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), - Instance(0.0, 4.0, Vectors.dense(3.0, 3.0)) - ).toDF() - val trainer = new GeneralizedLinearRegression() .setFamily("tweedie") .setVariancePower(1.6) .setLinkPower(0.0) .setWeightCol("weight") - .setFitIntercept(false) - - val model = trainer.fit(datasetWithWeight) - val coefficientsR = Vectors.dense(Array(-0.408746, -0.12125)) - val interceptR = 0.0 - val devianceResidualsR = Array(0.621047, -0.051515, 1.693473, -3.253946) - val pearsonResidualsR = Array(0.738362, -0.050946, 2.234834, -1.455209) - val workingResidualsR = Array(0.833541, -0.041036, 1.556764, -1.0) - val responseResidualsR = Array(0.454607, -0.021396, 0.608881, -0.203928) - val seCoefR = Array(0.520519, 0.408215) - val tValsR = Array(-0.785267, -0.297024) - val pValsR = Array(0.514549, 0.794457) - val dispersionR = 3.830036 - val nullDevianceR = 20.702 - val residualDevianceR = 13.844 - val residualDegreeOfFreedomNullR = 4 - val residualDegreeOfFreedomR = 2 + .setOffsetCol("offset") + + val model = trainer.fit(dataset) + + val coefficientsR = Vectors.dense(Array(-1.14577, -0.36585)) + val interceptR = -0.03047 + val devianceResidualsR = Array(0.89171, -2.13961, 1.2252, -1.79463) + val pearsonResidualsR = Array(1.14726, -1.46426, 1.49352, -0.80258) + val workingResidualsR = Array(1.36249, -0.83224, 0.98946, -1) + val responseResidualsR = Array(0.57672, -2.4804, 0.49735, -0.01041) + val seCoefR = Array(1.41674, 0.97065, 3.65) + val tValsR = Array(-0.80873, -0.37691, 
-0.00835) + val pValsR = Array(0.56707, 0.77053, 0.99468) + val dispersionR = 6.33496 + val nullDevianceR = 12.78358 + val residualDevianceR = 10.09488 + val residualDegreeOfFreedomNullR = 3 + val residualDegreeOfFreedomR = 1 val summary = model.summary From 3c2fc19d478256f8dc0ae7219fdd188030218c07 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 30 Jun 2017 20:30:26 +0800 Subject: [PATCH 117/118] [SPARK-18294][CORE] Implement commit protocol to support `mapred` package's committer ## What changes were proposed in this pull request? This PR makes the following changes: - Implement a new commit protocol `HadoopMapRedCommitProtocol` which support the old `mapred` package's committer; - Refactor SparkHadoopWriter and SparkHadoopMapReduceWriter, now they are combined together, thus we can support write through both mapred and mapreduce API by the new SparkHadoopWriter, a lot of duplicated codes are removed. After this change, it should be pretty easy for us to support the committer from both the new and the old hadoop API at high level. ## How was this patch tested? No major behavior change, passed the existing test cases. Author: Xingbo Jiang Closes #18438 from jiangxb1987/SparkHadoopWriter. --- .../io/HadoopMapRedCommitProtocol.scala | 36 ++ .../internal/io/HadoopWriteConfigUtil.scala | 79 ++++ .../io/SparkHadoopMapReduceWriter.scala | 181 -------- .../spark/internal/io/SparkHadoopWriter.scala | 393 ++++++++++++++---- .../apache/spark/rdd/PairRDDFunctions.scala | 72 +--- .../spark/rdd/PairRDDFunctionsSuite.scala | 2 +- .../OutputCommitCoordinatorSuite.scala | 35 +- 7 files changed, 461 insertions(+), 337 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala create mode 100644 core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala delete mode 100644 core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala new file mode 100644 index 0000000000000..ddbd624b380d4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapreduce.{TaskAttemptContext => NewTaskAttemptContext} + +/** + * An [[FileCommitProtocol]] implementation backed by an underlying Hadoop OutputCommitter + * (from the old mapred API). + * + * Unlike Hadoop's OutputCommitter, this implementation is serializable. 
+ */ +class HadoopMapRedCommitProtocol(jobId: String, path: String) + extends HadoopMapReduceCommitProtocol(jobId, path) { + + override def setupCommitter(context: NewTaskAttemptContext): OutputCommitter = { + val config = context.getConfiguration.asInstanceOf[JobConf] + config.getOutputCommitter + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala new file mode 100644 index 0000000000000..9b987e0e1bb67 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import scala.reflect.ClassTag + +import org.apache.hadoop.mapreduce._ + +import org.apache.spark.SparkConf + +/** + * Interface for create output format/committer/writer used during saving an RDD using a Hadoop + * OutputFormat (both from the old mapred API and the new mapreduce API) + * + * Notes: + * 1. Implementations should throw [[IllegalArgumentException]] when wrong hadoop API is + * referenced; + * 2. Implementations must be serializable, as the instance instantiated on the driver + * will be used for tasks on executors; + * 3. Implementations should have a constructor with exactly one argument: + * (conf: SerializableConfiguration) or (conf: SerializableJobConf). 
+ */ +abstract class HadoopWriteConfigUtil[K, V: ClassTag] extends Serializable { + + // -------------------------------------------------------------------------- + // Create JobContext/TaskAttemptContext + // -------------------------------------------------------------------------- + + def createJobContext(jobTrackerId: String, jobId: Int): JobContext + + def createTaskAttemptContext( + jobTrackerId: String, + jobId: Int, + splitId: Int, + taskAttemptId: Int): TaskAttemptContext + + // -------------------------------------------------------------------------- + // Create committer + // -------------------------------------------------------------------------- + + def createCommitter(jobId: Int): HadoopMapReduceCommitProtocol + + // -------------------------------------------------------------------------- + // Create writer + // -------------------------------------------------------------------------- + + def initWriter(taskContext: TaskAttemptContext, splitId: Int): Unit + + def write(pair: (K, V)): Unit + + def closeWriter(taskContext: TaskAttemptContext): Unit + + // -------------------------------------------------------------------------- + // Create OutputFormat + // -------------------------------------------------------------------------- + + def initOutputFormat(jobContext: JobContext): Unit + + // -------------------------------------------------------------------------- + // Verify hadoop config + // -------------------------------------------------------------------------- + + def assertConf(jobContext: JobContext, conf: SparkConf): Unit +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala deleted file mode 100644 index 376ff9bb19f74..0000000000000 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala +++ /dev/null @@ -1,181 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.internal.io - -import java.text.SimpleDateFormat -import java.util.{Date, Locale} - -import scala.reflect.ClassTag -import scala.util.DynamicVariable - -import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapred.{JobConf, JobID} -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl - -import org.apache.spark.{SparkConf, SparkException, TaskContext} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.OutputMetrics -import org.apache.spark.internal.Logging -import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage -import org.apache.spark.rdd.RDD -import org.apache.spark.util.{SerializableConfiguration, Utils} - -/** - * A helper object that saves an RDD using a Hadoop OutputFormat - * (from the newer mapreduce API, not the old mapred API). - */ -private[spark] -object SparkHadoopMapReduceWriter extends Logging { - - /** - * Basic work flow of this command is: - * 1. Driver side setup, prepare the data source and hadoop configuration for the write job to - * be issued. - * 2. Issues a write job consists of one or more executor side tasks, each of which writes all - * rows within an RDD partition. - * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any - * exception is thrown during task commitment, also aborts that task. - * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is - * thrown during job commitment, also aborts the job. - */ - def write[K, V: ClassTag]( - rdd: RDD[(K, V)], - hadoopConf: Configuration): Unit = { - // Extract context and configuration from RDD. - val sparkContext = rdd.context - val stageId = rdd.id - val sparkConf = rdd.conf - val conf = new SerializableConfiguration(hadoopConf) - - // Set up a job. - val jobTrackerId = SparkHadoopWriterUtils.createJobTrackerID(new Date()) - val jobAttemptId = new TaskAttemptID(jobTrackerId, stageId, TaskType.MAP, 0, 0) - val jobContext = new TaskAttemptContextImpl(conf.value, jobAttemptId) - val format = jobContext.getOutputFormatClass - - if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(sparkConf)) { - // FileOutputFormat ignores the filesystem parameter - val jobFormat = format.newInstance - jobFormat.checkOutputSpecs(jobContext) - } - - val committer = FileCommitProtocol.instantiate( - className = classOf[HadoopMapReduceCommitProtocol].getName, - jobId = stageId.toString, - outputPath = conf.value.get("mapreduce.output.fileoutputformat.outputdir"), - isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol] - committer.setupJob(jobContext) - - // Try to write all RDD partitions as a Hadoop OutputFormat. 
- try { - val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => { - executeTask( - context = context, - jobTrackerId = jobTrackerId, - sparkStageId = context.stageId, - sparkPartitionId = context.partitionId, - sparkAttemptNumber = context.attemptNumber, - committer = committer, - hadoopConf = conf.value, - outputFormat = format.asInstanceOf[Class[OutputFormat[K, V]]], - iterator = iter) - }) - - committer.commitJob(jobContext, ret) - logInfo(s"Job ${jobContext.getJobID} committed.") - } catch { - case cause: Throwable => - logError(s"Aborting job ${jobContext.getJobID}.", cause) - committer.abortJob(jobContext) - throw new SparkException("Job aborted.", cause) - } - } - - /** Write an RDD partition out in a single Spark task. */ - private def executeTask[K, V: ClassTag]( - context: TaskContext, - jobTrackerId: String, - sparkStageId: Int, - sparkPartitionId: Int, - sparkAttemptNumber: Int, - committer: FileCommitProtocol, - hadoopConf: Configuration, - outputFormat: Class[_ <: OutputFormat[K, V]], - iterator: Iterator[(K, V)]): TaskCommitMessage = { - // Set up a task. - val attemptId = new TaskAttemptID(jobTrackerId, sparkStageId, TaskType.REDUCE, - sparkPartitionId, sparkAttemptNumber) - val taskContext = new TaskAttemptContextImpl(hadoopConf, attemptId) - committer.setupTask(taskContext) - - val (outputMetrics, callback) = SparkHadoopWriterUtils.initHadoopOutputMetrics(context) - - // Initiate the writer. - val taskFormat = outputFormat.newInstance() - // If OutputFormat is Configurable, we should set conf to it. - taskFormat match { - case c: Configurable => c.setConf(hadoopConf) - case _ => () - } - var writer = taskFormat.getRecordWriter(taskContext) - .asInstanceOf[RecordWriter[K, V]] - require(writer != null, "Unable to obtain RecordWriter") - var recordsWritten = 0L - - // Write all rows in RDD partition. - try { - val ret = Utils.tryWithSafeFinallyAndFailureCallbacks { - // Write rows out, release resource and commit the task. - while (iterator.hasNext) { - val pair = iterator.next() - writer.write(pair._1, pair._2) - - // Update bytes written metric every few records - SparkHadoopWriterUtils.maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten) - recordsWritten += 1 - } - if (writer != null) { - writer.close(taskContext) - writer = null - } - committer.commitTask(taskContext) - }(catchBlock = { - // If there is an error, release resource and then abort the task. 
- try { - if (writer != null) { - writer.close(taskContext) - writer = null - } - } finally { - committer.abortTask(taskContext) - logError(s"Task ${taskContext.getTaskAttemptID} aborted.") - } - }) - - outputMetrics.setBytesWritten(callback()) - outputMetrics.setRecordsWritten(recordsWritten) - - ret - } catch { - case t: Throwable => - throw new SparkException("Task failed while writing rows", t) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index acc9c38571007..7d846f9354df6 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -17,143 +17,374 @@ package org.apache.spark.internal.io -import java.io.IOException -import java.text.{NumberFormat, SimpleDateFormat} +import java.text.NumberFormat import java.util.{Date, Locale} +import scala.reflect.ClassTag + +import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.mapred._ -import org.apache.hadoop.mapreduce.TaskType +import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, +OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, +TaskAttemptContext => NewTaskAttemptContext, TaskAttemptID => NewTaskAttemptID, TaskType} +import org.apache.hadoop.mapreduce.task.{TaskAttemptContextImpl => NewTaskAttemptContextImpl} -import org.apache.spark.SerializableWritable +import org.apache.spark.{SerializableWritable, SparkConf, SparkException, TaskContext} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.rdd.HadoopRDD -import org.apache.spark.util.SerializableJobConf +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage +import org.apache.spark.rdd.{HadoopRDD, RDD} +import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf, Utils} /** - * Internal helper class that saves an RDD using a Hadoop OutputFormat. - * - * Saves the RDD using a JobConf, which should contain an output key class, an output value class, - * a filename to write to, etc, exactly like in a Hadoop MapReduce job. + * A helper object that saves an RDD using a Hadoop OutputFormat. + */ +private[spark] +object SparkHadoopWriter extends Logging { + import SparkHadoopWriterUtils._ + + /** + * Basic work flow of this command is: + * 1. Driver side setup, prepare the data source and hadoop configuration for the write job to + * be issued. + * 2. Issues a write job consists of one or more executor side tasks, each of which writes all + * rows within an RDD partition. + * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any + * exception is thrown during task commitment, also aborts that task. + * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is + * thrown during job commitment, also aborts the job. + */ + def write[K, V: ClassTag]( + rdd: RDD[(K, V)], + config: HadoopWriteConfigUtil[K, V]): Unit = { + // Extract context and configuration from RDD. + val sparkContext = rdd.context + val stageId = rdd.id + + // Set up a job. 
+ val jobTrackerId = createJobTrackerID(new Date()) + val jobContext = config.createJobContext(jobTrackerId, stageId) + config.initOutputFormat(jobContext) + + // Assert the output format/key/value class is set in JobConf. + config.assertConf(jobContext, rdd.conf) + + val committer = config.createCommitter(stageId) + committer.setupJob(jobContext) + + // Try to write all RDD partitions as a Hadoop OutputFormat. + try { + val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => { + executeTask( + context = context, + config = config, + jobTrackerId = jobTrackerId, + sparkStageId = context.stageId, + sparkPartitionId = context.partitionId, + sparkAttemptNumber = context.attemptNumber, + committer = committer, + iterator = iter) + }) + + committer.commitJob(jobContext, ret) + logInfo(s"Job ${jobContext.getJobID} committed.") + } catch { + case cause: Throwable => + logError(s"Aborting job ${jobContext.getJobID}.", cause) + committer.abortJob(jobContext) + throw new SparkException("Job aborted.", cause) + } + } + + /** Write a RDD partition out in a single Spark task. */ + private def executeTask[K, V: ClassTag]( + context: TaskContext, + config: HadoopWriteConfigUtil[K, V], + jobTrackerId: String, + sparkStageId: Int, + sparkPartitionId: Int, + sparkAttemptNumber: Int, + committer: FileCommitProtocol, + iterator: Iterator[(K, V)]): TaskCommitMessage = { + // Set up a task. + val taskContext = config.createTaskAttemptContext( + jobTrackerId, sparkStageId, sparkPartitionId, sparkAttemptNumber) + committer.setupTask(taskContext) + + val (outputMetrics, callback) = initHadoopOutputMetrics(context) + + // Initiate the writer. + config.initWriter(taskContext, sparkPartitionId) + var recordsWritten = 0L + + // Write all rows in RDD partition. + try { + val ret = Utils.tryWithSafeFinallyAndFailureCallbacks { + while (iterator.hasNext) { + val pair = iterator.next() + config.write(pair) + + // Update bytes written metric every few records + maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten) + recordsWritten += 1 + } + + config.closeWriter(taskContext) + committer.commitTask(taskContext) + }(catchBlock = { + // If there is an error, release resource and then abort the task. + try { + config.closeWriter(taskContext) + } finally { + committer.abortTask(taskContext) + logError(s"Task ${taskContext.getTaskAttemptID} aborted.") + } + }) + + outputMetrics.setBytesWritten(callback()) + outputMetrics.setRecordsWritten(recordsWritten) + + ret + } catch { + case t: Throwable => + throw new SparkException("Task failed while writing rows", t) + } + } +} + +/** + * A helper class that reads JobConf from older mapred API, creates output Format/Committer/Writer. 
*/ private[spark] -class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { +class HadoopMapRedWriteConfigUtil[K, V: ClassTag](conf: SerializableJobConf) + extends HadoopWriteConfigUtil[K, V] with Logging { - private val now = new Date() - private val conf = new SerializableJobConf(jobConf) + private var outputFormat: Class[_ <: OutputFormat[K, V]] = null + private var writer: RecordWriter[K, V] = null - private var jobID = 0 - private var splitID = 0 - private var attemptID = 0 - private var jID: SerializableWritable[JobID] = null - private var taID: SerializableWritable[TaskAttemptID] = null + private def getConf: JobConf = conf.value - @transient private var writer: RecordWriter[AnyRef, AnyRef] = null - @transient private var format: OutputFormat[AnyRef, AnyRef] = null - @transient private var committer: OutputCommitter = null - @transient private var jobContext: JobContext = null - @transient private var taskContext: TaskAttemptContext = null + // -------------------------------------------------------------------------- + // Create JobContext/TaskAttemptContext + // -------------------------------------------------------------------------- - def preSetup() { - setIDs(0, 0, 0) - HadoopRDD.addLocalConfiguration("", 0, 0, 0, conf.value) + override def createJobContext(jobTrackerId: String, jobId: Int): NewJobContext = { + val jobAttemptId = new SerializableWritable(new JobID(jobTrackerId, jobId)) + new JobContextImpl(getConf, jobAttemptId.value) + } - val jCtxt = getJobContext() - getOutputCommitter().setupJob(jCtxt) + override def createTaskAttemptContext( + jobTrackerId: String, + jobId: Int, + splitId: Int, + taskAttemptId: Int): NewTaskAttemptContext = { + // Update JobConf. + HadoopRDD.addLocalConfiguration(jobTrackerId, jobId, splitId, taskAttemptId, conf.value) + // Create taskContext. + val attemptId = new TaskAttemptID(jobTrackerId, jobId, TaskType.MAP, splitId, taskAttemptId) + new TaskAttemptContextImpl(getConf, attemptId) } + // -------------------------------------------------------------------------- + // Create committer + // -------------------------------------------------------------------------- - def setup(jobid: Int, splitid: Int, attemptid: Int) { - setIDs(jobid, splitid, attemptid) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(now), - jobid, splitID, attemptID, conf.value) + override def createCommitter(jobId: Int): HadoopMapReduceCommitProtocol = { + // Update JobConf. + HadoopRDD.addLocalConfiguration("", 0, 0, 0, getConf) + // Create commit protocol. 
+ FileCommitProtocol.instantiate( + className = classOf[HadoopMapRedCommitProtocol].getName, + jobId = jobId.toString, + outputPath = getConf.get("mapred.output.dir"), + isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol] } - def open() { + // -------------------------------------------------------------------------- + // Create writer + // -------------------------------------------------------------------------- + + override def initWriter(taskContext: NewTaskAttemptContext, splitId: Int): Unit = { val numfmt = NumberFormat.getInstance(Locale.US) numfmt.setMinimumIntegerDigits(5) numfmt.setGroupingUsed(false) - val outputName = "part-" + numfmt.format(splitID) - val path = FileOutputFormat.getOutputPath(conf.value) + val outputName = "part-" + numfmt.format(splitId) + val path = FileOutputFormat.getOutputPath(getConf) val fs: FileSystem = { if (path != null) { - path.getFileSystem(conf.value) + path.getFileSystem(getConf) } else { - FileSystem.get(conf.value) + FileSystem.get(getConf) } } - getOutputCommitter().setupTask(getTaskContext()) - writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL) + writer = getConf.getOutputFormat + .getRecordWriter(fs, getConf, outputName, Reporter.NULL) + .asInstanceOf[RecordWriter[K, V]] + + require(writer != null, "Unable to obtain RecordWriter") } - def write(key: AnyRef, value: AnyRef) { + override def write(pair: (K, V)): Unit = { + require(writer != null, "Must call createWriter before write.") + writer.write(pair._1, pair._2) + } + + override def closeWriter(taskContext: NewTaskAttemptContext): Unit = { if (writer != null) { - writer.write(key, value) - } else { - throw new IOException("Writer is null, open() has not been called") + writer.close(Reporter.NULL) } } - def close() { - writer.close(Reporter.NULL) - } + // -------------------------------------------------------------------------- + // Create OutputFormat + // -------------------------------------------------------------------------- - def commit() { - SparkHadoopMapRedUtil.commitTask(getOutputCommitter(), getTaskContext(), jobID, splitID) + override def initOutputFormat(jobContext: NewJobContext): Unit = { + if (outputFormat == null) { + outputFormat = getConf.getOutputFormat.getClass + .asInstanceOf[Class[_ <: OutputFormat[K, V]]] + } } - def commitJob() { - val cmtr = getOutputCommitter() - cmtr.commitJob(getJobContext()) + private def getOutputFormat(): OutputFormat[K, V] = { + require(outputFormat != null, "Must call initOutputFormat first.") + + outputFormat.newInstance() } - // ********* Private Functions ********* + // -------------------------------------------------------------------------- + // Verify hadoop config + // -------------------------------------------------------------------------- + + override def assertConf(jobContext: NewJobContext, conf: SparkConf): Unit = { + val outputFormatInstance = getOutputFormat() + val keyClass = getConf.getOutputKeyClass + val valueClass = getConf.getOutputValueClass + if (outputFormatInstance == null) { + throw new SparkException("Output format class not set") + } + if (keyClass == null) { + throw new SparkException("Output key class not set") + } + if (valueClass == null) { + throw new SparkException("Output value class not set") + } + SparkHadoopUtil.get.addCredentials(getConf) + + logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + + valueClass.getSimpleName + ")") - private def getOutputFormat(): OutputFormat[AnyRef, AnyRef] = { - if (format == null) { - format = 
conf.value.getOutputFormat() - .asInstanceOf[OutputFormat[AnyRef, AnyRef]] + if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(conf)) { + // FileOutputFormat ignores the filesystem parameter + val ignoredFs = FileSystem.get(getConf) + getOutputFormat().checkOutputSpecs(ignoredFs, getConf) } - format + } +} + +/** + * A helper class that reads Configuration from newer mapreduce API, creates output + * Format/Committer/Writer. + */ +private[spark] +class HadoopMapReduceWriteConfigUtil[K, V: ClassTag](conf: SerializableConfiguration) + extends HadoopWriteConfigUtil[K, V] with Logging { + + private var outputFormat: Class[_ <: NewOutputFormat[K, V]] = null + private var writer: NewRecordWriter[K, V] = null + + private def getConf: Configuration = conf.value + + // -------------------------------------------------------------------------- + // Create JobContext/TaskAttemptContext + // -------------------------------------------------------------------------- + + override def createJobContext(jobTrackerId: String, jobId: Int): NewJobContext = { + val jobAttemptId = new NewTaskAttemptID(jobTrackerId, jobId, TaskType.MAP, 0, 0) + new NewTaskAttemptContextImpl(getConf, jobAttemptId) + } + + override def createTaskAttemptContext( + jobTrackerId: String, + jobId: Int, + splitId: Int, + taskAttemptId: Int): NewTaskAttemptContext = { + val attemptId = new NewTaskAttemptID( + jobTrackerId, jobId, TaskType.REDUCE, splitId, taskAttemptId) + new NewTaskAttemptContextImpl(getConf, attemptId) + } + + // -------------------------------------------------------------------------- + // Create committer + // -------------------------------------------------------------------------- + + override def createCommitter(jobId: Int): HadoopMapReduceCommitProtocol = { + FileCommitProtocol.instantiate( + className = classOf[HadoopMapReduceCommitProtocol].getName, + jobId = jobId.toString, + outputPath = getConf.get("mapreduce.output.fileoutputformat.outputdir"), + isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol] } - private def getOutputCommitter(): OutputCommitter = { - if (committer == null) { - committer = conf.value.getOutputCommitter + // -------------------------------------------------------------------------- + // Create writer + // -------------------------------------------------------------------------- + + override def initWriter(taskContext: NewTaskAttemptContext, splitId: Int): Unit = { + val taskFormat = getOutputFormat() + // If OutputFormat is Configurable, we should set conf to it. 
+ taskFormat match { + case c: Configurable => c.setConf(getConf) + case _ => () } - committer + + writer = taskFormat.getRecordWriter(taskContext) + .asInstanceOf[NewRecordWriter[K, V]] + + require(writer != null, "Unable to obtain RecordWriter") + } + + override def write(pair: (K, V)): Unit = { + require(writer != null, "Must call createWriter before write.") + writer.write(pair._1, pair._2) } - private def getJobContext(): JobContext = { - if (jobContext == null) { - jobContext = new JobContextImpl(conf.value, jID.value) + override def closeWriter(taskContext: NewTaskAttemptContext): Unit = { + if (writer != null) { + writer.close(taskContext) + writer = null + } else { + logWarning("Writer has been closed.") } - jobContext } - private def getTaskContext(): TaskAttemptContext = { - if (taskContext == null) { - taskContext = newTaskAttemptContext(conf.value, taID.value) + // -------------------------------------------------------------------------- + // Create OutputFormat + // -------------------------------------------------------------------------- + + override def initOutputFormat(jobContext: NewJobContext): Unit = { + if (outputFormat == null) { + outputFormat = jobContext.getOutputFormatClass + .asInstanceOf[Class[_ <: NewOutputFormat[K, V]]] } - taskContext } - protected def newTaskAttemptContext( - conf: JobConf, - attemptId: TaskAttemptID): TaskAttemptContext = { - new TaskAttemptContextImpl(conf, attemptId) + private def getOutputFormat(): NewOutputFormat[K, V] = { + require(outputFormat != null, "Must call initOutputFormat first.") + + outputFormat.newInstance() } - private def setIDs(jobid: Int, splitid: Int, attemptid: Int) { - jobID = jobid - splitID = splitid - attemptID = attemptid + // -------------------------------------------------------------------------- + // Verify hadoop config + // -------------------------------------------------------------------------- - jID = new SerializableWritable[JobID](SparkHadoopWriterUtils.createJobID(now, jobid)) - taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, TaskType.MAP, splitID), attemptID)) + override def assertConf(jobContext: NewJobContext, conf: SparkConf): Unit = { + if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(conf)) { + getOutputFormat().checkOutputSpecs(jobContext) + } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 58762cc0838cd..4628fa8ba270e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -27,7 +27,6 @@ import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} @@ -36,13 +35,11 @@ import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewO import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.annotation.Experimental -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.internal.io.{SparkHadoopMapReduceWriter, SparkHadoopWriter, - SparkHadoopWriterUtils} +import org.apache.spark.internal.io._ import org.apache.spark.internal.Logging import 
org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf, Utils} import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.util.random.StratifiedSamplingUtils @@ -1082,9 +1079,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * result of using direct output committer with speculation enabled. */ def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { - SparkHadoopMapReduceWriter.write( + val config = new HadoopMapReduceWriteConfigUtil[K, V](new SerializableConfiguration(conf)) + SparkHadoopWriter.write( rdd = self, - hadoopConf = conf) + config = config) } /** @@ -1094,62 +1092,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * MapReduce job. */ def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope { - // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). - val hadoopConf = conf - val outputFormatInstance = hadoopConf.getOutputFormat - val keyClass = hadoopConf.getOutputKeyClass - val valueClass = hadoopConf.getOutputValueClass - if (outputFormatInstance == null) { - throw new SparkException("Output format class not set") - } - if (keyClass == null) { - throw new SparkException("Output key class not set") - } - if (valueClass == null) { - throw new SparkException("Output value class not set") - } - SparkHadoopUtil.get.addCredentials(hadoopConf) - - logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + - valueClass.getSimpleName + ")") - - if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(self.conf)) { - // FileOutputFormat ignores the filesystem parameter - val ignoredFs = FileSystem.get(hadoopConf) - hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf) - } - - val writer = new SparkHadoopWriter(hadoopConf) - writer.preSetup() - - val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { - // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it - // around by taking a mod. We expect that no task will be attempted 2 billion times. 
- val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt - - val (outputMetrics, callback) = SparkHadoopWriterUtils.initHadoopOutputMetrics(context) - - writer.setup(context.stageId, context.partitionId, taskAttemptId) - writer.open() - var recordsWritten = 0L - - Utils.tryWithSafeFinallyAndFailureCallbacks { - while (iter.hasNext) { - val record = iter.next() - writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) - - // Update bytes written metric every few records - SparkHadoopWriterUtils.maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten) - recordsWritten += 1 - } - }(finallyBlock = writer.close()) - writer.commit() - outputMetrics.setBytesWritten(callback()) - outputMetrics.setRecordsWritten(recordsWritten) - } - - self.context.runJob(self, writeToFile) - writer.commitJob() + val config = new HadoopMapRedWriteConfigUtil[K, V](new SerializableJobConf(conf)) + SparkHadoopWriter.write( + rdd = self, + config = config) } /** diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 02df157be377c..44dd955ce8690 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -561,7 +561,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { pairs.saveAsHadoopFile( "ignored", pairs.keyClass, pairs.valueClass, classOf[FakeFormatWithCallback], conf) } - assert(e.getMessage contains "failed to write") + assert(e.getCause.getMessage contains "failed to write") assert(FakeWriterWithCallback.calledBy === "write,callback,close") assert(FakeWriterWithCallback.exception != null, "exception should be captured") diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index e51e6a0d3ff6b..1579b614ea5b0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -18,12 +18,14 @@ package org.apache.spark.scheduler import java.io.File +import java.util.Date import java.util.concurrent.TimeoutException import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.hadoop.mapred.{JobConf, OutputCommitter, TaskAttemptContext, TaskAttemptID} +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapreduce.TaskType import org.mockito.Matchers import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -31,7 +33,7 @@ import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter import org.apache.spark._ -import org.apache.spark.internal.io.SparkHadoopWriter +import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapRedCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.rdd.{FakeOutputCommitter, RDD} import org.apache.spark.util.{ThreadUtils, Utils} @@ -214,6 +216,8 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { */ private case class OutputCommitFunctions(tempDirPath: String) { + private val jobId = new SerializableWritable(SparkHadoopWriterUtils.createJobID(new Date, 0)) + // Mock output committer that simulates a successful commit (after commit is authorized) private def successfulOutputCommitter = new FakeOutputCommitter { override def commitTask(context: TaskAttemptContext): Unit = { @@ 
-256,14 +260,23 @@ private case class OutputCommitFunctions(tempDirPath: String) { def jobConf = new JobConf { override def getOutputCommitter(): OutputCommitter = outputCommitter } - val sparkHadoopWriter = new SparkHadoopWriter(jobConf) { - override def newTaskAttemptContext( - conf: JobConf, - attemptId: TaskAttemptID): TaskAttemptContext = { - mock(classOf[TaskAttemptContext]) - } - } - sparkHadoopWriter.setup(ctx.stageId, ctx.partitionId, ctx.attemptNumber) - sparkHadoopWriter.commit() + + // Instantiate committer. + val committer = FileCommitProtocol.instantiate( + className = classOf[HadoopMapRedCommitProtocol].getName, + jobId = jobId.value.getId.toString, + outputPath = jobConf.get("mapred.output.dir"), + isAppend = false) + + // Create TaskAttemptContext. + // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it + // around by taking a mod. We expect that no task will be attempted 2 billion times. + val taskAttemptId = (ctx.taskAttemptId % Int.MaxValue).toInt + val attemptId = new TaskAttemptID( + new TaskID(jobId.value, TaskType.MAP, ctx.partitionId), taskAttemptId) + val taskContext = new TaskAttemptContextImpl(jobConf, attemptId) + + committer.setupTask(taskContext) + committer.commitTask(taskContext) } } From 528c9281aecc49e9bff204dd303962c705c6f237 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 30 Jun 2017 23:25:14 +0800 Subject: [PATCH 118/118] [ML] Fix scala-2.10 build failure of GeneralizedLinearRegressionSuite. ## What changes were proposed in this pull request? Fix scala-2.10 build failure of ```GeneralizedLinearRegressionSuite```. ## How was this patch tested? Build with scala-2.10. Author: Yanbo Liang Closes #18489 from yanboliang/glr. --- .../ml/regression/GeneralizedLinearRegressionSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index cfaa57314bd66..83f1344a7bcb1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -1075,7 +1075,7 @@ class GeneralizedLinearRegressionSuite val seCoefR = Array(1.23439, 0.9669, 3.56866) val tValsR = Array(0.80297, -0.65737, -0.06017) val pValsR = Array(0.42199, 0.51094, 0.95202) - val dispersionR = 1 + val dispersionR = 1.0 val nullDevianceR = 2.17561 val residualDevianceR = 0.00018 val residualDegreeOfFreedomNullR = 3 @@ -1114,7 +1114,7 @@ class GeneralizedLinearRegressionSuite assert(x._1 ~== x._2 absTol 1E-3) } summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } - assert(summary.dispersion ~== dispersionR absTol 1E-3) + assert(summary.dispersion === dispersionR) assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3) assert(summary.deviance ~== residualDevianceR absTol 1E-3) assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) @@ -1190,7 +1190,7 @@ class GeneralizedLinearRegressionSuite val seCoefR = Array(1.16826, 0.41703, 1.96249) val tValsR = Array(-2.46387, 2.12428, -2.32757) val pValsR = Array(0.01374, 0.03365, 0.01993) - val dispersionR = 1 + val dispersionR = 1.0 val nullDevianceR = 22.55853 val residualDevianceR = 9.5622 val residualDegreeOfFreedomNullR = 3 @@ -1229,7 +1229,7 @@ class 
GeneralizedLinearRegressionSuite assert(x._1 ~== x._2 absTol 1E-3) } summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } - assert(summary.dispersion ~== dispersionR absTol 1E-3) + assert(summary.dispersion === dispersionR) assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3) assert(summary.deviance ~== residualDevianceR absTol 1E-3) assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
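
As a caller-side illustration of the SparkHadoopWriter refactoring in the patches above: after this change both `saveAsHadoopDataset` (old `mapred` `JobConf`) and `saveAsNewAPIHadoopDataset` (new `mapreduce` `Configuration`) funnel into the single `SparkHadoopWriter.write` entry point, differing only in which `HadoopWriteConfigUtil` wraps the configuration. The sketch below is not part of any patch in this series; the local master, output paths, and key/value types are placeholder assumptions chosen only to make the example self-contained and runnable.

```scala
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.hadoop.mapred.{FileOutputFormat => OldFileOutputFormat, JobConf, TextOutputFormat => OldTextOutputFormat}
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
import org.apache.spark.{SparkConf, SparkContext}

object HadoopDatasetWriteSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("write-sketch").setMaster("local[*]"))
    val pairs = sc.parallelize(Seq(1 -> "a", 2 -> "b", 3 -> "c"))
      .map { case (k, v) => (new IntWritable(k), new Text(v)) }

    // Old mapred API: the JobConf is wrapped in a HadoopMapRedWriteConfigUtil
    // before being handed to SparkHadoopWriter.write.
    val jobConf = new JobConf(sc.hadoopConfiguration)
    jobConf.setOutputKeyClass(classOf[IntWritable])
    jobConf.setOutputValueClass(classOf[Text])
    jobConf.setOutputFormat(classOf[OldTextOutputFormat[IntWritable, Text]])
    OldFileOutputFormat.setOutputPath(jobConf, new Path("/tmp/write-sketch-mapred")) // placeholder path, must not exist yet
    pairs.saveAsHadoopDataset(jobConf)

    // New mapreduce API: the Configuration is wrapped in a HadoopMapReduceWriteConfigUtil.
    val job = Job.getInstance(sc.hadoopConfiguration)
    job.setOutputKeyClass(classOf[IntWritable])
    job.setOutputValueClass(classOf[Text])
    job.setOutputFormatClass(classOf[TextOutputFormat[IntWritable, Text]])
    FileOutputFormat.setOutputPath(job, new Path("/tmp/write-sketch-mapreduce")) // placeholder path, must not exist yet
    pairs.saveAsNewAPIHadoopDataset(job.getConfiguration)

    sc.stop()
  }
}
```

In both cases the driver sets up the job, each task writes its partition and commits through the `FileCommitProtocol`, and the driver commits (or aborts) the job, following the workflow described in the scaladoc of `SparkHadoopWriter.write`.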