From 24a84b51034da204e3d9f6cbbccf57ec7a3a36fb Mon Sep 17 00:00:00 2001 From: Kyle Winkelman Date: Mon, 27 Feb 2023 13:15:10 -0600 Subject: [PATCH 01/19] Provider Databricks add jobs create operator. --- .../providers/databricks/hooks/databricks.py | 20 + .../databricks/operators/databricks.py | 139 ++++++ .../operators/jobs_create.rst | 91 ++++ .../databricks/hooks/test_databricks.py | 51 +++ .../databricks/operators/test_databricks.py | 423 ++++++++++++++++++ .../databricks/example_databricks.py | 66 ++- 6 files changed, 789 insertions(+), 1 deletion(-) create mode 100644 docs/apache-airflow-providers-databricks/operators/jobs_create.rst diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 5d3f714b44e6d..5a7abdc5df8a2 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -41,6 +41,8 @@ START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start") TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete") +CREATE_ENDPOINT = ("POST", "api/2.1/jobs/create") +RESET_ENDPOINT = ("POST", "api/2.1/jobs/reset") RUN_NOW_ENDPOINT = ("POST", "api/2.1/jobs/run-now") SUBMIT_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/submit") GET_RUN_ENDPOINT = ("GET", "api/2.1/jobs/runs/get") @@ -194,6 +196,24 @@ def __init__( ) -> None: super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay, retry_args, caller) + def create(self, json: dict) -> int: + """ + Utility function to call the ``api/2.1/jobs/create`` endpoint. + + :param json: The data used in the body of the request to the ``create`` endpoint. + :return: the job_id as an int + """ + response = self._do_api_call(CREATE_ENDPOINT, json) + return response["job_id"] + + def reset(self, job_id: str, json: dict): + """ + Utility function to call the ``api/2.1/jobs/reset`` endpoint. + + :param json: The data used in the new_settings of the request to the ``reset`` endpoint. + """ + self._do_api_call(RESET_ENDPOINT, {"job_id": job_id, "new_settings": json}) + def run_now(self, json: dict) -> int: """ Call the ``api/2.1/jobs/run-now`` endpoint. diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index fd9e256230f7c..fc7ad114ecc81 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -162,6 +162,145 @@ def get_link( return XCom.get_value(key=XCOM_RUN_PAGE_URL_KEY, ti_key=ti_key) +class DatabricksJobsCreateOperator(BaseOperator): + """ + Creates (or resets) a Databricks job using the + `api/2.1/jobs/create + `_ + (or `api/2.1/jobs/reset + `_) + API endpoint. + + .. seealso:: + https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate + + :param json: A JSON object containing API parameters which will be passed + directly to the ``api/2.1/jobs/create`` endpoint. The other named parameters + (i.e. ``name``, ``tags``, ``tasks``, etc.) to this operator will + be merged with this json dictionary if they are provided. + If there are conflicts during the merge, the named parameters will + take precedence and override the top level json keys. (templated) + + .. seealso:: + For more information about templating see :ref:`concepts:jinja-templating`. + :param name: An optional name for the job. + :param tags: A map of tags associated with the job. + :param tasks: A list of task specifications to be executed by this job. + Array of objects (JobTaskSettings). 
+ :param job_clusters: A list of job cluster specifications that can be shared and reused by + tasks of this job. Array of objects (JobCluster). + :param email_notifications: Object (JobEmailNotifications). + :param webhook_notifications: Object (WebhookNotifications). + :param timeout_seconds: An optional timeout applied to each run of this job. + :param schedule: Object (CronSchedule). + :param max_concurrent_runs: An optional maximum allowed number of concurrent runs of the job. + :param git_source: An optional specification for a remote repository containing the notebooks + used by this job's notebook tasks. Object (GitSource). + :param access_control_list: List of permissions to set on the job. Array of object + (AccessControlRequestForUser) or object (AccessControlRequestForGroup) or object + (AccessControlRequestForServicePrincipal). + + .. seealso:: + This will only be used on create. In order to reset ACL consider using the Databricks + UI. + :param databricks_conn_id: Reference to the + :ref:`Databricks connection `. (templated) + :param polling_period_seconds: Controls the rate which we poll for the result of + this run. By default the operator will poll every 30 seconds. + :param databricks_retry_limit: Amount of times retry if the Databricks backend is + unreachable. Its value must be greater than or equal to 1. + :param databricks_retry_delay: Number of seconds to wait between retries (it + might be a floating point number). + :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. + """ + + # Used in airflow.models.BaseOperator + template_fields: Sequence[str] = ("json", "databricks_conn_id") + # Databricks brand color (blue) under white text + ui_color = "#1CB1C2" + ui_fgcolor = "#fff" + + def __init__( + self, + *, + json: Any | None = None, + name: str | None = None, + tags: dict[str, str] | None = None, + tasks: list[object] | None = None, + job_clusters: list[object] | None = None, + email_notifications: object | None = None, + webhook_notifications: object | None = None, + timeout_seconds: int | None = None, + schedule: dict[str, str] | None = None, + max_concurrent_runs: int | None = None, + git_source: dict[str, str] | None = None, + access_control_list: list[dict[str, str]] | None = None, + databricks_conn_id: str = "databricks_default", + polling_period_seconds: int = 30, + databricks_retry_limit: int = 3, + databricks_retry_delay: int = 1, + databricks_retry_args: dict[Any, Any] | None = None, + **kwargs, + ) -> None: + """Creates a new ``DatabricksJobsCreateOperator``.""" + super().__init__(**kwargs) + self.json = json or {} + self.databricks_conn_id = databricks_conn_id + self.polling_period_seconds = polling_period_seconds + self.databricks_retry_limit = databricks_retry_limit + self.databricks_retry_delay = databricks_retry_delay + self.databricks_retry_args = databricks_retry_args + if name is not None: + self.json["name"] = name + if tags is not None: + self.json["tags"] = tags + if tasks is not None: + self.json["tasks"] = tasks + if job_clusters is not None: + self.json["job_clusters"] = job_clusters + if email_notifications is not None: + self.json["email_notifications"] = email_notifications + if webhook_notifications is not None: + self.json["webhook_notifications"] = webhook_notifications + if timeout_seconds is not None: + self.json["timeout_seconds"] = timeout_seconds + if schedule is not None: + self.json["schedule"] = schedule + if max_concurrent_runs is not None: + self.json["max_concurrent_runs"] 
= max_concurrent_runs + if git_source is not None: + self.json["git_source"] = git_source + if access_control_list is not None: + self.json["access_control_list"] = access_control_list + + self.json = normalise_json_content(self.json) + + @cached_property + def _hook(self): + return self._get_hook(caller="DatabricksJobsCreateOperator") + + def _get_hook(self, caller: str) -> DatabricksHook: + return DatabricksHook( + self.databricks_conn_id, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, + caller=caller, + ) + + def execute(self, context: Context) -> int: + self.job_id = self.xcom_pull( + context, + task_ids=self.task_id, + include_prior_dates=True, + ) + if self.job_id: + self._hook.reset(self.job_id, self.json) + else: + self.job_id = self._hook.create(self.json) + return self.job_id + + class DatabricksSubmitRunOperator(BaseOperator): """ Submits a Spark job run to Databricks using the api/2.1/jobs/runs/submit API endpoint. diff --git a/docs/apache-airflow-providers-databricks/operators/jobs_create.rst b/docs/apache-airflow-providers-databricks/operators/jobs_create.rst new file mode 100644 index 0000000000000..a9a69976768c3 --- /dev/null +++ b/docs/apache-airflow-providers-databricks/operators/jobs_create.rst @@ -0,0 +1,91 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + + +DatabricksJobsCreateOperator +============================ + +Use the :class:`~airflow.providers.databricks.operators.DatabricksJobsCreateOperator` to create +(or reset) a Databricks job. This operator relies on past XComs to remember the ``job_id`` that +was created so that repeated calls with this operator will update the existing job rather than +creating new ones. When paired with the DatabricksRunNowOperator all runs will fall under the same +job within the Databricks UI. + + +Using the Operator +------------------ + +There are three ways to instantiate this operator. In the first way, you can take the JSON payload that you typically use +to call the ``api/2.1/jobs/create`` endpoint and pass it directly to our ``DatabricksJobsCreateOperator`` through the +``json`` parameter. With this approach you get full control over the underlying payload to Jobs REST API, including +execution of Databricks jobs with multiple tasks, but it's harder to detect errors because of the lack of the type checking. + +The second way to accomplish the same thing is to use the named parameters of the ``DatabricksJobsCreateOperator`` directly. Note that there is exactly +one named parameter for each top level parameter in the ``api/2.1/jobs/create`` endpoint. + +The third way is to use both the json parameter **AND** the named parameters. They will be merged +together. 
If there are conflicts during the merge, the named parameters will take precedence and +override the top level ``json`` keys. + +Currently the named parameters that ``DatabricksJobsCreateOperator`` supports are: + - ``name`` + - ``tags`` + - ``tasks`` + - ``job_clusters`` + - ``email_notifications`` + - ``webhook_notifications`` + - ``timeout_seconds`` + - ``schedule`` + - ``max_concurrent_runs`` + - ``git_source`` + - ``access_control_list`` + + +Examples +-------- + +Specifying parameters as JSON +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +An example usage of the DatabricksJobsCreateOperator is as follows: + +.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py + :language: python + :start-after: [START howto_operator_databricks_jobs_create_json] + :end-before: [END howto_operator_databricks_jobs_create_json] + +Using named parameters +^^^^^^^^^^^^^^^^^^^^^^ + +You can also use named parameters to initialize the operator and run the job. + +.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py + :language: python + :start-after: [START howto_operator_databricks_jobs_create_named] + :end-before: [END howto_operator_databricks_jobs_create_named] + +Pairing with DatabricksRunNowOperator +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +You can use the ``job_id`` that is returned by the DatabricksJobsCreateOperator in the +return_value XCom as an argument to the DatabricksRunNowOperator to run the job. + +.. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py + :language: python + :start-after: [START howto_operator_databricks_run_now] + :end-before: [END howto_operator_databricks_run_now] diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py index 2566e9a394769..dde00b7c7322f 100644 --- a/tests/providers/databricks/hooks/test_databricks.py +++ b/tests/providers/databricks/hooks/test_databricks.py @@ -107,6 +107,20 @@ } +def create_endpoint(host): + """ + Utility function to generate the create endpoint given the host. + """ + return f"https://{host}/api/2.1/jobs/create" + + +def reset_endpoint(host): + """ + Utility function to generate the reset endpoint given the host. + """ + return f"https://{host}/api/2.1/jobs/reset" + + def run_now_endpoint(host): """ Utility function to generate the run now endpoint given the host. 
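
Before the remaining test hunks, a minimal sketch of how the pieces above fit together: the merge behaviour described in the new documentation and the two hook calls it ultimately drives. This is an illustrative example rather than code from the patch — the job settings mirror the system-test example added further down, a configured ``databricks_default`` connection is assumed, and later commits in this series rename the hook methods to ``create_job``/``reset_job``.

from airflow.providers.databricks.hooks.databricks import DatabricksHook
from airflow.providers.databricks.operators.databricks import DatabricksJobsCreateOperator

# Illustrative job settings, mirroring the system-test example in this patch.
job = {
    "tasks": [
        {
            "task_key": "test",
            "job_cluster_key": "job_cluster",
            "notebook_task": {"notebook_path": "/Shared/test"},
        },
    ],
    "job_clusters": [
        {
            "job_cluster_key": "job_cluster",
            "new_cluster": {
                "spark_version": "7.3.x-scala2.12",
                "node_type_id": "i3.xlarge",
                "num_workers": 2,
            },
        },
    ],
}

# Named parameters are merged over the ``json`` payload and win on conflict,
# so the effective job name here is "sessionize" even if ``job`` also set one.
jobs_create = DatabricksJobsCreateOperator(task_id="jobs_create", json=job, name="sessionize")

# The operator drives the two hook methods added above: the first execution
# creates the job, later executions reset (overwrite) its settings in place.
hook = DatabricksHook()  # defaults to the "databricks_default" connection
job_id = hook.create(jobs_create.json)  # POST api/2.1/jobs/create -> job_id
hook.reset(job_id, jobs_create.json)    # POST api/2.1/jobs/reset with job_id + new_settings

In the operator itself this pairing is keyed off the previously pushed XCom value (and, from a later commit in this series, off the job name), so repeated DAG runs update the same Databricks job instead of creating new ones.
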
@@ -387,6 +401,43 @@ def test_do_api_call_patch(self, mock_requests): timeout=self.hook.timeout_seconds, ) + @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") + def test_create(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.post.return_value.json.return_value = {"job_id": JOB_ID} + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + json = {"name": "test"} + job_id = self.hook.create(json) + + assert job_id == JOB_ID + + mock_requests.post.assert_called_once_with( + create_endpoint(HOST), + json={"name": "test"}, + params=None, + auth=HTTPBasicAuth(LOGIN, PASSWORD), + headers=self.hook.user_agent_header, + timeout=self.hook.timeout_seconds, + ) + + @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") + def test_reset(self, mock_requests): + mock_requests.codes.ok = 200 + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + json = {"name": "test"} + self.hook.reset(JOB_ID, json) + + mock_requests.post.assert_called_once_with( + reset_endpoint(HOST), + json={"job_id": JOB_ID, "new_settings": {"name": "test"}}, + params=None, + auth=HTTPBasicAuth(LOGIN, PASSWORD), + headers=self.hook.user_agent_header, + timeout=self.hook.timeout_seconds, + ) + @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") def test_submit_run(self, mock_requests): mock_requests.post.return_value.json.return_value = {"run_id": "1"} diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index e03eb7dcccfc1..ffb5fb12240d2 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -27,6 +27,7 @@ from airflow.models import DAG from airflow.providers.databricks.hooks.databricks import RunState from airflow.providers.databricks.operators.databricks import ( + DatabricksJobsCreateOperator, DatabricksRunNowDeferrableOperator, DatabricksRunNowOperator, DatabricksSubmitRunDeferrableOperator, @@ -71,6 +72,163 @@ "schema": "jaffle_shop", "warehouse_id": "123456789abcdef0", } +TAGS = { + "cost-center": "engineering", + "team": "jobs", +} +TASKS = [ + { + "task_key": "Sessionize", + "description": "Extracts session data from events", + "existing_cluster_id": "0923-164208-meows279", + "spark_jar_task": { + "main_class_name": "com.databricks.Sessionize", + "parameters": [ + "--data", + "dbfs:/path/to/data.json", + ] + }, + "libraries": [ + { + "jar": "dbfs:/mnt/databricks/Sessionize.jar" + }, + ], + "timeout_seconds": 86400, + "max_retries": 3, + "min_retry_interval_millis": 2000, + "retry_on_timeout": False, + }, + { + "task_key": "Orders_Ingest", + "description": "Ingests order data", + "job_cluster_key": "auto_scaling_cluster", + "spark_jar_task": { + "main_class_name": "com.databricks.OrdersIngest", + "parameters": [ + "--data", + "dbfs:/path/to/order-data.json" + ], + }, + "libraries": [ + { + "jar": "dbfs:/mnt/databricks/OrderIngest.jar" + }, + ], + "timeout_seconds": 86400, + "max_retries": 3, + "min_retry_interval_millis": 2000, + "retry_on_timeout": False, + }, + { + "task_key": "Match", + "description": "Matches orders with user sessions", + "depends_on": [ + { + "task_key": "Orders_Ingest" + }, + { + "task_key": "Sessionize" + }, + ], + "new_cluster": { + "spark_version": "7.3.x-scala2.12", + "node_type_id": "i3.xlarge", + "spark_conf": { + 
"spark.speculation": True, + }, + "aws_attributes": { + "availability": "SPOT", + "zone_id": "us-west-2a", + }, + "autoscale": { + "min_workers": 2, + "max_workers": 16, + } + }, + "notebook_task": { + "notebook_path": "/Users/user.name@databricks.com/Match", + "source": "WORKSPACE", + "base_parameters": { + "name": "John Doe", + "age": "35", + } + }, + "timeout_seconds": 86400, + "max_retries": 3, + "min_retry_interval_millis": 2000, + "retry_on_timeout": False, + }, +] +JOB_CLUSTERS = [ + { + "job_cluster_key": "auto_scaling_cluster", + "new_cluster": { + "spark_version": "7.3.x-scala2.12", + "node_type_id": "i3.xlarge", + "spark_conf": { + "spark.speculation": True, + }, + "aws_attributes": { + "availability": "SPOT", + "zone_id": "us-west-2a", + }, + "autoscale": { + "min_workers": 2, + "max_workers": 16, + } + } + }, +] +EMAIL_NOTIFICATIONS = { + "on_start": [ + "user.name@databricks.com", + ], + "on_success": [ + "user.name@databricks.com", + ], + "on_failure": [ + "user.name@databricks.com", + ], + "no_alert_for_skipped_runs": False, +} +WEBHOOK_NOTIFICATIONS = { + "on_start": [ + { + "id": "03dd86e4-57ef-4818-a950-78e41a1d71ab", + }, + { + "id": "0481e838-0a59-4eff-9541-a4ca6f149574", + }, + ], + "on_success": [ + { + "id": "03dd86e4-57ef-4818-a950-78e41a1d71ab", + } + ], + "on_failure": [ + { + "id": "0481e838-0a59-4eff-9541-a4ca6f149574", + } + ], +} +TIMEOUT_SECONDS = 86400 +SCHEDULE = { + "quartz_cron_expression": "20 30 * * * ?", + "timezone_id": "Europe/London", + "pause_status": "PAUSED" +} +MAX_CONCURRENT_RUNS = 10 +GIT_SOURCE = { + "git_url": "https://github.com/databricks/databricks-cli", + "git_branch": "main", + "git_provider": "gitHub" +} +ACCESS_CONTROL_LIST = [ + { + "user_name": "jsmith@example.com", + "permission_level": "CAN_MANAGE", + } +] def mock_dict(d: dict): @@ -94,6 +252,271 @@ def make_run_with_state_mock( } ) +class TestDatabricksJobsCreateOperator: + def test_init_with_named_parameters(self): + """ + Test the initializer with the named parameters. + """ + op = DatabricksJobsCreateOperator( + task_id=TASK_ID, + name=JOB_NAME, + tags=TAGS, + tasks=TASKS, + job_clusters=JOB_CLUSTERS, + email_notifications=EMAIL_NOTIFICATIONS, + webhook_notifications=WEBHOOK_NOTIFICATIONS, + timeout_seconds=TIMEOUT_SECONDS, + schedule=SCHEDULE, + max_concurrent_runs=MAX_CONCURRENT_RUNS, + git_source=GIT_SOURCE, + access_control_list=ACCESS_CONTROL_LIST, + ) + expected = utils.normalise_json_content( + { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + ) + + assert expected == op.json + + def test_init_with_json(self): + """ + Test the initializer with json data. 
+ """ + json = { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + op = DatabricksJobsCreateOperator(task_id=TASK_ID, json=json) + + expected = utils.normalise_json_content( + { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + ) + + assert expected == op.json + + def test_init_with_merging(self): + """ + Test the initializer when json and other named parameters are both + provided. The named parameters should override top level keys in the + json dict. + """ + override_name = "override" + override_tags = {} + override_tasks = [] + override_job_clusters = [] + override_email_notifications = {} + override_webhook_notifications = {} + override_timeout_seconds = 0 + override_schedule = {} + override_max_concurrent_runs = 0 + override_git_source = {} + override_access_control_list = [] + json = { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + + op = DatabricksJobsCreateOperator( + task_id=TASK_ID, + json=json, + name=override_name, + tags=override_tags, + tasks=override_tasks, + job_clusters=override_job_clusters, + email_notifications=override_email_notifications, + webhook_notifications=override_webhook_notifications, + timeout_seconds=override_timeout_seconds, + schedule=override_schedule, + max_concurrent_runs=override_max_concurrent_runs, + git_source=override_git_source, + access_control_list=override_access_control_list, + ) + + expected = utils.normalise_json_content( + { + "name": override_name, + "tags": override_tags, + "tasks": override_tasks, + "job_clusters": override_job_clusters, + "email_notifications": override_email_notifications, + "webhook_notifications": override_webhook_notifications, + "timeout_seconds": override_timeout_seconds, + "schedule": override_schedule, + "max_concurrent_runs": override_max_concurrent_runs, + "git_source": override_git_source, + "access_control_list": override_access_control_list, + } + ) + + assert expected == op.json + + def test_init_with_templating(self): + json = {"name": "test-{{ ds }}"} + + dag = DAG("test", start_date=datetime.now()) + op = DatabricksJobsCreateOperator(dag=dag, task_id=TASK_ID, json=json) + op.render_template_fields(context={"ds": DATE}) + expected = utils.normalise_json_content( + { + "name": f"test-{DATE}" + } + ) + assert expected == op.json + + def test_init_with_bad_type(self): + json = {"test": datetime.now()} + # Looks a bit weird since we have to escape regex reserved symbols. 
+ exception_message = ( + r"Type \<(type|class) \'datetime.datetime\'\> used " + r"for parameter json\[test\] is not a number or a string" + ) + with pytest.raises(AirflowException, match=exception_message): + DatabricksJobsCreateOperator(task_id=TASK_ID, json=json) + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_exec_create(self, db_mock_class): + """ + Test the execute function in case where the job does not exist. + """ + json = { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + op = DatabricksJobsCreateOperator(task_id=TASK_ID, json=json) + db_mock = db_mock_class.return_value + db_mock.create.return_value = JOB_ID + + ti = mock.MagicMock() + ti.xcom_pull.return_value = None + op.execute({"ti": ti}) + + expected = utils.normalise_json_content( + { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + ) + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksJobsCreateOperator", + ) + + db_mock.create.assert_called_once_with(expected) + assert JOB_ID == op.job_id + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_exec_reset(self, db_mock_class): + """ + Test the execute function in case where the job already exists. 
+ """ + json = { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + op = DatabricksJobsCreateOperator(task_id=TASK_ID, json=json) + db_mock = db_mock_class.return_value + + ti = mock.MagicMock() + ti.xcom_pull.return_value = JOB_ID + op.execute({"ti": ti}) + + expected = utils.normalise_json_content( + { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + ) + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksJobsCreateOperator", + ) + + db_mock.reset.assert_called_once_with(JOB_ID, expected) + assert JOB_ID == op.job_id + class TestDatabricksSubmitRunOperator: def test_init_with_notebook_task_named_parameters(self): diff --git a/tests/system/providers/databricks/example_databricks.py b/tests/system/providers/databricks/example_databricks.py index a6cae8a518a98..b568fa80471b5 100644 --- a/tests/system/providers/databricks/example_databricks.py +++ b/tests/system/providers/databricks/example_databricks.py @@ -36,7 +36,11 @@ from datetime import datetime from airflow import DAG -from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator +from airflow.providers.databricks.operators.databricks import ( + DatabricksJobsCreateOperator, + DatabricksRunNowOperator, + DatabricksSubmitRunOperator, +) ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "example_databricks_operator" @@ -48,6 +52,66 @@ tags=["example"], catchup=False, ) as dag: + + # [START howto_operator_databricks_jobs_create_json] + # Example of using the JSON parameter to initialize the operator. + job = { + "tasks": [ + { + "task_key": "test", + "job_cluster_key": "job_cluster", + "notebook_task": { + "notebook_path": "/Shared/test", + }, + }, + ], + "job_clusters": [ + { + "job_cluster_key": "job_cluster", + "new_cluster": { + "spark_version": "7.3.x-scala2.12", + "node_type_id": "i3.xlarge", + "num_workers": 2, + }, + }, + ] + } + + jobs_create_json = DatabricksJobsCreateOperator(task_id="jobs_create_json", json=job) + # [END howto_operator_databricks_jobs_create_json] + + # [START howto_operator_databricks_jobs_create_named] + # Example of using the named parameters to initialize the operator. + tasks = [ + { + "task_key": "test", + "job_cluster_key": "job_cluster", + "notebook_task": { + "notebook_path": "/Shared/test", + }, + }, + ] + job_clusters = [ + { + "job_cluster_key": "job_cluster", + "new_cluster": { + "spark_version": "7.3.x-scala2.12", + "node_type_id": "i3.xlarge", + "num_workers": 2, + }, + }, + ] + + jobs_create_named = DatabricksJobsCreateOperator(task_id="jobs_create_named", tasks=tasks, job_clusters=job_clusters) + # [END howto_operator_databricks_jobs_create_named] + + # [START howto_operator_databricks_run_now] + # Example of using the DatabricksRunNowOperator after creating a job with DatabricksJobsCreateOperator. 
+ run_now = DatabricksRunNowOperator(task_id="run_now", job_id="{{ ti.xcom_pull(task_ids='jobs_create_named') }}") + + jobs_create_named >> run_now + # [END howto_operator_databricks_run_now] + # [START howto_operator_databricks_json] # Example of using the JSON parameter to initialize the operator. new_cluster = { From 3798939b519628f038d16cc9582f0c8fcd95d973 Mon Sep 17 00:00:00 2001 From: Sri Tikkireddy Date: Tue, 27 Jun 2023 14:07:32 -0400 Subject: [PATCH 02/19] run black formatter with breeze --- .../providers/databricks/hooks/databricks.py | 2 +- .../databricks/operators/test_databricks.py | 48 +++++++------------ .../databricks/example_databricks.py | 10 ++-- 3 files changed, 25 insertions(+), 35 deletions(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 5a7abdc5df8a2..4e9a0b3dc0fca 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -212,7 +212,7 @@ def reset(self, job_id: str, json: dict): :param json: The data used in the new_settings of the request to the ``reset`` endpoint. """ - self._do_api_call(RESET_ENDPOINT, {"job_id": job_id, "new_settings": json}) + self._do_api_call(RESET_ENDPOINT, {"job_id": job_id, "new_settings": json}) def run_now(self, json: dict) -> int: """ diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index ffb5fb12240d2..4dfdae0bde962 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -86,12 +86,10 @@ "parameters": [ "--data", "dbfs:/path/to/data.json", - ] + ], }, "libraries": [ - { - "jar": "dbfs:/mnt/databricks/Sessionize.jar" - }, + {"jar": "dbfs:/mnt/databricks/Sessionize.jar"}, ], "timeout_seconds": 86400, "max_retries": 3, @@ -104,15 +102,10 @@ "job_cluster_key": "auto_scaling_cluster", "spark_jar_task": { "main_class_name": "com.databricks.OrdersIngest", - "parameters": [ - "--data", - "dbfs:/path/to/order-data.json" - ], + "parameters": ["--data", "dbfs:/path/to/order-data.json"], }, "libraries": [ - { - "jar": "dbfs:/mnt/databricks/OrderIngest.jar" - }, + {"jar": "dbfs:/mnt/databricks/OrderIngest.jar"}, ], "timeout_seconds": 86400, "max_retries": 3, @@ -123,12 +116,8 @@ "task_key": "Match", "description": "Matches orders with user sessions", "depends_on": [ - { - "task_key": "Orders_Ingest" - }, - { - "task_key": "Sessionize" - }, + {"task_key": "Orders_Ingest"}, + {"task_key": "Sessionize"}, ], "new_cluster": { "spark_version": "7.3.x-scala2.12", @@ -143,7 +132,7 @@ "autoscale": { "min_workers": 2, "max_workers": 16, - } + }, }, "notebook_task": { "notebook_path": "/Users/user.name@databricks.com/Match", @@ -151,7 +140,7 @@ "base_parameters": { "name": "John Doe", "age": "35", - } + }, }, "timeout_seconds": 86400, "max_retries": 3, @@ -175,19 +164,19 @@ "autoscale": { "min_workers": 2, "max_workers": 16, - } - } + }, + }, }, ] EMAIL_NOTIFICATIONS = { "on_start": [ - "user.name@databricks.com", + "user.name@databricks.com", ], "on_success": [ - "user.name@databricks.com", + "user.name@databricks.com", ], "on_failure": [ - "user.name@databricks.com", + "user.name@databricks.com", ], "no_alert_for_skipped_runs": False, } @@ -215,13 +204,13 @@ SCHEDULE = { "quartz_cron_expression": "20 30 * * * ?", "timezone_id": "Europe/London", - "pause_status": "PAUSED" + "pause_status": "PAUSED", } MAX_CONCURRENT_RUNS = 10 GIT_SOURCE = { "git_url": 
"https://github.com/databricks/databricks-cli", "git_branch": "main", - "git_provider": "gitHub" + "git_provider": "gitHub", } ACCESS_CONTROL_LIST = [ { @@ -252,6 +241,7 @@ def make_run_with_state_mock( } ) + class TestDatabricksJobsCreateOperator: def test_init_with_named_parameters(self): """ @@ -397,11 +387,7 @@ def test_init_with_templating(self): dag = DAG("test", start_date=datetime.now()) op = DatabricksJobsCreateOperator(dag=dag, task_id=TASK_ID, json=json) op.render_template_fields(context={"ds": DATE}) - expected = utils.normalise_json_content( - { - "name": f"test-{DATE}" - } - ) + expected = utils.normalise_json_content({"name": f"test-{DATE}"}) assert expected == op.json def test_init_with_bad_type(self): diff --git a/tests/system/providers/databricks/example_databricks.py b/tests/system/providers/databricks/example_databricks.py index b568fa80471b5..2f931aec93bff 100644 --- a/tests/system/providers/databricks/example_databricks.py +++ b/tests/system/providers/databricks/example_databricks.py @@ -74,7 +74,7 @@ "num_workers": 2, }, }, - ] + ], } jobs_create_json = DatabricksJobsCreateOperator(task_id="jobs_create_json", json=job) @@ -102,12 +102,16 @@ }, ] - jobs_create_named = DatabricksJobsCreateOperator(task_id="jobs_create_named", tasks=tasks, job_clusters=job_clusters) + jobs_create_named = DatabricksJobsCreateOperator( + task_id="jobs_create_named", tasks=tasks, job_clusters=job_clusters + ) # [END howto_operator_databricks_jobs_create_named] # [START howto_operator_databricks_run_now] # Example of using the DatabricksRunNowOperator after creating a job with DatabricksJobsCreateOperator. - run_now = DatabricksRunNowOperator(task_id="run_now", job_id="{{ ti.xcom_pull(task_ids='jobs_create_named') }}") + run_now = DatabricksRunNowOperator( + task_id="run_now", job_id="{{ ti.xcom_pull(task_ids='jobs_create_named') }}" + ) jobs_create_named >> run_now # [END howto_operator_databricks_run_now] From 9f382fef70fa2e05ce2047f830a5c3c10ed2a096 Mon Sep 17 00:00:00 2001 From: Sri Tikkireddy Date: Tue, 27 Jun 2023 17:28:14 -0400 Subject: [PATCH 03/19] added support for databricks sdk to use the latest set of objects for type hints --- airflow/providers/databricks/provider.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml index bfaed1bc70c5c..eae192906f989 100644 --- a/airflow/providers/databricks/provider.yaml +++ b/airflow/providers/databricks/provider.yaml @@ -61,6 +61,7 @@ dependencies: # The 2.9.1 (to be released soon) already contains the fix - databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0 - aiohttp>=3.6.3, <4 + - databricks-sdk>=0.1.11, <1.0.0 integrations: - integration-name: Databricks From e76687ea56a5d0e91ad0e23ed0d4412f99959cd2 Mon Sep 17 00:00:00 2001 From: Sri Tikkireddy Date: Tue, 27 Jun 2023 17:53:49 -0400 Subject: [PATCH 04/19] remove without precommit --- airflow/providers/databricks/provider.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml index eae192906f989..bfaed1bc70c5c 100644 --- a/airflow/providers/databricks/provider.yaml +++ b/airflow/providers/databricks/provider.yaml @@ -61,7 +61,6 @@ dependencies: # The 2.9.1 (to be released soon) already contains the fix - databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0 - aiohttp>=3.6.3, <4 - - databricks-sdk>=0.1.11, <1.0.0 integrations: - integration-name: Databricks From b6302a80e80ecdb04b24943d039a3aa61ddc7c09 Mon Sep 17 00:00:00 
2001 From: Sri Tikkireddy Date: Tue, 27 Jun 2023 17:56:21 -0400 Subject: [PATCH 05/19] added databricks-sdk with precommit --- airflow/providers/databricks/provider.yaml | 1 + generated/provider_dependencies.json | 1 + 2 files changed, 2 insertions(+) diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml index bfaed1bc70c5c..eae192906f989 100644 --- a/airflow/providers/databricks/provider.yaml +++ b/airflow/providers/databricks/provider.yaml @@ -61,6 +61,7 @@ dependencies: # The 2.9.1 (to be released soon) already contains the fix - databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0 - aiohttp>=3.6.3, <4 + - databricks-sdk>=0.1.11, <1.0.0 integrations: - integration-name: Databricks diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 772fa4f78a199..b9e92b5bc8dee 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -301,6 +301,7 @@ "aiohttp>=3.6.3, <4", "apache-airflow-providers-common-sql>=1.5.0", "apache-airflow>=2.5.0", + "databricks-sdk>=0.1.11, <1.0.0", "databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0", "requests>=2.27,<3" ], From 44d2f246d8f5d502d5618a277f81243a15568529 Mon Sep 17 00:00:00 2001 From: Sri Tikkireddy Date: Tue, 27 Jun 2023 18:36:24 -0400 Subject: [PATCH 06/19] use the databricks sdk objects --- .../databricks/operators/databricks.py | 55 +++++++++++-------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index fc7ad114ecc81..bf00b0fe21cbd 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -21,7 +21,10 @@ import time import warnings from functools import cached_property -from typing import TYPE_CHECKING, Any, Sequence +from logging import Logger +from typing import TYPE_CHECKING, Any, Sequence, Union + +from databricks.sdk.service import jobs as j from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning @@ -40,6 +43,15 @@ XCOM_RUN_ID_KEY = "run_id" XCOM_JOB_ID_KEY = "job_id" XCOM_RUN_PAGE_URL_KEY = "run_page_url" +DatabricksTaskType = Union[ + j.DbtTask, + j.NotebookTask, + j.PipelineTask, + j.PythonWheelTask, + j.SparkPythonTask, + j.SparkJarTask, + j.SparkSubmitTask, +] def _handle_databricks_operator_execution(operator, hook, log, context) -> None: @@ -226,15 +238,15 @@ def __init__( json: Any | None = None, name: str | None = None, tags: dict[str, str] | None = None, - tasks: list[object] | None = None, - job_clusters: list[object] | None = None, - email_notifications: object | None = None, - webhook_notifications: object | None = None, + tasks: list[DatabricksTaskType] | None = None, + job_clusters: list[j.JobCluster] | None = None, + email_notifications: j.JobEmailNotifications | None = None, + webhook_notifications: j.JobWebhookNotifications | None = None, timeout_seconds: int | None = None, - schedule: dict[str, str] | None = None, + schedule: j.CronSchedule | None = None, max_concurrent_runs: int | None = None, - git_source: dict[str, str] | None = None, - access_control_list: list[dict[str, str]] | None = None, + git_source: j.GitSource | None = None, + access_control_list: j.AccessControlRequest | None = None, databricks_conn_id: str = "databricks_default", polling_period_seconds: int = 30, databricks_retry_limit: int = 3, @@ -255,17 +267,17 @@ def __init__( if tags is 
not None: self.json["tags"] = tags if tasks is not None: - self.json["tasks"] = tasks + self.json["tasks"] = [task.as_dict() for task in tasks] if job_clusters is not None: - self.json["job_clusters"] = job_clusters + self.json["job_clusters"] = [job_cluster.as_dict() for job_cluster in job_clusters] if email_notifications is not None: - self.json["email_notifications"] = email_notifications + self.json["email_notifications"] = email_notifications.as_dict() if webhook_notifications is not None: - self.json["webhook_notifications"] = webhook_notifications + self.json["webhook_notifications"] = webhook_notifications.as_dict() if timeout_seconds is not None: self.json["timeout_seconds"] = timeout_seconds if schedule is not None: - self.json["schedule"] = schedule + self.json["schedule"] = schedule.as_dict() if max_concurrent_runs is not None: self.json["max_concurrent_runs"] = max_concurrent_runs if git_source is not None: @@ -289,16 +301,13 @@ def _get_hook(self, caller: str) -> DatabricksHook: ) def execute(self, context: Context) -> int: - self.job_id = self.xcom_pull( - context, - task_ids=self.task_id, - include_prior_dates=True, - ) - if self.job_id: - self._hook.reset(self.job_id, self.json) - else: - self.job_id = self._hook.create(self.json) - return self.job_id + if "name" not in self.json: + raise AirflowException("Missing required parameter: name") + job_id = self._hook.find_job_id_by_name(self.json["name"]) + if job_id is None: + return self._hook.create(self.json) + self._hook.reset(job_id, self.json) + return job_id class DatabricksSubmitRunOperator(BaseOperator): From 18374224e1423f4e95a308e5c0cf2a598058792c Mon Sep 17 00:00:00 2001 From: Sri Tikkireddy Date: Tue, 27 Jun 2023 18:52:52 -0400 Subject: [PATCH 07/19] fixed type hints and adjusted tests --- .../databricks/operators/databricks.py | 17 ++++---------- .../databricks/operators/test_databricks.py | 23 ++++++++++--------- 2 files changed, 16 insertions(+), 24 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index bf00b0fe21cbd..ab98883b61da5 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -22,7 +22,7 @@ import warnings from functools import cached_property from logging import Logger -from typing import TYPE_CHECKING, Any, Sequence, Union +from typing import TYPE_CHECKING, Any, Sequence from databricks.sdk.service import jobs as j @@ -43,15 +43,6 @@ XCOM_RUN_ID_KEY = "run_id" XCOM_JOB_ID_KEY = "job_id" XCOM_RUN_PAGE_URL_KEY = "run_page_url" -DatabricksTaskType = Union[ - j.DbtTask, - j.NotebookTask, - j.PipelineTask, - j.PythonWheelTask, - j.SparkPythonTask, - j.SparkJarTask, - j.SparkSubmitTask, -] def _handle_databricks_operator_execution(operator, hook, log, context) -> None: @@ -238,7 +229,7 @@ def __init__( json: Any | None = None, name: str | None = None, tags: dict[str, str] | None = None, - tasks: list[DatabricksTaskType] | None = None, + tasks: list[j.JobTaskSettings] | None = None, job_clusters: list[j.JobCluster] | None = None, email_notifications: j.JobEmailNotifications | None = None, webhook_notifications: j.JobWebhookNotifications | None = None, @@ -246,7 +237,7 @@ def __init__( schedule: j.CronSchedule | None = None, max_concurrent_runs: int | None = None, git_source: j.GitSource | None = None, - access_control_list: j.AccessControlRequest | None = None, + access_control_list: list[j.AccessControlRequest] | None = None, databricks_conn_id: str 
= "databricks_default", polling_period_seconds: int = 30, databricks_retry_limit: int = 3, @@ -283,7 +274,7 @@ def __init__( if git_source is not None: self.json["git_source"] = git_source if access_control_list is not None: - self.json["access_control_list"] = access_control_list + self.json["access_control_list"] = [acl.as_dict() for acl in access_control_list] self.json = normalise_json_content(self.json) diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 4dfdae0bde962..9aaf00c2a1f99 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -22,6 +22,7 @@ from unittest.mock import MagicMock import pytest +from databricks.sdk.service import jobs as j from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG @@ -251,15 +252,15 @@ def test_init_with_named_parameters(self): task_id=TASK_ID, name=JOB_NAME, tags=TAGS, - tasks=TASKS, - job_clusters=JOB_CLUSTERS, - email_notifications=EMAIL_NOTIFICATIONS, - webhook_notifications=WEBHOOK_NOTIFICATIONS, + tasks=[j.JobTaskSettings.from_dict(task) for task in TASKS], + job_clusters=[j.JobCluster.from_dict(cluster) for cluster in JOB_CLUSTERS], + email_notifications=j.JobEmailNotifications.from_dict(EMAIL_NOTIFICATIONS), + webhook_notifications=j.JobWebhookNotifications.from_dict(WEBHOOK_NOTIFICATIONS), timeout_seconds=TIMEOUT_SECONDS, - schedule=SCHEDULE, + schedule=j.CronSchedule.from_dict(SCHEDULE), max_concurrent_runs=MAX_CONCURRENT_RUNS, - git_source=GIT_SOURCE, - access_control_list=ACCESS_CONTROL_LIST, + git_source=j.GitSource.from_dict(GIT_SOURCE), + access_control_list=[j.AccessControlRequest.from_dict(acl) for acl in ACCESS_CONTROL_LIST], ) expected = utils.normalise_json_content( { @@ -354,12 +355,12 @@ def test_init_with_merging(self): tags=override_tags, tasks=override_tasks, job_clusters=override_job_clusters, - email_notifications=override_email_notifications, - webhook_notifications=override_webhook_notifications, + email_notifications=j.JobEmailNotifications.from_dict(override_email_notifications), + webhook_notifications=j.JobWebhookNotifications.from_dict(override_webhook_notifications), timeout_seconds=override_timeout_seconds, - schedule=override_schedule, + schedule=j.CronSchedule.from_dict(override_schedule), max_concurrent_runs=override_max_concurrent_runs, - git_source=override_git_source, + git_source=j.GitSource.from_dict(override_git_source), access_control_list=override_access_control_list, ) From 2c59ba251a9cef46acb79bcbe991dc460692c708 Mon Sep 17 00:00:00 2001 From: Sri Tikkireddy Date: Tue, 27 Jun 2023 18:58:58 -0400 Subject: [PATCH 08/19] fixed as dict --- airflow/providers/databricks/operators/databricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index ab98883b61da5..622e8cc424040 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -272,7 +272,7 @@ def __init__( if max_concurrent_runs is not None: self.json["max_concurrent_runs"] = max_concurrent_runs if git_source is not None: - self.json["git_source"] = git_source + self.json["git_source"] = git_source.as_dict() if access_control_list is not None: self.json["access_control_list"] = [acl.as_dict() for acl in access_control_list] From 
f35833ee4e43b7c3ea44a0ad7dad1a3296bdca1d Mon Sep 17 00:00:00 2001 From: Sri Tikkireddy Date: Tue, 27 Jun 2023 19:11:42 -0400 Subject: [PATCH 09/19] fixed tests with proper testing logic --- .../providers/databricks/operators/databricks.py | 2 +- .../databricks/operators/test_databricks.py | 15 +++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 622e8cc424040..63dd5901aa963 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -297,7 +297,7 @@ def execute(self, context: Context) -> int: job_id = self._hook.find_job_id_by_name(self.json["name"]) if job_id is None: return self._hook.create(self.json) - self._hook.reset(job_id, self.json) + self._hook.reset(str(job_id), self.json) return job_id diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 9aaf00c2a1f99..676c00aecbd4f 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -423,9 +423,9 @@ def test_exec_create(self, db_mock_class): db_mock = db_mock_class.return_value db_mock.create.return_value = JOB_ID - ti = mock.MagicMock() - ti.xcom_pull.return_value = None - op.execute({"ti": ti}) + db_mock.find_job_id_by_name.return_value = None + + return_result = op.execute({}) expected = utils.normalise_json_content( { @@ -451,7 +451,7 @@ def test_exec_create(self, db_mock_class): ) db_mock.create.assert_called_once_with(expected) - assert JOB_ID == op.job_id + assert JOB_ID == return_result @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_exec_reset(self, db_mock_class): @@ -473,10 +473,9 @@ def test_exec_reset(self, db_mock_class): } op = DatabricksJobsCreateOperator(task_id=TASK_ID, json=json) db_mock = db_mock_class.return_value + db_mock.find_job_id_by_name.return_value = JOB_ID - ti = mock.MagicMock() - ti.xcom_pull.return_value = JOB_ID - op.execute({"ti": ti}) + return_result = op.execute({}) expected = utils.normalise_json_content( { @@ -502,7 +501,7 @@ def test_exec_reset(self, db_mock_class): ) db_mock.reset.assert_called_once_with(JOB_ID, expected) - assert JOB_ID == op.job_id + assert JOB_ID == return_result class TestDatabricksSubmitRunOperator: From 6a43550652256d5b8d1d94eb4c14efdffa30302a Mon Sep 17 00:00:00 2001 From: Sri Tikkireddy Date: Wed, 28 Jun 2023 13:32:12 -0400 Subject: [PATCH 10/19] added jobs_create to provider.yaml file --- airflow/providers/databricks/provider.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml index eae192906f989..626b81831ad35 100644 --- a/airflow/providers/databricks/provider.yaml +++ b/airflow/providers/databricks/provider.yaml @@ -67,6 +67,7 @@ integrations: - integration-name: Databricks external-doc-url: https://databricks.com/ how-to-guide: + - /docs/apache-airflow-providers-databricks/operators/jobs_create.rst - /docs/apache-airflow-providers-databricks/operators/submit_run.rst - /docs/apache-airflow-providers-databricks/operators/run_now.rst logo: /integration-logos/databricks/Databricks.png From 66cc643c7bc0ba27517dd22189769fed673cb60a Mon Sep 17 00:00:00 2001 From: "sri.tikkireddy" Date: Tue, 25 Jul 2023 10:19:00 -0400 Subject: [PATCH 11/19] resoved comments on pr --- 
.../providers/databricks/hooks/databricks.py | 4 +-- .../databricks/operators/databricks.py | 29 +++++++++---------- .../operators/jobs_create.rst | 14 ++++----- .../databricks/hooks/test_databricks.py | 4 +-- .../databricks/operators/test_databricks.py | 28 +++++++++--------- .../databricks/example_databricks.py | 8 ++--- 6 files changed, 42 insertions(+), 45 deletions(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 4e9a0b3dc0fca..3a6e80609c87d 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -196,7 +196,7 @@ def __init__( ) -> None: super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay, retry_args, caller) - def create(self, json: dict) -> int: + def create_job(self, json: dict) -> int: """ Utility function to call the ``api/2.1/jobs/create`` endpoint. @@ -206,7 +206,7 @@ def create(self, json: dict) -> int: response = self._do_api_call(CREATE_ENDPOINT, json) return response["job_id"] - def reset(self, job_id: str, json: dict): + def reset_job(self, job_id: str, json: dict): """ Utility function to call the ``api/2.1/jobs/reset`` endpoint. diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 63dd5901aa963..716bd14e847b6 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -24,7 +24,7 @@ from logging import Logger from typing import TYPE_CHECKING, Any, Sequence -from databricks.sdk.service import jobs as j +from databricks.sdk.service import jobs from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning @@ -165,7 +165,7 @@ def get_link( return XCom.get_value(key=XCOM_RUN_PAGE_URL_KEY, ti_key=ti_key) -class DatabricksJobsCreateOperator(BaseOperator): +class DatabricksCreateJobsOperator(BaseOperator): """ Creates (or resets) a Databricks job using the `api/2.1/jobs/create @@ -229,15 +229,15 @@ def __init__( json: Any | None = None, name: str | None = None, tags: dict[str, str] | None = None, - tasks: list[j.JobTaskSettings] | None = None, - job_clusters: list[j.JobCluster] | None = None, - email_notifications: j.JobEmailNotifications | None = None, - webhook_notifications: j.JobWebhookNotifications | None = None, + tasks: list[jobs.JobTaskSettings] | None = None, + job_clusters: list[jobs.JobCluster] | None = None, + email_notifications: jobs.JobEmailNotifications | None = None, + webhook_notifications: jobs.JobWebhookNotifications | None = None, timeout_seconds: int | None = None, - schedule: j.CronSchedule | None = None, + schedule: jobs.CronSchedule | None = None, max_concurrent_runs: int | None = None, - git_source: j.GitSource | None = None, - access_control_list: list[j.AccessControlRequest] | None = None, + git_source: jobs.GitSource | None = None, + access_control_list: list[jobs.AccessControlRequest] | None = None, databricks_conn_id: str = "databricks_default", polling_period_seconds: int = 30, databricks_retry_limit: int = 3, @@ -245,7 +245,7 @@ def __init__( databricks_retry_args: dict[Any, Any] | None = None, **kwargs, ) -> None: - """Creates a new ``DatabricksJobsCreateOperator``.""" + """Creates a new ``DatabricksCreateJobsOperator``.""" super().__init__(**kwargs) self.json = json or {} self.databricks_conn_id = databricks_conn_id @@ -280,15 +280,12 @@ def __init__( @cached_property def _hook(self): 
- return self._get_hook(caller="DatabricksJobsCreateOperator") - - def _get_hook(self, caller: str) -> DatabricksHook: return DatabricksHook( self.databricks_conn_id, retry_limit=self.databricks_retry_limit, retry_delay=self.databricks_retry_delay, retry_args=self.databricks_retry_args, - caller=caller, + caller="DatabricksCreateJobsOperator", ) def execute(self, context: Context) -> int: @@ -296,8 +293,8 @@ def execute(self, context: Context) -> int: raise AirflowException("Missing required parameter: name") job_id = self._hook.find_job_id_by_name(self.json["name"]) if job_id is None: - return self._hook.create(self.json) - self._hook.reset(str(job_id), self.json) + return self._hook.create_job(self.json) + self._hook.reset_job(str(job_id), self.json) return job_id diff --git a/docs/apache-airflow-providers-databricks/operators/jobs_create.rst b/docs/apache-airflow-providers-databricks/operators/jobs_create.rst index a9a69976768c3..779095e92cd6b 100644 --- a/docs/apache-airflow-providers-databricks/operators/jobs_create.rst +++ b/docs/apache-airflow-providers-databricks/operators/jobs_create.rst @@ -17,10 +17,10 @@ -DatabricksJobsCreateOperator +DatabricksCreateJobsOperator ============================ -Use the :class:`~airflow.providers.databricks.operators.DatabricksJobsCreateOperator` to create +Use the :class:`~airflow.providers.databricks.operators.DatabricksCreateJobsOperator` to create (or reset) a Databricks job. This operator relies on past XComs to remember the ``job_id`` that was created so that repeated calls with this operator will update the existing job rather than creating new ones. When paired with the DatabricksRunNowOperator all runs will fall under the same @@ -31,18 +31,18 @@ Using the Operator ------------------ There are three ways to instantiate this operator. In the first way, you can take the JSON payload that you typically use -to call the ``api/2.1/jobs/create`` endpoint and pass it directly to our ``DatabricksJobsCreateOperator`` through the +to call the ``api/2.1/jobs/create`` endpoint and pass it directly to our ``DatabricksCreateJobsOperator`` through the ``json`` parameter. With this approach you get full control over the underlying payload to Jobs REST API, including execution of Databricks jobs with multiple tasks, but it's harder to detect errors because of the lack of the type checking. -The second way to accomplish the same thing is to use the named parameters of the ``DatabricksJobsCreateOperator`` directly. Note that there is exactly +The second way to accomplish the same thing is to use the named parameters of the ``DatabricksCreateJobsOperator`` directly. Note that there is exactly one named parameter for each top level parameter in the ``api/2.1/jobs/create`` endpoint. The third way is to use both the json parameter **AND** the named parameters. They will be merged together. If there are conflicts during the merge, the named parameters will take precedence and override the top level ``json`` keys. -Currently the named parameters that ``DatabricksJobsCreateOperator`` supports are: +Currently the named parameters that ``DatabricksCreateJobsOperator`` supports are: - ``name`` - ``tags`` - ``tasks`` @@ -62,7 +62,7 @@ Examples Specifying parameters as JSON ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -An example usage of the DatabricksJobsCreateOperator is as follows: +An example usage of the DatabricksCreateJobsOperator is as follows: .. 
exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py :language: python @@ -82,7 +82,7 @@ You can also use named parameters to initialize the operator and run the job. Pairing with DatabricksRunNowOperator ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -You can use the ``job_id`` that is returned by the DatabricksJobsCreateOperator in the +You can use the ``job_id`` that is returned by the DatabricksCreateJobsOperator in the return_value XCom as an argument to the DatabricksRunNowOperator to run the job. .. exampleinclude:: /../../tests/system/providers/databricks/example_databricks.py diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py index dde00b7c7322f..c836691b1e178 100644 --- a/tests/providers/databricks/hooks/test_databricks.py +++ b/tests/providers/databricks/hooks/test_databricks.py @@ -408,7 +408,7 @@ def test_create(self, mock_requests): status_code_mock = mock.PropertyMock(return_value=200) type(mock_requests.post.return_value).status_code = status_code_mock json = {"name": "test"} - job_id = self.hook.create(json) + job_id = self.hook.create_job(json) assert job_id == JOB_ID @@ -427,7 +427,7 @@ def test_reset(self, mock_requests): status_code_mock = mock.PropertyMock(return_value=200) type(mock_requests.post.return_value).status_code = status_code_mock json = {"name": "test"} - self.hook.reset(JOB_ID, json) + self.hook.reset_job(JOB_ID, json) mock_requests.post.assert_called_once_with( reset_endpoint(HOST), diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 676c00aecbd4f..5a38debd11105 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -28,7 +28,7 @@ from airflow.models import DAG from airflow.providers.databricks.hooks.databricks import RunState from airflow.providers.databricks.operators.databricks import ( - DatabricksJobsCreateOperator, + DatabricksCreateJobsOperator, DatabricksRunNowDeferrableOperator, DatabricksRunNowOperator, DatabricksSubmitRunDeferrableOperator, @@ -243,12 +243,12 @@ def make_run_with_state_mock( ) -class TestDatabricksJobsCreateOperator: +class TestDatabricksCreateJobsOperator: def test_init_with_named_parameters(self): """ Test the initializer with the named parameters. 
""" - op = DatabricksJobsCreateOperator( + op = DatabricksCreateJobsOperator( task_id=TASK_ID, name=JOB_NAME, tags=TAGS, @@ -297,7 +297,7 @@ def test_init_with_json(self): "git_source": GIT_SOURCE, "access_control_list": ACCESS_CONTROL_LIST, } - op = DatabricksJobsCreateOperator(task_id=TASK_ID, json=json) + op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) expected = utils.normalise_json_content( { @@ -348,7 +348,7 @@ def test_init_with_merging(self): "access_control_list": ACCESS_CONTROL_LIST, } - op = DatabricksJobsCreateOperator( + op = DatabricksCreateJobsOperator( task_id=TASK_ID, json=json, name=override_name, @@ -386,7 +386,7 @@ def test_init_with_templating(self): json = {"name": "test-{{ ds }}"} dag = DAG("test", start_date=datetime.now()) - op = DatabricksJobsCreateOperator(dag=dag, task_id=TASK_ID, json=json) + op = DatabricksCreateJobsOperator(dag=dag, task_id=TASK_ID, json=json) op.render_template_fields(context={"ds": DATE}) expected = utils.normalise_json_content({"name": f"test-{DATE}"}) assert expected == op.json @@ -399,7 +399,7 @@ def test_init_with_bad_type(self): r"for parameter json\[test\] is not a number or a string" ) with pytest.raises(AirflowException, match=exception_message): - DatabricksJobsCreateOperator(task_id=TASK_ID, json=json) + DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_exec_create(self, db_mock_class): @@ -419,9 +419,9 @@ def test_exec_create(self, db_mock_class): "git_source": GIT_SOURCE, "access_control_list": ACCESS_CONTROL_LIST, } - op = DatabricksJobsCreateOperator(task_id=TASK_ID, json=json) + op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) db_mock = db_mock_class.return_value - db_mock.create.return_value = JOB_ID + db_mock.create_job.return_value = JOB_ID db_mock.find_job_id_by_name.return_value = None @@ -447,10 +447,10 @@ def test_exec_create(self, db_mock_class): retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay, retry_args=None, - caller="DatabricksJobsCreateOperator", + caller="DatabricksCreateJobsOperator", ) - db_mock.create.assert_called_once_with(expected) + db_mock.create_job.assert_called_once_with(expected) assert JOB_ID == return_result @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") @@ -471,7 +471,7 @@ def test_exec_reset(self, db_mock_class): "git_source": GIT_SOURCE, "access_control_list": ACCESS_CONTROL_LIST, } - op = DatabricksJobsCreateOperator(task_id=TASK_ID, json=json) + op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) db_mock = db_mock_class.return_value db_mock.find_job_id_by_name.return_value = JOB_ID @@ -497,10 +497,10 @@ def test_exec_reset(self, db_mock_class): retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay, retry_args=None, - caller="DatabricksJobsCreateOperator", + caller="DatabricksCreateJobsOperator", ) - db_mock.reset.assert_called_once_with(JOB_ID, expected) + db_mock.reset_job.assert_called_once_with(JOB_ID, expected) assert JOB_ID == return_result diff --git a/tests/system/providers/databricks/example_databricks.py b/tests/system/providers/databricks/example_databricks.py index 2f931aec93bff..c7b4293efcf52 100644 --- a/tests/system/providers/databricks/example_databricks.py +++ b/tests/system/providers/databricks/example_databricks.py @@ -37,7 +37,7 @@ from airflow import DAG from airflow.providers.databricks.operators.databricks import ( - DatabricksJobsCreateOperator, + 
DatabricksCreateJobsOperator, DatabricksRunNowOperator, DatabricksSubmitRunOperator, ) @@ -77,7 +77,7 @@ ], } - jobs_create_json = DatabricksJobsCreateOperator(task_id="jobs_create_json", json=job) + jobs_create_json = DatabricksCreateJobsOperator(task_id="jobs_create_json", json=job) # [END howto_operator_databricks_jobs_create_json] # [START howto_operator_databricks_jobs_create_named] @@ -102,13 +102,13 @@ }, ] - jobs_create_named = DatabricksJobsCreateOperator( + jobs_create_named = DatabricksCreateJobsOperator( task_id="jobs_create_named", tasks=tasks, job_clusters=job_clusters ) # [END howto_operator_databricks_jobs_create_named] # [START howto_operator_databricks_run_now] - # Example of using the DatabricksRunNowOperator after creating a job with DatabricksJobsCreateOperator. + # Example of using the DatabricksRunNowOperator after creating a job with DatabricksCreateJobsOperator. run_now = DatabricksRunNowOperator( task_id="run_now", job_id="{{ ti.xcom_pull(task_ids='jobs_create_named') }}" ) From 7a5b650ebe007bd2366712a92e92df7dd5eac65d Mon Sep 17 00:00:00 2001 From: "sri.tikkireddy" Date: Tue, 25 Jul 2023 10:25:57 -0400 Subject: [PATCH 12/19] fixed imports in test_databricks.py --- .../databricks/operators/test_databricks.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 5a38debd11105..3916658a437c6 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -22,7 +22,7 @@ from unittest.mock import MagicMock import pytest -from databricks.sdk.service import jobs as j +from databricks.sdk.service import jobs from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG @@ -252,15 +252,15 @@ def test_init_with_named_parameters(self): task_id=TASK_ID, name=JOB_NAME, tags=TAGS, - tasks=[j.JobTaskSettings.from_dict(task) for task in TASKS], - job_clusters=[j.JobCluster.from_dict(cluster) for cluster in JOB_CLUSTERS], - email_notifications=j.JobEmailNotifications.from_dict(EMAIL_NOTIFICATIONS), - webhook_notifications=j.JobWebhookNotifications.from_dict(WEBHOOK_NOTIFICATIONS), + tasks=[jobs.JobTaskSettings.from_dict(task) for task in TASKS], + job_clusters=[jobs.JobCluster.from_dict(cluster) for cluster in JOB_CLUSTERS], + email_notifications=jobs.JobEmailNotifications.from_dict(EMAIL_NOTIFICATIONS), + webhook_notifications=jobs.JobWebhookNotifications.from_dict(WEBHOOK_NOTIFICATIONS), timeout_seconds=TIMEOUT_SECONDS, - schedule=j.CronSchedule.from_dict(SCHEDULE), + schedule=jobs.CronSchedule.from_dict(SCHEDULE), max_concurrent_runs=MAX_CONCURRENT_RUNS, - git_source=j.GitSource.from_dict(GIT_SOURCE), - access_control_list=[j.AccessControlRequest.from_dict(acl) for acl in ACCESS_CONTROL_LIST], + git_source=jobs.GitSource.from_dict(GIT_SOURCE), + access_control_list=[jobs.AccessControlRequest.from_dict(acl) for acl in ACCESS_CONTROL_LIST], ) expected = utils.normalise_json_content( { @@ -355,12 +355,12 @@ def test_init_with_merging(self): tags=override_tags, tasks=override_tasks, job_clusters=override_job_clusters, - email_notifications=j.JobEmailNotifications.from_dict(override_email_notifications), - webhook_notifications=j.JobWebhookNotifications.from_dict(override_webhook_notifications), + email_notifications=jobs.JobEmailNotifications.from_dict(override_email_notifications), + 
webhook_notifications=jobs.JobWebhookNotifications.from_dict(override_webhook_notifications), timeout_seconds=override_timeout_seconds, - schedule=j.CronSchedule.from_dict(override_schedule), + schedule=jobs.CronSchedule.from_dict(override_schedule), max_concurrent_runs=override_max_concurrent_runs, - git_source=j.GitSource.from_dict(override_git_source), + git_source=jobs.GitSource.from_dict(override_git_source), access_control_list=override_access_control_list, ) From 1e268352dd9d8551b56bcfe90c46498cb702c04c Mon Sep 17 00:00:00 2001 From: "sri.tikkireddy" Date: Tue, 25 Jul 2023 12:09:46 -0400 Subject: [PATCH 13/19] added correct type hint for reset_job --- airflow/providers/databricks/hooks/databricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 3a6e80609c87d..e1da837f43e5f 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -206,7 +206,7 @@ def create_job(self, json: dict) -> int: response = self._do_api_call(CREATE_ENDPOINT, json) return response["job_id"] - def reset_job(self, job_id: str, json: dict): + def reset_job(self, job_id: str, json: dict) -> None: """ Utility function to call the ``api/2.1/jobs/reset`` endpoint. From 683129982e65b806d5558979e0343364b8d6f4c4 Mon Sep 17 00:00:00 2001 From: stikkireddy <54602805+stikkireddy@users.noreply.github.com> Date: Tue, 15 Aug 2023 08:15:34 -0400 Subject: [PATCH 14/19] change type hint for json arg in DatabricksCreateJobsOperator --- airflow/providers/databricks/operators/databricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 716bd14e847b6..3f1194d760e3c 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -226,7 +226,7 @@ class DatabricksCreateJobsOperator(BaseOperator): def __init__( self, *, - json: Any | None = None, + json: dict | None = None, name: str | None = None, tags: dict[str, str] | None = None, tasks: list[jobs.JobTaskSettings] | None = None, From 6d2670417a5b37daa1b6800b8a01101b1f53b8a7 Mon Sep 17 00:00:00 2001 From: "sri.tikkireddy" Date: Wed, 11 Oct 2023 10:32:27 -0400 Subject: [PATCH 15/19] fixed CI errors --- .../databricks/operators/databricks.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 3f1194d760e3c..cc665e3669fd4 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -24,8 +24,6 @@ from logging import Logger from typing import TYPE_CHECKING, Any, Sequence -from databricks.sdk.service import jobs - from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import BaseOperator, BaseOperatorLink, XCom @@ -36,6 +34,8 @@ if TYPE_CHECKING: from logging import Logger + from databricks.sdk.service import jobs + from airflow.models.taskinstancekey import TaskInstanceKey from airflow.utils.context import Context @@ -166,16 +166,11 @@ def get_link( class DatabricksCreateJobsOperator(BaseOperator): - """ - Creates (or resets) a Databricks job using the - `api/2.1/jobs/create - `_ - (or `api/2.1/jobs/reset - `_) - 
API endpoint. + """Creates (or resets) a Databricks job using the API endpoint. .. seealso:: - https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate + https://docs.databricks.com/api/workspace/jobs/create + https://docs.databricks.com/api/workspace/jobs/reset :param json: A JSON object containing API parameters which will be passed directly to the ``api/2.1/jobs/create`` endpoint. The other named parameters @@ -215,6 +210,7 @@ class DatabricksCreateJobsOperator(BaseOperator): :param databricks_retry_delay: Number of seconds to wait between retries (it might be a floating point number). :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. + """ # Used in airflow.models.BaseOperator @@ -226,7 +222,7 @@ class DatabricksCreateJobsOperator(BaseOperator): def __init__( self, *, - json: dict | None = None, + json: Any | None = None, name: str | None = None, tags: dict[str, str] | None = None, tasks: list[jobs.JobTaskSettings] | None = None, From 064e17a74e010a6c94a227b718f530479b3ef61c Mon Sep 17 00:00:00 2001 From: "sri.tikkireddy" Date: Wed, 11 Oct 2023 12:17:00 -0400 Subject: [PATCH 16/19] fixed broken tests and imports. also pinned databricks sdk to a specific version ==0.10.0 --- airflow/providers/databricks/operators/databricks.py | 6 +++--- airflow/providers/databricks/provider.yaml | 2 +- generated/provider_dependencies.json | 2 +- tests/providers/databricks/operators/test_databricks.py | 8 ++++---- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index cc665e3669fd4..03c4de62fbd57 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -225,15 +225,15 @@ def __init__( json: Any | None = None, name: str | None = None, tags: dict[str, str] | None = None, - tasks: list[jobs.JobTaskSettings] | None = None, + tasks: list[jobs.Task] | None = None, job_clusters: list[jobs.JobCluster] | None = None, email_notifications: jobs.JobEmailNotifications | None = None, - webhook_notifications: jobs.JobWebhookNotifications | None = None, + webhook_notifications: jobs.WebhookNotifications | None = None, timeout_seconds: int | None = None, schedule: jobs.CronSchedule | None = None, max_concurrent_runs: int | None = None, git_source: jobs.GitSource | None = None, - access_control_list: list[jobs.AccessControlRequest] | None = None, + access_control_list: list[jobs.JobAccessControlRequest] | None = None, databricks_conn_id: str = "databricks_default", polling_period_seconds: int = 30, databricks_retry_limit: int = 3, diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml index 626b81831ad35..7c2839a008139 100644 --- a/airflow/providers/databricks/provider.yaml +++ b/airflow/providers/databricks/provider.yaml @@ -61,7 +61,7 @@ dependencies: # The 2.9.1 (to be released soon) already contains the fix - databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0 - aiohttp>=3.6.3, <4 - - databricks-sdk>=0.1.11, <1.0.0 + - databricks-sdk==0.10.0 integrations: - integration-name: Databricks diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index b9e92b5bc8dee..0e2003462e30a 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -301,7 +301,7 @@ "aiohttp>=3.6.3, <4", "apache-airflow-providers-common-sql>=1.5.0", 
"apache-airflow>=2.5.0", - "databricks-sdk>=0.1.11, <1.0.0", + "databricks-sdk==0.10.0", "databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0", "requests>=2.27,<3" ], diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 3916658a437c6..a699effdeb9a3 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -252,15 +252,15 @@ def test_init_with_named_parameters(self): task_id=TASK_ID, name=JOB_NAME, tags=TAGS, - tasks=[jobs.JobTaskSettings.from_dict(task) for task in TASKS], + tasks=[jobs.Task.from_dict(task) for task in TASKS], job_clusters=[jobs.JobCluster.from_dict(cluster) for cluster in JOB_CLUSTERS], email_notifications=jobs.JobEmailNotifications.from_dict(EMAIL_NOTIFICATIONS), - webhook_notifications=jobs.JobWebhookNotifications.from_dict(WEBHOOK_NOTIFICATIONS), + webhook_notifications=jobs.WebhookNotifications.from_dict(WEBHOOK_NOTIFICATIONS), timeout_seconds=TIMEOUT_SECONDS, schedule=jobs.CronSchedule.from_dict(SCHEDULE), max_concurrent_runs=MAX_CONCURRENT_RUNS, git_source=jobs.GitSource.from_dict(GIT_SOURCE), - access_control_list=[jobs.AccessControlRequest.from_dict(acl) for acl in ACCESS_CONTROL_LIST], + access_control_list=[jobs.JobAccessControlRequest.from_dict(acl) for acl in ACCESS_CONTROL_LIST], ) expected = utils.normalise_json_content( { @@ -356,7 +356,7 @@ def test_init_with_merging(self): tasks=override_tasks, job_clusters=override_job_clusters, email_notifications=jobs.JobEmailNotifications.from_dict(override_email_notifications), - webhook_notifications=jobs.JobWebhookNotifications.from_dict(override_webhook_notifications), + webhook_notifications=jobs.WebhookNotifications.from_dict(override_webhook_notifications), timeout_seconds=override_timeout_seconds, schedule=jobs.CronSchedule.from_dict(override_schedule), max_concurrent_runs=override_max_concurrent_runs, From 054d2e72ddc5b49d3f7a5cf9ae4773b2a3031fd8 Mon Sep 17 00:00:00 2001 From: "sri.tikkireddy" Date: Thu, 12 Oct 2023 06:42:09 -0400 Subject: [PATCH 17/19] fixed broken tests and imports. 
also pinned databricks sdk to a specific version ==0.10.0 --- .../databricks/operators/databricks.py | 30 +++++++++---------- airflow/providers/databricks/provider.yaml | 7 +++++ .../databricks/operators/test_databricks.py | 23 +++++++------- 3 files changed, 32 insertions(+), 28 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 03c4de62fbd57..559d6005a3353 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -34,8 +34,6 @@ if TYPE_CHECKING: from logging import Logger - from databricks.sdk.service import jobs - from airflow.models.taskinstancekey import TaskInstanceKey from airflow.utils.context import Context @@ -225,15 +223,15 @@ def __init__( json: Any | None = None, name: str | None = None, tags: dict[str, str] | None = None, - tasks: list[jobs.Task] | None = None, - job_clusters: list[jobs.JobCluster] | None = None, - email_notifications: jobs.JobEmailNotifications | None = None, - webhook_notifications: jobs.WebhookNotifications | None = None, + tasks: list[dict] | None = None, + job_clusters: list[dict] | None = None, + email_notifications: dict | None = None, + webhook_notifications: dict | None = None, timeout_seconds: int | None = None, - schedule: jobs.CronSchedule | None = None, + schedule: dict | None = None, max_concurrent_runs: int | None = None, - git_source: jobs.GitSource | None = None, - access_control_list: list[jobs.JobAccessControlRequest] | None = None, + git_source: dict | None = None, + access_control_list: list[dict] | None = None, databricks_conn_id: str = "databricks_default", polling_period_seconds: int = 30, databricks_retry_limit: int = 3, @@ -254,23 +252,23 @@ def __init__( if tags is not None: self.json["tags"] = tags if tasks is not None: - self.json["tasks"] = [task.as_dict() for task in tasks] + self.json["tasks"] = tasks if job_clusters is not None: - self.json["job_clusters"] = [job_cluster.as_dict() for job_cluster in job_clusters] + self.json["job_clusters"] = job_clusters if email_notifications is not None: - self.json["email_notifications"] = email_notifications.as_dict() + self.json["email_notifications"] = email_notifications if webhook_notifications is not None: - self.json["webhook_notifications"] = webhook_notifications.as_dict() + self.json["webhook_notifications"] = webhook_notifications if timeout_seconds is not None: self.json["timeout_seconds"] = timeout_seconds if schedule is not None: - self.json["schedule"] = schedule.as_dict() + self.json["schedule"] = schedule if max_concurrent_runs is not None: self.json["max_concurrent_runs"] = max_concurrent_runs if git_source is not None: - self.json["git_source"] = git_source.as_dict() + self.json["git_source"] = git_source if access_control_list is not None: - self.json["access_control_list"] = [acl.as_dict() for acl in access_control_list] + self.json["access_control_list"] = access_control_list self.json = normalise_json_content(self.json) diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml index 7c2839a008139..dfb5a2fecc291 100644 --- a/airflow/providers/databricks/provider.yaml +++ b/airflow/providers/databricks/provider.yaml @@ -125,3 +125,10 @@ connection-types: extra-links: - airflow.providers.databricks.operators.databricks.DatabricksJobRunLink + +additional-extras: + # pip install apache-airflow-providers-databricks[sdk] + - name: sdk + description: Install Databricks SDK + 
dependencies: + - databricks-sdk==0.10.0 diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index a699effdeb9a3..5fb1a31cd32bc 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -22,7 +22,6 @@ from unittest.mock import MagicMock import pytest -from databricks.sdk.service import jobs from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG @@ -252,15 +251,15 @@ def test_init_with_named_parameters(self): task_id=TASK_ID, name=JOB_NAME, tags=TAGS, - tasks=[jobs.Task.from_dict(task) for task in TASKS], - job_clusters=[jobs.JobCluster.from_dict(cluster) for cluster in JOB_CLUSTERS], - email_notifications=jobs.JobEmailNotifications.from_dict(EMAIL_NOTIFICATIONS), - webhook_notifications=jobs.WebhookNotifications.from_dict(WEBHOOK_NOTIFICATIONS), + tasks=TASKS, + job_clusters=JOB_CLUSTERS, + email_notifications=EMAIL_NOTIFICATIONS, + webhook_notifications=WEBHOOK_NOTIFICATIONS, timeout_seconds=TIMEOUT_SECONDS, - schedule=jobs.CronSchedule.from_dict(SCHEDULE), + schedule=SCHEDULE, max_concurrent_runs=MAX_CONCURRENT_RUNS, - git_source=jobs.GitSource.from_dict(GIT_SOURCE), - access_control_list=[jobs.JobAccessControlRequest.from_dict(acl) for acl in ACCESS_CONTROL_LIST], + git_source=GIT_SOURCE, + access_control_list=ACCESS_CONTROL_LIST, ) expected = utils.normalise_json_content( { @@ -355,12 +354,12 @@ def test_init_with_merging(self): tags=override_tags, tasks=override_tasks, job_clusters=override_job_clusters, - email_notifications=jobs.JobEmailNotifications.from_dict(override_email_notifications), - webhook_notifications=jobs.WebhookNotifications.from_dict(override_webhook_notifications), + email_notifications=override_email_notifications, + webhook_notifications=override_webhook_notifications, timeout_seconds=override_timeout_seconds, - schedule=jobs.CronSchedule.from_dict(override_schedule), + schedule=override_schedule, max_concurrent_runs=override_max_concurrent_runs, - git_source=jobs.GitSource.from_dict(override_git_source), + git_source=override_git_source, access_control_list=override_access_control_list, ) From 749dd84307ab36f894a24e7b3b8a3c3e98145bdb Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Fri, 27 Oct 2023 14:16:28 +0100 Subject: [PATCH 18/19] Fix CI static checks --- airflow/providers/databricks/operators/databricks.py | 2 -- tests/system/providers/databricks/example_databricks.py | 1 - 2 files changed, 3 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 559d6005a3353..5a54b7e9eea2f 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -32,8 +32,6 @@ from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event if TYPE_CHECKING: - from logging import Logger - from airflow.models.taskinstancekey import TaskInstanceKey from airflow.utils.context import Context diff --git a/tests/system/providers/databricks/example_databricks.py b/tests/system/providers/databricks/example_databricks.py index c7b4293efcf52..3a7ed3e53b2e0 100644 --- a/tests/system/providers/databricks/example_databricks.py +++ b/tests/system/providers/databricks/example_databricks.py @@ -52,7 +52,6 @@ tags=["example"], catchup=False, ) as dag: - # [START howto_operator_databricks_jobs_create_json] # Example of 
using the JSON parameter to initialize the operator. job = { From 043f692b42c573402a66fd230b7a2d65ac396a4f Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Fri, 27 Oct 2023 14:23:02 +0100 Subject: [PATCH 19/19] Remove databricks-sdk dependency This was agreed with @stikkireddy, since there the SDK interfaces are changing ATM. When it becomes stable, we can re-introduce this dependency --- airflow/providers/databricks/provider.yaml | 1 - generated/provider_dependencies.json | 1 - 2 files changed, 2 deletions(-) diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml index dfb5a2fecc291..52463c351705c 100644 --- a/airflow/providers/databricks/provider.yaml +++ b/airflow/providers/databricks/provider.yaml @@ -61,7 +61,6 @@ dependencies: # The 2.9.1 (to be released soon) already contains the fix - databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0 - aiohttp>=3.6.3, <4 - - databricks-sdk==0.10.0 integrations: - integration-name: Databricks diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 0e2003462e30a..772fa4f78a199 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -301,7 +301,6 @@ "aiohttp>=3.6.3, <4", "apache-airflow-providers-common-sql>=1.5.0", "apache-airflow>=2.5.0", - "databricks-sdk==0.10.0", "databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0", "requests>=2.27,<3" ],
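
A minimal sketch of the create-then-run pairing these patches document, assuming the default
``databricks_default`` connection and purely illustrative job settings (the DAG id, job name,
``notebook_path``, and cluster spec below are placeholders, not values taken from the patches):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.databricks.operators.databricks import (
        DatabricksCreateJobsOperator,
        DatabricksRunNowOperator,
    )

    with DAG(
        dag_id="databricks_create_then_run",
        start_date=datetime(2023, 1, 1),
        schedule=None,
        catchup=False,
    ) as dag:
        # Create the job on the first run; on later runs the operator finds the job by
        # name and resets it, so the same job_id keeps being returned via XCom.
        create_job = DatabricksCreateJobsOperator(
            task_id="create_job",
            json={
                "name": "airflow-managed-example-job",
                "tasks": [
                    {
                        "task_key": "notebook_task",
                        "notebook_task": {"notebook_path": "/Shared/example"},  # placeholder path
                        "new_cluster": {
                            "spark_version": "13.3.x-scala2.12",  # placeholder cluster spec
                            "node_type_id": "i3.xlarge",
                            "num_workers": 1,
                        },
                    }
                ],
            },
        )

        # Run the job that was just created or reset, pulling its job_id from the
        # create task's return_value XCom.
        run_job = DatabricksRunNowOperator(
            task_id="run_job",
            job_id="{{ ti.xcom_pull(task_ids='create_job') }}",
        )

        create_job >> run_job

Because the operator resolves the job by ``name`` and calls ``reset_job`` when the job already
exists (falling back to ``create_job`` otherwise), re-running this DAG updates the same
Databricks job instead of creating duplicates.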