[SPARK-45914][PYTHON] Support commit and abort API for Python data source write

### What changes were proposed in this pull request?

This PR introduces support for the commit and abort APIs for Python data source writes. After this PR, users can customize how their data sources commit or abort a write operation.
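
A minimal sketch of the user-facing shape (modeled on the new tests in this PR; the `CountMessage`, `SimpleWriter`, `SimpleDataSource`, and `simple-sink` names are illustrative, not part of the change):

```python
from dataclasses import dataclass

from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage


@dataclass
class CountMessage(WriterCommitMessage):
    # Per-task metadata returned from write() and collected on the driver.
    count: int


class SimpleWriter(DataSourceWriter):
    def write(self, iterator):
        # Runs in each write task; must return a WriterCommitMessage.
        count = sum(1 for _ in iterator)
        return CountMessage(count=count)

    def commit(self, messages):
        # Runs once on the driver after all write tasks succeed.
        print("rows written:", sum(m.count for m in messages))

    def abort(self, messages):
        # Runs once on the driver if the write job fails.
        print("write aborted")


class SimpleDataSource(DataSource):
    @classmethod
    def name(cls):
        return "simple-sink"

    def schema(self):
        # Declared for completeness, mirroring the structure of the tests below.
        return "id BIGINT"

    def writer(self, schema, overwrite):
        return SimpleWriter()
```

Registered with `spark.dataSource.register(SimpleDataSource)`, such a source can then be exercised with `df.write.format("simple-sink").mode("append").save(path)`, which is the pattern the updated tests follow.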

### Why are the changes needed?

To support Python data sources.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

New unit tests

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #44497 from allisonwang-db/spark-45914-commit-abort.

Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
allisonwang-db authored and HyukjinKwon committed Dec 29, 2023
1 parent e9c4e7e commit f99b86a
Showing 5 changed files with 352 additions and 28 deletions.
84 changes: 63 additions & 21 deletions python/pyspark/sql/tests/test_python_datasource.py
@@ -225,10 +225,12 @@ def reader(self, schema) -> "DataSourceReader":
assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
self.assertEqual(df.rdd.getNumPartitions(), 2)

def test_custom_json_data_source(self):
def _get_test_json_data_source(self):
import json
import os
from dataclasses import dataclass

class JsonDataSourceReader(DataSourceReader):
class TestJsonReader(DataSourceReader):
def __init__(self, options):
self.options = options

@@ -242,18 +244,39 @@ def read(self, partition):
data = json.loads(line)
yield data.get("name"), data.get("age")

class JsonDataSourceWriter(DataSourceWriter):
@dataclass
class TestCommitMessage(WriterCommitMessage):
count: int

class TestJsonWriter(DataSourceWriter):
def __init__(self, options):
self.options = options
self.path = self.options.get("path")

def write(self, iterator):
path = self.options.get("path")
with open(path, "w") as file:
from pyspark import TaskContext

context = TaskContext.get()
output_path = os.path.join(self.path, f"{context.partitionId()}.json")
count = 0
with open(output_path, "w") as file:
for row in iterator:
count += 1
if "id" in row and row.id > 5:
raise Exception("id > 5")
file.write(json.dumps(row.asDict()) + "\n")
return WriterCommitMessage()
return TestCommitMessage(count=count)

class JsonDataSource(DataSource):
def commit(self, messages):
total_count = sum(message.count for message in messages)
with open(os.path.join(self.path, "_success.txt"), "a") as file:
file.write(f"count: {total_count}\n")

def abort(self, messages):
with open(os.path.join(self.path, "_failed.txt"), "a") as file:
file.write("failed")

class TestJsonDataSource(DataSource):
@classmethod
def name(cls):
return "my-json"
@@ -262,13 +285,16 @@ def schema(self):
return "name STRING, age INT"

def reader(self, schema) -> "DataSourceReader":
return JsonDataSourceReader(self.options)
return TestJsonReader(self.options)

def writer(self, schema, overwrite):
return JsonDataSourceWriter(self.options)
return TestJsonWriter(self.options)

return TestJsonDataSource

self.spark.dataSource.register(JsonDataSource)
# Test data source read.
def test_custom_json_data_source_read(self):
data_source = self._get_test_json_data_source()
self.spark.dataSource.register(data_source)
path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
path2 = os.path.join(SPARK_HOME, "python/test_support/sql/people1.json")
assertDataFrameEqual(
@@ -279,18 +305,34 @@ def writer(self, schema, overwrite):
self.spark.read.format("my-json").load(path2),
[Row(name="Jonathan", age=None)],
)
# Test data source write.
df = self.spark.read.json(path1)

def test_custom_json_data_source_write(self):
data_source = self._get_test_json_data_source()
self.spark.dataSource.register(data_source)
input_path = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
df = self.spark.read.json(input_path)
with tempfile.TemporaryDirectory() as d:
df.write.format("my-json").mode("append").save(d)
assertDataFrameEqual(self.spark.read.json(d), self.spark.read.json(input_path))

def test_custom_json_data_source_commit(self):
data_source = self._get_test_json_data_source()
self.spark.dataSource.register(data_source)
with tempfile.TemporaryDirectory() as d:
self.spark.range(0, 5, 1, 3).write.format("my-json").mode("append").save(d)
with open(os.path.join(d, "_success.txt"), "r") as file:
text = file.read()
assert text == "count: 5\n"

def test_custom_json_data_source_abort(self):
data_source = self._get_test_json_data_source()
self.spark.dataSource.register(data_source)
with tempfile.TemporaryDirectory() as d:
path = os.path.join(d, "res.json")
df.write.format("my-json").mode("append").save(path)
with open(path, "r") as file:
with self.assertRaises(PythonException):
self.spark.range(0, 8, 1, 3).write.format("my-json").mode("append").save(d)
with open(os.path.join(d, "_failed.txt"), "r") as file:
text = file.read()
assert text == (
'{"age": null, "name": "Michael"}\n'
'{"age": 30, "name": "Andy"}\n'
'{"age": 19, "name": "Justin"}\n'
)
assert text == "failed"


class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
121 changes: 121 additions & 0 deletions python/pyspark/sql/worker/commit_data_source_write.py
@@ -0,0 +1,121 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
from typing import IO

from pyspark.accumulators import _accumulatorRegistry
from pyspark.errors import PySparkAssertionError
from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import (
read_bool,
read_int,
write_int,
SpecialLengths,
)
from pyspark.sql.datasource import DataSourceWriter, WriterCommitMessage
from pyspark.util import handle_worker_exception
from pyspark.worker_util import (
check_python_version,
pickleSer,
send_accumulator_updates,
setup_broadcasts,
setup_memory_limits,
setup_spark_files,
)


def main(infile: IO, outfile: IO) -> None:
"""
Main method for committing or aborting a data source write operation.
This process is invoked from the `UserDefinedPythonDataSourceCommitRunner.runInPython`
method in the BatchWrite implementation of the PythonTableProvider. It is
responsible for invoking either the `commit` or the `abort` method on a data source
writer instance, given a list of commit messages.
"""
try:
check_python_version(infile)

memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
setup_memory_limits(memory_limit_mb)

setup_spark_files(infile)
setup_broadcasts(infile)

_accumulatorRegistry.clear()

# Receive the data source writer instance.
writer = pickleSer._read_with_length(infile)
if not isinstance(writer, DataSourceWriter):
raise PySparkAssertionError(
error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "an instance of DataSourceWriter",
"actual": f"'{type(writer).__name__}'",
},
)

# Receive the commit messages.
num_messages = read_int(infile)
commit_messages = []
for _ in range(num_messages):
message = pickleSer._read_with_length(infile)
if message is not None and not isinstance(message, WriterCommitMessage):
raise PySparkAssertionError(
error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "an instance of WriterCommitMessage",
"actual": f"'{type(message).__name__}'",
},
)
commit_messages.append(message)

# Receive a boolean to indicate whether to invoke `abort`.
abort = read_bool(infile)

# Commit or abort the Python data source write.
# Note the commit messages can be None if there are failed tasks.
if abort:
writer.abort(commit_messages) # type: ignore[arg-type]
else:
writer.commit(commit_messages) # type: ignore[arg-type]

# Send a status code back to JVM.
write_int(0, outfile)

except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)

send_accumulator_updates(outfile)

# check end of stream
if read_int(infile) == SpecialLengths.END_OF_STREAM:
write_int(SpecialLengths.END_OF_STREAM, outfile)
else:
# write a different value to tell JVM to not reuse this worker
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
sys.exit(-1)


if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
main(sock_file, sock_file)
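
Because entries in the message list can be `None` when tasks failed (as noted in the worker above), implementations of `commit` and `abort` that inspect the messages should filter first. A small hedged sketch (`DefensiveWriter` is an illustrative name, not part of this change):

```python
from pyspark.sql.datasource import DataSourceWriter, WriterCommitMessage


class DefensiveWriter(DataSourceWriter):
    def write(self, iterator):
        # Consume the partition and report success with an (empty) message.
        for _ in iterator:
            pass
        return WriterCommitMessage()

    def commit(self, messages):
        # Filter out None entries before aggregating task-level metadata.
        succeeded = [m for m in messages if m is not None]
        print(f"commit called with {len(succeeded)} task message(s)")

    def abort(self, messages):
        # Failed tasks contribute None, so only count the non-None entries.
        succeeded = [m for m in messages if m is not None]
        print(f"abort called after {len(succeeded)} task(s) had completed")
```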
3 changes: 3 additions & 0 deletions python/pyspark/sql/worker/write_into_data_source.py
@@ -211,6 +211,9 @@ def batch_to_rows() -> Iterator[Row]:
command = (data_source_write_func, return_type)
pickleSer._write_with_length(command, outfile)

# Return the pickled writer.
pickleSer._write_with_length(writer, outfile)

except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
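
A side effect of sending the pickled writer instance back through the JVM is that user-defined writers need to be picklable on the driver. A minimal hedged check outside Spark (the standard `pickle` module stands in for PySpark's own serializer here, and `MyWriter` is a hypothetical example):

```python
import pickle

from pyspark.sql.datasource import DataSourceWriter, WriterCommitMessage


class MyWriter(DataSourceWriter):
    def __init__(self, options):
        # Keep only picklable state here; open files or connections inside write().
        self.options = dict(options)

    def write(self, iterator):
        return WriterCommitMessage()

    def commit(self, messages):
        pass

    def abort(self, messages):
        pass


# Round-trip the instance, roughly the way the driver-side workers do.
restored = pickle.loads(pickle.dumps(MyWriter({"path": "/tmp/out"})))
assert isinstance(restored, DataSourceWriter)
```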
@@ -128,6 +128,9 @@ class PythonTableProvider extends TableProvider {

override def toBatch: BatchWrite = new BatchWrite {

// Store the pickled data source writer instance.
private var pythonDataSourceWriter: Array[Byte] = _

override def createBatchWriterFactory(
physicalInfo: PhysicalWriteInfo): DataWriterFactory = {

@@ -136,14 +139,21 @@ class PythonTableProvider extends TableProvider {
info.schema(),
info.options(),
isTruncate)

pythonDataSourceWriter = writeInfo.writer

PythonBatchWriterFactory(source, writeInfo.func, info.schema(), jobArtifactUUID)
}

// TODO(SPARK-45914): Support commit protocol
override def commit(messages: Array[WriterCommitMessage]): Unit = {}
override def commit(messages: Array[WriterCommitMessage]): Unit = {
source.commitWriteInPython(pythonDataSourceWriter, messages)
}

// TODO(SPARK-45914): Support commit protocol
override def abort(messages: Array[WriterCommitMessage]): Unit = {}
override def abort(messages: Array[WriterCommitMessage]): Unit = {
source.commitWriteInPython(pythonDataSourceWriter, messages, abort = true)
}

override def toString: String = shortName
}

override def description: String = "(Python)"
@@ -333,6 +343,17 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
overwrite).runInPython()
}

/**
* (Driver-side) Run Python process to either commit or abort a write operation.
*/
def commitWriteInPython(
writer: Array[Byte],
messages: Array[WriterCommitMessage],
abort: Boolean = false): Unit = {
new UserDefinedPythonDataSourceCommitRunner(
dataSourceCls, writer, messages, abort).runInPython()
}

/**
* (Executor-side) Create an iterator that executes the Python function.
*/
@@ -590,7 +611,7 @@ class UserDefinedPythonDataSourceReadRunner(
/**
* Hold the results of running [[UserDefinedPythonDataSourceWriteRunner]].
*/
case class PythonDataSourceWriteInfo(func: Array[Byte])
case class PythonDataSourceWriteInfo(func: Array[Byte], writer: Array[Byte])

/**
* A runner that creates a Python data source writer instance and returns a Python function
@@ -640,9 +661,55 @@ class UserDefinedPythonDataSourceWriteRunner(
action = "plan", tpe = "write", msg = msg)
}

// Receive the pickled data source.
// Receive the pickled data source write function.
val writeUdf: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn)

PythonDataSourceWriteInfo(func = writeUdf)
// Receive the pickled instance of the data source writer.
val writer: Array[Byte] = PythonWorkerUtils.readBytes(dataIn)

PythonDataSourceWriteInfo(func = writeUdf, writer = writer)
}
}

/**
* A runner that takes a Python data source writer and a list of commit messages,
* and invokes the `commit` or `abort` method of the writer in Python.
*/
class UserDefinedPythonDataSourceCommitRunner(
dataSourceCls: PythonFunction,
writer: Array[Byte],
messages: Array[WriterCommitMessage],
abort: Boolean) extends PythonPlannerRunner[Unit](dataSourceCls) {
override val workerModule: String = "pyspark.sql.worker.commit_data_source_write"

override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
// Send the Python data source writer.
PythonWorkerUtils.writeBytes(writer, dataOut)

// Send the commit messages.
dataOut.writeInt(messages.length)
messages.foreach { message =>
// Commit messages can be null if there are task failures.
if (message == null) {
dataOut.writeInt(SpecialLengths.NULL)
} else {
PythonWorkerUtils.writeBytes(
message.asInstanceOf[PythonWriterCommitMessage].pickledMessage, dataOut)
}
}

// Send whether to invoke `abort` instead of `commit`.
dataOut.writeBoolean(abort)
}

override protected def receiveFromPython(dataIn: DataInputStream): Unit = {
// Receive any exceptions thrown in the Python worker.
val code = dataIn.readInt()
if (code == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
val msg = PythonWorkerUtils.readUTF(dataIn)
throw QueryCompilationErrors.failToPlanDataSourceError(
action = "commit or abort", tpe = "write", msg = msg)
}
assert(code == 0, s"Python commit job should run successfully, but got exit code: $code")
}
}