[SPARK-26019][PYSPARK] Allow insecure py4j gateways
Spark always creates secure py4j connections between the Java and Python
processes, but it also allows users to pass in their own connection. This
change restores the ability for users to pass in an _insecure_ connection,
though it requires them to set the environment variable
'PYSPARK_ALLOW_INSECURE_GATEWAY=1', and a warning is still issued.
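
For illustration, a minimal sketch of the opt-in path described above
(hypothetical usage, not part of this change). It assumes a py4j
GatewayServer with the Spark classes on its classpath is already listening
on py4j's default port, and that PySpark is importable:

    import os
    from py4j.java_gateway import JavaGateway, GatewayParameters
    from pyspark import SparkContext

    # Opt in to the insecure path; SparkContext will still emit a warning.
    os.environ["PYSPARK_ALLOW_INSECURE_GATEWAY"] = "1"

    # Connect to the already-running gateway without an auth token
    # (25333 is py4j's default port and is an assumption here).
    gateway = JavaGateway(
        gateway_parameters=GatewayParameters(port=25333, auto_convert=True))

    # Because gateway_parameters.auth_token is None, SparkContext warns
    # instead of raising, thanks to the environment variable set above.
    sc = SparkContext(gateway=gateway)

Without the environment variable, the SparkContext call above raises a
ValueError, as the context.py change below shows.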

Added test cases verifying the failure without the extra configuration,
and verifying things still work with an insecure configuration (in
particular, accumulators, as those were broken with an insecure py4j
gateway before).

For the tests, I added ways to create insecure gateways, but I tried to put in protections to make sure that wouldn't get used incorrectly.

Closes apache#23337 from squito/SPARK-26019.

Authored-by: Imran Rashid <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
squito authored and HyukjinKwon committed Jan 3, 2019
1 parent 1802124 commit 1e99f4e
Showing 6 changed files with 81 additions and 12 deletions.
11 changes: 8 additions & 3 deletions core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
@@ -43,12 +43,17 @@ private[spark] object PythonGatewayServer extends Logging {
     // with the same secret, in case the app needs callbacks from the JVM to the underlying
     // python processes.
     val localhost = InetAddress.getLoopbackAddress()
-    val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
-      .authToken(secret)
+    val builder = new GatewayServer.GatewayServerBuilder()
       .javaPort(0)
       .javaAddress(localhost)
       .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
-      .build()
+    if (sys.env.getOrElse("_PYSPARK_CREATE_INSECURE_GATEWAY", "0") != "1") {
+      builder.authToken(secret)
+    } else {
+      assert(sys.env.getOrElse("SPARK_TESTING", "0") == "1",
+        "Creating insecure Java gateways only allowed for testing")
+    }
+    val gatewayServer: GatewayServer = builder.build()
 
     gatewayServer.start()
     val boundPort: Int = gatewayServer.getListeningPort
6 changes: 4 additions & 2 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -616,8 +616,10 @@ private[spark] class PythonAccumulatorV2(
     if (socket == null || socket.isClosed) {
       socket = new Socket(serverHost, serverPort)
       logInfo(s"Connected to AccumulatorServer at host: $serverHost port: $serverPort")
-      // send the secret just for the initial authentication when opening a new connection
-      socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
+      if (secretToken != null) {
+        // send the secret just for the initial authentication when opening a new connection
+        socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
+      }
     }
     socket
   }
7 changes: 4 additions & 3 deletions python/pyspark/accumulators.py
@@ -262,9 +262,10 @@ def authenticate_and_accum_updates():
                 raise Exception(
                     "The value of the provided token to the AccumulatorServer is not correct.")
 
-        # first we keep polling till we've received the authentication token
-        poll(authenticate_and_accum_updates)
-        # now we've authenticated, don't need to check for the token anymore
+        if auth_token is not None:
+            # first we keep polling till we've received the authentication token
+            poll(authenticate_and_accum_updates)
+        # now we've authenticated if needed, don't need to check for the token anymore
         poll(accum_updates)
 
 
14 changes: 14 additions & 0 deletions python/pyspark/context.py
@@ -112,6 +112,20 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         ValueError:...
         """
         self._callsite = first_spark_call() or CallSite(None, None, None)
+        if gateway is not None and gateway.gateway_parameters.auth_token is None:
+            allow_insecure_env = os.environ.get("PYSPARK_ALLOW_INSECURE_GATEWAY", "0")
+            if allow_insecure_env == "1" or allow_insecure_env.lower() == "true":
+                warnings.warn(
+                    "You are passing in an insecure Py4j gateway. This "
+                    "presents a security risk, and will be completely forbidden in Spark 3.0")
+            else:
+                raise ValueError(
+                    "You are trying to pass an insecure Py4j gateway to Spark. This"
+                    " presents a security risk. If you are sure you understand and accept this"
+                    " risk, you can set the environment variable"
+                    " 'PYSPARK_ALLOW_INSECURE_GATEWAY=1', but"
+                    " note this option will be removed in Spark 3.0")
+
         SparkContext._ensure_initialized(self, gateway=gateway, conf=conf)
         try:
             self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
23 changes: 19 additions & 4 deletions python/pyspark/java_gateway.py
@@ -41,8 +41,20 @@ def launch_gateway(conf=None):
     """
     launch jvm gateway
     :param conf: spark configuration passed to spark-submit
-    :return:
+    :return: a JVM gateway
     """
+    return _launch_gateway(conf)
+
+
+def _launch_gateway(conf=None, insecure=False):
+    """
+    launch jvm gateway
+    :param conf: spark configuration passed to spark-submit
+    :param insecure: True to create an insecure gateway; only for testing
+    :return: a JVM gateway
+    """
+    if insecure and os.environ.get("SPARK_TESTING", "0") != "1":
+        raise ValueError("creating insecure gateways is only for testing")
     if "PYSPARK_GATEWAY_PORT" in os.environ:
         gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
         gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
@@ -74,6 +86,8 @@ def launch_gateway(conf=None):
 
         env = dict(os.environ)
         env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file
+        if insecure:
+            env["_PYSPARK_CREATE_INSECURE_GATEWAY"] = "1"
 
         # Launch the Java gateway.
         # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
@@ -116,9 +130,10 @@ def killChild():
             atexit.register(killChild)
 
     # Connect to the gateway
-    gateway = JavaGateway(
-        gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
-                                             auto_convert=True))
+    gateway_params = GatewayParameters(port=gateway_port, auto_convert=True)
+    if not insecure:
+        gateway_params.auth_token = gateway_secret
+    gateway = JavaGateway(gateway_parameters=gateway_params)
 
     # Import the classes used by PySpark
     java_import(gateway.jvm, "org.apache.spark.SparkConf")
32 changes: 32 additions & 0 deletions python/pyspark/tests.py
@@ -61,6 +61,7 @@
 from pyspark import keyword_only
 from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
+from pyspark.java_gateway import _launch_gateway
 from pyspark.rdd import RDD
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
@@ -2381,6 +2382,37 @@ def test_startTime(self):
         with SparkContext() as sc:
             self.assertGreater(sc.startTime, 0)
 
+    def test_forbid_insecure_gateway(self):
+        # By default, we fail immediately if you try to create a SparkContext
+        # with an insecure gateway
+        gateway = _launch_gateway(insecure=True)
+        log4j = gateway.jvm.org.apache.log4j
+        old_level = log4j.LogManager.getRootLogger().getLevel()
+        try:
+            log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL)
+            with self.assertRaises(Exception) as context:
+                SparkContext(gateway=gateway)
+            self.assertIn("insecure Py4j gateway", str(context.exception))
+            self.assertIn("PYSPARK_ALLOW_INSECURE_GATEWAY", str(context.exception))
+            self.assertIn("removed in Spark 3.0", str(context.exception))
+        finally:
+            log4j.LogManager.getRootLogger().setLevel(old_level)
+
+    def test_allow_insecure_gateway_with_conf(self):
+        with SparkContext._lock:
+            SparkContext._gateway = None
+            SparkContext._jvm = None
+        gateway = _launch_gateway(insecure=True)
+        try:
+            os.environ["PYSPARK_ALLOW_INSECURE_GATEWAY"] = "1"
+            with SparkContext(gateway=gateway) as sc:
+                a = sc.accumulator(1)
+                rdd = sc.parallelize([1, 2, 3])
+                rdd.foreach(lambda x: a.add(x))
+                self.assertEqual(7, a.value)
+        finally:
+            os.environ.pop("PYSPARK_ALLOW_INSECURE_GATEWAY", None)
+
 
 class ConfTests(unittest.TestCase):
     def test_memory_conf(self):
