Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add operator to create jobs in Databricks #35156

Merged
merged 19 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions airflow/providers/databricks/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start")
TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete")

CREATE_ENDPOINT = ("POST", "api/2.1/jobs/create")
RESET_ENDPOINT = ("POST", "api/2.1/jobs/reset")
RUN_NOW_ENDPOINT = ("POST", "api/2.1/jobs/run-now")
SUBMIT_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/submit")
GET_RUN_ENDPOINT = ("GET", "api/2.1/jobs/runs/get")
Expand Down Expand Up @@ -194,6 +196,24 @@ def __init__(
) -> None:
super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay, retry_args, caller)

def create_job(self, json: dict) -> int:
"""
Utility function to call the ``api/2.1/jobs/create`` endpoint.
:param json: The data used in the body of the request to the ``create`` endpoint.
:return: the job_id as an int
"""
response = self._do_api_call(CREATE_ENDPOINT, json)
return response["job_id"]

def reset_job(self, job_id: str, json: dict) -> None:
"""
Utility function to call the ``api/2.1/jobs/reset`` endpoint.
:param json: The data used in the new_settings of the request to the ``reset`` endpoint.
"""
self._do_api_call(RESET_ENDPOINT, {"job_id": job_id, "new_settings": json})

def run_now(self, json: dict) -> int:
"""
Call the ``api/2.1/jobs/run-now`` endpoint.
Expand Down
132 changes: 130 additions & 2 deletions airflow/providers/databricks/operators/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import time
import warnings
from functools import cached_property
from logging import Logger
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
Expand All @@ -31,8 +32,6 @@
from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event

if TYPE_CHECKING:
from logging import Logger

from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.context import Context

Expand Down Expand Up @@ -162,6 +161,135 @@ def get_link(
return XCom.get_value(key=XCOM_RUN_PAGE_URL_KEY, ti_key=ti_key)


class DatabricksCreateJobsOperator(BaseOperator):
"""Creates (or resets) a Databricks job using the API endpoint.
.. seealso::
https://docs.databricks.com/api/workspace/jobs/create
https://docs.databricks.com/api/workspace/jobs/reset
:param json: A JSON object containing API parameters which will be passed
directly to the ``api/2.1/jobs/create`` endpoint. The other named parameters
(i.e. ``name``, ``tags``, ``tasks``, etc.) to this operator will
be merged with this json dictionary if they are provided.
If there are conflicts during the merge, the named parameters will
take precedence and override the top level json keys. (templated)
.. seealso::
For more information about templating see :ref:`concepts:jinja-templating`.
:param name: An optional name for the job.
:param tags: A map of tags associated with the job.
:param tasks: A list of task specifications to be executed by this job.
Array of objects (JobTaskSettings).
:param job_clusters: A list of job cluster specifications that can be shared and reused by
tasks of this job. Array of objects (JobCluster).
:param email_notifications: Object (JobEmailNotifications).
:param webhook_notifications: Object (WebhookNotifications).
:param timeout_seconds: An optional timeout applied to each run of this job.
:param schedule: Object (CronSchedule).
:param max_concurrent_runs: An optional maximum allowed number of concurrent runs of the job.
:param git_source: An optional specification for a remote repository containing the notebooks
used by this job's notebook tasks. Object (GitSource).
:param access_control_list: List of permissions to set on the job. Array of object
(AccessControlRequestForUser) or object (AccessControlRequestForGroup) or object
(AccessControlRequestForServicePrincipal).
.. seealso::
This will only be used on create. In order to reset ACL consider using the Databricks
UI.
:param databricks_conn_id: Reference to the
:ref:`Databricks connection <howto/connection:databricks>`. (templated)
:param polling_period_seconds: Controls the rate which we poll for the result of
this run. By default the operator will poll every 30 seconds.
:param databricks_retry_limit: Amount of times retry if the Databricks backend is
unreachable. Its value must be greater than or equal to 1.
:param databricks_retry_delay: Number of seconds to wait between retries (it
might be a floating point number).
:param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
"""

# Used in airflow.models.BaseOperator
template_fields: Sequence[str] = ("json", "databricks_conn_id")
# Databricks brand color (blue) under white text
ui_color = "#1CB1C2"
ui_fgcolor = "#fff"

def __init__(
self,
*,
json: Any | None = None,
name: str | None = None,
tags: dict[str, str] | None = None,
tasks: list[dict] | None = None,
job_clusters: list[dict] | None = None,
email_notifications: dict | None = None,
webhook_notifications: dict | None = None,
timeout_seconds: int | None = None,
schedule: dict | None = None,
max_concurrent_runs: int | None = None,
git_source: dict | None = None,
access_control_list: list[dict] | None = None,
databricks_conn_id: str = "databricks_default",
polling_period_seconds: int = 30,
databricks_retry_limit: int = 3,
databricks_retry_delay: int = 1,
databricks_retry_args: dict[Any, Any] | None = None,
**kwargs,
) -> None:
"""Creates a new ``DatabricksCreateJobsOperator``."""
super().__init__(**kwargs)
self.json = json or {}
self.databricks_conn_id = databricks_conn_id
self.polling_period_seconds = polling_period_seconds
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
self.databricks_retry_args = databricks_retry_args
if name is not None:
self.json["name"] = name
if tags is not None:
self.json["tags"] = tags
if tasks is not None:
self.json["tasks"] = tasks
if job_clusters is not None:
self.json["job_clusters"] = job_clusters
if email_notifications is not None:
self.json["email_notifications"] = email_notifications
if webhook_notifications is not None:
self.json["webhook_notifications"] = webhook_notifications
if timeout_seconds is not None:
self.json["timeout_seconds"] = timeout_seconds
if schedule is not None:
self.json["schedule"] = schedule
if max_concurrent_runs is not None:
self.json["max_concurrent_runs"] = max_concurrent_runs
if git_source is not None:
self.json["git_source"] = git_source
if access_control_list is not None:
self.json["access_control_list"] = access_control_list

self.json = normalise_json_content(self.json)

@cached_property
def _hook(self):
return DatabricksHook(
self.databricks_conn_id,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
retry_args=self.databricks_retry_args,
caller="DatabricksCreateJobsOperator",
)

def execute(self, context: Context) -> int:
if "name" not in self.json:
raise AirflowException("Missing required parameter: name")
job_id = self._hook.find_job_id_by_name(self.json["name"])
if job_id is None:
return self._hook.create_job(self.json)
self._hook.reset_job(str(job_id), self.json)
return job_id


class DatabricksSubmitRunOperator(BaseOperator):
"""
Submits a Spark job run to Databricks using the api/2.1/jobs/runs/submit API endpoint.
Expand Down
8 changes: 8 additions & 0 deletions airflow/providers/databricks/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ integrations:
- integration-name: Databricks
external-doc-url: https://databricks.com/
how-to-guide:
- /docs/apache-airflow-providers-databricks/operators/jobs_create.rst
- /docs/apache-airflow-providers-databricks/operators/submit_run.rst
- /docs/apache-airflow-providers-databricks/operators/run_now.rst
logo: /integration-logos/databricks/Databricks.png
Expand Down Expand Up @@ -123,3 +124,10 @@ connection-types:

extra-links:
- airflow.providers.databricks.operators.databricks.DatabricksJobRunLink

additional-extras:
# pip install apache-airflow-providers-databricks[sdk]
- name: sdk
description: Install Databricks SDK
dependencies:
- databricks-sdk==0.10.0
91 changes: 91 additions & 0 deletions docs/apache-airflow-providers-databricks/operators/jobs_create.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
DatabricksCreateJobsOperator
============================

Use the :class:`~airflow.providers.databricks.operators.DatabricksCreateJobsOperator` to create
(or reset) a Databricks job. This operator relies on past XComs to remember the ``job_id`` that
was created so that repeated calls with this operator will update the existing job rather than
creating new ones. When paired with the DatabricksRunNowOperator all runs will fall under the same
job within the Databricks UI.


Using the Operator
------------------

There are three ways to instantiate this operator. In the first way, you can take the JSON payload that you typically use
to call the ``api/2.1/jobs/create`` endpoint and pass it directly to our ``DatabricksCreateJobsOperator`` through the
``json`` parameter. With this approach you get full control over the underlying payload to Jobs REST API, including
execution of Databricks jobs with multiple tasks, but it's harder to detect errors because of the lack of the type checking.

The second way to accomplish the same thing is to use the named parameters of the ``DatabricksCreateJobsOperator`` directly. Note that there is exactly
one named parameter for each top level parameter in the ``api/2.1/jobs/create`` endpoint.

The third way is to use both the json parameter **AND** the named parameters. They will be merged
together. If there are conflicts during the merge, the named parameters will take precedence and
override the top level ``json`` keys.

Currently the named parameters that ``DatabricksCreateJobsOperator`` supports are:
- ``name``
- ``tags``
- ``tasks``
- ``job_clusters``
- ``email_notifications``
- ``webhook_notifications``
- ``timeout_seconds``
- ``schedule``
- ``max_concurrent_runs``
- ``git_source``
- ``access_control_list``


Examples
--------

Specifying parameters as JSON
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

An example usage of the DatabricksCreateJobsOperator is as follows:

.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py
:language: python
:start-after: [START howto_operator_databricks_jobs_create_json]
:end-before: [END howto_operator_databricks_jobs_create_json]

Using named parameters
^^^^^^^^^^^^^^^^^^^^^^

You can also use named parameters to initialize the operator and run the job.

.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py
:language: python
:start-after: [START howto_operator_databricks_jobs_create_named]
:end-before: [END howto_operator_databricks_jobs_create_named]

Pairing with DatabricksRunNowOperator
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

You can use the ``job_id`` that is returned by the DatabricksCreateJobsOperator in the
return_value XCom as an argument to the DatabricksRunNowOperator to run the job.

.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py
:language: python
:start-after: [START howto_operator_databricks_run_now]
:end-before: [END howto_operator_databricks_run_now]
51 changes: 51 additions & 0 deletions tests/providers/databricks/hooks/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@
}


def create_endpoint(host):
"""
Utility function to generate the create endpoint given the host.
"""
return f"https://{host}/api/2.1/jobs/create"


def reset_endpoint(host):
"""
Utility function to generate the reset endpoint given the host.
"""
return f"https://{host}/api/2.1/jobs/reset"


def run_now_endpoint(host):
"""
Utility function to generate the run now endpoint given the host.
Expand Down Expand Up @@ -387,6 +401,43 @@ def test_do_api_call_patch(self, mock_requests):
timeout=self.hook.timeout_seconds,
)

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_create(self, mock_requests):
mock_requests.codes.ok = 200
mock_requests.post.return_value.json.return_value = {"job_id": JOB_ID}
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.post.return_value).status_code = status_code_mock
json = {"name": "test"}
job_id = self.hook.create_job(json)

assert job_id == JOB_ID

mock_requests.post.assert_called_once_with(
create_endpoint(HOST),
json={"name": "test"},
params=None,
auth=HTTPBasicAuth(LOGIN, PASSWORD),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
)

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_reset(self, mock_requests):
mock_requests.codes.ok = 200
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.post.return_value).status_code = status_code_mock
json = {"name": "test"}
self.hook.reset_job(JOB_ID, json)

mock_requests.post.assert_called_once_with(
reset_endpoint(HOST),
json={"job_id": JOB_ID, "new_settings": {"name": "test"}},
params=None,
auth=HTTPBasicAuth(LOGIN, PASSWORD),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
)

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_submit_run(self, mock_requests):
mock_requests.post.return_value.json.return_value = {"run_id": "1"}
Expand Down
Loading
Loading