Skip to content

Commit

Permalink
[SPARK-31945][SQL][PYSPARK] Enable cache for the same Python function
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR proposes to make `PythonFunction` holds `Seq[Byte]` instead of `Array[Byte]` to be able to compare if the byte array has the same values for the cache manager.

### Why are the changes needed?

Currently the cache manager doesn't use the cache for `udf` if the `udf` is created again even if the functions is the same.

```py
>>> func = lambda x: x

>>> df = spark.range(1)
>>> df.select(udf(func)("id")).cache()
```
```py
>>> df.select(udf(func)("id")).explain()
== Physical Plan ==
*(2) Project [pythonUDF0#14 AS <lambda>(id)#12]
+- BatchEvalPython [<lambda>(id#0L)], [pythonUDF0#14]
 +- *(1) Range (0, 1, step=1, splits=12)
```

This is because `PythonFunction` holds `Array[Byte]`, and `equals` method of array equals only when the both array is the same instance.

### Does this PR introduce _any_ user-facing change?

Yes, if the user reuse the Python function for the UDF, the cache manager will detect the same function and use the cache for it.

### How was this patch tested?

I added a test case and manually.

```py
>>> df.select(udf(func)("id")).explain()
== Physical Plan ==
InMemoryTableScan [<lambda>(id)#12]
   +- InMemoryRelation [<lambda>(id)#12], StorageLevel(disk, memory, deserialized, 1 replicas)
         +- *(2) Project [pythonUDF0#5 AS <lambda>(id)#3]
            +- BatchEvalPython [<lambda>(id#0L)], [pythonUDF0#5]
               +- *(1) Range (0, 1, step=1, splits=12)
```

Closes #28774 from ueshin/issues/SPARK-31945/udf_cache.

Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
  • Loading branch information
ueshin authored and HyukjinKwon committed Jun 10, 2020
1 parent e14029b commit 032d179
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 4 deletions.
16 changes: 14 additions & 2 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,25 @@ private[spark] class PythonRDD(
* runner.
*/
private[spark] case class PythonFunction(
command: Array[Byte],
command: Seq[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
pythonExec: String,
pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: PythonAccumulatorV2)
accumulator: PythonAccumulatorV2) {

def this(
command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
pythonExec: String,
pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: PythonAccumulatorV2) = {
this(command.toSeq, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator)
}
}

/**
* A wrapper for chained Python functions (from bottom to top).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions])
protected override def writeCommand(dataOut: DataOutputStream): Unit = {
val command = funcs.head.funcs.head.command
dataOut.writeInt(command.length)
dataOut.write(command)
dataOut.write(command.toArray)
}

protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
Expand Down
9 changes: 9 additions & 0 deletions python/pyspark/sql/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,15 @@ def f(*a):
r = df.select(fUdf(*df.columns))
self.assertEqual(r.first()[0], "success")

def test_udf_cache(self):
func = lambda x: x

df = self.spark.range(1)
df.select(udf(func)("id")).cache()

self.assertEqual(df.select(udf(func)("id"))._jdf.queryExecution()
.withCachedData().getClass().getSimpleName(), 'InMemoryRelation')


class UDFInitializationTests(unittest.TestCase):
def tearDown(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ object PythonUDFRunner {
dataOut.writeInt(chained.funcs.length)
chained.funcs.foreach { f =>
dataOut.writeInt(f.command.length)
dataOut.write(f.command)
dataOut.write(f.command.toArray)
}
}
}
Expand Down

0 comments on commit 032d179

Please sign in to comment.