Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only restrict spark binary passed via extra #30213

Merged
merged 1 commit into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 11 additions & 17 deletions airflow/providers/apache/spark/hooks/spark_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,7 @@ def __init__(
self._submit_sp: Any | None = None
self._yarn_application_id: str | None = None
self._kubernetes_driver_pod: str | None = None
self._spark_binary = spark_binary
if self._spark_binary is not None and self._spark_binary not in ALLOWED_SPARK_BINARIES:
raise RuntimeError(
                f"The spark-binary extra can be one of {ALLOWED_SPARK_BINARIES} and it"
f" was `{spark_binary}`. Please make sure your spark binary is one of the"
f" allowed ones and that it is available on the PATH"
)

self.spark_binary = spark_binary
self._connection = self._resolve_connection()
self._is_yarn = "yarn" in self._connection["master"]
self._is_kubernetes = "k8s" in self._connection["master"]
Expand Down Expand Up @@ -186,7 +179,7 @@ def _resolve_connection(self) -> dict[str, Any]:
"master": "yarn",
"queue": None,
"deploy_mode": None,
"spark_binary": self._spark_binary or "spark-submit",
"spark_binary": self.spark_binary or "spark-submit",
"namespace": None,
}

Expand All @@ -203,21 +196,22 @@ def _resolve_connection(self) -> dict[str, Any]:
extra = conn.extra_dejson
conn_data["queue"] = extra.get("queue")
conn_data["deploy_mode"] = extra.get("deploy-mode")
spark_binary = self._spark_binary or extra.get("spark-binary", "spark-submit")
if spark_binary not in ALLOWED_SPARK_BINARIES:
raise RuntimeError(
f"The `spark-binary` extra can be one of {ALLOWED_SPARK_BINARIES} and it"
f" was `{spark_binary}`. Please make sure your spark binary is one of the"
" allowed ones and that it is available on the PATH"
)
if not self.spark_binary:
self.spark_binary = extra.get("spark-binary", "spark-submit")
if self.spark_binary is not None and self.spark_binary not in ALLOWED_SPARK_BINARIES:
raise RuntimeError(
                    f"The spark-binary extra can be one of {ALLOWED_SPARK_BINARIES} and it"
f" was `{self.spark_binary}`. Please make sure your spark binary is one of the"
f" allowed ones and that it is available on the PATH"
)
conn_spark_home = extra.get("spark-home")
if conn_spark_home:
raise RuntimeError(
"The `spark-home` extra is not allowed any more. Please make sure one of"
f" {ALLOWED_SPARK_BINARIES} is available on the PATH, and set `spark-binary`"
" if needed."
)
conn_data["spark_binary"] = spark_binary
conn_data["spark_binary"] = self.spark_binary
conn_data["namespace"] = extra.get("namespace")
except AirflowException:
self.log.info(
Expand Down
5 changes: 2 additions & 3 deletions tests/providers/apache/spark/hooks/test_spark_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,9 +461,8 @@ def test_resolve_connection_spark_binary_spark3_submit_set_connection(self):
assert connection == expected_spark_connection
assert cmd[0] == "spark3-submit"

def test_resolve_connection_custom_spark_binary_not_allowed_runtime_error(self):
with pytest.raises(RuntimeError):
SparkSubmitHook(conn_id="spark_binary_set", spark_binary="another-custom-spark-submit")
def test_resolve_connection_custom_spark_binary_allowed_in_hook(self):
SparkSubmitHook(conn_id="spark_binary_set", spark_binary="another-custom-spark-submit")

def test_resolve_connection_spark_binary_extra_not_allowed_runtime_error(self):
with pytest.raises(RuntimeError):
Expand Down