Skip to content

Commit

Permalink
rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
zhzhan committed Oct 13, 2014
2 parents 8fad1cf + 497b0f4 commit 0d4d2ed
Show file tree
Hide file tree
Showing 176 changed files with 1,390 additions and 162 deletions.
12 changes: 7 additions & 5 deletions core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,18 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1
if (partsScanned > 0) {
// If we didn't find any rows after the first iteration, just try all partitions next.
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
// by 50%.
// by 50%. We also cap the estimation in the end.
if (results.size == 0) {
numPartsToTry = totalParts - 1
numPartsToTry = partsScanned * 4
} else {
numPartsToTry = (1.5 * num * partsScanned / results.size).toInt
// the second argument of max is >= 1 whenever partsScanned >= 2
numPartsToTry = Math.max(1,
(1.5 * num * partsScanned / results.size).toInt - partsScanned)
numPartsToTry = Math.min(numPartsToTry, partsScanned * 4)
}
}
numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions

val left = num - results.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
Expand Down
8 changes: 5 additions & 3 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1079,15 +1079,17 @@ abstract class RDD[T: ClassTag](
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise,
// If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise,
// interpolate the number of partitions we need to try, but overestimate it by 50%.
// We also cap the estimation in the end.
if (buf.size == 0) {
numPartsToTry = partsScanned * 4
} else {
numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt
// the left side of max is >=1 whenever partsScanned >= 2
numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1)
numPartsToTry = Math.min(numPartsToTry, partsScanned * 4)
}
}
numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions

val left = num - buf.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
Expand Down
5 changes: 4 additions & 1 deletion python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,10 +1070,13 @@ def take(self, num):
# If we didn't find any rows after the previous iteration,
# quadruple and retry. Otherwise, interpolate the number of
# partitions we need to try, but overestimate it by 50%.
# We also cap the estimation in the end.
if len(items) == 0:
numPartsToTry = partsScanned * 4
else:
numPartsToTry = int(1.5 * num * partsScanned / len(items))
# the first parameter of max is >=1 whenever partsScanned >= 2
numPartsToTry = int(1.5 * num * partsScanned / len(items)) - partsScanned
numPartsToTry = min(max(numPartsToTry, 1), partsScanned * 4)

left = num - len(items)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst

import java.sql.Timestamp
import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
Expand Down Expand Up @@ -77,8 +77,9 @@ object ScalaReflection {
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
typeCoercionRules ++
extendedRules : _*),
Batch("Check Analysis", Once,
CheckResolution),
CheckResolution,
CheckAggregation),
Batch("AnalysisOperators", fixedPoint,
EliminateAnalysisOperators)
)
Expand All @@ -88,6 +89,32 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
}
}

/**
 * Verifies that every aggregate expression of an [[Aggregate]] operator is valid with respect
 * to its grouping expressions, and fails analysis otherwise.
 *
 * An expression is accepted when it is an [[AggregateExpression]], when it appears among the
 * grouping expressions, when it references no attributes at all, or when all of its children
 * are themselves accepted. A bare [[Attribute]] is accepted only if it is a grouping expression.
 * The plan itself is returned unchanged; the rule only validates.
 */
object CheckAggregation extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
    case agg @ Aggregate(groupingExprs, aggregateExprs, _) =>
      // Recursively decides whether `expr` may appear in the aggregate's output list.
      def isValid(expr: Expression): Boolean = expr match {
        case _: AggregateExpression => true
        // An attribute must be grouped on; it never falls through to the checks below.
        case attr: Attribute => groupingExprs.contains(attr)
        case other =>
          groupingExprs.contains(other) ||
            other.references.isEmpty ||
            other.children.forall(isValid)
      }

      // Report the first offending expression, in declaration order.
      aggregateExprs.find(candidate => !isValid(candidate)).foreach { offending =>
        throw new TreeNodeException(plan, s"Expression not in GROUP BY: $offending")
      }

      agg
  }
}

/**
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
Expand Down Expand Up @@ -204,18 +231,17 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
*/
object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
if aggregate.resolved && containsAggregate(havingCondition) => {
val evaluatedCondition = Alias(havingCondition, "havingCondition")()
val aggExprsWithHaving = evaluatedCondition +: originalAggExprs

Project(aggregate.output,
Filter(evaluatedCondition.toAttribute,
aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
}

}

protected def containsAggregate(condition: Expression): Boolean =
condition
.collect { case ae: AggregateExpression => ae }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,20 +220,39 @@ trait HiveTypeCoercion {
case a: BinaryArithmetic if a.right.dataType == StringType =>
a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))

// Cast all timestamp/date/string comparisons into string comparisons.
case p: BinaryPredicate if p.left.dataType == StringType
&& p.right.dataType == DateType =>
p.makeCopy(Array(p.left, Cast(p.right, StringType)))
case p: BinaryPredicate if p.left.dataType == DateType
&& p.right.dataType == StringType =>
p.makeCopy(Array(Cast(p.left, StringType), p.right))
case p: BinaryPredicate if p.left.dataType == StringType
&& p.right.dataType == TimestampType =>
p.makeCopy(Array(Cast(p.left, TimestampType), p.right))
p.makeCopy(Array(p.left, Cast(p.right, StringType)))
case p: BinaryPredicate if p.left.dataType == TimestampType
&& p.right.dataType == StringType =>
p.makeCopy(Array(p.left, Cast(p.right, TimestampType)))
p.makeCopy(Array(Cast(p.left, StringType), p.right))
case p: BinaryPredicate if p.left.dataType == TimestampType
&& p.right.dataType == DateType =>
p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
case p: BinaryPredicate if p.left.dataType == DateType
&& p.right.dataType == TimestampType =>
p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))

case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType =>
p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType =>
p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))

case i @ In(a,b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) =>
i.makeCopy(Array(a,b.map(Cast(_,TimestampType))))
case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == StringType) =>
i.makeCopy(Array(Cast(a, StringType), b))
case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) =>
i.makeCopy(Array(Cast(a, StringType), b))
case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == DateType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))

case Sum(e) if e.dataType == StringType =>
Sum(Cast(e, DoubleType))
Expand Down Expand Up @@ -283,6 +302,8 @@ trait HiveTypeCoercion {
// Skip if the type is boolean type already. Note that this extra cast should be removed
// by optimizer.SimplifyCasts.
case Cast(e, BooleanType) if e.dataType == BooleanType => e
// Casting DateType to boolean should yield null, so keep the cast as is.
case Cast(e, BooleanType) if e.dataType == DateType => Cast(e, BooleanType)
// If the data type is not boolean and is being cast boolean, turn it into a comparison
// with the numeric value, i.e. x != 0. This will coerce the type into numeric type.
case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ case class Star(
mapFunction: Attribute => Expression = identity[Attribute])
extends Attribute with trees.LeafNode[Expression] {

override def name = throw new UnresolvedException(this, "exprId")
override def name = throw new UnresolvedException(this, "name")
override def exprId = throw new UnresolvedException(this, "exprId")
override def dataType = throw new UnresolvedException(this, "dataType")
override def nullable = throw new UnresolvedException(this, "nullable")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst

import java.sql.Timestamp
import java.sql.{Date, Timestamp}

import scala.language.implicitConversions

Expand Down Expand Up @@ -119,6 +119,7 @@ package object dsl {
implicit def floatToLiteral(f: Float) = Literal(f)
implicit def doubleToLiteral(d: Double) = Literal(d)
implicit def stringToLiteral(s: String) = Literal(s)
implicit def dateToLiteral(d: Date) = Literal(d)
implicit def decimalToLiteral(d: BigDecimal) = Literal(d)
implicit def timestampToLiteral(t: Timestamp) = Literal(t)
implicit def binaryToLiteral(a: Array[Byte]) = Literal(a)
Expand Down Expand Up @@ -174,6 +175,9 @@ package object dsl {
/** Creates a new AttributeReference of type string */
def string = AttributeReference(s, StringType, nullable = true)()

/** Creates a new AttributeReference of type date */
def date = AttributeReference(s, DateType, nullable = true)()

/** Creates a new AttributeReference of type decimal */
def decimal = AttributeReference(s, DecimalType, nullable = true)()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,26 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.Star

protected class AttributeEquals(val a: Attribute) {
override def hashCode() = a.exprId.hashCode()
override def equals(other: Any) = other match {
case otherReference: AttributeEquals => a.exprId == otherReference.a.exprId
case otherAttribute => false
override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match {
case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId
case (a1, a2) => a1 == a2
}
}

object AttributeSet {
/** Constructs a new [[AttributeSet]] given a sequence of [[Attribute Attributes]]. */
def apply(baseSet: Seq[Attribute]) = {
new AttributeSet(baseSet.map(new AttributeEquals(_)).toSet)
}
def apply(a: Attribute) =
new AttributeSet(Set(new AttributeEquals(a)))

/** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
def apply(baseSet: Seq[Expression]) =
new AttributeSet(
baseSet
.flatMap(_.references)
.map(new AttributeEquals(_)).toSet)
}

/**
Expand Down Expand Up @@ -103,4 +110,6 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
// We must force toSeq to not be strict otherwise we end up with a [[Stream]] that captures all
// sorts of things in its closure.
override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq

override def toString = "{" + baseSet.map(_.a).mkString(", ") + "}"
}
Loading

0 comments on commit 0d4d2ed

Please sign in to comment.