
Commit

Merge branch 'master' of github.com:apache/spark into concurrent-sql-executions

Conflicts:
	core/src/test/scala/org/apache/spark/ThreadingSuite.scala
Andrew Or committed Sep 11, 2015
2 parents 5297f79 + c34fc19 commit 3c00cc6
Showing 21 changed files with 626 additions and 62 deletions.
@@ -297,6 +297,8 @@ private[spark] class ExternalSorter[K, V, C](
val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
while (it.hasNext) {
val partitionId = it.nextPartition()
require(partitionId >= 0 && partitionId < numPartitions,
s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})")
it.writeNext(writer)
elementsPerPartition(partitionId) += 1
objectsWritten += 1
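Note: a minimal standalone sketch (with illustrative values, not taken from the sorter) of what the new require guard above does — when the predicate is false, Scala's require throws an IllegalArgumentException carrying the interpolated message, so an out-of-range partition id fails fast with a clear error:

// Hypothetical values, for illustration only
val numPartitions = 4
val partitionId = 7
require(partitionId >= 0 && partitionId < numPartitions,
  s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})")
// throws: java.lang.IllegalArgumentException: requirement failed:
//   partition Id: 7 should be in the range [0, 4)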
71 changes: 45 additions & 26 deletions core/src/test/scala/org/apache/spark/ThreadingSuite.scala
@@ -119,23 +119,30 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
val nums = sc.parallelize(1 to 2, 2)
val sem = new Semaphore(0)
ThreadingSuiteState.clear()
var throwable: Option[Throwable] = None
for (i <- 0 until 2) {
new Thread {
override def run() {
val ans = nums.map(number => {
val running = ThreadingSuiteState.runningThreads
running.getAndIncrement()
val time = System.currentTimeMillis()
while (running.get() != 4 && System.currentTimeMillis() < time + 1000) {
Thread.sleep(100)
}
if (running.get() != 4) {
ThreadingSuiteState.failed.set(true)
}
number
}).collect()
assert(ans.toList === List(1, 2))
sem.release()
try {
val ans = nums.map(number => {
val running = ThreadingSuiteState.runningThreads
running.getAndIncrement()
val time = System.currentTimeMillis()
while (running.get() != 4 && System.currentTimeMillis() < time + 1000) {
Thread.sleep(100)
}
if (running.get() != 4) {
ThreadingSuiteState.failed.set(true)
}
number
}).collect()
assert(ans.toList === List(1, 2))
} catch {
case t: Throwable =>
throwable = Some(t)
} finally {
sem.release()
}
}
}.start()
}
@@ -145,19 +152,25 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
ThreadingSuiteState.runningThreads.get() + "); failing test")
fail("One or more threads didn't see runningThreads = 4")
}
throwable.foreach { t => throw t }
}

test("set local properties in different thread") {
sc = new SparkContext("local", "test")
val sem = new Semaphore(0)

var throwable: Option[Throwable] = None
val threads = (1 to 5).map { i =>
new Thread() {
override def run() {
// TODO: these assertion failures don't actually fail the test...
sc.setLocalProperty("test", i.toString)
assert(sc.getLocalProperty("test") === i.toString)
sem.release()
try {
sc.setLocalProperty("test", i.toString)
assert(sc.getLocalProperty("test") === i.toString)
} catch {
case t: Throwable =>
throwable = Some(t)
} finally {
sem.release()
}
}
}
}
@@ -166,21 +179,27 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {

sem.acquire(5)
assert(sc.getLocalProperty("test") === null)
throwable.foreach { t => throw t }
}

test("set and get local properties in parent-children thread") {
sc = new SparkContext("local", "test")
sc.setLocalProperty("test", "parent")
val sem = new Semaphore(0)

var throwable: Option[Throwable] = None
val threads = (1 to 5).map { i =>
new Thread() {
override def run() {
// TODO: these assertion failures don't actually fail the test...
assert(sc.getLocalProperty("test") === "parent")
sc.setLocalProperty("test", i.toString)
assert(sc.getLocalProperty("test") === i.toString)
sem.release()
try {
assert(sc.getLocalProperty("test") === "parent")
sc.setLocalProperty("test", i.toString)
assert(sc.getLocalProperty("test") === i.toString)
} catch {
case t: Throwable =>
throwable = Some(t)
} finally {
sem.release()
}
}
}
}
@@ -190,6 +209,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
sem.acquire(5)
assert(sc.getLocalProperty("test") === "parent")
assert(sc.getLocalProperty("Foo") === null)
throwable.foreach { t => throw t }
}

test("inheritance exclusions (SPARK-10548)") {
@@ -236,7 +256,6 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
// Create a new thread which will inherit the current thread's properties
val thread = new Thread() {
override def run(): Unit = {
// TODO: these assertion failures don't actually fail the test...
assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId")
// Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task
try {
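Note: the pattern introduced throughout this file — catch Throwable in each spawned thread, release the semaphore in a finally block, and rethrow on the test thread — is what makes assertion failures in worker threads actually fail the test. A minimal sketch of the same pattern outside ScalaTest (names and the failing assertion are illustrative):

import java.util.concurrent.Semaphore

var throwable: Option[Throwable] = None
val sem = new Semaphore(0)
new Thread {
  override def run(): Unit = {
    try {
      assert(1 + 1 == 3)                         // on its own, this would only kill this thread
    } catch {
      case t: Throwable => throwable = Some(t)   // capture instead of losing the failure
    } finally {
      sem.release()                              // always unblock the waiting caller
    }
  }
}.start()
sem.acquire()                                    // wait for the worker to finish
throwable.foreach { t => throw t }               // re-throw on the calling thread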
2 changes: 1 addition & 1 deletion python/pyspark/context.py
@@ -255,7 +255,7 @@ def __getnewargs__(self):
# This method is called when attempting to pickle SparkContext, which is always an error:
raise Exception(
"It appears that you are attempting to reference SparkContext from a broadcast "
"variable, action, or transforamtion. SparkContext can only be used on the driver, "
"variable, action, or transformation. SparkContext can only be used on the driver, "
"not in code that it run on workers. For more information, see SPARK-5063."
)

13 changes: 13 additions & 0 deletions python/pyspark/sql/column.py
@@ -91,6 +91,17 @@ def _(self):
return _


def _bin_func_op(name, reverse=False, doc="binary function"):
def _(self, other):
sc = SparkContext._active_spark_context
fn = getattr(sc._jvm.functions, name)
jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other)
njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc)
return Column(njc)
_.__doc__ = doc
return _


def _bin_op(name, doc="binary operator"):
""" Create a method for given binary operator
"""
@@ -151,6 +162,8 @@ def __init__(self, jc):
__rdiv__ = _reverse_op("divide")
__rtruediv__ = _reverse_op("divide")
__rmod__ = _reverse_op("mod")
__pow__ = _bin_func_op("pow")
__rpow__ = _bin_func_op("pow", reverse=True)

# logistic operators
__eq__ = _bin_op("equalTo")
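Note: with the new _bin_func_op wrapper, Python's ** and reversed ** on a Column are dispatched to the JVM-side pow function in org.apache.spark.sql.functions (via sc._jvm.functions). A rough Scala sketch of the equivalent calls — the session setup is shown only for illustration and uses the current entry point:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lit, pow}

val spark = SparkSession.builder().master("local[*]").appName("pow-sketch").getOrCreate()
import spark.implicits._

val df = Seq(1, 2, 3).toDF("key")
df.select(pow(col("key"), lit(2)).as("squared")).show()      // Python: df.key ** 2
df.select(pow(lit(2), col("key")).as("two_pow_key")).show()  // Python: 2 ** df.key (reverse=True)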
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests.py
@@ -568,7 +568,7 @@ def test_column_operators(self):
cs = self.df.value
c = ci == cs
self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1)
self.assertTrue(all(isinstance(c, Column) for c in rcc))
cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7]
self.assertTrue(all(isinstance(c, Column) for c in cb))
@@ -22,7 +22,7 @@ import java.math.{BigDecimal => JavaBigDecimal}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.{StringUtils, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

@@ -140,7 +140,15 @@ case class Cast(child: Expression, dataType: DataType)
// UDFToBoolean
private[this] def castToBoolean(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, _.numBytes() != 0)
buildCast[UTF8String](_, s => {
if (StringUtils.isTrueString(s)) {
true
} else if (StringUtils.isFalseString(s)) {
false
} else {
null
}
})
case TimestampType =>
buildCast[Long](_, t => t != 0)
case DateType =>
@@ -646,7 +654,17 @@ case class Cast(child: Expression, dataType: DataType)

private[this] def castToBooleanCode(from: DataType): CastFunction = from match {
case StringType =>
(c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;"
val stringUtils = StringUtils.getClass.getName.stripSuffix("$")
(c, evPrim, evNull) =>
s"""
if ($stringUtils.isTrueString($c)) {
$evPrim = true;
} else if ($stringUtils.isFalseString($c)) {
$evPrim = false;
} else {
$evNull = true;
}
"""
case TimestampType =>
(c, evPrim, evNull) => s"$evPrim = $c != 0;"
case DateType =>
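Note: both the interpreted path (castToBoolean) and the generated-code path (castToBooleanCode) above now share the same semantics: recognised true-strings cast to true, recognised false-strings to false, and anything else to NULL, where previously any non-empty string cast to true. A hedged sketch of the observable behaviour (session setup is illustrative):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("cast-sketch").getOrCreate()

spark.sql("SELECT CAST('t' AS BOOLEAN), CAST('YES' AS BOOLEAN)").show()   // true, true (case-insensitive)
spark.sql("SELECT CAST('no' AS BOOLEAN), CAST('0' AS BOOLEAN)").show()    // false, false
spark.sql("SELECT CAST('abc' AS BOOLEAN), CAST('' AS BOOLEAN)").show()    // null, null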
@@ -435,10 +435,10 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
// a && a => a
case (l, r) if l fastEquals r => l
// a && (not(a) || b) => a && b
case (l, Or(l1, r)) if (Not(l) fastEquals l1) => And(l, r)
case (l, Or(r, l1)) if (Not(l) fastEquals l1) => And(l, r)
case (Or(l, l1), r) if (l1 fastEquals Not(r)) => And(l, r)
case (Or(l1, l), r) if (l1 fastEquals Not(r)) => And(l, r)
case (l, Or(l1, r)) if (Not(l) == l1) => And(l, r)
case (l, Or(r, l1)) if (Not(l) == l1) => And(l, r)
case (Or(l, l1), r) if (l1 == Not(r)) => And(l, r)
case (Or(l1, l), r) if (l1 == Not(r)) => And(l, r)
// (a || b) && (a || c) => a || (b && c)
case _ =>
// 1. Split left and right to get the disjunctive predicates,
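Note: the four rewritten guards all encode the same boolean identity, a && (!a || b) == a && b (plus its commuted forms); the change only swaps fastEquals for plain structural equality when comparing against the freshly built Not(...) node. A tiny truth-table check of the identity itself:

// Sanity-check the identity behind the a && (not(a) || b) => a && b rewrite
for (a <- Seq(true, false); b <- Seq(true, false)) {
  assert((a && (!a || b)) == (a && b))
}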
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util

import java.util.regex.Pattern

import org.apache.spark.unsafe.types.UTF8String

object StringUtils {

// replace the _ with .{1} exactly match 1 time of any character
@@ -44,4 +46,10 @@ object StringUtils {
v
}
}

private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString)
private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString)

def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)
}
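Note: the new helpers compare against a fixed vocabulary after lower-casing, so matching is case-insensitive and limited to the listed strings. A small sketch of direct, unit-level usage (illustrative only):

import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.unsafe.types.UTF8String

StringUtils.isTrueString(UTF8String.fromString("Yes"))    // true  ("yes" is in trueStrings)
StringUtils.isFalseString(UTF8String.fromString("0"))     // true  ("0" is in falseStrings)
StringUtils.isTrueString(UTF8String.fromString("abc"))    // false (not a recognised true-string)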