Skip to content

Commit

Permalink
[Databricks Provider] Revert PRs #40864 and #40471 (#41050)
Browse files Browse the repository at this point in the history
* Revert "Fix named parameters templating in Databricks operators (#40864)"

This reverts commit cfe1d53.

* Revert "Make Databricks operators' json parameter compatible with XComs, Jinja expression values (#40471)"

This reverts commit 4fb2140.

This reverts PR #40864 and PR #40471.

Previously, PR #40471 was contributed to address issue #35433. 
However, that contribution gave rise to another issue #40788. 
Next #40788 was being attempted to be resolved in PR #40864. 
However, with the second PR, it appears that the previous old 
issue #35433 has [resurfaced](#40864 (comment)). So, at the moment, the case is 
that we have 2 PRs on top of the existing implementation 
eventually having nil effect and the previous issues persists. 
I believe it is better to revert those 2 PRs, reopen the earlier 
issue #35433 and peacefully address it by taking the needed time.
  • Loading branch information
pankajkoti authored Jul 27, 2024
1 parent 047d139 commit 4535e08
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 813 deletions.
273 changes: 94 additions & 179 deletions airflow/providers/databricks/operators/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
WorkflowJobRunLink,
)
from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
from airflow.providers.databricks.utils.databricks import _normalise_json_content, validate_trigger_event
from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event

if TYPE_CHECKING:
from airflow.models.taskinstancekey import TaskInstanceKey
Expand Down Expand Up @@ -186,17 +186,6 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger)
raise AirflowException(error_message)


def _handle_overridden_json_params(operator):
for key, value in operator.overridden_json_params.items():
if value is not None:
operator.json[key] = value


def normalise_json_content(operator):
if operator.json:
operator.json = _normalise_json_content(operator.json)


class DatabricksJobRunLink(BaseOperatorLink):
"""Constructs a link to monitor a Databricks Job Run."""

Expand Down Expand Up @@ -263,23 +252,7 @@ class DatabricksCreateJobsOperator(BaseOperator):
"""

# Used in airflow.models.BaseOperator
template_fields: Sequence[str] = (
"json",
"databricks_conn_id",
"name",
"description",
"tags",
"tasks",
"job_clusters",
"email_notifications",
"webhook_notifications",
"notification_settings",
"timeout_seconds",
"schedule",
"max_concurrent_runs",
"git_source",
"access_control_list",
)
template_fields: Sequence[str] = ("json", "databricks_conn_id")
# Databricks brand color (blue) under white text
ui_color = "#1CB1C2"
ui_fgcolor = "#fff"
Expand Down Expand Up @@ -316,19 +289,34 @@ def __init__(
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
self.databricks_retry_args = databricks_retry_args
self.name = name
self.description = description
self.tags = tags
self.tasks = tasks
self.job_clusters = job_clusters
self.email_notifications = email_notifications
self.webhook_notifications = webhook_notifications
self.notification_settings = notification_settings
self.timeout_seconds = timeout_seconds
self.schedule = schedule
self.max_concurrent_runs = max_concurrent_runs
self.git_source = git_source
self.access_control_list = access_control_list
if name is not None:
self.json["name"] = name
if description is not None:
self.json["description"] = description
if tags is not None:
self.json["tags"] = tags
if tasks is not None:
self.json["tasks"] = tasks
if job_clusters is not None:
self.json["job_clusters"] = job_clusters
if email_notifications is not None:
self.json["email_notifications"] = email_notifications
if webhook_notifications is not None:
self.json["webhook_notifications"] = webhook_notifications
if notification_settings is not None:
self.json["notification_settings"] = notification_settings
if timeout_seconds is not None:
self.json["timeout_seconds"] = timeout_seconds
if schedule is not None:
self.json["schedule"] = schedule
if max_concurrent_runs is not None:
self.json["max_concurrent_runs"] = max_concurrent_runs
if git_source is not None:
self.json["git_source"] = git_source
if access_control_list is not None:
self.json["access_control_list"] = access_control_list
if self.json:
self.json = normalise_json_content(self.json)

@cached_property
def _hook(self):
Expand All @@ -340,40 +328,16 @@ def _hook(self):
caller="DatabricksCreateJobsOperator",
)

def _setup_and_validate_json(self):
self.overridden_json_params = {
"name": self.name,
"description": self.description,
"tags": self.tags,
"tasks": self.tasks,
"job_clusters": self.job_clusters,
"email_notifications": self.email_notifications,
"webhook_notifications": self.webhook_notifications,
"notification_settings": self.notification_settings,
"timeout_seconds": self.timeout_seconds,
"schedule": self.schedule,
"max_concurrent_runs": self.max_concurrent_runs,
"git_source": self.git_source,
"access_control_list": self.access_control_list,
}

_handle_overridden_json_params(self)

def execute(self, context: Context) -> int:
if "name" not in self.json:
raise AirflowException("Missing required parameter: name")

normalise_json_content(self)

def execute(self, context: Context) -> int:
self._setup_and_validate_json()

job_id = self._hook.find_job_id_by_name(self.json["name"])
if job_id is None:
return self._hook.create_job(self.json)
self._hook.reset_job(str(job_id), self.json)
if (access_control_list := self.json.get("access_control_list")) is not None:
acl_json = {"access_control_list": access_control_list}
self._hook.update_job_permission(job_id, _normalise_json_content(acl_json))
self._hook.update_job_permission(job_id, normalise_json_content(acl_json))

return job_id

Expand Down Expand Up @@ -500,25 +464,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
"""

# Used in airflow.models.BaseOperator
template_fields: Sequence[str] = (
"json",
"databricks_conn_id",
"tasks",
"spark_jar_task",
"notebook_task",
"spark_python_task",
"spark_submit_task",
"pipeline_task",
"dbt_task",
"new_cluster",
"existing_cluster_id",
"libraries",
"run_name",
"timeout_seconds",
"idempotency_token",
"access_control_list",
"git_source",
)
template_fields: Sequence[str] = ("json", "databricks_conn_id")
template_ext: Sequence[str] = (".json-tpl",)
# Databricks brand color (blue) under white text
ui_color = "#1CB1C2"
Expand Down Expand Up @@ -564,21 +510,43 @@ def __init__(
self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination
self.deferrable = deferrable
self.tasks = tasks
self.spark_jar_task = spark_jar_task
self.notebook_task = notebook_task
self.spark_python_task = spark_python_task
self.spark_submit_task = spark_submit_task
self.pipeline_task = pipeline_task
self.dbt_task = dbt_task
self.new_cluster = new_cluster
self.existing_cluster_id = existing_cluster_id
self.libraries = libraries
self.run_name = run_name
self.timeout_seconds = timeout_seconds
self.idempotency_token = idempotency_token
self.access_control_list = access_control_list
self.git_source = git_source
if tasks is not None:
self.json["tasks"] = tasks
if spark_jar_task is not None:
self.json["spark_jar_task"] = spark_jar_task
if notebook_task is not None:
self.json["notebook_task"] = notebook_task
if spark_python_task is not None:
self.json["spark_python_task"] = spark_python_task
if spark_submit_task is not None:
self.json["spark_submit_task"] = spark_submit_task
if pipeline_task is not None:
self.json["pipeline_task"] = pipeline_task
if dbt_task is not None:
self.json["dbt_task"] = dbt_task
if new_cluster is not None:
self.json["new_cluster"] = new_cluster
if existing_cluster_id is not None:
self.json["existing_cluster_id"] = existing_cluster_id
if libraries is not None:
self.json["libraries"] = libraries
if run_name is not None:
self.json["run_name"] = run_name
if timeout_seconds is not None:
self.json["timeout_seconds"] = timeout_seconds
if "run_name" not in self.json:
self.json["run_name"] = run_name or kwargs["task_id"]
if idempotency_token is not None:
self.json["idempotency_token"] = idempotency_token
if access_control_list is not None:
self.json["access_control_list"] = access_control_list
if git_source is not None:
self.json["git_source"] = git_source

if "dbt_task" in self.json and "git_source" not in self.json:
raise AirflowException("git_source is required for dbt_task")
if pipeline_task is not None and "pipeline_id" in pipeline_task and "pipeline_name" in pipeline_task:
raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'")

# This variable will be used in case our task gets killed.
self.run_id: int | None = None
Expand All @@ -597,43 +565,7 @@ def _get_hook(self, caller: str) -> DatabricksHook:
caller=caller,
)

def _setup_and_validate_json(self):
self.overridden_json_params = {
"tasks": self.tasks,
"spark_jar_task": self.spark_jar_task,
"notebook_task": self.notebook_task,
"spark_python_task": self.spark_python_task,
"spark_submit_task": self.spark_submit_task,
"pipeline_task": self.pipeline_task,
"dbt_task": self.dbt_task,
"new_cluster": self.new_cluster,
"existing_cluster_id": self.existing_cluster_id,
"libraries": self.libraries,
"run_name": self.run_name,
"timeout_seconds": self.timeout_seconds,
"idempotency_token": self.idempotency_token,
"access_control_list": self.access_control_list,
"git_source": self.git_source,
}

_handle_overridden_json_params(self)

if "run_name" not in self.json or self.json["run_name"] is None:
self.json["run_name"] = self.task_id

if "dbt_task" in self.json and "git_source" not in self.json:
raise AirflowException("git_source is required for dbt_task")
if (
"pipeline_task" in self.json
and "pipeline_id" in self.json["pipeline_task"]
and "pipeline_name" in self.json["pipeline_task"]
):
raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'")

normalise_json_content(self)

def execute(self, context: Context):
self._setup_and_validate_json()
if (
"pipeline_task" in self.json
and self.json["pipeline_task"].get("pipeline_id") is None
Expand All @@ -643,7 +575,7 @@ def execute(self, context: Context):
pipeline_name = self.json["pipeline_task"]["pipeline_name"]
self.json["pipeline_task"]["pipeline_id"] = self._hook.find_pipeline_id_by_name(pipeline_name)
del self.json["pipeline_task"]["pipeline_name"]
json_normalised = _normalise_json_content(self.json)
json_normalised = normalise_json_content(self.json)
self.run_id = self._hook.submit_run(json_normalised)
if self.deferrable:
_handle_deferrable_databricks_operator_execution(self, self._hook, self.log, context)
Expand Down Expand Up @@ -679,7 +611,7 @@ def __init__(self, *args, **kwargs):

def execute(self, context):
hook = self._get_hook(caller="DatabricksSubmitRunDeferrableOperator")
json_normalised = _normalise_json_content(self.json)
json_normalised = normalise_json_content(self.json)
self.run_id = hook.submit_run(json_normalised)
_handle_deferrable_databricks_operator_execution(self, hook, self.log, context)

Expand Down Expand Up @@ -836,18 +768,7 @@ class DatabricksRunNowOperator(BaseOperator):
"""

# Used in airflow.models.BaseOperator
template_fields: Sequence[str] = (
"json",
"databricks_conn_id",
"job_id",
"job_name",
"notebook_params",
"python_params",
"python_named_params",
"jar_params",
"spark_submit_params",
"idempotency_token",
)
template_fields: Sequence[str] = ("json", "databricks_conn_id")
template_ext: Sequence[str] = (".json-tpl",)
# Databricks brand color (blue) under white text
ui_color = "#1CB1C2"
Expand Down Expand Up @@ -890,14 +811,27 @@ def __init__(
self.deferrable = deferrable
self.repair_run = repair_run
self.cancel_previous_runs = cancel_previous_runs
self.job_id = job_id
self.job_name = job_name
self.notebook_params = notebook_params
self.python_params = python_params
self.python_named_params = python_named_params
self.jar_params = jar_params
self.spark_submit_params = spark_submit_params
self.idempotency_token = idempotency_token

if job_id is not None:
self.json["job_id"] = job_id
if job_name is not None:
self.json["job_name"] = job_name
if "job_id" in self.json and "job_name" in self.json:
raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'")
if notebook_params is not None:
self.json["notebook_params"] = notebook_params
if python_params is not None:
self.json["python_params"] = python_params
if python_named_params is not None:
self.json["python_named_params"] = python_named_params
if jar_params is not None:
self.json["jar_params"] = jar_params
if spark_submit_params is not None:
self.json["spark_submit_params"] = spark_submit_params
if idempotency_token is not None:
self.json["idempotency_token"] = idempotency_token
if self.json:
self.json = normalise_json_content(self.json)
# This variable will be used in case our task gets killed.
self.run_id: int | None = None
self.do_xcom_push = do_xcom_push
Expand All @@ -915,26 +849,7 @@ def _get_hook(self, caller: str) -> DatabricksHook:
caller=caller,
)

def _setup_and_validate_json(self):
self.overridden_json_params = {
"job_id": self.job_id,
"job_name": self.job_name,
"notebook_params": self.notebook_params,
"python_params": self.python_params,
"python_named_params": self.python_named_params,
"jar_params": self.jar_params,
"spark_submit_params": self.spark_submit_params,
"idempotency_token": self.idempotency_token,
}
_handle_overridden_json_params(self)

if "job_id" in self.json and "job_name" in self.json:
raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'")

normalise_json_content(self)

def execute(self, context: Context):
self._setup_and_validate_json()
hook = self._hook
if "job_name" in self.json:
job_id = hook.find_job_id_by_name(self.json["job_name"])
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/databricks/utils/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from airflow.providers.databricks.hooks.databricks import RunState


def _normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict:
def normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict:
"""
Normalize content or all values of content if it is a dict to a string.
Expand All @@ -33,7 +33,7 @@ def _normalise_json_content(content, json_path: str = "json") -> str | bool | li
The only one exception is when we have boolean values, they can not be converted
to string type because databricks does not understand 'True' or 'False' values.
"""
normalise = _normalise_json_content
normalise = normalise_json_content
if isinstance(content, (str, bool)):
return content
elif isinstance(content, (int, float)):
Expand Down
Loading

0 comments on commit 4535e08

Please sign in to comment.