Skip to content

Commit

Permalink
[SPARK-6194] [SPARK-677] [PySpark] fix memory leak in collect()
Browse files Browse the repository at this point in the history
Because circular reference between JavaObject and JavaMember, an Java object can not be released until Python GC kick in, then it will cause memory leak in collect(), which may consume lots of memory in JVM.

This PR change the way we sending collected data back into Python from local file to socket, which could avoid any disk IO during collect, also avoid any referrers of Java object in Python.

cc JoshRosen

Author: Davies Liu <[email protected]>

Closes #4923 from davies/fix_collect and squashes the following commits:

d730286 [Davies Liu] address comments
24c92a4 [Davies Liu] fix style
ba54614 [Davies Liu] use socket to transfer data from JVM
9517c8f [Davies Liu] fix memory leak in collect()

(cherry picked from commit 8767565)
Signed-off-by: Josh Rosen <[email protected]>

Conflicts:
	core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
	python/pyspark/rdd.py
	python/pyspark/sql/dataframe.py
  • Loading branch information
Davies Liu authored and JoshRosen committed Mar 10, 2015
1 parent e753f9c commit d7c359b
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 54 deletions.
74 changes: 58 additions & 16 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,28 @@ package org.apache.spark.api.python

import java.io._
import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections}

import org.apache.spark.input.PortableDataStream
import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap}

import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.language.existentials

import com.google.common.base.Charsets.UTF_8

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf}
import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}

import org.apache.spark._
import org.apache.spark.SparkContext._
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils

import scala.util.control.NonFatal

private[spark] class PythonRDD(
@transient parent: RDD[_],
command: Array[Byte],
Expand Down Expand Up @@ -331,21 +332,33 @@ private[spark] object PythonRDD extends Logging {
/**
* Adapter for calling SparkContext#runJob from Python.
*
* This method will return an iterator of an array that contains all elements in the RDD
* This method will serve an iterator of an array that contains all elements in the RDD
* (effectively a collect()), but allows you to run on a certain subset of partitions,
* or to enable local execution.
*
* @return the port number of a local socket which serves the data collected from this job.
*/
def runJob(
sc: SparkContext,
rdd: JavaRDD[Array[Byte]],
partitions: JArrayList[Int],
allowLocal: Boolean): Iterator[Array[Byte]] = {
allowLocal: Boolean): Int = {
type ByteArray = Array[Byte]
type UnrolledPartition = Array[ByteArray]
val allPartitions: Array[UnrolledPartition] =
sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
flattenedPartition.iterator
serveIterator(flattenedPartition.iterator,
s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}")
}

/**
* A helper function to collect an RDD as an iterator, then serve it via socket.
*
* @return the port number of a local socket which serves the data collected from this job.
*/
def collectAndServe[T](rdd: RDD[T]): Int = {
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}

def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
Expand Down Expand Up @@ -594,15 +607,44 @@ private[spark] object PythonRDD extends Logging {
dataOut.write(bytes)
}

def writeToFile[T](items: java.util.Iterator[T], filename: String) {
import scala.collection.JavaConverters._
writeToFile(items.asScala, filename)
}
/**
* Create a socket server and a background thread to serve the data in `items`,
*
* The socket server can only accept one connection, or close if no connection
* in 3 seconds.
*
* Once a connection comes in, it tries to serialize all the data in `items`
* and send them into this connection.
*
* The thread will terminate after all the data are sent or any exceptions happen.
*/
private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
val serverSocket = new ServerSocket(0, 1)
serverSocket.setReuseAddress(true)
// Close the socket if no connection in 3 seconds
serverSocket.setSoTimeout(3000)

new Thread(threadName) {
setDaemon(true)
override def run() {
try {
val sock = serverSocket.accept()
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
try {
writeIteratorToStream(items, out)
} finally {
out.close()
}
} catch {
case NonFatal(e) =>
logError(s"Error while sending iterator", e)
} finally {
serverSocket.close()
}
}
}.start()

def writeToFile[T](items: Iterator[T], filename: String) {
val file = new DataOutputStream(new FileOutputStream(filename))
writeIteratorToStream(items, file)
file.close()
serverSocket.getLocalPort
}

private def getMergedConf(confAsMap: java.util.HashMap[String, String],
Expand Down
13 changes: 6 additions & 7 deletions python/pyspark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from tempfile import NamedTemporaryFile
import atexit

from py4j.java_collections import ListConverter

from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
Expand All @@ -31,11 +33,9 @@
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
PairDeserializer, AutoBatchedSerializer, NoOpSerializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
from pyspark.rdd import RDD, _load_from_socket
from pyspark.traceback_utils import CallSite, first_spark_call

from py4j.java_collections import ListConverter


__all__ = ['SparkContext']

Expand All @@ -58,7 +58,6 @@ class SparkContext(object):

_gateway = None
_jvm = None
_writeToFile = None
_next_accum_id = 0
_active_spark_context = None
_lock = Lock()
Expand Down Expand Up @@ -211,7 +210,6 @@ def _ensure_initialized(cls, instance=None, gateway=None):
if not SparkContext._gateway:
SparkContext._gateway = gateway or launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile

if instance:
if (SparkContext._active_spark_context and
Expand Down Expand Up @@ -824,8 +822,9 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
# by runJob() in order to avoid having to pass a Python lambda into
# SparkContext#runJob.
mappedRDD = rdd.mapPartitions(partitionFunc)
it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
return list(mappedRDD._collect_iterator_through_file(it))
port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions,
allowLocal)
return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))

def _add_profile(self, id, profileAcc):
if not self._profile_stats:
Expand Down
43 changes: 27 additions & 16 deletions python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from collections import defaultdict
from itertools import chain, ifilter, imap
import operator
import os
import sys
import shlex
from subprocess import Popen, PIPE
Expand All @@ -29,6 +28,7 @@
import heapq
import bisect
import random
import socket
from math import sqrt, log, isinf, isnan

from pyspark.accumulators import PStatsParam
Expand Down Expand Up @@ -112,6 +112,30 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])


def _load_from_socket(port, serializer):
sock = socket.socket()
try:
sock.connect(("localhost", port))
rf = sock.makefile("rb", 65536)
for item in serializer.load_stream(rf):
yield item
finally:
sock.close()


class Partitioner(object):
def __init__(self, numPartitions, partitionFunc):
self.numPartitions = numPartitions
self.partitionFunc = partitionFunc

def __eq__(self, other):
return (isinstance(other, Partitioner) and self.numPartitions == other.numPartitions
and self.partitionFunc == other.partitionFunc)

def __call__(self, k):
return self.partitionFunc(k) % self.numPartitions


class RDD(object):

"""
Expand Down Expand Up @@ -683,21 +707,8 @@ def collect(self):
Return a list that contains all of the elements in this RDD.
"""
with SCCallSiteSync(self.context) as css:
bytesInJava = self._jrdd.collect().iterator()
return list(self._collect_iterator_through_file(bytesInJava))

def _collect_iterator_through_file(self, iterator):
# Transferring lots of data through Py4J can be slow because
# socket.readline() is inefficient. Instead, we'll dump the data to a
# file and read it back.
tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
tempFile.close()
self.ctx._writeToFile(iterator, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
for item in self._jrdd_deserializer.load_stream(tempFile):
yield item
os.unlink(tempFile.name)
port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
return list(_load_from_socket(port, self._jrdd_deserializer))

def reduce(self, f):
"""
Expand Down
8 changes: 5 additions & 3 deletions python/pyspark/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from py4j.protocol import Py4JError
from py4j.java_collections import ListConverter, MapConverter

from pyspark.rdd import RDD
from pyspark.rdd import RDD, _load_from_socket
from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
CloudPickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
Expand Down Expand Up @@ -1996,9 +1996,11 @@ def collect(self):
[Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
"""
with SCCallSiteSync(self.context) as css:
bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator()
rdd = self._jschema_rdd.baseSchemaRDD().javaToPython().rdd()
port = self._sc._jvm.PythonRDD.collectAndServe(rdd)
rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
cls = _create_cls(self.schema())
return map(cls, self._collect_iterator_through_file(bytesInJava))
return [cls(r) for r in rs]

def take(self, num):
"""Take the first num rows of the RDD.
Expand Down
12 changes: 0 additions & 12 deletions sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -412,18 +412,6 @@ class SchemaRDD(
SerDeUtil.javaToPython(jrdd)
}

/**
* Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same
* format as javaToPython. It is used by pyspark.
*/
private[sql] def collectToPython: JList[Array[Byte]] = {
val fieldTypes = schema.fields.map(_.dataType)
val pickle = new Pickler
new java.util.ArrayList(collect().map { row =>
EvaluatePython.rowToArray(row, fieldTypes)
}.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
}

/**
* Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value
* of base RDD functions that do not change schema.
Expand Down

0 comments on commit d7c359b

Please sign in to comment.