diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index b1299b9d85040..89aa6e9df91cb 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -40,7 +40,7 @@
     WorkflowJobRunLink,
 )
 from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
-from airflow.providers.databricks.utils.databricks import _normalise_json_content, validate_trigger_event
+from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event
 
 if TYPE_CHECKING:
     from airflow.models.taskinstancekey import TaskInstanceKey
@@ -186,17 +186,6 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger)
         raise AirflowException(error_message)
 
 
-def _handle_overridden_json_params(operator):
-    for key, value in operator.overridden_json_params.items():
-        if value is not None:
-            operator.json[key] = value
-
-
-def normalise_json_content(operator):
-    if operator.json:
-        operator.json = _normalise_json_content(operator.json)
-
-
 class DatabricksJobRunLink(BaseOperatorLink):
     """Constructs a link to monitor a Databricks Job Run."""
 
@@ -263,23 +252,7 @@ class DatabricksCreateJobsOperator(BaseOperator):
     """
 
     # Used in airflow.models.BaseOperator
-    template_fields: Sequence[str] = (
-        "json",
-        "databricks_conn_id",
-        "name",
-        "description",
-        "tags",
-        "tasks",
-        "job_clusters",
-        "email_notifications",
-        "webhook_notifications",
-        "notification_settings",
-        "timeout_seconds",
-        "schedule",
-        "max_concurrent_runs",
-        "git_source",
-        "access_control_list",
-    )
+    template_fields: Sequence[str] = ("json", "databricks_conn_id")
     # Databricks brand color (blue) under white text
     ui_color = "#1CB1C2"
     ui_fgcolor = "#fff"
@@ -316,19 +289,34 @@ def __init__(
         self.databricks_retry_limit = databricks_retry_limit
         self.databricks_retry_delay = databricks_retry_delay
         self.databricks_retry_args = databricks_retry_args
-        self.name = name
-        self.description = description
-        self.tags = tags
-        self.tasks = tasks
-        self.job_clusters = job_clusters
-        self.email_notifications = email_notifications
-        self.webhook_notifications = webhook_notifications
-        self.notification_settings = notification_settings
-        self.timeout_seconds = timeout_seconds
-        self.schedule = schedule
-        self.max_concurrent_runs = max_concurrent_runs
-        self.git_source = git_source
-        self.access_control_list = access_control_list
+        if name is not None:
+            self.json["name"] = name
+        if description is not None:
+            self.json["description"] = description
+        if tags is not None:
+            self.json["tags"] = tags
+        if tasks is not None:
+            self.json["tasks"] = tasks
+        if job_clusters is not None:
+            self.json["job_clusters"] = job_clusters
+        if email_notifications is not None:
+            self.json["email_notifications"] = email_notifications
+        if webhook_notifications is not None:
+            self.json["webhook_notifications"] = webhook_notifications
+        if notification_settings is not None:
+            self.json["notification_settings"] = notification_settings
+        if timeout_seconds is not None:
+            self.json["timeout_seconds"] = timeout_seconds
+        if schedule is not None:
+            self.json["schedule"] = schedule
+        if max_concurrent_runs is not None:
+            self.json["max_concurrent_runs"] = max_concurrent_runs
+        if git_source is not None:
+            self.json["git_source"] = git_source
+        if access_control_list is not None:
+            self.json["access_control_list"] = access_control_list
+        if self.json:
+            self.json = normalise_json_content(self.json)
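
The hunk above merges the named constructor arguments into `self.json` and normalises the payload at construction time, instead of deferring that work to a `_setup_and_validate_json()` call during execute. A minimal sketch of the resulting merge semantics; the task id, job name, tag and connection values below are placeholders for illustration only and are not part of this change:

    from airflow.providers.databricks.operators.databricks import DatabricksCreateJobsOperator

    create_job = DatabricksCreateJobsOperator(
        task_id="create_job",  # hypothetical task id
        json={"name": "nightly-job", "max_concurrent_runs": 1},
        # A named parameter that is not None overrides the matching top-level key in `json`.
        name="nightly-job-v2",
        # Keys absent from `json` are simply added when the named parameter is set.
        tags={"team": "data-eng"},
    )
    # After __init__, create_job.json already holds the merged, normalised payload,
    # with numeric values stringified by normalise_json_content (booleans are left as-is):
    # {"name": "nightly-job-v2", "max_concurrent_runs": "1", "tags": {"team": "data-eng"}}
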
@cached_property def _hook(self): @@ -340,40 +328,16 @@ def _hook(self): caller="DatabricksCreateJobsOperator", ) - def _setup_and_validate_json(self): - self.overridden_json_params = { - "name": self.name, - "description": self.description, - "tags": self.tags, - "tasks": self.tasks, - "job_clusters": self.job_clusters, - "email_notifications": self.email_notifications, - "webhook_notifications": self.webhook_notifications, - "notification_settings": self.notification_settings, - "timeout_seconds": self.timeout_seconds, - "schedule": self.schedule, - "max_concurrent_runs": self.max_concurrent_runs, - "git_source": self.git_source, - "access_control_list": self.access_control_list, - } - - _handle_overridden_json_params(self) - + def execute(self, context: Context) -> int: if "name" not in self.json: raise AirflowException("Missing required parameter: name") - - normalise_json_content(self) - - def execute(self, context: Context) -> int: - self._setup_and_validate_json() - job_id = self._hook.find_job_id_by_name(self.json["name"]) if job_id is None: return self._hook.create_job(self.json) self._hook.reset_job(str(job_id), self.json) if (access_control_list := self.json.get("access_control_list")) is not None: acl_json = {"access_control_list": access_control_list} - self._hook.update_job_permission(job_id, _normalise_json_content(acl_json)) + self._hook.update_job_permission(job_id, normalise_json_content(acl_json)) return job_id @@ -500,25 +464,7 @@ class DatabricksSubmitRunOperator(BaseOperator): """ # Used in airflow.models.BaseOperator - template_fields: Sequence[str] = ( - "json", - "databricks_conn_id", - "tasks", - "spark_jar_task", - "notebook_task", - "spark_python_task", - "spark_submit_task", - "pipeline_task", - "dbt_task", - "new_cluster", - "existing_cluster_id", - "libraries", - "run_name", - "timeout_seconds", - "idempotency_token", - "access_control_list", - "git_source", - ) + template_fields: Sequence[str] = ("json", "databricks_conn_id") template_ext: Sequence[str] = (".json-tpl",) # Databricks brand color (blue) under white text ui_color = "#1CB1C2" @@ -564,21 +510,43 @@ def __init__( self.databricks_retry_args = databricks_retry_args self.wait_for_termination = wait_for_termination self.deferrable = deferrable - self.tasks = tasks - self.spark_jar_task = spark_jar_task - self.notebook_task = notebook_task - self.spark_python_task = spark_python_task - self.spark_submit_task = spark_submit_task - self.pipeline_task = pipeline_task - self.dbt_task = dbt_task - self.new_cluster = new_cluster - self.existing_cluster_id = existing_cluster_id - self.libraries = libraries - self.run_name = run_name - self.timeout_seconds = timeout_seconds - self.idempotency_token = idempotency_token - self.access_control_list = access_control_list - self.git_source = git_source + if tasks is not None: + self.json["tasks"] = tasks + if spark_jar_task is not None: + self.json["spark_jar_task"] = spark_jar_task + if notebook_task is not None: + self.json["notebook_task"] = notebook_task + if spark_python_task is not None: + self.json["spark_python_task"] = spark_python_task + if spark_submit_task is not None: + self.json["spark_submit_task"] = spark_submit_task + if pipeline_task is not None: + self.json["pipeline_task"] = pipeline_task + if dbt_task is not None: + self.json["dbt_task"] = dbt_task + if new_cluster is not None: + self.json["new_cluster"] = new_cluster + if existing_cluster_id is not None: + self.json["existing_cluster_id"] = existing_cluster_id + if libraries is not None: + 
self.json["libraries"] = libraries + if run_name is not None: + self.json["run_name"] = run_name + if timeout_seconds is not None: + self.json["timeout_seconds"] = timeout_seconds + if "run_name" not in self.json: + self.json["run_name"] = run_name or kwargs["task_id"] + if idempotency_token is not None: + self.json["idempotency_token"] = idempotency_token + if access_control_list is not None: + self.json["access_control_list"] = access_control_list + if git_source is not None: + self.json["git_source"] = git_source + + if "dbt_task" in self.json and "git_source" not in self.json: + raise AirflowException("git_source is required for dbt_task") + if pipeline_task is not None and "pipeline_id" in pipeline_task and "pipeline_name" in pipeline_task: + raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'") # This variable will be used in case our task gets killed. self.run_id: int | None = None @@ -597,43 +565,7 @@ def _get_hook(self, caller: str) -> DatabricksHook: caller=caller, ) - def _setup_and_validate_json(self): - self.overridden_json_params = { - "tasks": self.tasks, - "spark_jar_task": self.spark_jar_task, - "notebook_task": self.notebook_task, - "spark_python_task": self.spark_python_task, - "spark_submit_task": self.spark_submit_task, - "pipeline_task": self.pipeline_task, - "dbt_task": self.dbt_task, - "new_cluster": self.new_cluster, - "existing_cluster_id": self.existing_cluster_id, - "libraries": self.libraries, - "run_name": self.run_name, - "timeout_seconds": self.timeout_seconds, - "idempotency_token": self.idempotency_token, - "access_control_list": self.access_control_list, - "git_source": self.git_source, - } - - _handle_overridden_json_params(self) - - if "run_name" not in self.json or self.json["run_name"] is None: - self.json["run_name"] = self.task_id - - if "dbt_task" in self.json and "git_source" not in self.json: - raise AirflowException("git_source is required for dbt_task") - if ( - "pipeline_task" in self.json - and "pipeline_id" in self.json["pipeline_task"] - and "pipeline_name" in self.json["pipeline_task"] - ): - raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'") - - normalise_json_content(self) - def execute(self, context: Context): - self._setup_and_validate_json() if ( "pipeline_task" in self.json and self.json["pipeline_task"].get("pipeline_id") is None @@ -643,7 +575,7 @@ def execute(self, context: Context): pipeline_name = self.json["pipeline_task"]["pipeline_name"] self.json["pipeline_task"]["pipeline_id"] = self._hook.find_pipeline_id_by_name(pipeline_name) del self.json["pipeline_task"]["pipeline_name"] - json_normalised = _normalise_json_content(self.json) + json_normalised = normalise_json_content(self.json) self.run_id = self._hook.submit_run(json_normalised) if self.deferrable: _handle_deferrable_databricks_operator_execution(self, self._hook, self.log, context) @@ -679,7 +611,7 @@ def __init__(self, *args, **kwargs): def execute(self, context): hook = self._get_hook(caller="DatabricksSubmitRunDeferrableOperator") - json_normalised = _normalise_json_content(self.json) + json_normalised = normalise_json_content(self.json) self.run_id = hook.submit_run(json_normalised) _handle_deferrable_databricks_operator_execution(self, hook, self.log, context) @@ -836,18 +768,7 @@ class DatabricksRunNowOperator(BaseOperator): """ # Used in airflow.models.BaseOperator - template_fields: Sequence[str] = ( - "json", - "databricks_conn_id", - "job_id", - "job_name", - "notebook_params", - 
"python_params", - "python_named_params", - "jar_params", - "spark_submit_params", - "idempotency_token", - ) + template_fields: Sequence[str] = ("json", "databricks_conn_id") template_ext: Sequence[str] = (".json-tpl",) # Databricks brand color (blue) under white text ui_color = "#1CB1C2" @@ -890,14 +811,27 @@ def __init__( self.deferrable = deferrable self.repair_run = repair_run self.cancel_previous_runs = cancel_previous_runs - self.job_id = job_id - self.job_name = job_name - self.notebook_params = notebook_params - self.python_params = python_params - self.python_named_params = python_named_params - self.jar_params = jar_params - self.spark_submit_params = spark_submit_params - self.idempotency_token = idempotency_token + + if job_id is not None: + self.json["job_id"] = job_id + if job_name is not None: + self.json["job_name"] = job_name + if "job_id" in self.json and "job_name" in self.json: + raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'") + if notebook_params is not None: + self.json["notebook_params"] = notebook_params + if python_params is not None: + self.json["python_params"] = python_params + if python_named_params is not None: + self.json["python_named_params"] = python_named_params + if jar_params is not None: + self.json["jar_params"] = jar_params + if spark_submit_params is not None: + self.json["spark_submit_params"] = spark_submit_params + if idempotency_token is not None: + self.json["idempotency_token"] = idempotency_token + if self.json: + self.json = normalise_json_content(self.json) # This variable will be used in case our task gets killed. self.run_id: int | None = None self.do_xcom_push = do_xcom_push @@ -915,26 +849,7 @@ def _get_hook(self, caller: str) -> DatabricksHook: caller=caller, ) - def _setup_and_validate_json(self): - self.overridden_json_params = { - "job_id": self.job_id, - "job_name": self.job_name, - "notebook_params": self.notebook_params, - "python_params": self.python_params, - "python_named_params": self.python_named_params, - "jar_params": self.jar_params, - "spark_submit_params": self.spark_submit_params, - "idempotency_token": self.idempotency_token, - } - _handle_overridden_json_params(self) - - if "job_id" in self.json and "job_name" in self.json: - raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'") - - normalise_json_content(self) - def execute(self, context: Context): - self._setup_and_validate_json() hook = self._hook if "job_name" in self.json: job_id = hook.find_job_id_by_name(self.json["job_name"]) diff --git a/airflow/providers/databricks/utils/databricks.py b/airflow/providers/databricks/utils/databricks.py index ec99bce17873c..88d622c3bc1fb 100644 --- a/airflow/providers/databricks/utils/databricks.py +++ b/airflow/providers/databricks/utils/databricks.py @@ -21,7 +21,7 @@ from airflow.providers.databricks.hooks.databricks import RunState -def _normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict: +def normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict: """ Normalize content or all values of content if it is a dict to a string. @@ -33,7 +33,7 @@ def _normalise_json_content(content, json_path: str = "json") -> str | bool | li The only one exception is when we have boolean values, they can not be converted to string type because databricks does not understand 'True' or 'False' values. 
""" - normalise = _normalise_json_content + normalise = normalise_json_content if isinstance(content, (str, bool)): return content elif isinstance(content, (int, float)): diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index a7337669047cb..7ff2295eda94a 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -23,10 +23,8 @@ import pytest -from airflow.decorators import task from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG -from airflow.operators.python import PythonOperator from airflow.providers.databricks.hooks.databricks import RunState from airflow.providers.databricks.operators.databricks import ( DatabricksCreateJobsOperator, @@ -38,7 +36,6 @@ ) from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger from airflow.providers.databricks.utils import databricks as utils -from airflow.utils import timezone pytestmark = pytest.mark.db_test @@ -66,11 +63,7 @@ RUN_ID = 1 RUN_PAGE_URL = "run-page-url" JOB_ID = "42" -TEMPLATED_JOB_ID = "job-id-{{ ds }}" -RENDERED_TEMPLATED_JOB_ID = f"job-id-{DATE}" JOB_NAME = "job-name" -TEMPLATED_JOB_NAME = "job-name-{{ ds }}" -RENDERED_TEMPLATED_JOB_NAME = f"job-name-{DATE}" JOB_DESCRIPTION = "job-description" NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"} JAR_PARAMS = ["param1", "param2"] @@ -255,9 +248,9 @@ def make_run_with_state_mock( class TestDatabricksCreateJobsOperator: - def test_validate_json_with_named_parameters(self): + def test_init_with_named_parameters(self): """ - Test the _setup_and_validate_json function with the named parameters. + Test the initializer with the named parameters. """ op = DatabricksCreateJobsOperator( task_id=TASK_ID, @@ -273,9 +266,7 @@ def test_validate_json_with_named_parameters(self): git_source=GIT_SOURCE, access_control_list=ACCESS_CONTROL_LIST, ) - op._setup_and_validate_json() - - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "name": JOB_NAME, "tags": TAGS, @@ -293,9 +284,9 @@ def test_validate_json_with_named_parameters(self): assert expected == op.json - def test_validate_json_with_json(self): + def test_init_with_json(self): """ - Test the _setup_and_validate_json function with json data. + Test the initializer with json data. """ json = { "name": JOB_NAME, @@ -311,9 +302,8 @@ def test_validate_json_with_json(self): "access_control_list": ACCESS_CONTROL_LIST, } op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) - op._setup_and_validate_json() - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "name": JOB_NAME, "tags": TAGS, @@ -331,9 +321,9 @@ def test_validate_json_with_json(self): assert expected == op.json - def test_validate_json_with_merging(self): + def test_init_with_merging(self): """ - Test the _setup_and_validate_json function when json and other named parameters are both + Test the initializer when json and other named parameters are both provided. The named parameters should override top level keys in the json dict. 
""" @@ -377,9 +367,8 @@ def test_validate_json_with_merging(self): git_source=override_git_source, access_control_list=override_access_control_list, ) - op._setup_and_validate_json() - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "name": override_name, "tags": override_tags, @@ -397,220 +386,24 @@ def test_validate_json_with_merging(self): assert expected == op.json - def test_validate_json_with_templating(self): + def test_init_with_templating(self): json = {"name": "test-{{ ds }}"} dag = DAG("test", start_date=datetime.now()) op = DatabricksCreateJobsOperator(dag=dag, task_id=TASK_ID, json=json) op.render_template_fields(context={"ds": DATE}) - op._setup_and_validate_json() - - expected = utils._normalise_json_content({"name": f"test-{DATE}"}) + expected = utils.normalise_json_content({"name": f"test-{DATE}"}) assert expected == op.json - def test_validate_json_with_bad_type(self): - json = {"test": datetime.now(), "name": "test"} + def test_init_with_bad_type(self): + json = {"test": datetime.now()} # Looks a bit weird since we have to escape regex reserved symbols. exception_message = ( r"Type \<(type|class) \'datetime.datetime\'\> used " r"for parameter json\[test\] is not a number or a string" ) with pytest.raises(AirflowException, match=exception_message): - DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)._setup_and_validate_json() - - def test_validate_json_with_no_name(self): - json = {} - exception_message = "Missing required parameter: name" - with pytest.raises(AirflowException, match=exception_message): - DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)._setup_and_validate_json() - - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): - json = "{{ ti.xcom_pull(task_ids='push_json') }}" - with dag_maker("test_templated", render_template_as_native_obj=True): - push_json = PythonOperator( - task_id="push_json", - python_callable=lambda: { - "name": JOB_NAME, - "description": JOB_DESCRIPTION, - "tags": TAGS, - "tasks": TASKS, - "job_clusters": JOB_CLUSTERS, - "email_notifications": EMAIL_NOTIFICATIONS, - "webhook_notifications": WEBHOOK_NOTIFICATIONS, - "notification_settings": NOTIFICATION_SETTINGS, - "timeout_seconds": TIMEOUT_SECONDS, - "schedule": SCHEDULE, - "max_concurrent_runs": MAX_CONCURRENT_RUNS, - "git_source": GIT_SOURCE, - "access_control_list": ACCESS_CONTROL_LIST, - }, - ) - op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) - push_json >> op - - db_mock = db_mock_class.return_value - db_mock.create_job.return_value = JOB_ID - - db_mock.find_job_id_by_name.return_value = None - - dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow()) - tis = {ti.task_id: ti for ti in dagrun.task_instances} - tis["push_json"].run() - tis[TASK_ID].run() - - expected = utils._normalise_json_content( - { - "name": JOB_NAME, - "description": JOB_DESCRIPTION, - "tags": TAGS, - "tasks": TASKS, - "job_clusters": JOB_CLUSTERS, - "email_notifications": EMAIL_NOTIFICATIONS, - "webhook_notifications": WEBHOOK_NOTIFICATIONS, - "notification_settings": NOTIFICATION_SETTINGS, - "timeout_seconds": TIMEOUT_SECONDS, - "schedule": SCHEDULE, - "max_concurrent_runs": MAX_CONCURRENT_RUNS, - "git_source": GIT_SOURCE, - "access_control_list": ACCESS_CONTROL_LIST, - } - ) - db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit, - retry_delay=op.databricks_retry_delay, - 
retry_args=None, - caller="DatabricksCreateJobsOperator", - ) - - db_mock.create_job.assert_called_once_with(expected) - assert JOB_ID == tis[TASK_ID].xcom_pull() - - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def test_validate_json_with_templated_named_param(self, db_mock_class, dag_maker): - json = "{{ ti.xcom_pull(task_ids='push_json') }}" - with dag_maker("test_templated", render_template_as_native_obj=True): - push_json = PythonOperator( - task_id="push_json", - python_callable=lambda: { - "description": JOB_DESCRIPTION, - "tags": TAGS, - "tasks": TASKS, - "job_clusters": JOB_CLUSTERS, - "email_notifications": EMAIL_NOTIFICATIONS, - "webhook_notifications": WEBHOOK_NOTIFICATIONS, - "notification_settings": NOTIFICATION_SETTINGS, - "timeout_seconds": TIMEOUT_SECONDS, - "schedule": SCHEDULE, - "max_concurrent_runs": MAX_CONCURRENT_RUNS, - "git_source": GIT_SOURCE, - "access_control_list": ACCESS_CONTROL_LIST, - }, - ) - op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json, name=TEMPLATED_JOB_NAME) - push_json >> op - - db_mock = db_mock_class.return_value - db_mock.create_job.return_value = JOB_ID - - db_mock.find_job_id_by_name.return_value = None - - dagrun = dag_maker.create_dagrun(execution_date=datetime.strptime(DATE, "%Y-%m-%d")) - tis = {ti.task_id: ti for ti in dagrun.task_instances} - tis["push_json"].run() - tis[TASK_ID].run() - - expected = utils._normalise_json_content( - { - "name": RENDERED_TEMPLATED_JOB_NAME, - "description": JOB_DESCRIPTION, - "tags": TAGS, - "tasks": TASKS, - "job_clusters": JOB_CLUSTERS, - "email_notifications": EMAIL_NOTIFICATIONS, - "webhook_notifications": WEBHOOK_NOTIFICATIONS, - "notification_settings": NOTIFICATION_SETTINGS, - "timeout_seconds": TIMEOUT_SECONDS, - "schedule": SCHEDULE, - "max_concurrent_runs": MAX_CONCURRENT_RUNS, - "git_source": GIT_SOURCE, - "access_control_list": ACCESS_CONTROL_LIST, - } - ) - db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit, - retry_delay=op.databricks_retry_delay, - retry_args=None, - caller="DatabricksCreateJobsOperator", - ) - - db_mock.create_job.assert_called_once_with(expected) - assert JOB_ID == tis[TASK_ID].xcom_pull() - - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker): - with dag_maker("test_xcomarg", render_template_as_native_obj=True): - - @task - def push_json() -> dict: - return { - "name": JOB_NAME, - "description": JOB_DESCRIPTION, - "tags": TAGS, - "tasks": TASKS, - "job_clusters": JOB_CLUSTERS, - "email_notifications": EMAIL_NOTIFICATIONS, - "webhook_notifications": WEBHOOK_NOTIFICATIONS, - "notification_settings": NOTIFICATION_SETTINGS, - "timeout_seconds": TIMEOUT_SECONDS, - "schedule": SCHEDULE, - "max_concurrent_runs": MAX_CONCURRENT_RUNS, - "git_source": GIT_SOURCE, - "access_control_list": ACCESS_CONTROL_LIST, - } - - json = push_json() - op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) - - db_mock = db_mock_class.return_value - db_mock.create_job.return_value = JOB_ID - - db_mock.find_job_id_by_name.return_value = None - - dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow()) - tis = {ti.task_id: ti for ti in dagrun.task_instances} - tis["push_json"].run() - tis[TASK_ID].run() - - expected = utils._normalise_json_content( - { - "name": JOB_NAME, - "description": JOB_DESCRIPTION, - "tags": TAGS, - "tasks": TASKS, - "job_clusters": JOB_CLUSTERS, - 
"email_notifications": EMAIL_NOTIFICATIONS, - "webhook_notifications": WEBHOOK_NOTIFICATIONS, - "notification_settings": NOTIFICATION_SETTINGS, - "timeout_seconds": TIMEOUT_SECONDS, - "schedule": SCHEDULE, - "max_concurrent_runs": MAX_CONCURRENT_RUNS, - "git_source": GIT_SOURCE, - "access_control_list": ACCESS_CONTROL_LIST, - } - ) - db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit, - retry_delay=op.databricks_retry_delay, - retry_args=None, - caller="DatabricksCreateJobsOperator", - ) - - db_mock.create_job.assert_called_once_with(expected) - assert JOB_ID == tis[TASK_ID].xcom_pull() + DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_exec_create(self, db_mock_class): @@ -640,7 +433,7 @@ def test_exec_create(self, db_mock_class): return_result = op.execute({}) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "name": JOB_NAME, "description": JOB_DESCRIPTION, @@ -694,7 +487,7 @@ def test_exec_reset(self, db_mock_class): return_result = op.execute({}) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "name": JOB_NAME, "description": JOB_DESCRIPTION, @@ -746,7 +539,7 @@ def test_exec_update_job_permission(self, db_mock_class): op.execute({}) - expected = utils._normalise_json_content({"access_control_list": ACCESS_CONTROL_LIST}) + expected = utils.normalise_json_content({"access_control_list": ACCESS_CONTROL_LIST}) db_mock_class.assert_called_once_with( DEFAULT_CONN_ID, @@ -793,76 +586,66 @@ def test_exec_update_job_permission_with_empty_acl(self, db_mock_class): class TestDatabricksSubmitRunOperator: - def test_validate_json_with_notebook_task_named_parameters(self): + def test_init_with_notebook_task_named_parameters(self): """ - Test the _setup_and_validate_json function with named parameters. + Test the initializer with the named parameters. """ op = DatabricksSubmitRunOperator( task_id=TASK_ID, new_cluster=NEW_CLUSTER, notebook_task=NOTEBOOK_TASK ) - op._setup_and_validate_json() - - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID} ) - assert expected == utils._normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op.json) - def test_validate_json_with_spark_python_task_named_parameters(self): + def test_init_with_spark_python_task_named_parameters(self): """ - Test the _setup_and_validate_json function with the named parameters. + Test the initializer with the named parameters. """ op = DatabricksSubmitRunOperator( task_id=TASK_ID, new_cluster=NEW_CLUSTER, spark_python_task=SPARK_PYTHON_TASK ) - op._setup_and_validate_json() - - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "spark_python_task": SPARK_PYTHON_TASK, "run_name": TASK_ID} ) - assert expected == utils._normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op.json) - def test_validate_json_with_pipeline_name_task_named_parameters(self): + def test_init_with_pipeline_name_task_named_parameters(self): """ - Test the _setup_and_validate_json function with the named parameters. + Test the initializer with the named parameters. 
""" op = DatabricksSubmitRunOperator(task_id=TASK_ID, pipeline_task=PIPELINE_NAME_TASK) - op._setup_and_validate_json() + expected = utils.normalise_json_content({"pipeline_task": PIPELINE_NAME_TASK, "run_name": TASK_ID}) - expected = utils._normalise_json_content({"pipeline_task": PIPELINE_NAME_TASK, "run_name": TASK_ID}) + assert expected == utils.normalise_json_content(op.json) - assert expected == utils._normalise_json_content(op.json) - - def test_validate_json_with_pipeline_id_task_named_parameters(self): + def test_init_with_pipeline_id_task_named_parameters(self): """ - Test the _setup_and_validate_json function with the named parameters. + Test the initializer with the named parameters. """ op = DatabricksSubmitRunOperator(task_id=TASK_ID, pipeline_task=PIPELINE_ID_TASK) - op._setup_and_validate_json() - - expected = utils._normalise_json_content({"pipeline_task": PIPELINE_ID_TASK, "run_name": TASK_ID}) + expected = utils.normalise_json_content({"pipeline_task": PIPELINE_ID_TASK, "run_name": TASK_ID}) - assert expected == utils._normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op.json) - def test_validate_json_with_spark_submit_task_named_parameters(self): + def test_init_with_spark_submit_task_named_parameters(self): """ - Test the _setup_and_validate_json function with the named parameters. + Test the initializer with the named parameters. """ op = DatabricksSubmitRunOperator( task_id=TASK_ID, new_cluster=NEW_CLUSTER, spark_submit_task=SPARK_SUBMIT_TASK ) - op._setup_and_validate_json() - - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "spark_submit_task": SPARK_SUBMIT_TASK, "run_name": TASK_ID} ) - assert expected == utils._normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op.json) - def test_validate_json_with_dbt_task_named_parameters(self): + def test_init_with_dbt_task_named_parameters(self): """ - Test the _setup_and_validate_json function with the named parameters. + Test the initializer with the named parameters. """ git_source = { "git_url": "https://github.com/dbt-labs/jaffle_shop", @@ -872,17 +655,15 @@ def test_validate_json_with_dbt_task_named_parameters(self): op = DatabricksSubmitRunOperator( task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK, git_source=git_source ) - op._setup_and_validate_json() - - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "dbt_task": DBT_TASK, "git_source": git_source, "run_name": TASK_ID} ) - assert expected == utils._normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op.json) - def test_validate_json_with_dbt_task_mixed_parameters(self): + def test_init_with_dbt_task_mixed_parameters(self): """ - Test the _setup_and_validate_json function with mixed parameters. + Test the initializer with mixed parameters. 
""" git_source = { "git_url": "https://github.com/dbt-labs/jaffle_shop", @@ -893,85 +674,73 @@ def test_validate_json_with_dbt_task_mixed_parameters(self): op = DatabricksSubmitRunOperator( task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK, json=json ) - op._setup_and_validate_json() - - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "dbt_task": DBT_TASK, "git_source": git_source, "run_name": TASK_ID} ) - assert expected == utils._normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op.json) - def test_validate_json_with_dbt_task_without_git_source_raises_error(self): + def test_init_with_dbt_task_without_git_source_raises_error(self): """ - Test the _setup_and_validate_json function without the necessary git_source for dbt_task raises error. + Test the initializer without the necessary git_source for dbt_task raises error. """ exception_message = "git_source is required for dbt_task" with pytest.raises(AirflowException, match=exception_message): - DatabricksSubmitRunOperator( - task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK - )._setup_and_validate_json() + DatabricksSubmitRunOperator(task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK) - def test_validate_json_with_dbt_task_json_without_git_source_raises_error(self): + def test_init_with_dbt_task_json_without_git_source_raises_error(self): """ - Test the _setup_and_validate_json function without the necessary git_source for dbt_task raises error. + Test the initializer without the necessary git_source for dbt_task raises error. """ json = {"dbt_task": DBT_TASK, "new_cluster": NEW_CLUSTER} exception_message = "git_source is required for dbt_task" with pytest.raises(AirflowException, match=exception_message): - DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)._setup_and_validate_json() + DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) - def test_validate_json_with_json(self): + def test_init_with_json(self): """ - Test the _setup_and_validate_json function with json data. + Test the initializer with json data. """ json = {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK} op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) - op._setup_and_validate_json() - - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID} ) - assert expected == utils._normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op.json) - def test_validate_json_with_tasks(self): + def test_init_with_tasks(self): tasks = [{"task_key": 1, "new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK}] op = DatabricksSubmitRunOperator(task_id=TASK_ID, tasks=tasks) - op._setup_and_validate_json() - - expected = utils._normalise_json_content({"run_name": TASK_ID, "tasks": tasks}) - assert expected == utils._normalise_json_content(op.json) + expected = utils.normalise_json_content({"run_name": TASK_ID, "tasks": tasks}) + assert expected == utils.normalise_json_content(op.json) - def test_validate_json_with_specified_run_name(self): + def test_init_with_specified_run_name(self): """ - Test the _setup_and_validate_json function with a specified run_name. + Test the initializer with a specified run_name. 
""" json = {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": RUN_NAME} op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) - op._setup_and_validate_json() - - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": RUN_NAME} ) - assert expected == utils._normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op.json) - def test_validate_json_with_pipeline_task(self): + def test_pipeline_task(self): """ - Test the _setup_and_validate_json function with a pipeline task. + Test the initializer with a pipeline task. """ pipeline_task = {"pipeline_id": "test-dlt"} json = {"new_cluster": NEW_CLUSTER, "run_name": RUN_NAME, "pipeline_task": pipeline_task} op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) - op._setup_and_validate_json() - - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "pipeline_task": pipeline_task, "run_name": RUN_NAME} ) - assert expected == utils._normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op.json) - def test_validate_json_with_merging(self): + def test_init_with_merging(self): """ - Test the _setup_and_validate_json function when json and other named parameters are both + Test the initializer when json and other named parameters are both provided. The named parameters should override top level keys in the json dict. """ @@ -981,38 +750,34 @@ def test_validate_json_with_merging(self): "notebook_task": NOTEBOOK_TASK, } op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json, new_cluster=override_new_cluster) - op._setup_and_validate_json() - - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "new_cluster": override_new_cluster, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID, } ) - assert expected == utils._normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op.json) @pytest.mark.db_test - def test_validate_json_with_templating(self): + def test_init_with_templating(self): json = { "new_cluster": NEW_CLUSTER, "notebook_task": TEMPLATED_NOTEBOOK_TASK, } dag = DAG("test", start_date=datetime.now()) op = DatabricksSubmitRunOperator(dag=dag, task_id=TASK_ID, json=json) - op._setup_and_validate_json() - op.render_template_fields(context={"ds": DATE}) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "new_cluster": NEW_CLUSTER, "notebook_task": RENDERED_TEMPLATED_NOTEBOOK_TASK, "run_name": TASK_ID, } ) - assert expected == utils._normalise_json_content(op.json) + assert expected == utils.normalise_json_content(op.json) - def test_validate_json_with_git_source(self): + def test_init_with_git_source(self): json = {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": RUN_NAME} git_source = { "git_url": "https://github.com/apache/airflow", @@ -1020,9 +785,7 @@ def test_validate_json_with_git_source(self): "git_branch": "main", } op = DatabricksSubmitRunOperator(task_id=TASK_ID, git_source=git_source, json=json) - op._setup_and_validate_json() - - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, @@ -1030,139 +793,18 @@ def test_validate_json_with_git_source(self): "git_source": git_source, } ) - assert expected == utils._normalise_json_content(op.json) + assert expected == 
utils.normalise_json_content(op.json) - def test_validate_json_with_bad_type(self): + def test_init_with_bad_type(self): json = {"test": datetime.now()} + op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) # Looks a bit weird since we have to escape regex reserved symbols. exception_message = ( r"Type \<(type|class) \'datetime.datetime\'\> used " r"for parameter json\[test\] is not a number or a string" ) with pytest.raises(AirflowException, match=exception_message): - DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)._setup_and_validate_json() - - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): - json = "{{ ti.xcom_pull(task_ids='push_json') }}" - with dag_maker("test_templated", render_template_as_native_obj=True): - push_json = PythonOperator( - task_id="push_json", - python_callable=lambda: { - "new_cluster": NEW_CLUSTER, - "notebook_task": NOTEBOOK_TASK, - }, - ) - op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) - push_json >> op - - db_mock = db_mock_class.return_value - db_mock.submit_run.return_value = RUN_ID - db_mock.get_run_page_url.return_value = RUN_PAGE_URL - db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") - - dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow()) - tis = {ti.task_id: ti for ti in dagrun.task_instances} - tis["push_json"].run() - tis[TASK_ID].run() - - expected = utils._normalise_json_content( - {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID} - ) - db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit, - retry_delay=op.databricks_retry_delay, - retry_args=None, - caller="DatabricksSubmitRunOperator", - ) - - db_mock.submit_run.assert_called_once_with(expected) - db_mock.get_run_page_url.assert_called_once_with(RUN_ID) - db_mock.get_run.assert_called_once_with(RUN_ID) - - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def test_validate_json_with_templated_named_params(self, db_mock_class, dag_maker): - json = "{{ ti.xcom_pull(task_ids='push_json') }}" - with dag_maker("test_templated", render_template_as_native_obj=True): - push_json = PythonOperator( - task_id="push_json", - python_callable=lambda: { - "new_cluster": NEW_CLUSTER, - }, - ) - op = DatabricksSubmitRunOperator( - task_id=TASK_ID, json=json, notebook_task=TEMPLATED_NOTEBOOK_TASK - ) - push_json >> op - - db_mock = db_mock_class.return_value - db_mock.submit_run.return_value = RUN_ID - db_mock.get_run_page_url.return_value = RUN_PAGE_URL - db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") - - dagrun = dag_maker.create_dagrun(execution_date=datetime.strptime(DATE, "%Y-%m-%d")) - tis = {ti.task_id: ti for ti in dagrun.task_instances} - tis["push_json"].run() - tis[TASK_ID].run() - - expected = utils._normalise_json_content( - { - "new_cluster": NEW_CLUSTER, - "notebook_task": RENDERED_TEMPLATED_NOTEBOOK_TASK, - "run_name": TASK_ID, - } - ) - db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit, - retry_delay=op.databricks_retry_delay, - retry_args=None, - caller="DatabricksSubmitRunOperator", - ) - - db_mock.submit_run.assert_called_once_with(expected) - db_mock.get_run_page_url.assert_called_once_with(RUN_ID) - db_mock.get_run.assert_called_once_with(RUN_ID) - - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def 
test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker): - with dag_maker("test_xcomarg", render_template_as_native_obj=True): - - @task - def push_json() -> dict: - return { - "new_cluster": NEW_CLUSTER, - "notebook_task": NOTEBOOK_TASK, - } - - json = push_json() - - op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) - db_mock = db_mock_class.return_value - db_mock.submit_run.return_value = RUN_ID - db_mock.get_run_page_url.return_value = RUN_PAGE_URL - db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") - - dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow()) - tis = {ti.task_id: ti for ti in dagrun.task_instances} - tis["push_json"].run() - tis[TASK_ID].run() - - expected = utils._normalise_json_content( - {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID} - ) - db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit, - retry_delay=op.databricks_retry_delay, - retry_args=None, - caller="DatabricksSubmitRunOperator", - ) - - db_mock.submit_run.assert_called_once_with(expected) - db_mock.get_run_page_url.assert_called_once_with(RUN_ID) - db_mock.get_run.assert_called_once_with(RUN_ID) + utils.normalise_json_content(op.json) @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_exec_success(self, db_mock_class): @@ -1180,7 +822,7 @@ def test_exec_success(self, db_mock_class): op.execute(None) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID} ) db_mock_class.assert_called_once_with( @@ -1210,7 +852,7 @@ def test_exec_pipeline_name(self, db_mock_class): op.execute(None) - expected = utils._normalise_json_content({"pipeline_task": PIPELINE_ID_TASK, "run_name": TASK_ID}) + expected = utils.normalise_json_content({"pipeline_task": PIPELINE_ID_TASK, "run_name": TASK_ID}) db_mock_class.assert_called_once_with( DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, @@ -1242,7 +884,7 @@ def test_exec_failure(self, db_mock_class): with pytest.raises(AirflowException): op.execute(None) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, @@ -1290,7 +932,7 @@ def test_wait_for_termination(self, db_mock_class): op.execute(None) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID} ) db_mock_class.assert_called_once_with( @@ -1319,7 +961,7 @@ def test_no_wait_for_termination(self, db_mock_class): op.execute(None) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID} ) db_mock_class.assert_called_once_with( @@ -1353,7 +995,7 @@ def test_execute_task_deferred(self, db_mock_class): assert isinstance(exc.value.trigger, DatabricksExecutionTrigger) assert exc.value.method_name == "execute_complete" - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID} ) db_mock_class.assert_called_once_with( @@ -1435,7 +1077,7 @@ def test_databricks_submit_run_deferrable_operator_failed_before_defer(self, moc db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED") op.execute(None) - expected = 
utils._normalise_json_content( + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID} ) db_mock_class.assert_called_once_with( @@ -1465,7 +1107,7 @@ def test_databricks_submit_run_deferrable_operator_success_before_defer(self, mo db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") op.execute(None) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID} ) db_mock_class.assert_called_once_with( @@ -1483,20 +1125,18 @@ def test_databricks_submit_run_deferrable_operator_success_before_defer(self, mo class TestDatabricksRunNowOperator: - def test_validate_json_with_named_parameters(self): + def test_init_with_named_parameters(self): """ - Test the _setup_and_validate_json function with named parameters. + Test the initializer with the named parameters. """ op = DatabricksRunNowOperator(job_id=JOB_ID, task_id=TASK_ID) - op._setup_and_validate_json() - - expected = utils._normalise_json_content({"job_id": 42}) + expected = utils.normalise_json_content({"job_id": 42}) assert expected == op.json - def test_validate_json_with_json(self): + def test_init_with_json(self): """ - Test the _setup_and_validate_json function with json data. + Test the initializer with json data. """ json = { "notebook_params": NOTEBOOK_PARAMS, @@ -1507,9 +1147,8 @@ def test_validate_json_with_json(self): "repair_run": False, } op = DatabricksRunNowOperator(task_id=TASK_ID, json=json) - op._setup_and_validate_json() - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "notebook_params": NOTEBOOK_PARAMS, "jar_params": JAR_PARAMS, @@ -1522,9 +1161,9 @@ def test_validate_json_with_json(self): assert expected == op.json - def test_validate_json_with_merging(self): + def test_init_with_merging(self): """ - Test the _setup_and_validate_json function when json and other named parameters are both + Test the initializer when json and other named parameters are both provided. The named parameters should override top level keys in the json dict. """ @@ -1541,9 +1180,8 @@ def test_validate_json_with_merging(self): jar_params=override_jar_params, spark_submit_params=SPARK_SUBMIT_PARAMS, ) - op._setup_and_validate_json() - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "notebook_params": override_notebook_params, "jar_params": override_jar_params, @@ -1556,14 +1194,13 @@ def test_validate_json_with_merging(self): assert expected == op.json @pytest.mark.db_test - def test_validate_json_with_templating(self): + def test_init_with_templating(self): json = {"notebook_params": NOTEBOOK_PARAMS, "jar_params": TEMPLATED_JAR_PARAMS} dag = DAG("test", start_date=datetime.now()) op = DatabricksRunNowOperator(dag=dag, task_id=TASK_ID, job_id=JOB_ID, json=json) op.render_template_fields(context={"ds": DATE}) - op._setup_and_validate_json() - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "notebook_params": NOTEBOOK_PARAMS, "jar_params": RENDERED_TEMPLATED_JAR_PARAMS, @@ -1572,7 +1209,7 @@ def test_validate_json_with_templating(self): ) assert expected == op.json - def test_validate_json_with_bad_type(self): + def test_init_with_bad_type(self): json = {"test": datetime.now()} # Looks a bit weird since we have to escape regex reserved symbols. 
exception_message = ( @@ -1580,162 +1217,7 @@ def test_validate_json_with_bad_type(self): r"for parameter json\[test\] is not a number or a string" ) with pytest.raises(AirflowException, match=exception_message): - DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json)._setup_and_validate_json() - - def test_validate_json_exception_with_job_name_and_job_id(self): - exception_message = "Argument 'job_name' is not allowed with argument 'job_id'" - - with pytest.raises(AirflowException, match=exception_message): - DatabricksRunNowOperator( - task_id=TASK_ID, job_id=JOB_ID, job_name=JOB_NAME - )._setup_and_validate_json() - - run = {"job_id": JOB_ID, "job_name": JOB_NAME} - with pytest.raises(AirflowException, match=exception_message): - DatabricksRunNowOperator(task_id=TASK_ID, json=run)._setup_and_validate_json() - - run = {"job_id": JOB_ID} - with pytest.raises(AirflowException, match=exception_message): - DatabricksRunNowOperator(task_id=TASK_ID, json=run, job_name=JOB_NAME)._setup_and_validate_json() - - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def test_validate_json_with_templated_json(self, db_mock_class, dag_maker): - json = "{{ ti.xcom_pull(task_ids='push_json') }}" - with dag_maker("test_templated", render_template_as_native_obj=True): - push_json = PythonOperator( - task_id="push_json", - python_callable=lambda: { - "notebook_params": NOTEBOOK_PARAMS, - "notebook_task": NOTEBOOK_TASK, - "jar_params": JAR_PARAMS, - "job_id": JOB_ID, - }, - ) - op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json) - push_json >> op - - db_mock = db_mock_class.return_value - db_mock.run_now.return_value = RUN_ID - db_mock.get_run_page_url.return_value = RUN_PAGE_URL - db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") - - dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow()) - tis = {ti.task_id: ti for ti in dagrun.task_instances} - tis["push_json"].run() - tis[TASK_ID].run() - - expected = utils._normalise_json_content( - { - "notebook_params": NOTEBOOK_PARAMS, - "notebook_task": NOTEBOOK_TASK, - "jar_params": JAR_PARAMS, - "job_id": JOB_ID, - } - ) - - db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit, - retry_delay=op.databricks_retry_delay, - retry_args=None, - caller="DatabricksRunNowOperator", - ) - db_mock.run_now.assert_called_once_with(expected) - db_mock.get_run_page_url.assert_called_once_with(RUN_ID) - db_mock.get_run.assert_called_once_with(RUN_ID) - - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def test_validate_json_with_templated_named_params(self, db_mock_class, dag_maker): - json = "{{ ti.xcom_pull(task_ids='push_json') }}" - with dag_maker("test_templated", render_template_as_native_obj=True): - push_json = PythonOperator( - task_id="push_json", - python_callable=lambda: { - "notebook_params": NOTEBOOK_PARAMS, - "notebook_task": NOTEBOOK_TASK, - }, - ) - op = DatabricksRunNowOperator( - task_id=TASK_ID, job_id=TEMPLATED_JOB_ID, jar_params=TEMPLATED_JAR_PARAMS, json=json - ) - push_json >> op - - db_mock = db_mock_class.return_value - db_mock.run_now.return_value = RUN_ID - db_mock.get_run_page_url.return_value = RUN_PAGE_URL - db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") - - dagrun = dag_maker.create_dagrun(execution_date=datetime.strptime(DATE, "%Y-%m-%d")) - tis = {ti.task_id: ti for ti in dagrun.task_instances} - tis["push_json"].run() - tis[TASK_ID].run() - - expected = 
utils._normalise_json_content( - { - "notebook_params": NOTEBOOK_PARAMS, - "notebook_task": NOTEBOOK_TASK, - "jar_params": RENDERED_TEMPLATED_JAR_PARAMS, - "job_id": RENDERED_TEMPLATED_JOB_ID, - } - ) - - db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit, - retry_delay=op.databricks_retry_delay, - retry_args=None, - caller="DatabricksRunNowOperator", - ) - db_mock.run_now.assert_called_once_with(expected) - db_mock.get_run_page_url.assert_called_once_with(RUN_ID) - db_mock.get_run.assert_called_once_with(RUN_ID) - - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") - def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker): - with dag_maker("test_xcomarg", render_template_as_native_obj=True): - - @task - def push_json() -> dict: - return { - "notebook_params": NOTEBOOK_PARAMS, - "notebook_task": NOTEBOOK_TASK, - "jar_params": JAR_PARAMS, - "job_id": JOB_ID, - } - - json = push_json() - - op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json) - - db_mock = db_mock_class.return_value - db_mock.run_now.return_value = RUN_ID - db_mock.get_run_page_url.return_value = RUN_PAGE_URL - db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS") - - dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow()) - tis = {ti.task_id: ti for ti in dagrun.task_instances} - tis["push_json"].run() - tis[TASK_ID].run() - - expected = utils._normalise_json_content( - { - "notebook_params": NOTEBOOK_PARAMS, - "notebook_task": NOTEBOOK_TASK, - "jar_params": JAR_PARAMS, - "job_id": JOB_ID, - } - ) - - db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit, - retry_delay=op.databricks_retry_delay, - retry_args=None, - caller="DatabricksRunNowOperator", - ) - db_mock.run_now.assert_called_once_with(expected) - db_mock.get_run_page_url.assert_called_once_with(RUN_ID) - db_mock.get_run.assert_called_once_with(RUN_ID) + DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json) @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_exec_success(self, db_mock_class): @@ -1750,7 +1232,7 @@ def test_exec_success(self, db_mock_class): op.execute(None) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, @@ -1785,7 +1267,7 @@ def test_exec_failure(self, db_mock_class): with pytest.raises(AirflowException): op.execute(None) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, @@ -1841,7 +1323,7 @@ def test_exec_failure_with_message(self, db_mock_class): with pytest.raises(AirflowException, match="Exception: Something went wrong"): op.execute(None) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, @@ -1909,7 +1391,7 @@ def test_exec_multiple_failures_with_message(self, db_mock_class): ): op.execute(None) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, @@ -1953,7 +1435,7 @@ def test_wait_for_termination(self, db_mock_class): op.execute(None) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, @@ 
-1984,7 +1466,7 @@ def test_no_wait_for_termination(self, db_mock_class): op.execute(None) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, @@ -2004,6 +1486,20 @@ def test_no_wait_for_termination(self, db_mock_class): db_mock.get_run_page_url.assert_called_once_with(RUN_ID) db_mock.get_run.assert_not_called() + def test_init_exception_with_job_name_and_job_id(self): + exception_message = "Argument 'job_name' is not allowed with argument 'job_id'" + + with pytest.raises(AirflowException, match=exception_message): + DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, job_name=JOB_NAME) + + run = {"job_id": JOB_ID, "job_name": JOB_NAME} + with pytest.raises(AirflowException, match=exception_message): + DatabricksRunNowOperator(task_id=TASK_ID, json=run) + + run = {"job_id": JOB_ID} + with pytest.raises(AirflowException, match=exception_message): + DatabricksRunNowOperator(task_id=TASK_ID, json=run, job_name=JOB_NAME) + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_exec_with_job_name(self, db_mock_class): run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS} @@ -2015,7 +1511,7 @@ def test_exec_with_job_name(self, db_mock_class): op.execute(None) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, @@ -2063,7 +1559,7 @@ def test_cancel_previous_runs(self, db_mock_class): op.execute(None) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, @@ -2097,7 +1593,7 @@ def test_no_cancel_previous_runs(self, db_mock_class): op.execute(None) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, @@ -2134,7 +1630,7 @@ def test_execute_task_deferred(self, db_mock_class): assert isinstance(exc.value.trigger, DatabricksExecutionTrigger) assert exc.value.method_name == "execute_complete" - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, @@ -2244,7 +1740,7 @@ def test_databricks_run_now_deferrable_operator_failed_before_defer(self, mock_d db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED") op.execute(None) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, @@ -2278,7 +1774,7 @@ def test_databricks_run_now_deferrable_operator_success_before_defer(self, mock_ op.execute(None) - expected = utils._normalise_json_content( + expected = utils.normalise_json_content( { "notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, diff --git a/tests/providers/databricks/utils/test_databricks.py b/tests/providers/databricks/utils/test_databricks.py index 4b57573253d47..8c6ce8ce4ba59 100644 --- a/tests/providers/databricks/utils/test_databricks.py +++ b/tests/providers/databricks/utils/test_databricks.py @@ -21,7 +21,7 @@ from airflow.exceptions import AirflowException from airflow.providers.databricks.hooks.databricks import RunState -from airflow.providers.databricks.utils.databricks import _normalise_json_content, validate_trigger_event +from 
airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event RUN_ID = 1 RUN_PAGE_URL = "run-page-url" @@ -46,7 +46,7 @@ def test_normalise_json_content(self): "test_list": ["1", "1.0", "a", "b"], "test_tuple": ["1", "1.0", "a", "b"], } - assert _normalise_json_content(test_json) == expected + assert normalise_json_content(test_json) == expected def test_validate_trigger_event_success(self): event = {