diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index 5d3f714b44e6d..e1da837f43e5f 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -41,6 +41,8 @@ START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start")
 TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete")
+CREATE_ENDPOINT = ("POST", "api/2.1/jobs/create")
+RESET_ENDPOINT = ("POST", "api/2.1/jobs/reset")
 RUN_NOW_ENDPOINT = ("POST", "api/2.1/jobs/run-now")
 SUBMIT_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/submit")
 GET_RUN_ENDPOINT = ("GET", "api/2.1/jobs/runs/get")
@@ -194,6 +196,24 @@ def __init__(
     ) -> None:
         super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay, retry_args, caller)
 
+    def create_job(self, json: dict) -> int:
+        """
+        Call the ``api/2.1/jobs/create`` endpoint.
+
+        :param json: The data used in the body of the request to the ``create`` endpoint.
+        :return: the job_id as an int
+        """
+        response = self._do_api_call(CREATE_ENDPOINT, json)
+        return response["job_id"]
+
+    def reset_job(self, job_id: str, json: dict) -> None:
+        """
+        Call the ``api/2.1/jobs/reset`` endpoint.
+
+        :param job_id: The id of the job to reset.
+        :param json: The data used as the ``new_settings`` of the request to the ``reset`` endpoint.
+        """
+        self._do_api_call(RESET_ENDPOINT, {"job_id": job_id, "new_settings": json})
+
     def run_now(self, json: dict) -> int:
         """
         Call the ``api/2.1/jobs/run-now`` endpoint.
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index fd9e256230f7c..5a54b7e9eea2f 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -21,6 +21,7 @@
 import time
 import warnings
 from functools import cached_property
+from logging import Logger
 from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.configuration import conf
@@ -31,8 +32,6 @@
 from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event
 
 if TYPE_CHECKING:
-    from logging import Logger
-
     from airflow.models.taskinstancekey import TaskInstanceKey
     from airflow.utils.context import Context
 
@@ -162,6 +161,135 @@ def get_link(
         return XCom.get_value(key=XCOM_RUN_PAGE_URL_KEY, ti_key=ti_key)
 
 
+class DatabricksCreateJobsOperator(BaseOperator):
+    """Creates (or resets) a Databricks job using the ``api/2.1/jobs/create`` API endpoint.
+
+    .. seealso::
+        https://docs.databricks.com/api/workspace/jobs/create
+        https://docs.databricks.com/api/workspace/jobs/reset
+
+    :param json: A JSON object containing API parameters which will be passed
+        directly to the ``api/2.1/jobs/create`` endpoint. The other named parameters
+        (i.e. ``name``, ``tags``, ``tasks``, etc.) to this operator will
+        be merged with this json dictionary if they are provided.
+        If there are conflicts during the merge, the named parameters will
+        take precedence and override the top level json keys. (templated)
+
+        .. seealso::
+            For more information about templating see :ref:`concepts:jinja-templating`.
+    :param name: An optional name for the job.
+    :param tags: A map of tags associated with the job.
+    :param tasks: A list of task specifications to be executed by this job.
+        Array of objects (JobTaskSettings).
+    :param job_clusters: A list of job cluster specifications that can be shared and reused by
+        tasks of this job. Array of objects (JobCluster).
+    :param email_notifications: Object (JobEmailNotifications).
+    :param webhook_notifications: Object (WebhookNotifications).
+    :param timeout_seconds: An optional timeout applied to each run of this job.
+    :param schedule: Object (CronSchedule).
+    :param max_concurrent_runs: An optional maximum allowed number of concurrent runs of the job.
+    :param git_source: An optional specification for a remote repository containing the notebooks
+        used by this job's notebook tasks. Object (GitSource).
+    :param access_control_list: List of permissions to set on the job. Array of object
+        (AccessControlRequestForUser) or object (AccessControlRequestForGroup) or object
+        (AccessControlRequestForServicePrincipal).
+
+        .. seealso::
+            This will only be used on create. In order to reset ACL consider using the Databricks
+            UI.
+    :param databricks_conn_id: Reference to the
+        :ref:`Databricks connection <howto/connection:databricks>`. (templated)
+    :param polling_period_seconds: Controls the rate at which we poll for the result of
+        this run. By default, the operator will poll every 30 seconds.
+    :param databricks_retry_limit: Amount of times to retry if the Databricks backend is
+        unreachable. Its value must be greater than or equal to 1.
+    :param databricks_retry_delay: Number of seconds to wait between retries (it
+        might be a floating point number).
+    :param databricks_retry_args: An optional dictionary with arguments passed to the ``tenacity.Retrying`` class.
+
+    """
+
+    # Used in airflow.models.BaseOperator
+    template_fields: Sequence[str] = ("json", "databricks_conn_id")
+    # Databricks brand color (blue) under white text
+    ui_color = "#1CB1C2"
+    ui_fgcolor = "#fff"
+
+    def __init__(
+        self,
+        *,
+        json: Any | None = None,
+        name: str | None = None,
+        tags: dict[str, str] | None = None,
+        tasks: list[dict] | None = None,
+        job_clusters: list[dict] | None = None,
+        email_notifications: dict | None = None,
+        webhook_notifications: dict | None = None,
+        timeout_seconds: int | None = None,
+        schedule: dict | None = None,
+        max_concurrent_runs: int | None = None,
+        git_source: dict | None = None,
+        access_control_list: list[dict] | None = None,
+        databricks_conn_id: str = "databricks_default",
+        polling_period_seconds: int = 30,
+        databricks_retry_limit: int = 3,
+        databricks_retry_delay: int = 1,
+        databricks_retry_args: dict[Any, Any] | None = None,
+        **kwargs,
+    ) -> None:
+        """Creates a new ``DatabricksCreateJobsOperator``."""
+        super().__init__(**kwargs)
+        self.json = json or {}
+        self.databricks_conn_id = databricks_conn_id
+        self.polling_period_seconds = polling_period_seconds
+        self.databricks_retry_limit = databricks_retry_limit
+        self.databricks_retry_delay = databricks_retry_delay
+        self.databricks_retry_args = databricks_retry_args
+        if name is not None:
+            self.json["name"] = name
+        if tags is not None:
+            self.json["tags"] = tags
+        if tasks is not None:
+            self.json["tasks"] = tasks
+        if job_clusters is not None:
+            self.json["job_clusters"] = job_clusters
+        if email_notifications is not None:
+            self.json["email_notifications"] = email_notifications
+        if webhook_notifications is not None:
+            self.json["webhook_notifications"] = webhook_notifications
+        if timeout_seconds is not None:
+            self.json["timeout_seconds"] = timeout_seconds
+        if schedule is not None:
+            self.json["schedule"] = schedule
+        if max_concurrent_runs is not None:
+            self.json["max_concurrent_runs"] = max_concurrent_runs
+        if git_source is not None:
+            self.json["git_source"] = git_source
+        if access_control_list is not None:
+            self.json["access_control_list"] = access_control_list
+
+        self.json = normalise_json_content(self.json)
+
+    @cached_property
+    def _hook(self):
+        return DatabricksHook(
+            self.databricks_conn_id,
+            retry_limit=self.databricks_retry_limit,
+            retry_delay=self.databricks_retry_delay,
+            retry_args=self.databricks_retry_args,
+            caller="DatabricksCreateJobsOperator",
+        )
+
+    def execute(self, context: Context) -> int:
+        if "name" not in self.json:
+            raise AirflowException("Missing required parameter: name")
+        job_id = self._hook.find_job_id_by_name(self.json["name"])
+        if job_id is None:
+            return self._hook.create_job(self.json)
+        self._hook.reset_job(str(job_id), self.json)
+        return job_id
+
+
 class DatabricksSubmitRunOperator(BaseOperator):
     """
     Submits a Spark job run to Databricks using the api/2.1/jobs/runs/submit API endpoint.
diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml
index bfaed1bc70c5c..52463c351705c 100644
--- a/airflow/providers/databricks/provider.yaml
+++ b/airflow/providers/databricks/provider.yaml
@@ -66,6 +66,7 @@ integrations:
   - integration-name: Databricks
     external-doc-url: https://databricks.com/
     how-to-guide:
+      - /docs/apache-airflow-providers-databricks/operators/jobs_create.rst
      - /docs/apache-airflow-providers-databricks/operators/submit_run.rst
      - /docs/apache-airflow-providers-databricks/operators/run_now.rst
     logo: /integration-logos/databricks/Databricks.png
@@ -123,3 +124,10 @@ connection-types:
 
 extra-links:
   - airflow.providers.databricks.operators.databricks.DatabricksJobRunLink
+
+additional-extras:
+  # pip install apache-airflow-providers-databricks[sdk]
+  - name: sdk
+    description: Install Databricks SDK
+    dependencies:
+      - databricks-sdk==0.10.0
diff --git a/docs/apache-airflow-providers-databricks/operators/jobs_create.rst b/docs/apache-airflow-providers-databricks/operators/jobs_create.rst
new file mode 100644
index 0000000000000..779095e92cd6b
--- /dev/null
+++ b/docs/apache-airflow-providers-databricks/operators/jobs_create.rst
@@ -0,0 +1,91 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements. See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership. The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied. See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+
+
+DatabricksCreateJobsOperator
+============================
+
+Use the :class:`~airflow.providers.databricks.operators.databricks.DatabricksCreateJobsOperator` to create
+(or reset) a Databricks job. The operator looks up an existing job by its ``name``; if one is found, its
+settings are reset with the new configuration, otherwise a new job is created. Repeated calls with this
+operator therefore update the existing job rather than creating new ones, and the resulting ``job_id`` is
+returned as the task's ``return_value`` XCom. When paired with the DatabricksRunNowOperator, all runs will
+fall under the same job within the Databricks UI.
+
+
+Using the Operator
+------------------
+
+There are three ways to instantiate this operator.
+In the first way, you can take the JSON payload that you typically use
+to call the ``api/2.1/jobs/create`` endpoint and pass it directly to the ``DatabricksCreateJobsOperator`` through the
+``json`` parameter. With this approach, you get full control over the underlying payload to the Jobs REST API, including
+execution of Databricks jobs with multiple tasks, but it is harder to detect errors because of the lack of type checking.
+
+The second way to accomplish the same thing is to use the named parameters of the ``DatabricksCreateJobsOperator`` directly. Note that there is exactly
+one named parameter for each top level parameter in the ``api/2.1/jobs/create`` endpoint.
+
+The third way is to use both the ``json`` parameter **AND** the named parameters. They will be merged
+together. If there are conflicts during the merge, the named parameters will take precedence and
+override the top level ``json`` keys.
+
+Currently the named parameters that ``DatabricksCreateJobsOperator`` supports are:
+  - ``name``
+  - ``tags``
+  - ``tasks``
+  - ``job_clusters``
+  - ``email_notifications``
+  - ``webhook_notifications``
+  - ``timeout_seconds``
+  - ``schedule``
+  - ``max_concurrent_runs``
+  - ``git_source``
+  - ``access_control_list``
+
+
+Examples
+--------
+
+Specifying parameters as JSON
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+An example usage of the DatabricksCreateJobsOperator is as follows:
+
+.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py
+    :language: python
+    :start-after: [START howto_operator_databricks_jobs_create_json]
+    :end-before: [END howto_operator_databricks_jobs_create_json]
+
+Using named parameters
+^^^^^^^^^^^^^^^^^^^^^^
+
+You can also use named parameters to initialize the operator and run the job.
+
+.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py
+    :language: python
+    :start-after: [START howto_operator_databricks_jobs_create_named]
+    :end-before: [END howto_operator_databricks_jobs_create_named]
+
+Pairing with DatabricksRunNowOperator
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+You can use the ``job_id`` that is returned by the DatabricksCreateJobsOperator in the
+``return_value`` XCom as an argument to the DatabricksRunNowOperator to run the job.
+
+.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py
+    :language: python
+    :start-after: [START howto_operator_databricks_run_now]
+    :end-before: [END howto_operator_databricks_run_now]
diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index 2566e9a394769..c836691b1e178 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -107,6 +107,20 @@
 }
 
 
+def create_endpoint(host):
+    """
+    Utility function to generate the create endpoint given the host.
+    """
+    return f"https://{host}/api/2.1/jobs/create"
+
+
+def reset_endpoint(host):
+    """
+    Utility function to generate the reset endpoint given the host.
+    """
+    return f"https://{host}/api/2.1/jobs/reset"
+
+
 def run_now_endpoint(host):
     """
     Utility function to generate the run now endpoint given the host.
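The tests in the next hunk exercise these two new hook methods against a mocked ``requests`` session. Outside of the test suite, the same calls reduce to the short sketch below; the connection id and the job settings here are illustrative assumptions, not values defined in this change:

    from airflow.providers.databricks.hooks.databricks import DatabricksHook

    # Assumed connection id; any configured Databricks connection would work here.
    hook = DatabricksHook(databricks_conn_id="databricks_default")

    # Minimal, hypothetical job specification used purely for illustration.
    job_settings = {
        "name": "example-job",
        "tasks": [
            {
                "task_key": "example_task",
                "notebook_task": {"notebook_path": "/Shared/example"},
                "existing_cluster_id": "1234-567890-abcde123",
            }
        ],
    }

    # create_job POSTs to api/2.1/jobs/create and returns the new job_id.
    job_id = hook.create_job(job_settings)

    # reset_job POSTs to api/2.1/jobs/reset, replacing the job's settings with new_settings.
    hook.reset_job(str(job_id), job_settings)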
@@ -387,6 +401,43 @@ def test_do_api_call_patch(self, mock_requests): timeout=self.hook.timeout_seconds, ) + @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") + def test_create(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.post.return_value.json.return_value = {"job_id": JOB_ID} + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + json = {"name": "test"} + job_id = self.hook.create_job(json) + + assert job_id == JOB_ID + + mock_requests.post.assert_called_once_with( + create_endpoint(HOST), + json={"name": "test"}, + params=None, + auth=HTTPBasicAuth(LOGIN, PASSWORD), + headers=self.hook.user_agent_header, + timeout=self.hook.timeout_seconds, + ) + + @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") + def test_reset(self, mock_requests): + mock_requests.codes.ok = 200 + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + json = {"name": "test"} + self.hook.reset_job(JOB_ID, json) + + mock_requests.post.assert_called_once_with( + reset_endpoint(HOST), + json={"job_id": JOB_ID, "new_settings": {"name": "test"}}, + params=None, + auth=HTTPBasicAuth(LOGIN, PASSWORD), + headers=self.hook.user_agent_header, + timeout=self.hook.timeout_seconds, + ) + @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") def test_submit_run(self, mock_requests): mock_requests.post.return_value.json.return_value = {"run_id": "1"} diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index e03eb7dcccfc1..5fb1a31cd32bc 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -27,6 +27,7 @@ from airflow.models import DAG from airflow.providers.databricks.hooks.databricks import RunState from airflow.providers.databricks.operators.databricks import ( + DatabricksCreateJobsOperator, DatabricksRunNowDeferrableOperator, DatabricksRunNowOperator, DatabricksSubmitRunDeferrableOperator, @@ -71,6 +72,152 @@ "schema": "jaffle_shop", "warehouse_id": "123456789abcdef0", } +TAGS = { + "cost-center": "engineering", + "team": "jobs", +} +TASKS = [ + { + "task_key": "Sessionize", + "description": "Extracts session data from events", + "existing_cluster_id": "0923-164208-meows279", + "spark_jar_task": { + "main_class_name": "com.databricks.Sessionize", + "parameters": [ + "--data", + "dbfs:/path/to/data.json", + ], + }, + "libraries": [ + {"jar": "dbfs:/mnt/databricks/Sessionize.jar"}, + ], + "timeout_seconds": 86400, + "max_retries": 3, + "min_retry_interval_millis": 2000, + "retry_on_timeout": False, + }, + { + "task_key": "Orders_Ingest", + "description": "Ingests order data", + "job_cluster_key": "auto_scaling_cluster", + "spark_jar_task": { + "main_class_name": "com.databricks.OrdersIngest", + "parameters": ["--data", "dbfs:/path/to/order-data.json"], + }, + "libraries": [ + {"jar": "dbfs:/mnt/databricks/OrderIngest.jar"}, + ], + "timeout_seconds": 86400, + "max_retries": 3, + "min_retry_interval_millis": 2000, + "retry_on_timeout": False, + }, + { + "task_key": "Match", + "description": "Matches orders with user sessions", + "depends_on": [ + {"task_key": "Orders_Ingest"}, + {"task_key": "Sessionize"}, + ], + "new_cluster": { + "spark_version": "7.3.x-scala2.12", + "node_type_id": "i3.xlarge", + "spark_conf": { + "spark.speculation": 
True, + }, + "aws_attributes": { + "availability": "SPOT", + "zone_id": "us-west-2a", + }, + "autoscale": { + "min_workers": 2, + "max_workers": 16, + }, + }, + "notebook_task": { + "notebook_path": "/Users/user.name@databricks.com/Match", + "source": "WORKSPACE", + "base_parameters": { + "name": "John Doe", + "age": "35", + }, + }, + "timeout_seconds": 86400, + "max_retries": 3, + "min_retry_interval_millis": 2000, + "retry_on_timeout": False, + }, +] +JOB_CLUSTERS = [ + { + "job_cluster_key": "auto_scaling_cluster", + "new_cluster": { + "spark_version": "7.3.x-scala2.12", + "node_type_id": "i3.xlarge", + "spark_conf": { + "spark.speculation": True, + }, + "aws_attributes": { + "availability": "SPOT", + "zone_id": "us-west-2a", + }, + "autoscale": { + "min_workers": 2, + "max_workers": 16, + }, + }, + }, +] +EMAIL_NOTIFICATIONS = { + "on_start": [ + "user.name@databricks.com", + ], + "on_success": [ + "user.name@databricks.com", + ], + "on_failure": [ + "user.name@databricks.com", + ], + "no_alert_for_skipped_runs": False, +} +WEBHOOK_NOTIFICATIONS = { + "on_start": [ + { + "id": "03dd86e4-57ef-4818-a950-78e41a1d71ab", + }, + { + "id": "0481e838-0a59-4eff-9541-a4ca6f149574", + }, + ], + "on_success": [ + { + "id": "03dd86e4-57ef-4818-a950-78e41a1d71ab", + } + ], + "on_failure": [ + { + "id": "0481e838-0a59-4eff-9541-a4ca6f149574", + } + ], +} +TIMEOUT_SECONDS = 86400 +SCHEDULE = { + "quartz_cron_expression": "20 30 * * * ?", + "timezone_id": "Europe/London", + "pause_status": "PAUSED", +} +MAX_CONCURRENT_RUNS = 10 +GIT_SOURCE = { + "git_url": "https://github.com/databricks/databricks-cli", + "git_branch": "main", + "git_provider": "gitHub", +} +ACCESS_CONTROL_LIST = [ + { + "user_name": "jsmith@example.com", + "permission_level": "CAN_MANAGE", + } +] def mock_dict(d: dict): @@ -95,6 +242,267 @@ def make_run_with_state_mock( ) +class TestDatabricksCreateJobsOperator: + def test_init_with_named_parameters(self): + """ + Test the initializer with the named parameters. + """ + op = DatabricksCreateJobsOperator( + task_id=TASK_ID, + name=JOB_NAME, + tags=TAGS, + tasks=TASKS, + job_clusters=JOB_CLUSTERS, + email_notifications=EMAIL_NOTIFICATIONS, + webhook_notifications=WEBHOOK_NOTIFICATIONS, + timeout_seconds=TIMEOUT_SECONDS, + schedule=SCHEDULE, + max_concurrent_runs=MAX_CONCURRENT_RUNS, + git_source=GIT_SOURCE, + access_control_list=ACCESS_CONTROL_LIST, + ) + expected = utils.normalise_json_content( + { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + ) + + assert expected == op.json + + def test_init_with_json(self): + """ + Test the initializer with json data. 
+ """ + json = { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) + + expected = utils.normalise_json_content( + { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + ) + + assert expected == op.json + + def test_init_with_merging(self): + """ + Test the initializer when json and other named parameters are both + provided. The named parameters should override top level keys in the + json dict. + """ + override_name = "override" + override_tags = {} + override_tasks = [] + override_job_clusters = [] + override_email_notifications = {} + override_webhook_notifications = {} + override_timeout_seconds = 0 + override_schedule = {} + override_max_concurrent_runs = 0 + override_git_source = {} + override_access_control_list = [] + json = { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + + op = DatabricksCreateJobsOperator( + task_id=TASK_ID, + json=json, + name=override_name, + tags=override_tags, + tasks=override_tasks, + job_clusters=override_job_clusters, + email_notifications=override_email_notifications, + webhook_notifications=override_webhook_notifications, + timeout_seconds=override_timeout_seconds, + schedule=override_schedule, + max_concurrent_runs=override_max_concurrent_runs, + git_source=override_git_source, + access_control_list=override_access_control_list, + ) + + expected = utils.normalise_json_content( + { + "name": override_name, + "tags": override_tags, + "tasks": override_tasks, + "job_clusters": override_job_clusters, + "email_notifications": override_email_notifications, + "webhook_notifications": override_webhook_notifications, + "timeout_seconds": override_timeout_seconds, + "schedule": override_schedule, + "max_concurrent_runs": override_max_concurrent_runs, + "git_source": override_git_source, + "access_control_list": override_access_control_list, + } + ) + + assert expected == op.json + + def test_init_with_templating(self): + json = {"name": "test-{{ ds }}"} + + dag = DAG("test", start_date=datetime.now()) + op = DatabricksCreateJobsOperator(dag=dag, task_id=TASK_ID, json=json) + op.render_template_fields(context={"ds": DATE}) + expected = utils.normalise_json_content({"name": f"test-{DATE}"}) + assert expected == op.json + + def test_init_with_bad_type(self): + json = {"test": datetime.now()} + # Looks a bit weird since we have to escape regex reserved symbols. 
+ exception_message = ( + r"Type \<(type|class) \'datetime.datetime\'\> used " + r"for parameter json\[test\] is not a number or a string" + ) + with pytest.raises(AirflowException, match=exception_message): + DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_exec_create(self, db_mock_class): + """ + Test the execute function in case where the job does not exist. + """ + json = { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) + db_mock = db_mock_class.return_value + db_mock.create_job.return_value = JOB_ID + + db_mock.find_job_id_by_name.return_value = None + + return_result = op.execute({}) + + expected = utils.normalise_json_content( + { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + ) + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksCreateJobsOperator", + ) + + db_mock.create_job.assert_called_once_with(expected) + assert JOB_ID == return_result + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_exec_reset(self, db_mock_class): + """ + Test the execute function in case where the job already exists. 
+ """ + json = { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) + db_mock = db_mock_class.return_value + db_mock.find_job_id_by_name.return_value = JOB_ID + + return_result = op.execute({}) + + expected = utils.normalise_json_content( + { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + ) + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksCreateJobsOperator", + ) + + db_mock.reset_job.assert_called_once_with(JOB_ID, expected) + assert JOB_ID == return_result + + class TestDatabricksSubmitRunOperator: def test_init_with_notebook_task_named_parameters(self): """ diff --git a/tests/system/providers/databricks/example_databricks.py b/tests/system/providers/databricks/example_databricks.py index a6cae8a518a98..3a7ed3e53b2e0 100644 --- a/tests/system/providers/databricks/example_databricks.py +++ b/tests/system/providers/databricks/example_databricks.py @@ -36,7 +36,11 @@ from datetime import datetime from airflow import DAG -from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator +from airflow.providers.databricks.operators.databricks import ( + DatabricksCreateJobsOperator, + DatabricksRunNowOperator, + DatabricksSubmitRunOperator, +) ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "example_databricks_operator" @@ -48,6 +52,69 @@ tags=["example"], catchup=False, ) as dag: + # [START howto_operator_databricks_jobs_create_json] + # Example of using the JSON parameter to initialize the operator. + job = { + "tasks": [ + { + "task_key": "test", + "job_cluster_key": "job_cluster", + "notebook_task": { + "notebook_path": "/Shared/test", + }, + }, + ], + "job_clusters": [ + { + "job_cluster_key": "job_cluster", + "new_cluster": { + "spark_version": "7.3.x-scala2.12", + "node_type_id": "i3.xlarge", + "num_workers": 2, + }, + }, + ], + } + + jobs_create_json = DatabricksCreateJobsOperator(task_id="jobs_create_json", json=job) + # [END howto_operator_databricks_jobs_create_json] + + # [START howto_operator_databricks_jobs_create_named] + # Example of using the named parameters to initialize the operator. 
+ tasks = [ + { + "task_key": "test", + "job_cluster_key": "job_cluster", + "notebook_task": { + "notebook_path": "/Shared/test", + }, + }, + ] + job_clusters = [ + { + "job_cluster_key": "job_cluster", + "new_cluster": { + "spark_version": "7.3.x-scala2.12", + "node_type_id": "i3.xlarge", + "num_workers": 2, + }, + }, + ] + + jobs_create_named = DatabricksCreateJobsOperator( + task_id="jobs_create_named", tasks=tasks, job_clusters=job_clusters + ) + # [END howto_operator_databricks_jobs_create_named] + + # [START howto_operator_databricks_run_now] + # Example of using the DatabricksRunNowOperator after creating a job with DatabricksCreateJobsOperator. + run_now = DatabricksRunNowOperator( + task_id="run_now", job_id="{{ ti.xcom_pull(task_ids='jobs_create_named') }}" + ) + + jobs_create_named >> run_now + # [END howto_operator_databricks_run_now] + # [START howto_operator_databricks_json] # Example of using the JSON parameter to initialize the operator. new_cluster = {