From ae4c60feaedc9cd0fe510d6df0db22bdc021cecd Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Tue, 3 May 2022 17:47:23 -0600 Subject: [PATCH 01/35] first databrick implementation --- dbt/adapters/spark/impl.py | 125 ++++++++++++++++++ .../spark/macros/materializations/table.sql | 23 +++- 2 files changed, 145 insertions(+), 3 deletions(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index eb001fbc9..c5c0e73f6 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -1,4 +1,7 @@ import re +import requests +import time +import base64 from concurrent.futures import Future from dataclasses import dataclass from typing import Optional, List, Dict, Any, Union, Iterable @@ -400,6 +403,128 @@ def run_sql_for_tests(self, sql, fetch, conn): finally: conn.transaction_open = False + def submit_python_job(self, schema, identifier, file_contents): + # basically copying drew's other script over :) + auth_header = {'Authorization': f'Bearer {self.connections.profile.credentials.token}'} + b64_encoded_content = base64.b64encode(file_contents.encode()).decode() + + # create new dir + # response = requests.post( + # f'https://{self.connections.profile.credentials.host}/api/2.0/workspace/mkdirs', + # headers=auth_header, + # json={ + # 'path': f'/Users/{schema}/', + # } + # ) + + # add notebook + response = requests.post( + f'https://{self.connections.profile.credentials.host}/api/2.0/workspace/import', + headers=auth_header, + json={ + 'path': f'/Users/{schema}/{identifier}', + 'content': b64_encoded_content, + 'language': 'PYTHON', + 'overwrite': True, + 'format': 'SOURCE' + } + ) + # need to validate submit succeed here + resp = response.json() + + # submit job + response = requests.post( + f'https://{self.connections.profile.credentials.host}/api/2.1/jobs/runs/submit', + headers=auth_header, + json={ + "run_name": "debug task", + "existing_cluster_id": self.connections.profile.credentials.cluster, + 'notebook_task': { + 'notebook_path': f'/Users/{schema}/{identifier}', + } + } + ) + + + run_id = response.json()['run_id'] + + # poll until job finish + # this feels bad + state = None + while state != 'TERMINATED': + resp = requests.get( + f'https://{self.connections.profile.credentials.host}/api/2.1/jobs/runs/get?run_id={run_id}', + headers=auth_header, + ) + state = resp.json()['state']['life_cycle_state'] + print(f"Polling.... 
in state: {state}") + time.sleep(1) + + run_output = requests.get( + f'https://{self.connections.profile.credentials.host}/api/2.1/jobs/runs/get-output?run_id={run_id}', + headers=auth_header, + ) + +""" +sample run_output + +{ + "metadata": { + "job_id": 116603964177912, + "run_id": 981, + "number_in_job": 981, + "state": { + "life_cycle_state": "TERMINATED", + "result_state": "FAILED", + "state_message": "", + "user_cancelled_or_timedout": false + }, + "start_time": 1651187347908, + "setup_duration": 0, + "execution_duration": 14000, + "cleanup_duration": 0, + "end_time": 1651187362169, + "creator_user_name": "chenyu.li@dbtlabs.com", + "run_name": "debug task", + "run_page_url": "https://dbc-9274a712-595c.cloud.databricks.com/?o=733816330658499#job/116603964177912/run/981", + "run_type": "SUBMIT_RUN", + "tasks": [ + { + "run_id": 981, + "task_key": "debug_task", + "notebook_task": { + "notebook_path": "/Users/chenyu.li@dbtlabs.com/random" + }, + "existing_cluster_id": "0411-132815-9avnz2eh", + "state": { + "life_cycle_state": "TERMINATED", + "result_state": "FAILED", + "state_message": "", + "user_cancelled_or_timedout": false + }, + "run_page_url": "https://dbc-9274a712-595c.cloud.databricks.com/?o=733816330658499#job/116603964177912/run/981", + "start_time": 1651187347908, + "setup_duration": 0, + "execution_duration": 14000, + "cleanup_duration": 0, + "end_time": 1651187362169, + "cluster_instance": { + "cluster_id": "0411-132815-9avnz2eh", + "spark_context_id": "768047027524473660" + }, + "attempt_number": 0 + } + ], + "format": "MULTI_TASK" + }, + "error": "SyntaxError: EOL while scanning string literal", + "error_trace": "File \"\", line 1\n print(\"hello world\n ^\nSyntaxError: EOL while scanning string literal", + "notebook_output": {} +} +""" + + + # spark does something interesting with joins when both tables have the same # static values for the join condition and complains that the join condition is diff --git a/dbt/include/spark/macros/materializations/table.sql b/dbt/include/spark/macros/materializations/table.sql index 3ae2df973..843b685b8 100644 --- a/dbt/include/spark/macros/materializations/table.sql +++ b/dbt/include/spark/macros/materializations/table.sql @@ -18,9 +18,16 @@ {%- endif %} -- build model - {% call statement('main') -%} - {{ create_table_as(False, target_relation, sql) }} - {%- endcall %} + {% if config.get('language', 'sql') == 'python' -%}} + -- sql here is really just the compiled python code + {%- set python_code = py_complete_script(model=model, schema=schema, python_code=sql) -%} + {{ log("python code " ~ python_code ) }} + {{adapter.submit_python_job('chenyu.li@dbtlabs.com', identifier, python_code)}} + {%- else -%} + {% call statement('main') -%} + {{ create_table_as(False, target_relation, sql) }} + {%- endcall %} + {%- endif %} {% do persist_docs(target_relation, model) %} @@ -29,3 +36,13 @@ {{ return({'relations': [target_relation]})}} {% endmaterialization %} + + +{% macro py_complete_script(model, schema, python_code) %} +{#-- can we wrap in 'def model:' here? or will formatting screw us? 
--#} +{#-- Above was Drew's comment --#} +{{ python_code }} + +df.write.mode("overwrite").format("delta").saveAsTable("{{schema}}.{{model['alias']}}") + +{% endmacro %} From ecece22692a70b22d954fdf807fa2c89dc573ac5 Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Wed, 4 May 2022 19:13:54 -0600 Subject: [PATCH 02/35] add cell to notebook --- dbt/adapters/spark/impl.py | 1 + .../spark/macros/materializations/table.sql | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index c5c0e73f6..8a2336a60 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -459,6 +459,7 @@ def submit_python_job(self, schema, identifier, file_contents): state = resp.json()['state']['life_cycle_state'] print(f"Polling.... in state: {state}") time.sleep(1) + # TODO resp.json()['state_message'] contain useful information and we may want to surface to user if job fails run_output = requests.get( f'https://{self.connections.profile.credentials.host}/api/2.1/jobs/runs/get-output?run_id={run_id}', diff --git a/dbt/include/spark/macros/materializations/table.sql b/dbt/include/spark/macros/materializations/table.sql index 843b685b8..f91bb1a4d 100644 --- a/dbt/include/spark/macros/materializations/table.sql +++ b/dbt/include/spark/macros/materializations/table.sql @@ -37,12 +37,31 @@ {% endmaterialization %} +{% macro py_script_prefix( model) %} +# this part is dbt logic for get ref work +{{ build_ref_function(model ) }} +{{ build_source_function(model ) }} + +def config(*args, **kwargs): + pass + +class dbt: + config = config + ref = ref + source = source + +# COMMAND ---------- +# This part of the code is python model code + +{% endmacro %} {% macro py_complete_script(model, schema, python_code) %} {#-- can we wrap in 'def model:' here? or will formatting screw us? 
--#} {#-- Above was Drew's comment --#} {{ python_code }} +# COMMAND ---------- +# this is materialization code df.write.mode("overwrite").format("delta").saveAsTable("{{schema}}.{{model['alias']}}") {% endmacro %} From 77dcf6b56c5bbd3e21bf185e500004ae83c9966c Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Thu, 5 May 2022 16:32:05 -0600 Subject: [PATCH 03/35] proper return run result --- dbt/adapters/spark/impl.py | 84 ++++--------------- .../spark/macros/materializations/table.sql | 32 ++----- 2 files changed, 25 insertions(+), 91 deletions(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 8a2336a60..edb323ad9 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -407,22 +407,26 @@ def submit_python_job(self, schema, identifier, file_contents): # basically copying drew's other script over :) auth_header = {'Authorization': f'Bearer {self.connections.profile.credentials.token}'} b64_encoded_content = base64.b64encode(file_contents.encode()).decode() - # create new dir - # response = requests.post( - # f'https://{self.connections.profile.credentials.host}/api/2.0/workspace/mkdirs', - # headers=auth_header, - # json={ - # 'path': f'/Users/{schema}/', - # } - # ) + if not self.connections.profile.credentials.user: + raise ValueError('Need to supply user in profile to submit python job') + # it is safe to call mkdirs even if dir already exists and have content inside + work_dir = f'/Users/{self.connections.profile.credentials.user}/{schema}' + response = requests.post( + f'https://{self.connections.profile.credentials.host}/api/2.0/workspace/mkdirs', + headers=auth_header, + json={ + 'path': work_dir, + } + ) + # TODO check the response # add notebook response = requests.post( f'https://{self.connections.profile.credentials.host}/api/2.0/workspace/import', headers=auth_header, json={ - 'path': f'/Users/{schema}/{identifier}', + 'path': f'{work_dir}/{identifier}', 'content': b64_encoded_content, 'language': 'PYTHON', 'overwrite': True, @@ -440,7 +444,7 @@ def submit_python_job(self, schema, identifier, file_contents): "run_name": "debug task", "existing_cluster_id": self.connections.profile.credentials.cluster, 'notebook_task': { - 'notebook_path': f'/Users/{schema}/{identifier}', + 'notebook_path': f'{work_dir}/{identifier}', } } ) @@ -465,64 +469,8 @@ def submit_python_job(self, schema, identifier, file_contents): f'https://{self.connections.profile.credentials.host}/api/2.1/jobs/runs/get-output?run_id={run_id}', headers=auth_header, ) - -""" -sample run_output - -{ - "metadata": { - "job_id": 116603964177912, - "run_id": 981, - "number_in_job": 981, - "state": { - "life_cycle_state": "TERMINATED", - "result_state": "FAILED", - "state_message": "", - "user_cancelled_or_timedout": false - }, - "start_time": 1651187347908, - "setup_duration": 0, - "execution_duration": 14000, - "cleanup_duration": 0, - "end_time": 1651187362169, - "creator_user_name": "chenyu.li@dbtlabs.com", - "run_name": "debug task", - "run_page_url": "https://dbc-9274a712-595c.cloud.databricks.com/?o=733816330658499#job/116603964177912/run/981", - "run_type": "SUBMIT_RUN", - "tasks": [ - { - "run_id": 981, - "task_key": "debug_task", - "notebook_task": { - "notebook_path": "/Users/chenyu.li@dbtlabs.com/random" - }, - "existing_cluster_id": "0411-132815-9avnz2eh", - "state": { - "life_cycle_state": "TERMINATED", - "result_state": "FAILED", - "state_message": "", - "user_cancelled_or_timedout": false - }, - "run_page_url": 
"https://dbc-9274a712-595c.cloud.databricks.com/?o=733816330658499#job/116603964177912/run/981", - "start_time": 1651187347908, - "setup_duration": 0, - "execution_duration": 14000, - "cleanup_duration": 0, - "end_time": 1651187362169, - "cluster_instance": { - "cluster_id": "0411-132815-9avnz2eh", - "spark_context_id": "768047027524473660" - }, - "attempt_number": 0 - } - ], - "format": "MULTI_TASK" - }, - "error": "SyntaxError: EOL while scanning string literal", - "error_trace": "File \"\", line 1\n print(\"hello world\n ^\nSyntaxError: EOL while scanning string literal", - "notebook_output": {} -} -""" + # TODO have more info here and determine what do we want to if python model fail + return run_output.json()['metadata']['state']['result_state'] diff --git a/dbt/include/spark/macros/materializations/table.sql b/dbt/include/spark/macros/materializations/table.sql index f91bb1a4d..2c57efc10 100644 --- a/dbt/include/spark/macros/materializations/table.sql +++ b/dbt/include/spark/macros/materializations/table.sql @@ -22,7 +22,11 @@ -- sql here is really just the compiled python code {%- set python_code = py_complete_script(model=model, schema=schema, python_code=sql) -%} {{ log("python code " ~ python_code ) }} - {{adapter.submit_python_job('chenyu.li@dbtlabs.com', identifier, python_code)}} + {% set result = adapter.submit_python_job(schema, identifier, python_code) %} + {% call noop_statement('main', result,) %} + -- python model return run result -- + {% endcall %} + {%- else -%} {% call statement('main') -%} {{ create_table_as(False, target_relation, sql) }} @@ -37,31 +41,13 @@ {% endmaterialization %} -{% macro py_script_prefix( model) %} -# this part is dbt logic for get ref work -{{ build_ref_function(model ) }} -{{ build_source_function(model ) }} - -def config(*args, **kwargs): - pass - -class dbt: - config = config - ref = ref - source = source - -# COMMAND ---------- -# This part of the code is python model code - -{% endmacro %} {% macro py_complete_script(model, schema, python_code) %} -{#-- can we wrap in 'def model:' here? or will formatting screw us? 
--#} -{#-- Above was Drew's comment --#} {{ python_code }} - # COMMAND ---------- -# this is materialization code -df.write.mode("overwrite").format("delta").saveAsTable("{{schema}}.{{model['alias']}}") +# this is materialization code dbt generated, please do not modify +# we are doing this to make some example code working databricks and snowflake +df = model(spark, dbt) +df.write.mode("overwrite").format("delta").saveAsTable("{{schema}}.{{model['alias']}}") {% endmacro %} From a4211a965cd190320ef434f956f9256d20e7b323 Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Thu, 5 May 2022 16:56:58 -0600 Subject: [PATCH 04/35] properly make function available --- dbt/adapters/spark/impl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index edb323ad9..5de3fa887 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -13,6 +13,7 @@ from dbt.adapters.base import AdapterConfig from dbt.adapters.base.impl import catch_as_completed +from dbt.adapters.base.meta import available from dbt.adapters.sql import SQLAdapter from dbt.adapters.spark import SparkConnectionManager from dbt.adapters.spark import SparkRelation @@ -402,7 +403,7 @@ def run_sql_for_tests(self, sql, fetch, conn): raise finally: conn.transaction_open = False - + @available def submit_python_job(self, schema, identifier, file_contents): # basically copying drew's other script over :) auth_header = {'Authorization': f'Bearer {self.connections.profile.credentials.token}'} From 2e2cae1cbef7e5a3a7a47c5f7b65994849e203c9 Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Fri, 6 May 2022 16:51:40 -0600 Subject: [PATCH 05/35] ref return df --- dbt/include/spark/macros/adapters.sql | 4 ++++ dbt/include/spark/macros/materializations/table.sql | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/dbt/include/spark/macros/adapters.sql b/dbt/include/spark/macros/adapters.sql index e96501c45..cb9dcc79d 100644 --- a/dbt/include/spark/macros/adapters.sql +++ b/dbt/include/spark/macros/adapters.sql @@ -284,3 +284,7 @@ {% do run_query(sql) %} {% endmacro %} + +{% macro load_df_def() %} + load_df_function = spark.table +{% endmacro %} diff --git a/dbt/include/spark/macros/materializations/table.sql b/dbt/include/spark/macros/materializations/table.sql index 2c57efc10..744dfeca2 100644 --- a/dbt/include/spark/macros/materializations/table.sql +++ b/dbt/include/spark/macros/materializations/table.sql @@ -48,6 +48,6 @@ # this is materialization code dbt generated, please do not modify # we are doing this to make some example code working databricks and snowflake -df = model(spark, dbt) +df = model(dbt) df.write.mode("overwrite").format("delta").saveAsTable("{{schema}}.{{model['alias']}}") {% endmacro %} From d69fe4ccb9654f8dc00674f927235424ae334e07 Mon Sep 17 00:00:00 2001 From: Jeremy Cohen Date: Fri, 6 May 2022 15:41:33 +0200 Subject: [PATCH 06/35] Bump version to 1.3.0a1 --- dbt/adapters/spark/__version__.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dbt/adapters/spark/__version__.py b/dbt/adapters/spark/__version__.py index a6b977228..a9fe3c3ee 100644 --- a/dbt/adapters/spark/__version__.py +++ b/dbt/adapters/spark/__version__.py @@ -1 +1 @@ -version = "1.2.0a1" +version = "1.3.0a1" diff --git a/setup.py b/setup.py index 12ecbacde..e9ba3cc1e 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ def _get_dbt_core_version(): package_name = "dbt-spark" -package_version = "1.2.0a1" +package_version = "1.3.0a1" 
dbt_core_version = _get_dbt_core_version() description = """The Apache Spark adapter plugin for dbt""" From c79991d651bee0ce07f7626fecd0aa8dcb1bfb6a Mon Sep 17 00:00:00 2001 From: Jeremy Cohen Date: Fri, 6 May 2022 15:42:19 +0200 Subject: [PATCH 07/35] Small quality of life fixups --- dbt/adapters/spark/impl.py | 2 +- dbt/include/spark/macros/materializations/table.sql | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 5de3fa887..2f2fd49bd 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -462,7 +462,7 @@ def submit_python_job(self, schema, identifier, file_contents): headers=auth_header, ) state = resp.json()['state']['life_cycle_state'] - print(f"Polling.... in state: {state}") + logger.debug(f"Polling.... in state: {state}") time.sleep(1) # TODO resp.json()['state_message'] contain useful information and we may want to surface to user if job fails diff --git a/dbt/include/spark/macros/materializations/table.sql b/dbt/include/spark/macros/materializations/table.sql index 744dfeca2..a6b1d2a65 100644 --- a/dbt/include/spark/macros/materializations/table.sql +++ b/dbt/include/spark/macros/materializations/table.sql @@ -23,7 +23,7 @@ {%- set python_code = py_complete_script(model=model, schema=schema, python_code=sql) -%} {{ log("python code " ~ python_code ) }} {% set result = adapter.submit_python_job(schema, identifier, python_code) %} - {% call noop_statement('main', result,) %} + {% call noop_statement('main', 'OK', 'OK', 1) %} -- python model return run result -- {% endcall %} @@ -44,6 +44,9 @@ {% macro py_complete_script(model, schema, python_code) %} {{ python_code }} + +df = model(spark, dbt) + # COMMAND ---------- # this is materialization code dbt generated, please do not modify From ccdc1706e73b29ec9feecfb01f053371b91230f0 Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Fri, 6 May 2022 17:44:30 -0600 Subject: [PATCH 08/35] update more result --- dbt/include/spark/macros/materializations/table.sql | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/dbt/include/spark/macros/materializations/table.sql b/dbt/include/spark/macros/materializations/table.sql index a6b1d2a65..961b2a9f0 100644 --- a/dbt/include/spark/macros/materializations/table.sql +++ b/dbt/include/spark/macros/materializations/table.sql @@ -23,7 +23,7 @@ {%- set python_code = py_complete_script(model=model, schema=schema, python_code=sql) -%} {{ log("python code " ~ python_code ) }} {% set result = adapter.submit_python_job(schema, identifier, python_code) %} - {% call noop_statement('main', 'OK', 'OK', 1) %} + {% call noop_statement('main', result, 'OK', 1) %} -- python model return run result -- {% endcall %} @@ -45,12 +45,10 @@ {% macro py_complete_script(model, schema, python_code) %} {{ python_code }} -df = model(spark, dbt) +df = model(dbt) # COMMAND ---------- # this is materialization code dbt generated, please do not modify -# we are doing this to make some example code working databricks and snowflake -df = model(dbt) df.write.mode("overwrite").format("delta").saveAsTable("{{schema}}.{{model['alias']}}") {% endmacro %} From be2b0a2c8c7e05aa4cf18adfbedb0008b4fa217d Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Mon, 9 May 2022 18:14:07 -0700 Subject: [PATCH 09/35] fix format --- dbt/adapters/spark/impl.py | 284 ++++++++++++++++++------------------- 1 file changed, 141 insertions(+), 143 deletions(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 
2f2fd49bd..eb24202d5 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -25,19 +25,19 @@ logger = AdapterLogger("Spark") -GET_COLUMNS_IN_RELATION_MACRO_NAME = 'get_columns_in_relation' -LIST_SCHEMAS_MACRO_NAME = 'list_schemas' -LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching' -DROP_RELATION_MACRO_NAME = 'drop_relation' -FETCH_TBL_PROPERTIES_MACRO_NAME = 'fetch_tbl_properties' +GET_COLUMNS_IN_RELATION_MACRO_NAME = "get_columns_in_relation" +LIST_SCHEMAS_MACRO_NAME = "list_schemas" +LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching" +DROP_RELATION_MACRO_NAME = "drop_relation" +FETCH_TBL_PROPERTIES_MACRO_NAME = "fetch_tbl_properties" -KEY_TABLE_OWNER = 'Owner' -KEY_TABLE_STATISTICS = 'Statistics' +KEY_TABLE_OWNER = "Owner" +KEY_TABLE_STATISTICS = "Statistics" @dataclass class SparkConfig(AdapterConfig): - file_format: str = 'parquet' + file_format: str = "parquet" location_root: Optional[str] = None partition_by: Optional[Union[List[str], str]] = None clustered_by: Optional[Union[List[str], str]] = None @@ -48,38 +48,36 @@ class SparkConfig(AdapterConfig): class SparkAdapter(SQLAdapter): COLUMN_NAMES = ( - 'table_database', - 'table_schema', - 'table_name', - 'table_type', - 'table_comment', - 'table_owner', - 'column_name', - 'column_index', - 'column_type', - 'column_comment', - - 'stats:bytes:label', - 'stats:bytes:value', - 'stats:bytes:description', - 'stats:bytes:include', - - 'stats:rows:label', - 'stats:rows:value', - 'stats:rows:description', - 'stats:rows:include', + "table_database", + "table_schema", + "table_name", + "table_type", + "table_comment", + "table_owner", + "column_name", + "column_index", + "column_type", + "column_comment", + "stats:bytes:label", + "stats:bytes:value", + "stats:bytes:description", + "stats:bytes:include", + "stats:rows:label", + "stats:rows:value", + "stats:rows:description", + "stats:rows:include", ) INFORMATION_COLUMNS_REGEX = re.compile( - r"^ \|-- (.*): (.*) \(nullable = (.*)\b", re.MULTILINE) + r"^ \|-- (.*): (.*) \(nullable = (.*)\b", re.MULTILINE + ) INFORMATION_OWNER_REGEX = re.compile(r"^Owner: (.*)$", re.MULTILINE) - INFORMATION_STATISTICS_REGEX = re.compile( - r"^Statistics: (.*)$", re.MULTILINE) + INFORMATION_STATISTICS_REGEX = re.compile(r"^Statistics: (.*)$", re.MULTILINE) HUDI_METADATA_COLUMNS = [ - '_hoodie_commit_time', - '_hoodie_commit_seqno', - '_hoodie_record_key', - '_hoodie_partition_path', - '_hoodie_file_name' + "_hoodie_commit_time", + "_hoodie_commit_seqno", + "_hoodie_record_key", + "_hoodie_partition_path", + "_hoodie_file_name", ] Relation = SparkRelation @@ -89,7 +87,7 @@ class SparkAdapter(SQLAdapter): @classmethod def date_function(cls) -> str: - return 'current_timestamp()' + return "current_timestamp()" @classmethod def convert_text_type(cls, agate_table, col_idx): @@ -113,31 +111,28 @@ def convert_datetime_type(cls, agate_table, col_idx): return "timestamp" def quote(self, identifier): - return '`{}`'.format(identifier) + return "`{}`".format(identifier) def add_schema_to_cache(self, schema) -> str: """Cache a new schema in dbt. 
It will show up in `list relations`.""" if schema is None: name = self.nice_connection_name() dbt.exceptions.raise_compiler_error( - 'Attempted to cache a null schema for {}'.format(name) + "Attempted to cache a null schema for {}".format(name) ) if dbt.flags.USE_CACHE: self.cache.add_schema(None, schema) # so jinja doesn't render things - return '' + return "" def list_relations_without_caching( self, schema_relation: SparkRelation ) -> List[SparkRelation]: - kwargs = {'schema_relation': schema_relation} + kwargs = {"schema_relation": schema_relation} try: - results = self.execute_macro( - LIST_RELATIONS_MACRO_NAME, - kwargs=kwargs - ) + results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs) except dbt.exceptions.RuntimeException as e: - errmsg = getattr(e, 'msg', '') + errmsg = getattr(e, "msg", "") if f"Database '{schema_relation}' not found" in errmsg: return [] else: @@ -150,13 +145,14 @@ def list_relations_without_caching( if len(row) != 4: raise dbt.exceptions.RuntimeException( f'Invalid value from "show table extended ...", ' - f'got {len(row)} values, expected 4' + f"got {len(row)} values, expected 4" ) _schema, name, _, information = row - rel_type = RelationType.View \ - if 'Type: VIEW' in information else RelationType.Table - is_delta = 'Provider: delta' in information - is_hudi = 'Provider: hudi' in information + rel_type = ( + RelationType.View if "Type: VIEW" in information else RelationType.Table + ) + is_delta = "Provider: delta" in information + is_hudi = "Provider: hudi" in information relation = self.Relation.create( schema=_schema, identifier=name, @@ -178,9 +174,7 @@ def get_relation( return super().get_relation(database, schema, identifier) def parse_describe_extended( - self, - relation: Relation, - raw_rows: List[agate.Row] + self, relation: Relation, raw_rows: List[agate.Row] ) -> List[SparkColumn]: # Convert the Row to a dict dict_rows = [dict(zip(row._keys, row._values)) for row in raw_rows] @@ -189,44 +183,45 @@ def parse_describe_extended( pos = self.find_table_information_separator(dict_rows) # Remove rows that start with a hash, they are comments - rows = [ - row for row in raw_rows[0:pos] - if not row['col_name'].startswith('#') - ] - metadata = { - col['col_name']: col['data_type'] for col in raw_rows[pos + 1:] - } + rows = [row for row in raw_rows[0:pos] if not row["col_name"].startswith("#")] + metadata = {col["col_name"]: col["data_type"] for col in raw_rows[pos + 1 :]} raw_table_stats = metadata.get(KEY_TABLE_STATISTICS) table_stats = SparkColumn.convert_table_stats(raw_table_stats) - return [SparkColumn( - table_database=None, - table_schema=relation.schema, - table_name=relation.name, - table_type=relation.type, - table_owner=str(metadata.get(KEY_TABLE_OWNER)), - table_stats=table_stats, - column=column['col_name'], - column_index=idx, - dtype=column['data_type'], - ) for idx, column in enumerate(rows)] + return [ + SparkColumn( + table_database=None, + table_schema=relation.schema, + table_name=relation.name, + table_type=relation.type, + table_owner=str(metadata.get(KEY_TABLE_OWNER)), + table_stats=table_stats, + column=column["col_name"], + column_index=idx, + dtype=column["data_type"], + ) + for idx, column in enumerate(rows) + ] @staticmethod def find_table_information_separator(rows: List[dict]) -> int: pos = 0 for row in rows: - if not row['col_name'] or row['col_name'].startswith('#'): + if not row["col_name"] or row["col_name"].startswith("#"): break pos += 1 return pos def get_columns_in_relation(self, relation: Relation) -> 
List[SparkColumn]: - cached_relations = self.cache.get_relations( - relation.database, relation.schema) - cached_relation = next((cached_relation - for cached_relation in cached_relations - if str(cached_relation) == str(relation)), - None) + cached_relations = self.cache.get_relations(relation.database, relation.schema) + cached_relation = next( + ( + cached_relation + for cached_relation in cached_relations + if str(cached_relation) == str(relation) + ), + None, + ) columns = [] if cached_relation and cached_relation.information: columns = self.parse_columns_from_information(cached_relation) @@ -243,29 +238,27 @@ def get_columns_in_relation(self, relation: Relation) -> List[SparkColumn]: # CDW would just return and empty list, normalizing the behavior here errmsg = getattr(e, "msg", "") if ( - "Table or view not found" in errmsg or - "NoSuchTableException" in errmsg + "Table or view not found" in errmsg + or "NoSuchTableException" in errmsg ): pass else: raise e # strip hudi metadata columns. - columns = [x for x in columns - if x.name not in self.HUDI_METADATA_COLUMNS] + columns = [x for x in columns if x.name not in self.HUDI_METADATA_COLUMNS] return columns def parse_columns_from_information( - self, relation: SparkRelation + self, relation: SparkRelation ) -> List[SparkColumn]: - owner_match = re.findall( - self.INFORMATION_OWNER_REGEX, relation.information) + owner_match = re.findall(self.INFORMATION_OWNER_REGEX, relation.information) owner = owner_match[0] if owner_match else None - matches = re.finditer( - self.INFORMATION_COLUMNS_REGEX, relation.information) + matches = re.finditer(self.INFORMATION_COLUMNS_REGEX, relation.information) columns = [] stats_match = re.findall( - self.INFORMATION_STATISTICS_REGEX, relation.information) + self.INFORMATION_STATISTICS_REGEX, relation.information + ) raw_table_stats = stats_match[0] if stats_match else None table_stats = SparkColumn.convert_table_stats(raw_table_stats) for match_num, match in enumerate(matches): @@ -279,7 +272,7 @@ def parse_columns_from_information( table_owner=owner, column=column_name, dtype=column_type, - table_stats=table_stats + table_stats=table_stats, ) columns.append(column) return columns @@ -292,15 +285,14 @@ def _get_columns_for_catalog( for column in columns: # convert SparkColumns into catalog dicts as_dict = column.to_column_dict() - as_dict['column_name'] = as_dict.pop('column', None) - as_dict['column_type'] = as_dict.pop('dtype') - as_dict['table_database'] = None + as_dict["column_name"] = as_dict.pop("column", None) + as_dict["column_type"] = as_dict.pop("dtype") + as_dict["table_database"] = None yield as_dict def get_properties(self, relation: Relation) -> Dict[str, str]: properties = self.execute_macro( - FETCH_TBL_PROPERTIES_MACRO_NAME, - kwargs={'relation': relation} + FETCH_TBL_PROPERTIES_MACRO_NAME, kwargs={"relation": relation} ) return dict(properties) @@ -308,28 +300,37 @@ def get_catalog(self, manifest): schema_map = self._get_catalog_schemas(manifest) if len(schema_map) > 1: dbt.exceptions.raise_compiler_error( - f'Expected only one database in get_catalog, found ' - f'{list(schema_map)}' + f"Expected only one database in get_catalog, found " + f"{list(schema_map)}" ) with executor(self.config) as tpe: futures: List[Future[agate.Table]] = [] for info, schemas in schema_map.items(): for schema in schemas: - futures.append(tpe.submit_connected( - self, schema, - self._get_one_catalog, info, [schema], manifest - )) + futures.append( + tpe.submit_connected( + self, + schema, + 
self._get_one_catalog, + info, + [schema], + manifest, + ) + ) catalogs, exceptions = catch_as_completed(futures) return catalogs, exceptions def _get_one_catalog( - self, information_schema, schemas, manifest, + self, + information_schema, + schemas, + manifest, ) -> agate.Table: if len(schemas) != 1: dbt.exceptions.raise_compiler_error( - f'Expected only one schema in spark _get_one_catalog, found ' - f'{schemas}' + f"Expected only one schema in spark _get_one_catalog, found " + f"{schemas}" ) database = information_schema.database @@ -339,14 +340,11 @@ def _get_one_catalog( for relation in self.list_relations(database, schema): logger.debug("Getting table schema for relation {}", relation) columns.extend(self._get_columns_for_catalog(relation)) - return agate.Table.from_object( - columns, column_types=DEFAULT_TYPE_TESTER - ) + return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER) def check_schema_exists(self, database, schema): results = self.execute_macro( - LIST_SCHEMAS_MACRO_NAME, - kwargs={'database': database} + LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database} ) exists = True if schema in [row[0] for row in results] else False @@ -357,7 +355,7 @@ def get_rows_different_sql( relation_a: BaseRelation, relation_b: BaseRelation, column_names: Optional[List[str]] = None, - except_operator: str = 'EXCEPT', + except_operator: str = "EXCEPT", ) -> str: """Generate SQL for a query that returns a single row with a two columns: the number of rows that are different between the two @@ -370,7 +368,7 @@ def get_rows_different_sql( names = sorted((self.quote(c.name) for c in columns)) else: names = sorted((self.quote(n) for n in column_names)) - columns_csv = ', '.join(names) + columns_csv = ", ".join(names) sql = COLUMNS_EQUAL_SQL.format( columns=columns_csv, @@ -388,7 +386,7 @@ def run_sql_for_tests(self, sql, fetch, conn): try: cursor.execute(sql) if fetch == "one": - if hasattr(cursor, 'fetchone'): + if hasattr(cursor, "fetchone"): return cursor.fetchone() else: # AttributeError: 'PyhiveConnectionWrapper' object has no attribute 'fetchone' @@ -403,77 +401,77 @@ def run_sql_for_tests(self, sql, fetch, conn): raise finally: conn.transaction_open = False + @available def submit_python_job(self, schema, identifier, file_contents): # basically copying drew's other script over :) - auth_header = {'Authorization': f'Bearer {self.connections.profile.credentials.token}'} + auth_header = { + "Authorization": f"Bearer {self.connections.profile.credentials.token}" + } b64_encoded_content = base64.b64encode(file_contents.encode()).decode() # create new dir if not self.connections.profile.credentials.user: - raise ValueError('Need to supply user in profile to submit python job') + raise ValueError("Need to supply user in profile to submit python job") # it is safe to call mkdirs even if dir already exists and have content inside - work_dir = f'/Users/{self.connections.profile.credentials.user}/{schema}' + work_dir = f"/Users/{self.connections.profile.credentials.user}/{schema}" response = requests.post( - f'https://{self.connections.profile.credentials.host}/api/2.0/workspace/mkdirs', + f"https://{self.connections.profile.credentials.host}/api/2.0/workspace/mkdirs", headers=auth_header, json={ - 'path': work_dir, - } + "path": work_dir, + }, ) # TODO check the response # add notebook response = requests.post( - f'https://{self.connections.profile.credentials.host}/api/2.0/workspace/import', + f"https://{self.connections.profile.credentials.host}/api/2.0/workspace/import", 
headers=auth_header, json={ - 'path': f'{work_dir}/{identifier}', - 'content': b64_encoded_content, - 'language': 'PYTHON', - 'overwrite': True, - 'format': 'SOURCE' - } + "path": f"{work_dir}/{identifier}", + "content": b64_encoded_content, + "language": "PYTHON", + "overwrite": True, + "format": "SOURCE", + }, ) # need to validate submit succeed here resp = response.json() - + # submit job response = requests.post( - f'https://{self.connections.profile.credentials.host}/api/2.1/jobs/runs/submit', + f"https://{self.connections.profile.credentials.host}/api/2.1/jobs/runs/submit", headers=auth_header, json={ "run_name": "debug task", "existing_cluster_id": self.connections.profile.credentials.cluster, - 'notebook_task': { - 'notebook_path': f'{work_dir}/{identifier}', - } - } + "notebook_task": { + "notebook_path": f"{work_dir}/{identifier}", + }, + }, ) - - run_id = response.json()['run_id'] + run_id = response.json()["run_id"] # poll until job finish # this feels bad state = None - while state != 'TERMINATED': + while state != "TERMINATED": resp = requests.get( - f'https://{self.connections.profile.credentials.host}/api/2.1/jobs/runs/get?run_id={run_id}', + f"https://{self.connections.profile.credentials.host}/api/2.1/jobs/runs/get?run_id={run_id}", headers=auth_header, ) - state = resp.json()['state']['life_cycle_state'] + state = resp.json()["state"]["life_cycle_state"] logger.debug(f"Polling.... in state: {state}") time.sleep(1) # TODO resp.json()['state_message'] contain useful information and we may want to surface to user if job fails - + run_output = requests.get( - f'https://{self.connections.profile.credentials.host}/api/2.1/jobs/runs/get-output?run_id={run_id}', + f"https://{self.connections.profile.credentials.host}/api/2.1/jobs/runs/get-output?run_id={run_id}", headers=auth_header, ) # TODO have more info here and determine what do we want to if python model fail - return run_output.json()['metadata']['state']['result_state'] - - + return run_output.json()["metadata"]["state"]["result_state"] # spark does something interesting with joins when both tables have the same @@ -481,7 +479,7 @@ def submit_python_job(self, schema, identifier, file_contents): # "trivial". Which is true, though it seems like an unreasonable cause for # failure! It also doesn't like the `from foo, bar` syntax as opposed to # `from foo cross join bar`. 
-COLUMNS_EQUAL_SQL = ''' +COLUMNS_EQUAL_SQL = """ with diff_count as ( SELECT 1 as id, @@ -508,4 +506,4 @@ def submit_python_job(self, schema, identifier, file_contents): diff_count.num_missing as num_mismatched from row_count_diff cross join diff_count -'''.strip() +""".strip() From 4195ccde5aa1a47049e128edc8b99fb1e2a9d8fa Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Tue, 31 May 2022 16:25:09 -0700 Subject: [PATCH 10/35] better error handling for api call and target relation templating --- dbt/adapters/spark/impl.py | 57 +++++++++++++------ .../spark/macros/materializations/table.sql | 6 +- 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index eb24202d5..3b7bc43f9 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -403,12 +403,20 @@ def run_sql_for_tests(self, sql, fetch, conn): conn.transaction_open = False @available - def submit_python_job(self, schema, identifier, file_contents): - # basically copying drew's other script over :) + def submit_python_job(self, schema, identifier, file_contents, timeout=None): + + # TODO limit this function to run only when doing the materialization of python nodes + + # assuming that for python job running over 1 day user would mannually overwrite this + if not timeout: + timeout = 60*60*24 + if timeout <= 0 : + raise ValueError('Timeout must larger than 0') + auth_header = { "Authorization": f"Bearer {self.connections.profile.credentials.token}" } - b64_encoded_content = base64.b64encode(file_contents.encode()).decode() + # create new dir if not self.connections.profile.credentials.user: raise ValueError("Need to supply user in profile to submit python job") @@ -421,9 +429,11 @@ def submit_python_job(self, schema, identifier, file_contents): "path": work_dir, }, ) - # TODO check the response + if response.status_code != 200: + raise dbt.exceptions.RuntimeException(f'Error creating work_dir for python notebooks\n {response.content}') # add notebook + b64_encoded_content = base64.b64encode(file_contents.encode()).decode() response = requests.post( f"https://{self.connections.profile.credentials.host}/api/2.0/workspace/import", headers=auth_header, @@ -435,11 +445,12 @@ def submit_python_job(self, schema, identifier, file_contents): "format": "SOURCE", }, ) - # need to validate submit succeed here - resp = response.json() + if response.status_code != 200: + raise dbt.exceptions.RuntimeException(f'Error creating python notebook.\n {response.content}') + # submit job - response = requests.post( + submit_response = requests.post( f"https://{self.connections.profile.credentials.host}/api/2.1/jobs/runs/submit", headers=auth_header, json={ @@ -450,28 +461,42 @@ def submit_python_job(self, schema, identifier, file_contents): }, }, ) - - run_id = response.json()["run_id"] + if submit_response.status_code != 200: + raise dbt.exceptions.RuntimeException(f'Error creating python run.\n {response.content}') + # poll until job finish - # this feels bad state = None - while state != "TERMINATED": + start = time.time() + run_id = submit_response.json()["run_id"] + terminal_states = ['TERMINATED', 'SKIPPED', 'INTERNAL_ERROR'] + while state not in terminal_states and time.time() - start < timeout: + time.sleep(1) resp = requests.get( f"https://{self.connections.profile.credentials.host}/api/2.1/jobs/runs/get?run_id={run_id}", headers=auth_header, ) - state = resp.json()["state"]["life_cycle_state"] + json_resp = resp.json() + state = json_resp["state"]["life_cycle_state"] 
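            # (descriptive note, not part of this diff) life_cycle_state passes through
            # non-terminal values such as PENDING and RUNNING before reaching one of
            # terminal_states; if the timeout elapses first, the loop exits with the
            # last observed state and the check below raises.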
logger.debug(f"Polling.... in state: {state}") - time.sleep(1) - # TODO resp.json()['state_message'] contain useful information and we may want to surface to user if job fails + if state != "TERMINATED": + raise dbt.exceptions.RuntimeException(f"python model run ended in state {state} with state_message\n {json_resp['state']['state_message']}") + # get end state to return to user run_output = requests.get( f"https://{self.connections.profile.credentials.host}/api/2.1/jobs/runs/get-output?run_id={run_id}", headers=auth_header, ) - # TODO have more info here and determine what do we want to if python model fail - return run_output.json()["metadata"]["state"]["result_state"] + json_run_output = run_output.json() + result_state = json_run_output["metadata"]["state"]["result_state"] + if result_state != 'SUCCESS': + raise dbt.exceptions.RuntimeException( + f"\ +Python model failed with traceback as:\n \ +(Note that the line number here does not match the line number in your code due to dbt templating)\n \ +{json_run_output['error_trace']}" + ) + return result_state # spark does something interesting with joins when both tables have the same diff --git a/dbt/include/spark/macros/materializations/table.sql b/dbt/include/spark/macros/materializations/table.sql index 961b2a9f0..6d7a8dc75 100644 --- a/dbt/include/spark/macros/materializations/table.sql +++ b/dbt/include/spark/macros/materializations/table.sql @@ -20,7 +20,7 @@ -- build model {% if config.get('language', 'sql') == 'python' -%}} -- sql here is really just the compiled python code - {%- set python_code = py_complete_script(model=model, schema=schema, python_code=sql) -%} + {%- set python_code = py_complete_script(python_code=sql, target_relation=target_relation) -%} {{ log("python code " ~ python_code ) }} {% set result = adapter.submit_python_job(schema, identifier, python_code) %} {% call noop_statement('main', result, 'OK', 1) %} @@ -42,7 +42,7 @@ {% endmaterialization %} -{% macro py_complete_script(model, schema, python_code) %} +{% macro py_complete_script(python_code, target_relation) %} {{ python_code }} df = model(dbt) @@ -50,5 +50,5 @@ df = model(dbt) # COMMAND ---------- # this is materialization code dbt generated, please do not modify -df.write.mode("overwrite").format("delta").saveAsTable("{{schema}}.{{model['alias']}}") +df.write.mode("overwrite").format("delta").saveAsTable("{{ target_relation }}") {% endmacro %} From 98f60e7cde1bddd06697a338bd62748045a1b31b Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Tue, 31 May 2022 16:30:54 -0700 Subject: [PATCH 11/35] fix format --- dbt/adapters/spark/impl.py | 75 +++++++++++++++----------------------- 1 file changed, 30 insertions(+), 45 deletions(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 3b7bc43f9..5c5092b8e 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -67,9 +67,7 @@ class SparkAdapter(SQLAdapter): "stats:rows:description", "stats:rows:include", ) - INFORMATION_COLUMNS_REGEX = re.compile( - r"^ \|-- (.*): (.*) \(nullable = (.*)\b", re.MULTILINE - ) + INFORMATION_COLUMNS_REGEX = re.compile(r"^ \|-- (.*): (.*) \(nullable = (.*)\b", re.MULTILINE) INFORMATION_OWNER_REGEX = re.compile(r"^Owner: (.*)$", re.MULTILINE) INFORMATION_STATISTICS_REGEX = re.compile(r"^Statistics: (.*)$", re.MULTILINE) HUDI_METADATA_COLUMNS = [ @@ -148,9 +146,7 @@ def list_relations_without_caching( f"got {len(row)} values, expected 4" ) _schema, name, _, information = row - rel_type = ( - RelationType.View if "Type: VIEW" in information else 
RelationType.Table - ) + rel_type = RelationType.View if "Type: VIEW" in information else RelationType.Table is_delta = "Provider: delta" in information is_hudi = "Provider: hudi" in information relation = self.Relation.create( @@ -165,9 +161,7 @@ def list_relations_without_caching( return relations - def get_relation( - self, database: str, schema: str, identifier: str - ) -> Optional[BaseRelation]: + def get_relation(self, database: str, schema: str, identifier: str) -> Optional[BaseRelation]: if not self.Relation.include_policy.database: database = None @@ -237,10 +231,7 @@ def get_columns_in_relation(self, relation: Relation) -> List[SparkColumn]: # spark would throw error when table doesn't exist, where other # CDW would just return and empty list, normalizing the behavior here errmsg = getattr(e, "msg", "") - if ( - "Table or view not found" in errmsg - or "NoSuchTableException" in errmsg - ): + if "Table or view not found" in errmsg or "NoSuchTableException" in errmsg: pass else: raise e @@ -249,16 +240,12 @@ def get_columns_in_relation(self, relation: Relation) -> List[SparkColumn]: columns = [x for x in columns if x.name not in self.HUDI_METADATA_COLUMNS] return columns - def parse_columns_from_information( - self, relation: SparkRelation - ) -> List[SparkColumn]: + def parse_columns_from_information(self, relation: SparkRelation) -> List[SparkColumn]: owner_match = re.findall(self.INFORMATION_OWNER_REGEX, relation.information) owner = owner_match[0] if owner_match else None matches = re.finditer(self.INFORMATION_COLUMNS_REGEX, relation.information) columns = [] - stats_match = re.findall( - self.INFORMATION_STATISTICS_REGEX, relation.information - ) + stats_match = re.findall(self.INFORMATION_STATISTICS_REGEX, relation.information) raw_table_stats = stats_match[0] if stats_match else None table_stats = SparkColumn.convert_table_stats(raw_table_stats) for match_num, match in enumerate(matches): @@ -277,9 +264,7 @@ def parse_columns_from_information( columns.append(column) return columns - def _get_columns_for_catalog( - self, relation: SparkRelation - ) -> Iterable[Dict[str, Any]]: + def _get_columns_for_catalog(self, relation: SparkRelation) -> Iterable[Dict[str, Any]]: columns = self.parse_columns_from_information(relation) for column in columns: @@ -300,8 +285,7 @@ def get_catalog(self, manifest): schema_map = self._get_catalog_schemas(manifest) if len(schema_map) > 1: dbt.exceptions.raise_compiler_error( - f"Expected only one database in get_catalog, found " - f"{list(schema_map)}" + f"Expected only one database in get_catalog, found " f"{list(schema_map)}" ) with executor(self.config) as tpe: @@ -329,8 +313,7 @@ def _get_one_catalog( ) -> agate.Table: if len(schemas) != 1: dbt.exceptions.raise_compiler_error( - f"Expected only one schema in spark _get_one_catalog, found " - f"{schemas}" + f"Expected only one schema in spark _get_one_catalog, found " f"{schemas}" ) database = information_schema.database @@ -343,9 +326,7 @@ def _get_one_catalog( return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER) def check_schema_exists(self, database, schema): - results = self.execute_macro( - LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database} - ) + results = self.execute_macro(LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database}) exists = True if schema in [row[0] for row in results] else False return exists @@ -409,14 +390,12 @@ def submit_python_job(self, schema, identifier, file_contents, timeout=None): # assuming that for python job running over 1 day user 
would mannually overwrite this if not timeout: - timeout = 60*60*24 - if timeout <= 0 : - raise ValueError('Timeout must larger than 0') - - auth_header = { - "Authorization": f"Bearer {self.connections.profile.credentials.token}" - } - + timeout = 60 * 60 * 24 + if timeout <= 0: + raise ValueError("Timeout must larger than 0") + + auth_header = {"Authorization": f"Bearer {self.connections.profile.credentials.token}"} + # create new dir if not self.connections.profile.credentials.user: raise ValueError("Need to supply user in profile to submit python job") @@ -430,7 +409,9 @@ def submit_python_job(self, schema, identifier, file_contents, timeout=None): }, ) if response.status_code != 200: - raise dbt.exceptions.RuntimeException(f'Error creating work_dir for python notebooks\n {response.content}') + raise dbt.exceptions.RuntimeException( + f"Error creating work_dir for python notebooks\n {response.content}" + ) # add notebook b64_encoded_content = base64.b64encode(file_contents.encode()).decode() @@ -446,8 +427,9 @@ def submit_python_job(self, schema, identifier, file_contents, timeout=None): }, ) if response.status_code != 200: - raise dbt.exceptions.RuntimeException(f'Error creating python notebook.\n {response.content}') - + raise dbt.exceptions.RuntimeException( + f"Error creating python notebook.\n {response.content}" + ) # submit job submit_response = requests.post( @@ -462,14 +444,15 @@ def submit_python_job(self, schema, identifier, file_contents, timeout=None): }, ) if submit_response.status_code != 200: - raise dbt.exceptions.RuntimeException(f'Error creating python run.\n {response.content}') - + raise dbt.exceptions.RuntimeException( + f"Error creating python run.\n {response.content}" + ) # poll until job finish state = None start = time.time() run_id = submit_response.json()["run_id"] - terminal_states = ['TERMINATED', 'SKIPPED', 'INTERNAL_ERROR'] + terminal_states = ["TERMINATED", "SKIPPED", "INTERNAL_ERROR"] while state not in terminal_states and time.time() - start < timeout: time.sleep(1) resp = requests.get( @@ -480,7 +463,9 @@ def submit_python_job(self, schema, identifier, file_contents, timeout=None): state = json_resp["state"]["life_cycle_state"] logger.debug(f"Polling.... 
in state: {state}") if state != "TERMINATED": - raise dbt.exceptions.RuntimeException(f"python model run ended in state {state} with state_message\n {json_resp['state']['state_message']}") + raise dbt.exceptions.RuntimeException( + f"python model run ended in state {state} with state_message\n{json_resp['state']['state_message']}" + ) # get end state to return to user run_output = requests.get( @@ -489,7 +474,7 @@ def submit_python_job(self, schema, identifier, file_contents, timeout=None): ) json_run_output = run_output.json() result_state = json_run_output["metadata"]["state"]["result_state"] - if result_state != 'SUCCESS': + if result_state != "SUCCESS": raise dbt.exceptions.RuntimeException( f"\ Python model failed with traceback as:\n \ From 0a6e673d12a092cb5bee24bb6bbfc9bcffcae901 Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Wed, 1 Jun 2022 15:48:14 -0700 Subject: [PATCH 12/35] fix format --- dbt/adapters/spark/impl.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 5c5092b8e..ee60b219f 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -178,7 +178,7 @@ def parse_describe_extended( # Remove rows that start with a hash, they are comments rows = [row for row in raw_rows[0:pos] if not row["col_name"].startswith("#")] - metadata = {col["col_name"]: col["data_type"] for col in raw_rows[pos + 1 :]} + metadata = {col["col_name"]: col["data_type"] for col in raw_rows[pos + 1:]} raw_table_stats = metadata.get(KEY_TABLE_STATISTICS) table_stats = SparkColumn.convert_table_stats(raw_table_stats) @@ -456,7 +456,8 @@ def submit_python_job(self, schema, identifier, file_contents, timeout=None): while state not in terminal_states and time.time() - start < timeout: time.sleep(1) resp = requests.get( - f"https://{self.connections.profile.credentials.host}/api/2.1/jobs/runs/get?run_id={run_id}", + f"https://{self.connections.profile.credentials.host}" + f"/api/2.1/jobs/runs/get?run_id={run_id}", headers=auth_header, ) json_resp = resp.json() @@ -464,22 +465,24 @@ def submit_python_job(self, schema, identifier, file_contents, timeout=None): logger.debug(f"Polling.... 
in state: {state}") if state != "TERMINATED": raise dbt.exceptions.RuntimeException( - f"python model run ended in state {state} with state_message\n{json_resp['state']['state_message']}" + "python model run ended in state" + f"{state} with state_message\n{json_resp['state']['state_message']}" ) # get end state to return to user run_output = requests.get( - f"https://{self.connections.profile.credentials.host}/api/2.1/jobs/runs/get-output?run_id={run_id}", + f"https://{self.connections.profile.credentials.host}" + f"/api/2.1/jobs/runs/get-output?run_id={run_id}", headers=auth_header, ) json_run_output = run_output.json() result_state = json_run_output["metadata"]["state"]["result_state"] if result_state != "SUCCESS": raise dbt.exceptions.RuntimeException( - f"\ -Python model failed with traceback as:\n \ -(Note that the line number here does not match the line number in your code due to dbt templating)\n \ -{json_run_output['error_trace']}" + "Python model failed with traceback as:\n" + "(Note that the line number here does not " + "match the line number in your code due to dbt templating)\n" + f"{json_run_output['error_trace']}" ) return result_state From c29867ec9ea64b14b38dd162d8f96342e8f1dab1 Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Thu, 2 Jun 2022 16:17:20 -0700 Subject: [PATCH 13/35] add functional test --- tests/conftest.py | 1 + tests/functional/adapter/test_basic.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 7ba95d47b..f2f0abcf0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -91,6 +91,7 @@ def databricks_http_cluster_target(): "connect_retries": 5, "connect_timeout": 60, "retry_all": bool(os.getenv('DBT_DATABRICKS_RETRY_ALL', False)), + "user": os.getenv('DBT_DATABRICKS_USER') } diff --git a/tests/functional/adapter/test_basic.py b/tests/functional/adapter/test_basic.py index 70f3267a4..04498052e 100644 --- a/tests/functional/adapter/test_basic.py +++ b/tests/functional/adapter/test_basic.py @@ -12,6 +12,7 @@ from dbt.tests.adapter.basic.test_snapshot_check_cols import BaseSnapshotCheckCols from dbt.tests.adapter.basic.test_snapshot_timestamp import BaseSnapshotTimestamp from dbt.tests.adapter.basic.test_adapter_methods import BaseAdapterMethod +from dbt.tests.adapter.python_model.test_python_model import BasePythonModelTests @pytest.mark.skip_profile('spark_session') @@ -80,4 +81,8 @@ def project_config_update(self): } class TestBaseAdapterMethod(BaseAdapterMethod): + pass + + +class TestBasePythonModelSnowflake(BasePythonModelTests): pass \ No newline at end of file From f87a30b5accd4322cc1ea0c7a90471f0a462a2b2 Mon Sep 17 00:00:00 2001 From: Ian Knox Date: Thu, 16 Jun 2022 13:19:02 -0500 Subject: [PATCH 14/35] first pass --- dbt/adapters/spark/impl.py | 4 +- .../incremental/incremental.sql | 57 ++++++++++++++++--- 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index ee60b219f..14653c291 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -384,7 +384,9 @@ def run_sql_for_tests(self, sql, fetch, conn): conn.transaction_open = False @available - def submit_python_job(self, schema, identifier, file_contents, timeout=None): + def submit_python_job(self, schema: str, identifier: str, file_contents: str, timeout=None): + # TODO improve the typing here. N.B. Jinja returns a `jinja2.runtime.Undefined` instead + # of `None` which evaluates to True! 
# TODO limit this function to run only when doing the materialization of python nodes diff --git a/dbt/include/spark/macros/materializations/incremental/incremental.sql b/dbt/include/spark/macros/materializations/incremental/incremental.sql index d0b6e89ba..6291afc5f 100644 --- a/dbt/include/spark/macros/materializations/incremental/incremental.sql +++ b/dbt/include/spark/macros/materializations/incremental/incremental.sql @@ -24,23 +24,62 @@ {% endcall %} {% endif %} + {% set language = config.get('language') %} + {{ run_hooks(pre_hooks) }} {% if existing_relation is none %} - {% set build_sql = create_table_as(False, target_relation, sql) %} + {{ log("#-- Relation must be created --#") }} + {% if language == 'sql'%} + {%- call statement('main') -%} + {{ create_table_as(False, target_relation, sql) }} + {%- endcall -%} + {% elif language == 'python' %} + {%- set python_code = py_complete_script(python_code=sql, target_relation=target_relation) -%} + {{ log("python code: " ~ python_code ) }} + {% set result = adapter.submit_python_job(schema, model['alias'], python_code) %} + {% call noop_statement('main', result, 'OK', 1) %} + -- python model return run result -- + {% endcall %} + {% endif %} {% elif existing_relation.is_view or full_refresh_mode %} + {{ log("#-- Relation must be dropped & recreated --#") }} {% do adapter.drop_relation(existing_relation) %} - {% set build_sql = create_table_as(False, target_relation, sql) %} + {% if language == 'sql'%} + {%- call statement('main') -%} + {{ create_table_as(False, target_relation, sql) }} + {%- endcall -%} + {% elif language == 'python' %} + {%- set python_code = py_complete_script(python_code=sql, target_relation=target_relation) -%} + {{ log("python code " ~ python_code ) }} + {% set result = adapter.submit_python_job(schema, model['alias'], python_code) %} + {% call noop_statement('main', result, 'OK', 1) %} + -- python model return run result -- + {% endcall %} + {% endif %} {% else %} - {% do run_query(create_table_as(True, tmp_relation, sql)) %} - {% do process_schema_changes(on_schema_change, tmp_relation, existing_relation) %} - {% set build_sql = dbt_spark_get_incremental_sql(strategy, tmp_relation, target_relation, unique_key) %} + {{ log("#-- Relation must be merged --#") }} + {% if language == 'sql'%} + {% do run_query(create_table_as(True, tmp_relation, sql)) %} + {% do process_schema_changes(on_schema_change, tmp_relation, existing_relation) %} + {%- call statement('main') -%} + {{ dbt_spark_get_incremental_sql(strategy, tmp_relation, target_relation, unique_key) }} + {%- endcall -%} + {% elif language == 'python' %} + {%- set python_code = py_complete_script(python_code=sql, target_relation=tmp_relation) -%} + {% set result = adapter.submit_python_job(schema, model['alias'], python_code) %} + {{ log("python code " ~ python_code ) }} + {% call noop_statement('main', result, 'OK', 1) %} + -- python model return run result -- + {% endcall %} + {{ log("XXXXXX-" ~ result) }} + {% do process_schema_changes(on_schema_change, tmp_relation, existing_relation) %} + {%- call statement('main') -%} + {{ dbt_spark_get_incremental_sql(strategy, tmp_relation, target_relation, unique_key) }} + {%- endcall -%} + {% endif %} {% endif %} - {%- call statement('main') -%} - {{ build_sql }} - {%- endcall -%} - {% do persist_docs(target_relation, model) %} {{ run_hooks(post_hooks) }} From d6ac3b96de54473ddf689317f3624e13477d0cb7 Mon Sep 17 00:00:00 2001 From: Ian Knox Date: Thu, 16 Jun 2022 15:52:32 -0500 Subject: [PATCH 15/35] cleanup , pt 1 --- 
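[Not part of the commit below: a minimal, illustrative sketch of the user-facing Python
model that the `language == 'python'` branches in these materializations are meant to run.
The model and column names are invented; the `def model(dbt)` signature, `dbt.ref()`
resolving to a Spark DataFrame via `spark.table` (PATCH 05), and the generated
`df.write...saveAsTable(...)` epilogue in py_complete_script come from the earlier patches.]

import pyspark.sql.functions as F

def model(dbt):
    # dbt.ref() hands back a Spark DataFrame (load_df_function = spark.table)
    upstream = dbt.ref("upstream_model")  # hypothetical upstream model name
    # whatever DataFrame is returned here is what py_complete_script persists via
    # df.write.mode("overwrite").format("delta").saveAsTable(<target or tmp relation>)
    return upstream.withColumn("loaded_at", F.current_timestamp())
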
.../incremental/incremental.sql | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/dbt/include/spark/macros/materializations/incremental/incremental.sql b/dbt/include/spark/macros/materializations/incremental/incremental.sql index 6291afc5f..3da7d5bc4 100644 --- a/dbt/include/spark/macros/materializations/incremental/incremental.sql +++ b/dbt/include/spark/macros/materializations/incremental/incremental.sql @@ -1,32 +1,33 @@ {% materialization incremental, adapter='spark' -%} - + {#-- Validate early so we don't run SQL if the file_format + strategy combo is invalid --#} {%- set raw_file_format = config.get('file_format', default='parquet') -%} {%- set raw_strategy = config.get('incremental_strategy', default='append') -%} - {%- set file_format = dbt_spark_validate_get_file_format(raw_file_format) -%} {%- set strategy = dbt_spark_validate_get_incremental_strategy(raw_strategy, file_format) -%} - + + {#-- Set vars --#} {%- set unique_key = config.get('unique_key', none) -%} - {%- set partition_by = config.get('partition_by', none) -%} + {%- set partition_by = config.get('partition_by', none) -%} + {%- set language = config.get('language') -%} + {%- set on_schema_change = incremental_validate_on_schema_change(config.get('on_schema_change'), default='ignore') -%} + {%- set target_relation = this -%} + {%- set existing_relation = load_relation(this) -%} + {%- set tmp_relation = make_temp_relation(this) -%} - {%- set full_refresh_mode = (should_full_refresh()) -%} - - {% set on_schema_change = incremental_validate_on_schema_change(config.get('on_schema_change'), default='ignore') %} + {%- if strategy == 'insert_overwrite' and partition_by -%} + {%- call statement() -%} + set spark.sql.sources.partitionOverwriteMode = DYNAMIC + {%- endcall -%} + {%- endif -%} - {% set target_relation = this %} - {% set existing_relation = load_relation(this) %} - {% set tmp_relation = make_temp_relation(this) %} + {#-- Run pre-hooks --#} + {{ run_hooks(pre_hooks) }} + + {#-- Incremental run logic --#} - {% if strategy == 'insert_overwrite' and partition_by %} - {% call statement() %} - set spark.sql.sources.partitionOverwriteMode = DYNAMIC - {% endcall %} - {% endif %} - {% set language = config.get('language') %} - {{ run_hooks(pre_hooks) }} {% if existing_relation is none %} {{ log("#-- Relation must be created --#") }} @@ -42,7 +43,7 @@ -- python model return run result -- {% endcall %} {% endif %} - {% elif existing_relation.is_view or full_refresh_mode %} + {% elif existing_relation.is_view or should_full_refresh() %} {{ log("#-- Relation must be dropped & recreated --#") }} {% do adapter.drop_relation(existing_relation) %} {% if language == 'sql'%} From ca04f357f8daa3ae67c409319a41b05b2a6ddd1b Mon Sep 17 00:00:00 2001 From: Ian Knox Date: Thu, 23 Jun 2022 15:56:10 -0500 Subject: [PATCH 16/35] cleaned up incremental logic --- dbt/include/spark/macros/adapters.sql | 54 +++++++----- .../incremental/incremental.sql | 85 +++++++------------ .../spark/macros/materializations/table.sql | 36 +++----- 3 files changed, 74 insertions(+), 101 deletions(-) diff --git a/dbt/include/spark/macros/adapters.sql b/dbt/include/spark/macros/adapters.sql index cb9dcc79d..28d89e85d 100644 --- a/dbt/include/spark/macros/adapters.sql +++ b/dbt/include/spark/macros/adapters.sql @@ -117,34 +117,42 @@ {%- endmacro %} -{% macro create_temporary_view(relation, sql) -%} - {{ return(adapter.dispatch('create_temporary_view', 'dbt')(relation, sql)) }} +{% macro create_temporary_view(relation, model_code, 
language) -%} + {{ return(adapter.dispatch('create_temporary_view', 'dbt')(relation, model_code, language)) }} {%- endmacro -%} -{#-- We can't use temporary tables with `create ... as ()` syntax #} -{% macro spark__create_temporary_view(relation, sql) -%} - create temporary view {{ relation.include(schema=false) }} as - {{ sql }} -{% endmacro %} +{#-- We can't use temporary tables with `create ... as ()` syntax --#} +{% macro spark__create_temporary_view(relation, model_code, language='sql') -%} + {%- if language == 'sql' -%} + create temporary view {{ relation.include(schema=false) }} as + {{ model_code }} + {%- elif language == 'python' -%} + {{ py_complete_script(python_code=model_code, target_relation=relation, is_tmp_view=True) }} + {%- endif -%} +{%- endmacro -%} -{% macro spark__create_table_as(temporary, relation, sql) -%} - {% if temporary -%} - {{ create_temporary_view(relation, sql) }} +{%- macro spark__create_table_as(temporary, relation, model_code, language='sql') -%} + {%- if temporary -%} + {{ create_temporary_view(relation, model_code, language) }} {%- else -%} - {% if config.get('file_format', validator=validation.any[basestring]) == 'delta' %} - create or replace table {{ relation }} - {% else %} - create table {{ relation }} - {% endif %} - {{ file_format_clause() }} - {{ options_clause() }} - {{ partition_cols(label="partitioned by") }} - {{ clustered_cols(label="clustered by") }} - {{ location_clause() }} - {{ comment_clause() }} - as - {{ sql }} + {%- if language == 'sql' -%} + {% if config.get('file_format', validator=validation.any[basestring]) == 'delta' %} + create or replace table {{ relation }} + {% else %} + create table {{ relation }} + {% endif %} + {{ file_format_clause() }} + {{ options_clause() }} + {{ partition_cols(label="partitioned by") }} + {{ clustered_cols(label="clustered by") }} + {{ location_clause() }} + {{ comment_clause() }} + as + {{ model_code }} + {%- elif language == 'python' -%} + {{ py_complete_script(python_code=model_code, target_relation=relation) }} + {%- endif -%} {%- endif %} {%- endmacro -%} diff --git a/dbt/include/spark/macros/materializations/incremental/incremental.sql b/dbt/include/spark/macros/materializations/incremental/incremental.sql index 3da7d5bc4..dd4a3f39b 100644 --- a/dbt/include/spark/macros/materializations/incremental/incremental.sql +++ b/dbt/include/spark/macros/materializations/incremental/incremental.sql @@ -14,7 +14,9 @@ {%- set target_relation = this -%} {%- set existing_relation = load_relation(this) -%} {%- set tmp_relation = make_temp_relation(this) -%} + {%- set model_code = sql -%} + {#-- Set Overwrite Mode --#} {%- if strategy == 'insert_overwrite' and partition_by -%} {%- call statement() -%} set spark.sql.sources.partitionOverwriteMode = DYNAMIC @@ -25,62 +27,33 @@ {{ run_hooks(pre_hooks) }} {#-- Incremental run logic --#} - - - - - {% if existing_relation is none %} - {{ log("#-- Relation must be created --#") }} - {% if language == 'sql'%} - {%- call statement('main') -%} - {{ create_table_as(False, target_relation, sql) }} - {%- endcall -%} - {% elif language == 'python' %} - {%- set python_code = py_complete_script(python_code=sql, target_relation=target_relation) -%} - {{ log("python code: " ~ python_code ) }} - {% set result = adapter.submit_python_job(schema, model['alias'], python_code) %} - {% call noop_statement('main', result, 'OK', 1) %} - -- python model return run result -- - {% endcall %} - {% endif %} - {% elif existing_relation.is_view or should_full_refresh() %} - {{ log("#-- 
Relation must be dropped & recreated --#") }} - {% do adapter.drop_relation(existing_relation) %} - {% if language == 'sql'%} - {%- call statement('main') -%} - {{ create_table_as(False, target_relation, sql) }} - {%- endcall -%} - {% elif language == 'python' %} - {%- set python_code = py_complete_script(python_code=sql, target_relation=target_relation) -%} - {{ log("python code " ~ python_code ) }} - {% set result = adapter.submit_python_job(schema, model['alias'], python_code) %} - {% call noop_statement('main', result, 'OK', 1) %} - -- python model return run result -- - {% endcall %} - {% endif %} - {% else %} - {{ log("#-- Relation must be merged --#") }} - {% if language == 'sql'%} - {% do run_query(create_table_as(True, tmp_relation, sql)) %} - {% do process_schema_changes(on_schema_change, tmp_relation, existing_relation) %} - {%- call statement('main') -%} - {{ dbt_spark_get_incremental_sql(strategy, tmp_relation, target_relation, unique_key) }} - {%- endcall -%} - {% elif language == 'python' %} - {%- set python_code = py_complete_script(python_code=sql, target_relation=tmp_relation) -%} - {% set result = adapter.submit_python_job(schema, model['alias'], python_code) %} - {{ log("python code " ~ python_code ) }} - {% call noop_statement('main', result, 'OK', 1) %} - -- python model return run result -- - {% endcall %} - {{ log("XXXXXX-" ~ result) }} - {% do process_schema_changes(on_schema_change, tmp_relation, existing_relation) %} - {%- call statement('main') -%} - {{ dbt_spark_get_incremental_sql(strategy, tmp_relation, target_relation, unique_key) }} - {%- endcall -%} - {% endif %} - {% endif %} - + {%- if existing_relation is none -%} + {#-- Relation must be created --#} + {{log("make rel")}} + {%- call statement('main', language=language) -%} + {{ create_table_as(False, target_relation, model_code, language) }} + {%- endcall -%} + {%- elif existing_relation.is_view or should_full_refresh() -%} + {#-- Relation must be dropped & recreated --#} + {{log("remake rel")}} + {%- do adapter.drop_relation(existing_relation) -%} + {%- call statement('main', language=language) -%} + {{ create_table_as(False, target_relation, model_code, language) }} + {%- endcall -%} + {%- else -%} + {#-- Relation must be merged --#} + {{log("merge rel")}} + {%- call statement('create_tmp_relation', language=language) -%} + {{ create_table_as(True, tmp_relation, model_code, language) }} + {%- endcall -%} + {%- do process_schema_changes(on_schema_change, tmp_relation, existing_relation) -%} + {%- call statement('main') -%} + {{ dbt_spark_get_incremental_sql(strategy, tmp_relation, target_relation, unique_key) }} + {%- endcall -%} + {%- endif -%} + + {{ log("Inc logic complete") }} + {% do persist_docs(target_relation, model) %} {{ run_hooks(post_hooks) }} diff --git a/dbt/include/spark/macros/materializations/table.sql b/dbt/include/spark/macros/materializations/table.sql index 6d7a8dc75..2c82a0e94 100644 --- a/dbt/include/spark/macros/materializations/table.sql +++ b/dbt/include/spark/macros/materializations/table.sql @@ -1,7 +1,8 @@ {% materialization table, adapter = 'spark' %} + {%- set language = config.get('language') -%} {%- set identifier = model['alias'] -%} - + {%- set model_code = sql -%} {%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%} {%- set target_relation = api.Relation.create(identifier=identifier, schema=schema, @@ -17,22 +18,10 @@ {{ adapter.drop_relation(old_relation) }} {%- endif %} - -- build model - {% if 
config.get('language', 'sql') == 'python' -%}} - -- sql here is really just the compiled python code - {%- set python_code = py_complete_script(python_code=sql, target_relation=target_relation) -%} - {{ log("python code " ~ python_code ) }} - {% set result = adapter.submit_python_job(schema, identifier, python_code) %} - {% call noop_statement('main', result, 'OK', 1) %} - -- python model return run result -- - {% endcall %} - - {%- else -%} - {% call statement('main') -%} - {{ create_table_as(False, target_relation, sql) }} - {%- endcall %} - {%- endif %} - + {%- call statement('main', language=language) -%} + {{ create_table_as(False, target_relation, model_code, language) }} + {%- endcall -%} + {% do persist_docs(target_relation, model) %} {{ run_hooks(post_hooks) }} @@ -42,13 +31,16 @@ {% endmaterialization %} -{% macro py_complete_script(python_code, target_relation) %} +{% macro py_complete_script(python_code, target_relation, is_tmp_view=False) %} {{ python_code }} -df = model(dbt) -# COMMAND ---------- -# this is materialization code dbt generated, please do not modify +# --- Autogenerated dbt code below this line. Do not modify. --- # +df = model(dbt) +{%- if is_tmp_view %} +df.createTempView("{{ target_relation }}") +{%- else %} df.write.mode("overwrite").format("delta").saveAsTable("{{ target_relation }}") -{% endmacro %} +{%- endif -%} +{%- endmacro -%} From d639594c1a9727dd951ea13c49b42e6946b77ef3 Mon Sep 17 00:00:00 2001 From: Ian Knox Date: Fri, 24 Jun 2022 13:59:43 -0500 Subject: [PATCH 17/35] cleanup, add is_incremental --- dbt/include/spark/macros/adapters.sql | 27 ++++++++++--------- .../incremental/incremental.sql | 12 +++++++++ .../spark/macros/materializations/table.sql | 7 +---- 3 files changed, 28 insertions(+), 18 deletions(-) diff --git a/dbt/include/spark/macros/adapters.sql b/dbt/include/spark/macros/adapters.sql index 28d89e85d..ac6b6b1ef 100644 --- a/dbt/include/spark/macros/adapters.sql +++ b/dbt/include/spark/macros/adapters.sql @@ -122,21 +122,17 @@ {%- endmacro -%} {#-- We can't use temporary tables with `create ... as ()` syntax --#} -{% macro spark__create_temporary_view(relation, model_code, language='sql') -%} - {%- if language == 'sql' -%} +{% macro spark__create_temporary_view(relation, model_code) -%} create temporary view {{ relation.include(schema=false) }} as {{ model_code }} - {%- elif language == 'python' -%} - {{ py_complete_script(python_code=model_code, target_relation=relation, is_tmp_view=True) }} - {%- endif -%} {%- endmacro -%} {%- macro spark__create_table_as(temporary, relation, model_code, language='sql') -%} - {%- if temporary -%} - {{ create_temporary_view(relation, model_code, language) }} - {%- else -%} - {%- if language == 'sql' -%} + {%- if language == 'sql' -%} + {%- if temporary -%} + {{ create_temporary_view(relation, model_code, language) }} + {%- else -%} {% if config.get('file_format', validator=validation.any[basestring]) == 'delta' %} create or replace table {{ relation }} {% else %} @@ -150,10 +146,17 @@ {{ comment_clause() }} as {{ model_code }} - {%- elif language == 'python' -%} - {{ py_complete_script(python_code=model_code, target_relation=relation) }} {%- endif -%} - {%- endif %} + {%- elif language == 'python' -%} + {#-- + N.B. Python models _can_ write to temp views HOWEVER they use a different session + and have already expired by the time they need to be used (I.E. in merges for incremental models) + + TODO: Deep dive into spark sessions to see if we can reuse a single session for an entire + dbt invocation. 
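[Editor's aside on the session caveat described in this comment: a small, self-contained illustration of why a temp view created in one Spark session is not visible from another. It is a sketch for readers, not part of the patch, and assumes a local PySpark installation.]

```python
# Illustrative sketch only: temp views are scoped to the session that created them.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("temp-view-scope").getOrCreate()
spark.createDataFrame([(1, 2)], ["test", "test2"]).createOrReplaceTempView("dbt_tmp")
assert spark.sql("select * from dbt_tmp").count() == 1  # visible in the creating session

other = spark.newSession()  # shares the SparkContext but has its own temp-view catalog
try:
    other.sql("select * from dbt_tmp").collect()
except Exception as exc:  # AnalysisException: table or view not found
    print(f"not visible from the second session: {exc}")
```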
+ --#} + {{ py_complete_script(python_code=model_code, target_relation=relation) }} + {%- endif -%} {%- endmacro -%} diff --git a/dbt/include/spark/macros/materializations/incremental/incremental.sql b/dbt/include/spark/macros/materializations/incremental/incremental.sql index dd4a3f39b..9fc0527f7 100644 --- a/dbt/include/spark/macros/materializations/incremental/incremental.sql +++ b/dbt/include/spark/macros/materializations/incremental/incremental.sql @@ -50,6 +50,18 @@ {%- call statement('main') -%} {{ dbt_spark_get_incremental_sql(strategy, tmp_relation, target_relation, unique_key) }} {%- endcall -%} + {%- if language == 'python' -%} + {#-- + This is yucky. + See note in dbt-spark/dbt/include/spark/macros/adapters.sql + re: python models and temporary views. + + Also, why doesn't either drop_relation or adapter.drop_relation work here?! + --#} + {% call statement('drop_relation') -%} + drop table if exists {{ tmp_relation }} + {%- endcall %} + {%- endif -%} {%- endif -%} {{ log("Inc logic complete") }} diff --git a/dbt/include/spark/macros/materializations/table.sql b/dbt/include/spark/macros/materializations/table.sql index 2c82a0e94..edebd2c87 100644 --- a/dbt/include/spark/macros/materializations/table.sql +++ b/dbt/include/spark/macros/materializations/table.sql @@ -31,16 +31,11 @@ {% endmaterialization %} -{% macro py_complete_script(python_code, target_relation, is_tmp_view=False) %} +{% macro py_complete_script(python_code, target_relation) %} {{ python_code }} # --- Autogenerated dbt code below this line. Do not modify. --- # df = model(dbt) - -{%- if is_tmp_view %} -df.createTempView("{{ target_relation }}") -{%- else %} df.write.mode("overwrite").format("delta").saveAsTable("{{ target_relation }}") -{%- endif -%} {%- endmacro -%} From f5c178e4251e78a687b245986f907bceae0500eb Mon Sep 17 00:00:00 2001 From: Ian Knox Date: Fri, 24 Jun 2022 14:24:15 -0500 Subject: [PATCH 18/35] remove debug logging --- .../spark/macros/materializations/incremental/incremental.sql | 3 --- 1 file changed, 3 deletions(-) diff --git a/dbt/include/spark/macros/materializations/incremental/incremental.sql b/dbt/include/spark/macros/materializations/incremental/incremental.sql index 9fc0527f7..6b109e132 100644 --- a/dbt/include/spark/macros/materializations/incremental/incremental.sql +++ b/dbt/include/spark/macros/materializations/incremental/incremental.sql @@ -29,20 +29,17 @@ {#-- Incremental run logic --#} {%- if existing_relation is none -%} {#-- Relation must be created --#} - {{log("make rel")}} {%- call statement('main', language=language) -%} {{ create_table_as(False, target_relation, model_code, language) }} {%- endcall -%} {%- elif existing_relation.is_view or should_full_refresh() -%} {#-- Relation must be dropped & recreated --#} - {{log("remake rel")}} {%- do adapter.drop_relation(existing_relation) -%} {%- call statement('main', language=language) -%} {{ create_table_as(False, target_relation, model_code, language) }} {%- endcall -%} {%- else -%} {#-- Relation must be merged --#} - {{log("merge rel")}} {%- call statement('create_tmp_relation', language=language) -%} {{ create_table_as(True, tmp_relation, model_code, language) }} {%- endcall -%} From 7a44feb1b3997e646f4bd701b9590e4d59a40b43 Mon Sep 17 00:00:00 2001 From: Ian Knox Date: Mon, 27 Jun 2022 12:50:48 -0500 Subject: [PATCH 19/35] flake8 --- dbt/adapters/spark/impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 14653c291..3563f7c3d 100644 --- 
a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -385,7 +385,7 @@ def run_sql_for_tests(self, sql, fetch, conn): @available def submit_python_job(self, schema: str, identifier: str, file_contents: str, timeout=None): - # TODO improve the typing here. N.B. Jinja returns a `jinja2.runtime.Undefined` instead + # TODO improve the typing here. N.B. Jinja returns a `jinja2.runtime.Undefined` instead # of `None` which evaluates to True! # TODO limit this function to run only when doing the materialization of python nodes From d7b06d4ae1302cccd9aa8831c141c6d07b112532 Mon Sep 17 00:00:00 2001 From: Ian Knox Date: Mon, 27 Jun 2022 12:57:01 -0500 Subject: [PATCH 20/35] removed python lang from temp views for now --- dbt/include/spark/macros/adapters.sql | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dbt/include/spark/macros/adapters.sql b/dbt/include/spark/macros/adapters.sql index ac6b6b1ef..a6720b246 100644 --- a/dbt/include/spark/macros/adapters.sql +++ b/dbt/include/spark/macros/adapters.sql @@ -117,8 +117,8 @@ {%- endmacro %} -{% macro create_temporary_view(relation, model_code, language) -%} - {{ return(adapter.dispatch('create_temporary_view', 'dbt')(relation, model_code, language)) }} +{% macro create_temporary_view(relation, model_code) -%} + {{ return(adapter.dispatch('create_temporary_view', 'dbt')(relation, model_code)) }} {%- endmacro -%} {#-- We can't use temporary tables with `create ... as ()` syntax --#} @@ -131,7 +131,7 @@ {%- macro spark__create_table_as(temporary, relation, model_code, language='sql') -%} {%- if language == 'sql' -%} {%- if temporary -%} - {{ create_temporary_view(relation, model_code, language) }} + {{ create_temporary_view(relation, model_code) }} {%- else -%} {% if config.get('file_format', validator=validation.any[basestring]) == 'delta' %} create or replace table {{ relation }} From 4d4ae513dc21a016bdc7d778e93b52d21e68ac6e Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Fri, 17 Jun 2022 16:16:14 -0700 Subject: [PATCH 21/35] add change schema test --- tests/functional/adapter/test_basic.py | 5 ---- tests/functional/adapter/test_python.py | 38 +++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 5 deletions(-) create mode 100644 tests/functional/adapter/test_python.py diff --git a/tests/functional/adapter/test_basic.py b/tests/functional/adapter/test_basic.py index 04498052e..70f3267a4 100644 --- a/tests/functional/adapter/test_basic.py +++ b/tests/functional/adapter/test_basic.py @@ -12,7 +12,6 @@ from dbt.tests.adapter.basic.test_snapshot_check_cols import BaseSnapshotCheckCols from dbt.tests.adapter.basic.test_snapshot_timestamp import BaseSnapshotTimestamp from dbt.tests.adapter.basic.test_adapter_methods import BaseAdapterMethod -from dbt.tests.adapter.python_model.test_python_model import BasePythonModelTests @pytest.mark.skip_profile('spark_session') @@ -81,8 +80,4 @@ def project_config_update(self): } class TestBaseAdapterMethod(BaseAdapterMethod): - pass - - -class TestBasePythonModelSnowflake(BasePythonModelTests): pass \ No newline at end of file diff --git a/tests/functional/adapter/test_python.py b/tests/functional/adapter/test_python.py new file mode 100644 index 000000000..103ed1829 --- /dev/null +++ b/tests/functional/adapter/test_python.py @@ -0,0 +1,38 @@ +import pytest +from dbt.tests.util import run_dbt, write_file +from dbt.tests.adapter.python_model.test_python_model import BasePythonModelTests + +class TestPythonModelSpark(BasePythonModelTests): + pass + +models__simple_python_model = """ 
+import pandas + +def model(dbt): + dbt.config( + materialized='table', + ) + data = [[1,2]] * 10 + return spark.createDataFrame(data, schema=['test', 'test2']) +""" +models__simple_python_model_v2 = """ +import pandas + +def model(dbt): + dbt.config( + materialized='table', + ) + data = [[1,2]] * 10 + return spark.createDataFrame(data, schema=['test1', 'test3']) +""" + +class TestChangingSchemaSnowflake: + @pytest.fixture(scope="class") + def models(self): + return { + "simple_python_model.py": models__simple_python_model + } + def test_changing_schema(self,project): + run_dbt(["run"]) + write_file(models__simple_python_model_v2, project.project_root + '/models', "simple_python_model.py") + run_dbt(["run"]) \ No newline at end of file From 8b95b2e18c4cbbd8b91a2d31ee29e9d5bf51a790 Mon Sep 17 00:00:00 2001 From: Ian Knox Date: Mon, 27 Jun 2022 13:07:05 -0500 Subject: [PATCH 22/35] removed log line --- .../spark/macros/materializations/incremental/incremental.sql | 2 -- 1 file changed, 2 deletions(-) diff --git a/dbt/include/spark/macros/materializations/incremental/incremental.sql b/dbt/include/spark/macros/materializations/incremental/incremental.sql index 6b109e132..8fe1b3d31 100644 --- a/dbt/include/spark/macros/materializations/incremental/incremental.sql +++ b/dbt/include/spark/macros/materializations/incremental/incremental.sql @@ -61,8 +61,6 @@ {%- endif -%} {%- endif -%} - {{ log("Inc logic complete") }} - {% do persist_docs(target_relation, model) %} {{ run_hooks(post_hooks) }} From a758930b85b69baada39ab5e05e132f34fe6f76c Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Mon, 27 Jun 2022 22:35:28 -0700 Subject: [PATCH 23/35] more restiction and adjust syntax --- dbt/adapters/spark/impl.py | 2 +- dbt/include/spark/macros/materializations/table.sql | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 3563f7c3d..7e148a472 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -383,7 +383,7 @@ def run_sql_for_tests(self, sql, fetch, conn): finally: conn.transaction_open = False - @available + @available.parse_none def submit_python_job(self, schema: str, identifier: str, file_contents: str, timeout=None): # TODO improve the typing here. N.B. Jinja returns a `jinja2.runtime.Undefined` instead # of `None` which evaluates to True! diff --git a/dbt/include/spark/macros/materializations/table.sql b/dbt/include/spark/macros/materializations/table.sql index edebd2c87..ee3508f16 100644 --- a/dbt/include/spark/macros/materializations/table.sql +++ b/dbt/include/spark/macros/materializations/table.sql @@ -2,7 +2,6 @@ {%- set language = config.get('language') -%} {%- set identifier = model['alias'] -%} - {%- set model_code = sql -%} {%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%} {%- set target_relation = api.Relation.create(identifier=identifier, schema=schema, @@ -19,7 +18,7 @@ {%- endif %} {%- call statement('main', language=language) -%} - {{ create_table_as(False, target_relation, model_code, language) }} + {{ create_table_as(False, target_relation, compiled_code, language) }} {%- endcall -%} {% do persist_docs(target_relation, model) %} @@ -36,6 +35,7 @@ # --- Autogenerated dbt code below this line. Do not modify. 
--- # +dbt = dbtObj(spark.table) df = model(dbt) df.write.mode("overwrite").format("delta").saveAsTable("{{ target_relation }}") {%- endmacro -%} From 88b7ad4b4063f07981b6bac2feed6c741bdee37a Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Tue, 28 Jun 2022 16:11:23 -0700 Subject: [PATCH 24/35] adjust name for incremental model --- .../macros/materializations/incremental/incremental.sql | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dbt/include/spark/macros/materializations/incremental/incremental.sql b/dbt/include/spark/macros/materializations/incremental/incremental.sql index 8fe1b3d31..aadc8c743 100644 --- a/dbt/include/spark/macros/materializations/incremental/incremental.sql +++ b/dbt/include/spark/macros/materializations/incremental/incremental.sql @@ -14,7 +14,6 @@ {%- set target_relation = this -%} {%- set existing_relation = load_relation(this) -%} {%- set tmp_relation = make_temp_relation(this) -%} - {%- set model_code = sql -%} {#-- Set Overwrite Mode --#} {%- if strategy == 'insert_overwrite' and partition_by -%} @@ -30,18 +29,18 @@ {%- if existing_relation is none -%} {#-- Relation must be created --#} {%- call statement('main', language=language) -%} - {{ create_table_as(False, target_relation, model_code, language) }} + {{ create_table_as(False, target_relation, compiled_code, language) }} {%- endcall -%} {%- elif existing_relation.is_view or should_full_refresh() -%} {#-- Relation must be dropped & recreated --#} {%- do adapter.drop_relation(existing_relation) -%} {%- call statement('main', language=language) -%} - {{ create_table_as(False, target_relation, model_code, language) }} + {{ create_table_as(False, target_relation, compiled_code, language) }} {%- endcall -%} {%- else -%} {#-- Relation must be merged --#} {%- call statement('create_tmp_relation', language=language) -%} - {{ create_table_as(True, tmp_relation, model_code, language) }} + {{ create_table_as(True, tmp_relation, compiled_code, language) }} {%- endcall -%} {%- do process_schema_changes(on_schema_change, tmp_relation, existing_relation) -%} {%- call statement('main') -%} From 85a49ae926f466b1263abe20350cfc2d51876dd6 Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Tue, 28 Jun 2022 20:43:44 -0700 Subject: [PATCH 25/35] stage changes --- dbt/adapters/spark/impl.py | 6 ++++-- dbt/include/spark/macros/adapters.sql | 16 ++++++++-------- .../materializations/incremental/incremental.sql | 1 - .../spark/macros/materializations/table.sql | 4 ++-- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 7e148a472..bd238a073 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -384,13 +384,15 @@ def run_sql_for_tests(self, sql, fetch, conn): conn.transaction_open = False @available.parse_none - def submit_python_job(self, schema: str, identifier: str, file_contents: str, timeout=None): + def submit_python_job(self, parsed_model:dict, compiled_code: str, timeout=None): # TODO improve the typing here. N.B. Jinja returns a `jinja2.runtime.Undefined` instead # of `None` which evaluates to True! 
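[Editor's aside: once `py_complete_script` has wrapped the user's compiled code, the notebook this method uploads looks roughly like the sketch below. It is illustrative only; the model body is the simple fixture from the tests, `my_schema.my_model` is a made-up target, and the `spark` and `dbtObj` names are assumed to be provided by the notebook environment and the dbt wrapper respectively.]

```python
# Illustrative sketch only: approximate shape of a rendered python model submission.
def model(dbt):
    dbt.config(materialized="table")
    data = [[1, 2]] * 10
    return spark.createDataFrame(data, schema=["test", "test2"])

# --- Autogenerated dbt code below this line. Do not modify. --- #
dbt = dbtObj(spark.table)
df = model(dbt)
df.write.mode("overwrite").format("delta").saveAsTable("my_schema.my_model")
```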
# TODO limit this function to run only when doing the materialization of python nodes # assuming that for python job running over 1 day user would mannually overwrite this + schema = getattr(parsed_model, "schema", self.config.credentials.schema) + identifier = parsed_model['alias'] if not timeout: timeout = 60 * 60 * 24 if timeout <= 0: @@ -416,7 +418,7 @@ def submit_python_job(self, schema: str, identifier: str, file_contents: str, ti ) # add notebook - b64_encoded_content = base64.b64encode(file_contents.encode()).decode() + b64_encoded_content = base64.b64encode(compiled_code.encode()).decode() response = requests.post( f"https://{self.connections.profile.credentials.host}/api/2.0/workspace/import", headers=auth_header, diff --git a/dbt/include/spark/macros/adapters.sql b/dbt/include/spark/macros/adapters.sql index a6720b246..83e1a5932 100644 --- a/dbt/include/spark/macros/adapters.sql +++ b/dbt/include/spark/macros/adapters.sql @@ -117,21 +117,21 @@ {%- endmacro %} -{% macro create_temporary_view(relation, model_code) -%} - {{ return(adapter.dispatch('create_temporary_view', 'dbt')(relation, model_code)) }} +{% macro create_temporary_view(relation, compiled_code) -%} + {{ return(adapter.dispatch('create_temporary_view', 'dbt')(relation, compiled_code)) }} {%- endmacro -%} {#-- We can't use temporary tables with `create ... as ()` syntax --#} -{% macro spark__create_temporary_view(relation, model_code) -%} +{% macro spark__create_temporary_view(relation, compiled_code) -%} create temporary view {{ relation.include(schema=false) }} as - {{ model_code }} + {{ compiled_code }} {%- endmacro -%} -{%- macro spark__create_table_as(temporary, relation, model_code, language='sql') -%} +{%- macro spark__create_table_as(temporary, relation, compiled_code, language='sql') -%} {%- if language == 'sql' -%} {%- if temporary -%} - {{ create_temporary_view(relation, model_code) }} + {{ create_temporary_view(relation, compiled_code) }} {%- else -%} {% if config.get('file_format', validator=validation.any[basestring]) == 'delta' %} create or replace table {{ relation }} @@ -145,7 +145,7 @@ {{ location_clause() }} {{ comment_clause() }} as - {{ model_code }} + {{ compiled_code }} {%- endif -%} {%- elif language == 'python' -%} {#-- @@ -155,7 +155,7 @@ TODO: Deep dive into spark sessions to see if we can reuse a single session for an entire dbt invocation. 
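[One more editorial note on this patch: `submit_python_job` now receives the parsed model node instead of a schema/identifier pair. The hedged sketch below pulls the target out of it with plain dict access; worth noting because `getattr` on a real `dict` always falls back to its default, so if `parsed_model` truly is a dict, as the annotation says, `.get()` is the safer spelling. `target_for` and `default_schema` are made-up names for illustration.]

```python
# Illustrative sketch only: defensive extraction of a python model's target.
def target_for(parsed_model: dict, default_schema: str):
    schema = parsed_model.get("schema") or default_schema
    identifier = parsed_model["alias"]
    return schema, identifier

print(target_for({"alias": "my_python_model"}, "analytics"))  # ('analytics', 'my_python_model')
```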
--#} - {{ py_complete_script(python_code=model_code, target_relation=relation) }} + {{ py_complete_script(compiled_code=compiled_code, target_relation=relation) }} {%- endif -%} {%- endmacro -%} diff --git a/dbt/include/spark/macros/materializations/incremental/incremental.sql b/dbt/include/spark/macros/materializations/incremental/incremental.sql index aadc8c743..878a338f2 100644 --- a/dbt/include/spark/macros/materializations/incremental/incremental.sql +++ b/dbt/include/spark/macros/materializations/incremental/incremental.sql @@ -1,5 +1,4 @@ {% materialization incremental, adapter='spark' -%} - {#-- Validate early so we don't run SQL if the file_format + strategy combo is invalid --#} {%- set raw_file_format = config.get('file_format', default='parquet') -%} {%- set raw_strategy = config.get('incremental_strategy', default='append') -%} diff --git a/dbt/include/spark/macros/materializations/table.sql b/dbt/include/spark/macros/materializations/table.sql index ee3508f16..192435427 100644 --- a/dbt/include/spark/macros/materializations/table.sql +++ b/dbt/include/spark/macros/materializations/table.sql @@ -1,5 +1,5 @@ {% materialization table, adapter = 'spark' %} - + {{debug()}} {%- set language = config.get('language') -%} {%- set identifier = model['alias'] -%} {%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%} @@ -30,7 +30,7 @@ {% endmaterialization %} -{% macro py_complete_script(python_code, target_relation) %} +{% macro py_complete_script(compiled_code, target_relation) %} {{ python_code }} From 3ee0a42052b4fab36f34de48af930fb3b5f03d07 Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Tue, 28 Jun 2022 20:56:35 -0700 Subject: [PATCH 26/35] fixed it --- dbt/include/spark/macros/materializations/table.sql | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dbt/include/spark/macros/materializations/table.sql b/dbt/include/spark/macros/materializations/table.sql index 192435427..8ac005bf7 100644 --- a/dbt/include/spark/macros/materializations/table.sql +++ b/dbt/include/spark/macros/materializations/table.sql @@ -1,5 +1,4 @@ {% materialization table, adapter = 'spark' %} - {{debug()}} {%- set language = config.get('language') -%} {%- set identifier = model['alias'] -%} {%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%} @@ -31,8 +30,7 @@ {% macro py_complete_script(compiled_code, target_relation) %} -{{ python_code }} - +{{ compiled_code }} # --- Autogenerated dbt code below this line. Do not modify. 
--- # dbt = dbtObj(spark.table) From d596866b7baf3f13493c52d7f09fe827d5e4e5bf Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Thu, 14 Jul 2022 17:00:44 -0700 Subject: [PATCH 27/35] remove unneed macro --- dbt/include/spark/macros/adapters.sql | 3 --- 1 file changed, 3 deletions(-) diff --git a/dbt/include/spark/macros/adapters.sql b/dbt/include/spark/macros/adapters.sql index 83e1a5932..2811e7b45 100644 --- a/dbt/include/spark/macros/adapters.sql +++ b/dbt/include/spark/macros/adapters.sql @@ -296,6 +296,3 @@ {% endmacro %} -{% macro load_df_def() %} - load_df_function = spark.table -{% endmacro %} From 27c1441ad34456bc77ef37bbf6299cf8fcf0e7d4 Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Mon, 18 Jul 2022 17:43:10 -0700 Subject: [PATCH 28/35] minic result for python job --- dbt/adapters/spark/impl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index bd238a073..dae2aaa4f 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -488,7 +488,8 @@ def submit_python_job(self, parsed_model:dict, compiled_code: str, timeout=None) "match the line number in your code due to dbt templating)\n" f"{json_run_output['error_trace']}" ) - return result_state + return self.connections.get_response(None) + # spark does something interesting with joins when both tables have the same From 9eea39616b400b3d71d78307a6fcaa1b4ea30022 Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Mon, 25 Jul 2022 21:21:14 -0700 Subject: [PATCH 29/35] fix python model test (#406) --- .bumpversion.cfg | 7 +- .circleci/config.yml | 25 +-- .flake8 | 12 + .github/ISSUE_TEMPLATE/dependabot.yml | 2 +- .github/ISSUE_TEMPLATE/release.md | 2 +- .github/pull_request_template.md | 2 +- .github/workflows/jira-creation.yml | 2 +- .github/workflows/jira-label.yml | 3 +- .github/workflows/jira-transition.yml | 2 +- .github/workflows/main.yml | 68 +++--- .github/workflows/release.yml | 33 ++- .github/workflows/stale.yml | 2 - .github/workflows/version-bump.yml | 20 +- .gitignore | 53 ++++- .pre-commit-config.yaml | 66 ++++++ CHANGELOG.md | 43 +++- CONTRIBUTING.md | 101 +++++++++ MANIFEST.in | 2 +- Makefile | 56 +++++ README.md | 2 +- dbt/adapters/spark/__init__.py | 5 +- dbt/adapters/spark/column.py | 33 +-- dbt/adapters/spark/connections.py | 210 ++++++++---------- dbt/adapters/spark/impl.py | 53 +++-- dbt/adapters/spark/relation.py | 10 +- dbt/adapters/spark/session.py | 22 +- dbt/include/spark/__init__.py | 1 + dbt/include/spark/macros/adapters.sql | 37 +-- dbt/include/spark/macros/apply_grants.sql | 39 ++++ .../incremental/incremental.sql | 17 +- .../incremental/strategies.sql | 10 +- .../materializations/incremental/validate.sql | 4 +- .../macros/materializations/snapshot.sql | 8 +- .../spark/macros/materializations/table.sql | 12 +- dbt/include/spark/macros/utils/any_value.sql | 5 + .../spark/macros/utils/assert_not_null.sql | 9 + dbt/include/spark/macros/utils/bool_or.sql | 11 + dbt/include/spark/macros/utils/concat.sql | 3 + dbt/include/spark/macros/utils/dateadd.sql | 62 ++++++ dbt/include/spark/macros/utils/datediff.sql | 107 +++++++++ dbt/include/spark/macros/utils/listagg.sql | 17 ++ dbt/include/spark/macros/utils/split_part.sql | 23 ++ dev_requirements.txt => dev-requirements.txt | 24 +- docker-compose.yml | 4 +- docker/spark-defaults.conf | 4 +- requirements.txt | 3 +- scripts/build-dist.sh | 2 +- setup.py | 72 +++--- test.env.example | 15 ++ tests/conftest.py | 4 +- tests/functional/adapter/test_basic.py | 5 +- 
 tests/functional/adapter/test_grants.py | 60 +++++ tests/functional/adapter/test_python.py | 38 ---- tests/functional/adapter/test_python_model.py | 55 +++++ .../adapter/utils/fixture_listagg.py | 61 +++++ .../adapter/utils/test_data_types.py | 67 ++++++ tests/functional/adapter/utils/test_utils.py | 122 ++++++++++ tox.ini | 20 +- 58 files changed, 1335 insertions(+), 422 deletions(-) create mode 100644 .flake8 create mode 100644 .pre-commit-config.yaml create mode 100644 CONTRIBUTING.md create mode 100644 Makefile create mode 100644 dbt/include/spark/macros/apply_grants.sql create mode 100644 dbt/include/spark/macros/utils/any_value.sql create mode 100644 dbt/include/spark/macros/utils/assert_not_null.sql create mode 100644 dbt/include/spark/macros/utils/bool_or.sql create mode 100644 dbt/include/spark/macros/utils/concat.sql create mode 100644 dbt/include/spark/macros/utils/dateadd.sql create mode 100644 dbt/include/spark/macros/utils/datediff.sql create mode 100644 dbt/include/spark/macros/utils/listagg.sql create mode 100644 dbt/include/spark/macros/utils/split_part.sql rename dev_requirements.txt => dev-requirements.txt (51%) create mode 100644 test.env.example create mode 100644 tests/functional/adapter/test_grants.py delete mode 100644 tests/functional/adapter/test_python.py create mode 100644 tests/functional/adapter/test_python_model.py create mode 100644 tests/functional/adapter/utils/fixture_listagg.py create mode 100644 tests/functional/adapter/utils/test_data_types.py create mode 100644 tests/functional/adapter/utils/test_utils.py diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 744284849..605b6f378 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,10 +1,10 @@ [bumpversion] -current_version = 1.2.0a1 +current_version = 1.3.0a1 parse = (?P<major>\d+) \.(?P<minor>\d+) \.(?P<patch>\d+) ((?P<prerelease>a|b|rc)(?P<num>\d+))?
-serialize = +serialize = {major}.{minor}.{patch}{prerelease}{num} {major}.{minor}.{patch} commit = False @@ -13,7 +13,7 @@ tag = False [bumpversion:part:prerelease] first_value = a optional_value = final -values = +values = a b rc @@ -25,4 +25,3 @@ first_value = 1 [bumpversion:file:setup.py] [bumpversion:file:dbt/adapters/spark/__version__.py] - diff --git a/.circleci/config.yml b/.circleci/config.yml index 34e449acf..8f0afa6ce 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -33,29 +33,12 @@ jobs: DBT_INVOCATION_ENV: circle docker: - image: fishtownanalytics/test-container:10 - - image: godatadriven/spark:2 + - image: godatadriven/spark:3.1.1 environment: WAIT_FOR: localhost:5432 command: > --class org.apache.spark.sql.hive.thriftserver.HiveThriftServer2 --name Thrift JDBC/ODBC Server - --conf spark.hadoop.javax.jdo.option.ConnectionURL=jdbc:postgresql://localhost/metastore - --conf spark.hadoop.javax.jdo.option.ConnectionUserName=dbt - --conf spark.hadoop.javax.jdo.option.ConnectionPassword=dbt - --conf spark.hadoop.javax.jdo.option.ConnectionDriverName=org.postgresql.Driver - --conf spark.serializer=org.apache.spark.serializer.KryoSerializer - --conf spark.jars.packages=org.apache.hudi:hudi-spark-bundle_2.11:0.9.0 - --conf spark.sql.extensions=org.apache.spark.sql.hudi.HoodieSparkSessionExtension - --conf spark.driver.userClassPathFirst=true - --conf spark.hadoop.datanucleus.autoCreateTables=true - --conf spark.hadoop.datanucleus.schema.autoCreateTables=true - --conf spark.hadoop.datanucleus.fixedDatastore=false - --conf spark.sql.hive.convertMetastoreParquet=false - --hiveconf hoodie.datasource.hive_sync.use_jdbc=false - --hiveconf hoodie.datasource.hive_sync.mode=hms - --hiveconf datanucleus.schema.autoCreateAll=true - --hiveconf hive.metastore.schema.verification=false - - image: postgres:9.6.17-alpine environment: POSTGRES_USER: dbt @@ -80,6 +63,9 @@ jobs: environment: DBT_INVOCATION_ENV: circle DBT_DATABRICKS_RETRY_ALL: True + DBT_TEST_USER_1: "buildbot+dbt_test_user_1@dbtlabs.com" + DBT_TEST_USER_2: "buildbot+dbt_test_user_2@dbtlabs.com" + DBT_TEST_USER_3: "buildbot+dbt_test_user_3@dbtlabs.com" docker: - image: fishtownanalytics/test-container:10 steps: @@ -95,6 +81,9 @@ jobs: environment: DBT_INVOCATION_ENV: circle ODBC_DRIVER: Simba # TODO: move env var to Docker image + DBT_TEST_USER_1: "buildbot+dbt_test_user_1@dbtlabs.com" + DBT_TEST_USER_2: "buildbot+dbt_test_user_2@dbtlabs.com" + DBT_TEST_USER_3: "buildbot+dbt_test_user_3@dbtlabs.com" docker: # image based on `fishtownanalytics/test-container` w/ Simba ODBC Spark driver installed - image: 828731156495.dkr.ecr.us-east-1.amazonaws.com/dbt-spark-odbc-test-container:latest diff --git a/.flake8 b/.flake8 new file mode 100644 index 000000000..f39d154c0 --- /dev/null +++ b/.flake8 @@ -0,0 +1,12 @@ +[flake8] +select = + E + W + F +ignore = + W503 # makes Flake8 work like black + W504 + E203 # makes Flake8 work like black + E741 + E501 +exclude = test diff --git a/.github/ISSUE_TEMPLATE/dependabot.yml b/.github/ISSUE_TEMPLATE/dependabot.yml index 8a8c85b9f..2a6f34492 100644 --- a/.github/ISSUE_TEMPLATE/dependabot.yml +++ b/.github/ISSUE_TEMPLATE/dependabot.yml @@ -5,4 +5,4 @@ updates: directory: "/" schedule: interval: "daily" - rebase-strategy: "disabled" \ No newline at end of file + rebase-strategy: "disabled" diff --git a/.github/ISSUE_TEMPLATE/release.md b/.github/ISSUE_TEMPLATE/release.md index ac28792a3..a69349f54 100644 --- a/.github/ISSUE_TEMPLATE/release.md +++ b/.github/ISSUE_TEMPLATE/release.md @@ -7,4 
+7,4 @@ assignees: '' --- -### TBD \ No newline at end of file +### TBD diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 60e12779b..5928b1cbf 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -18,4 +18,4 @@ resolves # - [ ] I have signed the [CLA](https://docs.getdbt.com/docs/contributor-license-agreements) - [ ] I have run this code in development and it appears to resolve the stated issue - [ ] This PR includes tests, or tests are not required/relevant for this PR -- [ ] I have updated the `CHANGELOG.md` and added information about my change to the "dbt-spark next" section. \ No newline at end of file +- [ ] I have updated the `CHANGELOG.md` and added information about my change to the "dbt-spark next" section. diff --git a/.github/workflows/jira-creation.yml b/.github/workflows/jira-creation.yml index c84e106a7..b4016befc 100644 --- a/.github/workflows/jira-creation.yml +++ b/.github/workflows/jira-creation.yml @@ -13,7 +13,7 @@ name: Jira Issue Creation on: issues: types: [opened, labeled] - + permissions: issues: write diff --git a/.github/workflows/jira-label.yml b/.github/workflows/jira-label.yml index fd533a170..3da2e3a38 100644 --- a/.github/workflows/jira-label.yml +++ b/.github/workflows/jira-label.yml @@ -13,7 +13,7 @@ name: Jira Label Mirroring on: issues: types: [labeled, unlabeled] - + permissions: issues: read @@ -24,4 +24,3 @@ jobs: JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} - diff --git a/.github/workflows/jira-transition.yml b/.github/workflows/jira-transition.yml index 71273c7a9..ed9f9cd4f 100644 --- a/.github/workflows/jira-transition.yml +++ b/.github/workflows/jira-transition.yml @@ -21,4 +21,4 @@ jobs: secrets: JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} - JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} \ No newline at end of file + JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index fbdbbbaae..bf607c379 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -18,7 +18,6 @@ on: push: branches: - "main" - - "develop" - "*.latest" - "releases/*" pull_request: @@ -37,18 +36,10 @@ defaults: jobs: code-quality: - name: ${{ matrix.toxenv }} + name: code-quality runs-on: ubuntu-latest - - strategy: - fail-fast: false - matrix: - toxenv: [flake8] - - env: - TOXENV: ${{ matrix.toxenv }} - PYTEST_ADDOPTS: "-v --color=yes" + timeout-minutes: 10 steps: - name: Check out the repository @@ -58,28 +49,36 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 - with: + with: python-version: '3.8' - name: Install python dependencies run: | sudo apt-get install libsasl2-dev - pip install --user --upgrade pip - pip install tox - pip --version - tox --version - - name: Run tox - run: tox + python -m pip install --user --upgrade pip + python -m pip --version + python -m pip install pre-commit + pre-commit --version + python -m pip install mypy==0.942 + python -m pip install types-requests + mypy --version + python -m pip install -r requirements.txt + python -m pip install -r dev-requirements.txt + dbt --version + + - name: Run pre-commit hooks + run: pre-commit run --all-files --show-diff-on-failure unit: name: unit test / python ${{ matrix.python-version }} runs-on: ubuntu-latest + timeout-minutes: 10 strategy: fail-fast: false matrix: - python-version: [3.7, 3.8] # TODO: support 
unit testing for python 3.9 (https://github.com/dbt-labs/dbt/issues/3689) + python-version: ["3.7", "3.8", "3.9", "3.10"] env: TOXENV: "unit" @@ -88,8 +87,6 @@ jobs: steps: - name: Check out the repository uses: actions/checkout@v2 - with: - persist-credentials: false - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 @@ -99,9 +96,9 @@ jobs: - name: Install python dependencies run: | sudo apt-get install libsasl2-dev - pip install --user --upgrade pip - pip install tox - pip --version + python -m pip install --user --upgrade pip + python -m pip --version + python -m pip install tox tox --version - name: Run tox run: tox @@ -128,8 +125,6 @@ jobs: steps: - name: Check out the repository uses: actions/checkout@v2 - with: - persist-credentials: false - name: Set up Python uses: actions/setup-python@v2 @@ -138,9 +133,10 @@ jobs: - name: Install python dependencies run: | - pip install --user --upgrade pip - pip install --upgrade setuptools wheel twine check-wheel-contents - pip --version + python -m pip install --user --upgrade pip + python -m pip install --upgrade setuptools wheel twine check-wheel-contents + python -m pip --version + - name: Build distributions run: ./scripts/build-dist.sh @@ -153,7 +149,7 @@ jobs: - name: Check wheel contents run: | check-wheel-contents dist/*.whl --ignore W007,W008 - + - name: Check if this is an alpha version id: check-is-alpha run: | @@ -179,7 +175,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: [3.7, 3.8, 3.9] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - name: Set up Python ${{ matrix.python-version }} @@ -189,9 +185,9 @@ jobs: - name: Install python dependencies run: | - pip install --user --upgrade pip - pip install --upgrade wheel - pip --version + python -m pip install --user --upgrade pip + python -m pip install --upgrade wheel + python -m pip --version - uses: actions/download-artifact@v2 with: name: dist @@ -202,13 +198,13 @@ jobs: - name: Install wheel distributions run: | - find ./dist/*.whl -maxdepth 1 -type f | xargs pip install --force-reinstall --find-links=dist/ + find ./dist/*.whl -maxdepth 1 -type f | xargs python -m pip install --force-reinstall --find-links=dist/ - name: Check wheel distributions run: | dbt --version - name: Install source distributions run: | - find ./dist/*.gz -maxdepth 1 -type f | xargs pip install --force-reinstall --find-links=dist/ + find ./dist/*.gz -maxdepth 1 -type f | xargs python -m pip install --force-reinstall --find-links=dist/ - name: Check source distributions run: | dbt --version diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b25ea884e..554e13a8d 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -3,28 +3,28 @@ name: Build and Release on: workflow_dispatch: - + # Release version number that must be updated for each release env: version_number: '0.20.0rc2' -jobs: +jobs: Test: runs-on: ubuntu-latest steps: - name: Setup Python uses: actions/setup-python@v2.2.2 - with: + with: python-version: '3.8' - + - uses: actions/checkout@v2 - - name: Test release + - name: Test release run: | python3 -m venv env source env/bin/activate sudo apt-get install libsasl2-dev - pip install -r dev_requirements.txt + pip install -r dev-requirements.txt pip install twine wheel setuptools python setup.py sdist bdist_wheel pip install dist/dbt-spark-*.tar.gz @@ -38,9 +38,9 @@ jobs: steps: - name: Setup Python uses: actions/setup-python@v2.2.2 - with: + with: 
python-version: '3.8' - + - uses: actions/checkout@v2 - name: Bumping version @@ -48,7 +48,7 @@ jobs: python3 -m venv env source env/bin/activate sudo apt-get install libsasl2-dev - pip install -r dev_requirements.txt + pip install -r dev-requirements.txt bumpversion --config-file .bumpversion-dbt.cfg patch --new-version ${{env.version_number}} bumpversion --config-file .bumpversion.cfg patch --new-version ${{env.version_number}} --allow-dirty git status @@ -60,7 +60,7 @@ jobs: author_email: 'leah.antkiewicz@dbtlabs.com' message: 'Bumping version to ${{env.version_number}}' tag: v${{env.version_number}} - + # Need to set an output variable because env variables can't be taken as input # This is needed for the next step with releasing to GitHub - name: Find release type @@ -69,7 +69,7 @@ jobs: IS_PRERELEASE: ${{ contains(env.version_number, 'rc') || contains(env.version_number, 'b') }} run: | echo ::set-output name=isPrerelease::$IS_PRERELEASE - + - name: Create GitHub release uses: actions/create-release@v1 env: @@ -88,7 +88,7 @@ jobs: # or $ pip install "dbt-spark[PyHive]==${{env.version_number}}" ``` - + PypiRelease: name: Pypi release runs-on: ubuntu-latest @@ -97,13 +97,13 @@ jobs: steps: - name: Setup Python uses: actions/setup-python@v2.2.2 - with: + with: python-version: '3.8' - + - uses: actions/checkout@v2 with: ref: v${{env.version_number}} - + - name: Release to pypi env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} @@ -112,8 +112,7 @@ jobs: python3 -m venv env source env/bin/activate sudo apt-get install libsasl2-dev - pip install -r dev_requirements.txt + pip install -r dev-requirements.txt pip install twine wheel setuptools python setup.py sdist bdist_wheel twine upload --non-interactive dist/dbt_spark-${{env.version_number}}-py3-none-any.whl dist/dbt-spark-${{env.version_number}}.tar.gz - diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 2848ce8f7..a56455d55 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,5 +13,3 @@ jobs: stale-pr-message: "This PR has been marked as Stale because it has been open for 180 days with no activity. If you would like the PR to remain open, please remove the stale label or comment on the PR, or it will be closed in 7 days." # mark issues/PRs stale when they haven't seen activity in 180 days days-before-stale: 180 - # ignore checking issues with the following labels - exempt-issue-labels: "epic, discussion" \ No newline at end of file diff --git a/.github/workflows/version-bump.yml b/.github/workflows/version-bump.yml index 7fb8bb6eb..a8b3236ce 100644 --- a/.github/workflows/version-bump.yml +++ b/.github/workflows/version-bump.yml @@ -1,16 +1,16 @@ # **what?** # This workflow will take a version number and a dry run flag. With that -# it will run versionbump to update the version number everywhere in the +# it will run versionbump to update the version number everywhere in the # code base and then generate an update Docker requirements file. If this # is a dry run, a draft PR will open with the changes. If this isn't a dry # run, the changes will be committed to the branch this is run on. # **why?** -# This is to aid in releasing dbt and making sure we have updated +# This is to aid in releasing dbt and making sure we have updated # the versions and Docker requirements in all places. 
# **when?** -# This is triggered either manually OR +# This is triggered either manually OR # from the repository_dispatch event "version-bump" which is sent from # the dbt-release repo Action @@ -25,11 +25,11 @@ on: is_dry_run: description: 'Creates a draft PR to allow testing instead of committing to a branch' required: true - default: 'true' + default: 'true' repository_dispatch: types: [version-bump] -jobs: +jobs: bump: runs-on: ubuntu-latest steps: @@ -58,19 +58,19 @@ jobs: sudo apt-get install libsasl2-dev python3 -m venv env source env/bin/activate - pip install --upgrade pip - + pip install --upgrade pip + - name: Create PR branch if: ${{ steps.variables.outputs.IS_DRY_RUN == 'true' }} run: | git checkout -b bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_$GITHUB_RUN_ID git push origin bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_$GITHUB_RUN_ID git branch --set-upstream-to=origin/bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_$GITHUB_RUN_ID bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_$GITHUB_RUN_ID - + - name: Bumping version run: | source env/bin/activate - pip install -r dev_requirements.txt + pip install -r dev-requirements.txt env/bin/bumpversion --allow-dirty --new-version ${{steps.variables.outputs.VERSION_NUMBER}} major git status @@ -100,4 +100,4 @@ jobs: draft: true base: ${{github.ref}} title: 'Bumping version to ${{steps.variables.outputs.VERSION_NUMBER}}' - branch: 'bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_${{GITHUB.RUN_ID}}' + branch: 'bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_${{GITHUB.RUN_ID}}' diff --git a/.gitignore b/.gitignore index cc586f5fe..189589cf4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,18 +1,47 @@ -.hive-metastore/ -.spark-warehouse/ -*.egg-info -env/ -*.pyc +# Byte-compiled / optimized / DLL files __pycache__ +*.py[cod] +*$py.class + +# Distribution / packaging +.Python +build/ +env*/ +dbt_env/ +dist/ +*.egg-info +logs/ + + +# Unit test .tox/ .env +test.env + + +# Django stuff +*.log + +# Mypy +*.pytest_cache/ + +# Vim +*.sw* + +# Pyenv +.python-version + +# pycharm .idea/ -build/ -dist/ -dbt-integration-tests -test/integration/.user.yml + +# MacOS .DS_Store -test.env + +# vscode .vscode -*.log -logs/ \ No newline at end of file + +# other +.hive-metastore/ +.spark-warehouse/ +dbt-integration-tests +test/integration/.user.yml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..e70156dcd --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,66 @@ +# For more on configuring pre-commit hooks (see https://pre-commit.com/) + +# TODO: remove global exclusion of tests when testing overhaul is complete +exclude: '^tests/.*' + +# Force all unspecified python hooks to run python 3.8 +default_language_version: + python: python3.8 + +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: check-yaml + args: [--unsafe] + - id: check-json + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-case-conflict +- repo: https://github.com/psf/black + rev: 21.12b0 + hooks: + - id: black + additional_dependencies: ['click==8.0.4'] + args: + - "--line-length=99" + - "--target-version=py38" + - id: black + alias: black-check + stages: [manual] + additional_dependencies: ['click==8.0.4'] + args: + - "--line-length=99" + - "--target-version=py38" + - "--check" + - "--diff" +- repo: https://gitlab.com/pycqa/flake8 + rev: 4.0.1 + hooks: + - id: flake8 + - id: flake8 + alias: 
flake8-check + stages: [manual] +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.950 + hooks: + - id: mypy + # N.B.: Mypy is... a bit fragile. + # + # By using `language: system` we run this hook in the local + # environment instead of a pre-commit isolated one. This is needed + # to ensure mypy correctly parses the project. + + # It may cause trouble in that it adds environmental variables out + # of our control to the mix. Unfortunately, there's nothing we can + # do about per pre-commit's author. + # See https://github.com/pre-commit/pre-commit/issues/730 for details. + args: [--show-error-codes, --ignore-missing-imports] + files: ^dbt/adapters/.* + language: system + - id: mypy + alias: mypy-check + stages: [manual] + args: [--show-error-codes, --pretty, --ignore-missing-imports] + files: ^dbt/adapters + language: system diff --git a/CHANGELOG.md b/CHANGELOG.md index f9a094942..28f7e138b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,50 @@ -## dbt-spark 1.1.0 (TBD) +## dbt-spark 1.3.0b1 (Release TBD) + +### Fixes +- Pin `pyodbc` to version 4.0.32 to prevent overwriting `libodbc.so` and `libltdl.so` on Linux ([#397](https://github.com/dbt-labs/dbt-spark/issues/397/), [#398](https://github.com/dbt-labs/dbt-spark/pull/398/)) + +### Contributors +- [@barberscott](https://github.com/barberscott) ([#398](https://github.com/dbt-labs/dbt-spark/pull/398/)) + +## dbt-spark 1.2.0rc1 (July 12, 2022) + +### Fixes +- Incremental materialization updated to not drop table first if full refresh for delta lake format, as it already runs _create or replace table_ ([#286](https://github.com/dbt-labs/dbt-spark/issues/286), [#287](https://github.com/dbt-labs/dbt-spark/pull/287/)) +- Apache Spark version upgraded to 3.1.1 ([#348](https://github.com/dbt-labs/dbt-spark/issues/348), [#349](https://github.com/dbt-labs/dbt-spark/pull/349)) + +### Features +- Add grants to materializations ([#366](https://github.com/dbt-labs/dbt-spark/issues/366), [#381](https://github.com/dbt-labs/dbt-spark/pull/381)) + +### Under the hood +- Update `SparkColumn.numeric_type` to return `decimal` instead of `numeric`, since SparkSQL exclusively supports the former ([#380](https://github.com/dbt-labs/dbt-spark/pull/380)) +- Make minimal changes to support dbt Core incremental materialization refactor ([#402](https://github.com/dbt-labs/dbt-spark/issue/402), [#394](httpe://github.com/dbt-labs/dbt-spark/pull/394)) + +### Contributors +- [@grindheim](https://github.com/grindheim) ([#287](https://github.com/dbt-labs/dbt-spark/pull/287/)) +- [@nssalian](https://github.com/nssalian) ([#349](https://github.com/dbt-labs/dbt-spark/pull/349)) + +## dbt-spark 1.2.0b1 (June 24, 2022) + +### Fixes +- `adapter.get_columns_in_relation` (method) and `get_columns_in_relation` (macro) now return identical responses. 
The previous behavior of `get_columns_in_relation` (macro) is now represented by a new macro, `get_columns_in_relation_raw` ([#354](https://github.com/dbt-labs/dbt-spark/issues/354), [#355](https://github.com/dbt-labs/dbt-spark/pull/355)) + +### Under the hood +- Initialize lift + shift for cross-db macros ([#359](https://github.com/dbt-labs/dbt-spark/pull/359)) +- Add invocation env to user agent string ([#367](https://github.com/dbt-labs/dbt-spark/pull/367)) +- Use dispatch pattern for get_columns_in_relation_raw macro ([#365](https://github.com/dbt-labs/dbt-spark/pull/365)) + +### Contributors +- [@ueshin](https://github.com/ueshin) ([#365](https://github.com/dbt-labs/dbt-spark/pull/365)) +- [@dbeatty10](https://github.com/dbeatty10) ([#359](https://github.com/dbt-labs/dbt-spark/pull/359)) + +## dbt-spark 1.1.0 (April 28, 2022) ### Features - Add session connection method ([#272](https://github.com/dbt-labs/dbt-spark/issues/272), [#279](https://github.com/dbt-labs/dbt-spark/pull/279)) +- rename file to match reference to dbt-core ([#344](https://github.com/dbt-labs/dbt-spark/pull/344)) ### Under the hood +- Add precommit tooling to this repo ([#356](https://github.com/dbt-labs/dbt-spark/pull/356)) - Use dbt.tests.adapter.basic in test suite ([#298](https://github.com/dbt-labs/dbt-spark/issues/298), [#299](https://github.com/dbt-labs/dbt-spark/pull/299)) - Make internal macros use macro dispatch to be overridable in child adapters ([#319](https://github.com/dbt-labs/dbt-spark/issues/319), [#320](https://github.com/dbt-labs/dbt-spark/pull/320)) - Override adapter method 'run_sql_for_tests' ([#323](https://github.com/dbt-labs/dbt-spark/issues/323), [#324](https://github.com/dbt-labs/dbt-spark/pull/324)) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..c0d9bb3d2 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,101 @@ +# Contributing to `dbt-spark` + +1. [About this document](#about-this-document) +3. [Getting the code](#getting-the-code) +5. [Running `dbt-spark` in development](#running-dbt-spark-in-development) +6. [Testing](#testing) +7. [Updating Docs](#updating-docs) +7. [Submitting a Pull Request](#submitting-a-pull-request) + +## About this document +This document is a guide intended for folks interested in contributing to `dbt-spark`. Below, we document the process by which members of the community should create issues and submit pull requests (PRs) in this repository. It is not intended as a guide for using `dbt-spark`, and it assumes a certain level of familiarity with Python concepts such as virtualenvs, `pip`, Python modules, and so on. This guide assumes you are using macOS or Linux and are comfortable with the command line. + +For those wishing to contribute we highly suggest reading the dbt-core's [contribution guide](https://github.com/dbt-labs/dbt-core/blob/HEAD/CONTRIBUTING.md) if you haven't already. Almost all of the information there is applicable to contributing here, too! + +### Signing the CLA + +Please note that all contributors to `dbt-spark` must sign the [Contributor License Agreement](https://docs.getdbt.com/docs/contributor-license-agreements) to have their Pull Request merged into an `dbt-spark` codebase. If you are unable to sign the CLA, then the `dbt-spark` maintainers will unfortunately be unable to merge your Pull Request. You are, however, welcome to open issues and comment on existing ones. + + +## Getting the code + +You will need `git` in order to download and modify the `dbt-spark` source code. 
You can find directions [here](https://github.com/git-guides/install-git) on how to install `git`. + +### External contributors + +If you are not a member of the `dbt-labs` GitHub organization, you can contribute to `dbt-spark` by forking the `dbt-spark` repository. For a detailed overview on forking, check out the [GitHub docs on forking](https://help.github.com/en/articles/fork-a-repo). In short, you will need to: + +1. fork the `dbt-spark` repository +2. clone your fork locally +3. check out a new branch for your proposed changes +4. push changes to your fork +5. open a pull request against `dbt-labs/dbt-spark` from your forked repository + +### dbt Labs contributors + +If you are a member of the `dbt Labs` GitHub organization, you will have push access to the `dbt-spark` repo. Rather than forking `dbt-spark` to make your changes, just clone the repository, check out a new branch, and push directly to that branch. + + +## Running `dbt-spark` in development + +### Installation + +First make sure that you set up your `virtualenv` as described in [Setting up an environment](https://github.com/dbt-labs/dbt-core/blob/HEAD/CONTRIBUTING.md#setting-up-an-environment). Ensure you have the latest version of pip installed with `pip install --upgrade pip`. Next, install `dbt-spark` latest dependencies: + +```sh +pip install -e . -r dev-requirements.txt +``` + +When `dbt-spark` is installed this way, any changes you make to the `dbt-spark` source code will be reflected immediately in your next `dbt-spark` run. + +To confirm you have correct version of `dbt-core` installed please run `dbt --version` and `which dbt`. + + +## Testing + +### Initial Setup + +`dbt-spark` uses test credentials specified in a `test.env` file in the root of the repository. This `test.env` file is git-ignored, but please be _extra_ careful to never check in credentials or other sensitive information when developing. To create your `test.env` file, copy the provided example file, then supply your relevant credentials. + +``` +cp test.env.example test.env +$EDITOR test.env +``` + +### Test commands +There are a few methods for running tests locally. + +#### `tox` +`tox` takes care of managing Python virtualenvs and installing dependencies in order to run tests. You can also run tests in parallel, for example you can run unit tests for Python 3.7, Python 3.8, Python 3.9, and `flake8` checks in parallel with `tox -p`. Also, you can run unit tests for specific python versions with `tox -e py37`. The configuration of these tests are located in `tox.ini`. + +#### `pytest` +Finally, you can also run a specific test or group of tests using `pytest` directly. With a Python virtualenv active and dev dependencies installed you can do things like: + +```sh +# run specific spark integration tests +python -m pytest -m profile_spark tests/integration/get_columns_in_relation +# run specific functional tests +python -m pytest --profile databricks_sql_endpoint tests/functional/adapter/test_basic.py +# run all unit tests in a file +python -m pytest tests/unit/test_adapter.py +# run a specific unit test +python -m pytest test/unit/test_adapter.py::TestSparkAdapter::test_profile_with_database +``` +## Updating Docs + +Many changes will require and update to the `dbt-spark` docs here are some useful resources. + +- Docs are [here](https://docs.getdbt.com/). +- The docs repo for making changes is located [here]( https://github.com/dbt-labs/docs.getdbt.com). 
+- The changes made are likely to impact one or both of [Spark Profile](https://docs.getdbt.com/reference/warehouse-profiles/spark-profile), or [Saprk Configs](https://docs.getdbt.com/reference/resource-configs/spark-configs). +- We ask every community member who makes a user-facing change to open an issue or PR regarding doc changes. + +## Submitting a Pull Request + +dbt Labs provides a CI environment to test changes to the `dbt-spark` adapter, and periodic checks against the development version of `dbt-core` through Github Actions. + +A `dbt-spark` maintainer will review your PR. They may suggest code revision for style or clarity, or request that you add unit or integration test(s). These are good things! We believe that, with a little bit of help, anyone can contribute high-quality code. + +Once all requests and answers have been answered the `dbt-spark` maintainer can trigger CI testing. + +Once all tests are passing and your PR has been approved, a `dbt-spark` maintainer will merge your changes into the active development branch. And that's it! Happy developing :tada: diff --git a/MANIFEST.in b/MANIFEST.in index 78412d5b8..cfbc714ed 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1 @@ -recursive-include dbt/include *.sql *.yml *.md \ No newline at end of file +recursive-include dbt/include *.sql *.yml *.md diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..a520c425f --- /dev/null +++ b/Makefile @@ -0,0 +1,56 @@ +.DEFAULT_GOAL:=help + +.PHONY: dev +dev: ## Installs adapter in develop mode along with development depedencies + @\ + pip install -r dev-requirements.txt && pre-commit install + +.PHONY: mypy +mypy: ## Runs mypy against staged changes for static type checking. + @\ + pre-commit run --hook-stage manual mypy-check | grep -v "INFO" + +.PHONY: flake8 +flake8: ## Runs flake8 against staged changes to enforce style guide. + @\ + pre-commit run --hook-stage manual flake8-check | grep -v "INFO" + +.PHONY: black +black: ## Runs black against staged changes to enforce style guide. + @\ + pre-commit run --hook-stage manual black-check -v | grep -v "INFO" + +.PHONY: lint +lint: ## Runs flake8 and mypy code checks against staged changes. + @\ + pre-commit run flake8-check --hook-stage manual | grep -v "INFO"; \ + pre-commit run mypy-check --hook-stage manual | grep -v "INFO" + +.PHONY: linecheck +linecheck: ## Checks for all Python lines 100 characters or more + @\ + find dbt -type f -name "*.py" -exec grep -I -r -n '.\{100\}' {} \; + +.PHONY: unit +unit: ## Runs unit tests with py38. + @\ + tox -e py38 + +.PHONY: test +test: ## Runs unit tests with py38 and code checks against staged changes. + @\ + tox -p -e py38; \ + pre-commit run black-check --hook-stage manual | grep -v "INFO"; \ + pre-commit run flake8-check --hook-stage manual | grep -v "INFO"; \ + pre-commit run mypy-check --hook-stage manual | grep -v "INFO" + +.PHONY: clean + @echo "cleaning repo" + @git clean -f -X + +.PHONY: help +help: ## Show this help message. + @echo 'usage: make [target]' + @echo + @echo 'targets:' + @grep -E '^[7+a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' diff --git a/README.md b/README.md index 037a49895..241d869d7 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ more information, consult [the docs](https://docs.getdbt.com/docs/profile-spark) ## Running locally A `docker-compose` environment starts a Spark Thrift server and a Postgres database as a Hive Metastore backend. 
-Note that this is spark 2 not spark 3 so some functionalities might not be available. +Note: dbt-spark now supports Spark 3.1.1 (formerly on Spark 2.x). The following command would start two docker containers ``` diff --git a/dbt/adapters/spark/__init__.py b/dbt/adapters/spark/__init__.py index 469e202b9..6ecc5eccf 100644 --- a/dbt/adapters/spark/__init__.py +++ b/dbt/adapters/spark/__init__.py @@ -8,6 +8,5 @@ from dbt.include import spark Plugin = AdapterPlugin( - adapter=SparkAdapter, - credentials=SparkCredentials, - include_path=spark.PACKAGE_PATH) + adapter=SparkAdapter, credentials=SparkCredentials, include_path=spark.PACKAGE_PATH +) diff --git a/dbt/adapters/spark/column.py b/dbt/adapters/spark/column.py index fd377ad15..dcf7590e9 100644 --- a/dbt/adapters/spark/column.py +++ b/dbt/adapters/spark/column.py @@ -1,11 +1,11 @@ from dataclasses import dataclass -from typing import TypeVar, Optional, Dict, Any +from typing import Any, Dict, Optional, TypeVar, Union from dbt.adapters.base.column import Column from dbt.dataclass_schema import dbtClassMixin from hologram import JsonDict -Self = TypeVar('Self', bound='SparkColumn') +Self = TypeVar("Self", bound="SparkColumn") @dataclass @@ -31,37 +31,42 @@ def literal(self, value): @property def quoted(self) -> str: - return '`{}`'.format(self.column) + return "`{}`".format(self.column) @property def data_type(self) -> str: return self.dtype + @classmethod + def numeric_type(cls, dtype: str, precision: Any, scale: Any) -> str: + # SparkSQL does not support 'numeric' or 'number', only 'decimal' + if precision is None or scale is None: + return "decimal" + else: + return "{}({},{})".format("decimal", precision, scale) + def __repr__(self) -> str: return "".format(self.name, self.data_type) @staticmethod def convert_table_stats(raw_stats: Optional[str]) -> Dict[str, Any]: - table_stats = {} + table_stats: Dict[str, Union[int, str, bool]] = {} if raw_stats: # format: 1109049927 bytes, 14093476 rows stats = { - stats.split(" ")[1]: int(stats.split(" ")[0]) - for stats in raw_stats.split(', ') + stats.split(" ")[1]: int(stats.split(" ")[0]) for stats in raw_stats.split(", ") } for key, val in stats.items(): - table_stats[f'stats:{key}:label'] = key - table_stats[f'stats:{key}:value'] = val - table_stats[f'stats:{key}:description'] = '' - table_stats[f'stats:{key}:include'] = True + table_stats[f"stats:{key}:label"] = key + table_stats[f"stats:{key}:value"] = val + table_stats[f"stats:{key}:description"] = "" + table_stats[f"stats:{key}:include"] = True return table_stats - def to_column_dict( - self, omit_none: bool = True, validate: bool = False - ) -> JsonDict: + def to_column_dict(self, omit_none: bool = True, validate: bool = False) -> JsonDict: original_dict = self.to_dict(omit_none=omit_none) # If there are stats, merge them into the root of the dict - original_stats = original_dict.pop('table_stats', None) + original_stats = original_dict.pop("table_stats", None) if original_stats: original_dict.update(original_stats) return original_dict diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py index 11163ccf0..59ceb9dd8 100644 --- a/dbt/adapters/spark/connections.py +++ b/dbt/adapters/spark/connections.py @@ -26,6 +26,7 @@ from hologram.helpers import StrEnum from dataclasses import dataclass, field from typing import Any, Dict, Optional + try: from thrift.transport.TSSLSocket import TSSLSocket import thrift @@ -33,11 +34,7 @@ import sasl import thrift_sasl except ImportError: - TSSLSocket = None - thrift = None 
- ssl = None - sasl = None - thrift_sasl = None + pass # done deliberately: setting modules to None explicitly violates MyPy contracts by degrading type semantics import base64 import time @@ -52,10 +49,10 @@ def _build_odbc_connnection_string(**kwargs) -> str: class SparkConnectionMethod(StrEnum): - THRIFT = 'thrift' - HTTP = 'http' - ODBC = 'odbc' - SESSION = 'session' + THRIFT = "thrift" + HTTP = "http" + ODBC = "odbc" + SESSION = "session" @dataclass @@ -71,7 +68,7 @@ class SparkCredentials(Credentials): port: int = 443 auth: Optional[str] = None kerberos_service_name: Optional[str] = None - organization: str = '0' + organization: str = "0" connect_retries: int = 0 connect_timeout: int = 10 use_ssl: bool = False @@ -81,27 +78,24 @@ class SparkCredentials(Credentials): @classmethod def __pre_deserialize__(cls, data): data = super().__pre_deserialize__(data) - if 'database' not in data: - data['database'] = None + if "database" not in data: + data["database"] = None return data def __post_init__(self): # spark classifies database and schema as the same thing - if ( - self.database is not None and - self.database != self.schema - ): + if self.database is not None and self.database != self.schema: raise dbt.exceptions.RuntimeException( - f' schema: {self.schema} \n' - f' database: {self.database} \n' - f'On Spark, database must be omitted or have the same value as' - f' schema.' + f" schema: {self.schema} \n" + f" database: {self.database} \n" + f"On Spark, database must be omitted or have the same value as" + f" schema." ) self.database = None if self.method == SparkConnectionMethod.ODBC: try: - import pyodbc # noqa: F401 + import pyodbc # noqa: F401 except ImportError as e: raise dbt.exceptions.RuntimeException( f"{self.method} connection method requires " @@ -111,22 +105,16 @@ def __post_init__(self): f"ImportError({e.msg})" ) from e - if ( - self.method == SparkConnectionMethod.ODBC and - self.cluster and - self.endpoint - ): + if self.method == SparkConnectionMethod.ODBC and self.cluster and self.endpoint: raise dbt.exceptions.RuntimeException( "`cluster` and `endpoint` cannot both be set when" f" using {self.method} method to connect to Spark" ) if ( - self.method == SparkConnectionMethod.HTTP or - self.method == SparkConnectionMethod.THRIFT - ) and not ( - ThriftState and THttpClient and hive - ): + self.method == SparkConnectionMethod.HTTP + or self.method == SparkConnectionMethod.THRIFT + ) and not (ThriftState and THttpClient and hive): raise dbt.exceptions.RuntimeException( f"{self.method} connection method requires " "additional dependencies. 
\n" @@ -148,19 +136,19 @@ def __post_init__(self): @property def type(self): - return 'spark' + return "spark" @property def unique_field(self): return self.host def _connection_keys(self): - return ('host', 'port', 'cluster', - 'endpoint', 'schema', 'organization') + return ("host", "port", "cluster", "endpoint", "schema", "organization") class PyhiveConnectionWrapper(object): """Wrap a Spark connection in a way that no-ops transactions""" + # https://forums.databricks.com/questions/2157/in-apache-spark-sql-can-we-roll-back-the-transacti.html # noqa def __init__(self, handle): @@ -178,9 +166,7 @@ def cancel(self): try: self._cursor.cancel() except EnvironmentError as exc: - logger.debug( - "Exception while cancelling query: {}".format(exc) - ) + logger.debug("Exception while cancelling query: {}".format(exc)) def close(self): if self._cursor: @@ -189,9 +175,7 @@ def close(self): try: self._cursor.close() except EnvironmentError as exc: - logger.debug( - "Exception while closing cursor: {}".format(exc) - ) + logger.debug("Exception while closing cursor: {}".format(exc)) self.handle.close() def rollback(self, *args, **kwargs): @@ -247,23 +231,20 @@ def execute(self, sql, bindings=None): dbt.exceptions.raise_database_error(poll_state.errorMessage) elif state not in STATE_SUCCESS: - status_type = ThriftState._VALUES_TO_NAMES.get( - state, - 'Unknown<{!r}>'.format(state)) + status_type = ThriftState._VALUES_TO_NAMES.get(state, "Unknown<{!r}>".format(state)) - dbt.exceptions.raise_database_error( - "Query failed with status: {}".format(status_type)) + dbt.exceptions.raise_database_error("Query failed with status: {}".format(status_type)) logger.debug("Poll status: {}, query complete".format(state)) @classmethod def _fix_binding(cls, value): """Convert complex datatypes to primitives that can be loaded by - the Spark driver""" + the Spark driver""" if isinstance(value, NUMBERS): return float(value) elif isinstance(value, datetime): - return value.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] + return value.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] else: return value @@ -273,7 +254,6 @@ def description(self): class PyodbcConnectionWrapper(PyhiveConnectionWrapper): - def execute(self, sql, bindings=None): if sql.strip().endswith(";"): sql = sql.strip()[:-1] @@ -282,19 +262,17 @@ def execute(self, sql, bindings=None): self._cursor.execute(sql) else: # pyodbc only supports `qmark` sql params! 
- query = sqlparams.SQLParams('format', 'qmark') + query = sqlparams.SQLParams("format", "qmark") sql, bindings = query.format(sql, bindings) self._cursor.execute(sql, *bindings) class SparkConnectionManager(SQLConnectionManager): - TYPE = 'spark' + TYPE = "spark" SPARK_CLUSTER_HTTP_PATH = "/sql/protocolv1/o/{organization}/{cluster}" SPARK_SQL_ENDPOINT_HTTP_PATH = "/sql/1.0/endpoints/{endpoint}" - SPARK_CONNECTION_URL = ( - "{host}:{port}" + SPARK_CLUSTER_HTTP_PATH - ) + SPARK_CONNECTION_URL = "{host}:{port}" + SPARK_CLUSTER_HTTP_PATH @contextmanager def exception_handler(self, sql): @@ -308,7 +286,7 @@ def exception_handler(self, sql): raise thrift_resp = exc.args[0] - if hasattr(thrift_resp, 'status'): + if hasattr(thrift_resp, "status"): msg = thrift_resp.status.errorMessage raise dbt.exceptions.RuntimeException(msg) else: @@ -320,10 +298,8 @@ def cancel(self, connection): @classmethod def get_response(cls, cursor) -> AdapterResponse: # https://github.com/dbt-labs/dbt-spark/issues/142 - message = 'OK' - return AdapterResponse( - _message=message - ) + message = "OK" + return AdapterResponse(_message=message) # No transactions on Spark.... def add_begin_query(self, *args, **kwargs): @@ -346,12 +322,13 @@ def validate_creds(cls, creds, required): if not hasattr(creds, key): raise dbt.exceptions.DbtProfileError( "The config '{}' is required when using the {} method" - " to connect to Spark".format(key, method)) + " to connect to Spark".format(key, method) + ) @classmethod def open(cls, connection): if connection.state == ConnectionState.OPEN: - logger.debug('Connection is already open, skipping open.') + logger.debug("Connection is already open, skipping open.") return connection creds = connection.credentials @@ -360,19 +337,18 @@ def open(cls, connection): for i in range(1 + creds.connect_retries): try: if creds.method == SparkConnectionMethod.HTTP: - cls.validate_creds(creds, ['token', 'host', 'port', - 'cluster', 'organization']) + cls.validate_creds(creds, ["token", "host", "port", "cluster", "organization"]) # Prepend https:// if it is missing host = creds.host - if not host.startswith('https://'): - host = 'https://' + creds.host + if not host.startswith("https://"): + host = "https://" + creds.host conn_url = cls.SPARK_CONNECTION_URL.format( host=host, port=creds.port, organization=creds.organization, - cluster=creds.cluster + cluster=creds.cluster, ) logger.debug("connection url: {}".format(conn_url)) @@ -381,15 +357,12 @@ def open(cls, connection): raw_token = "token:{}".format(creds.token).encode() token = base64.standard_b64encode(raw_token).decode() - transport.setCustomHeaders({ - 'Authorization': 'Basic {}'.format(token) - }) + transport.setCustomHeaders({"Authorization": "Basic {}".format(token)}) conn = hive.connect(thrift_transport=transport) handle = PyhiveConnectionWrapper(conn) elif creds.method == SparkConnectionMethod.THRIFT: - cls.validate_creds(creds, - ['host', 'port', 'user', 'schema']) + cls.validate_creds(creds, ["host", "port", "user", "schema"]) if creds.use_ssl: transport = build_ssl_transport( @@ -397,26 +370,33 @@ def open(cls, connection): port=creds.port, username=creds.user, auth=creds.auth, - kerberos_service_name=creds.kerberos_service_name) + kerberos_service_name=creds.kerberos_service_name, + ) conn = hive.connect(thrift_transport=transport) else: - conn = hive.connect(host=creds.host, - port=creds.port, - username=creds.user, - auth=creds.auth, - kerberos_service_name=creds.kerberos_service_name) # noqa + conn = hive.connect( + host=creds.host, + 
port=creds.port, + username=creds.user, + auth=creds.auth, + kerberos_service_name=creds.kerberos_service_name, + ) # noqa handle = PyhiveConnectionWrapper(conn) elif creds.method == SparkConnectionMethod.ODBC: if creds.cluster is not None: - required_fields = ['driver', 'host', 'port', 'token', - 'organization', 'cluster'] + required_fields = [ + "driver", + "host", + "port", + "token", + "organization", + "cluster", + ] http_path = cls.SPARK_CLUSTER_HTTP_PATH.format( - organization=creds.organization, - cluster=creds.cluster + organization=creds.organization, cluster=creds.cluster ) elif creds.endpoint is not None: - required_fields = ['driver', 'host', 'port', 'token', - 'endpoint'] + required_fields = ["driver", "host", "port", "token", "endpoint"] http_path = cls.SPARK_SQL_ENDPOINT_HTTP_PATH.format( endpoint=creds.endpoint ) @@ -429,13 +409,12 @@ def open(cls, connection): cls.validate_creds(creds, required_fields) dbt_spark_version = __version__.version - user_agent_entry = f"dbt-labs-dbt-spark/{dbt_spark_version} (Databricks)" # noqa + user_agent_entry = ( + f"dbt-labs-dbt-spark/{dbt_spark_version} (Databricks)" # noqa + ) # http://simba.wpengine.com/products/Spark/doc/ODBC_InstallGuide/unix/content/odbc/hi/configuring/serverside.htm - ssp = { - f"SSP_{k}": f"{{{v}}}" - for k, v in creds.server_side_parameters.items() - } + ssp = {f"SSP_{k}": f"{{{v}}}" for k, v in creds.server_side_parameters.items()} # https://www.simba.com/products/Spark/doc/v2/ODBC_InstallGuide/unix/content/odbc/options/driver.htm connection_str = _build_odbc_connnection_string( @@ -461,6 +440,7 @@ def open(cls, connection): Connection, SessionConnectionWrapper, ) + handle = SessionConnectionWrapper(Connection()) else: raise dbt.exceptions.DbtProfileError( @@ -472,9 +452,9 @@ def open(cls, connection): if isinstance(e, EOFError): # The user almost certainly has invalid credentials. # Perhaps a token expired, or something - msg = 'Failed to connect' + msg = "Failed to connect" if creds.token is not None: - msg += ', is your token valid?' + msg += ", is your token valid?" 
raise dbt.exceptions.FailedToConnectException(msg) from e retryable_message = _is_retryable_error(e) if retryable_message and creds.connect_retries > 0: @@ -496,9 +476,7 @@ def open(cls, connection): logger.warning(msg) time.sleep(creds.connect_timeout) else: - raise dbt.exceptions.FailedToConnectException( - 'failed to connect' - ) from e + raise dbt.exceptions.FailedToConnectException("failed to connect") from e else: raise exc @@ -507,56 +485,50 @@ def open(cls, connection): return connection -def build_ssl_transport(host, port, username, auth, - kerberos_service_name, password=None): +def build_ssl_transport(host, port, username, auth, kerberos_service_name, password=None): transport = None if port is None: port = 10000 if auth is None: - auth = 'NONE' + auth = "NONE" socket = TSSLSocket(host, port, cert_reqs=ssl.CERT_NONE) - if auth == 'NOSASL': + if auth == "NOSASL": # NOSASL corresponds to hive.server2.authentication=NOSASL # in hive-site.xml transport = thrift.transport.TTransport.TBufferedTransport(socket) - elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'): + elif auth in ("LDAP", "KERBEROS", "NONE", "CUSTOM"): # Defer import so package dependency is optional - if auth == 'KERBEROS': + if auth == "KERBEROS": # KERBEROS mode in hive.server2.authentication is GSSAPI # in sasl library - sasl_auth = 'GSSAPI' + sasl_auth = "GSSAPI" else: - sasl_auth = 'PLAIN' + sasl_auth = "PLAIN" if password is None: # Password doesn't matter in NONE mode, just needs # to be nonempty. - password = 'x' + password = "x" def sasl_factory(): sasl_client = sasl.Client() - sasl_client.setAttr('host', host) - if sasl_auth == 'GSSAPI': - sasl_client.setAttr('service', kerberos_service_name) - elif sasl_auth == 'PLAIN': - sasl_client.setAttr('username', username) - sasl_client.setAttr('password', password) + sasl_client.setAttr("host", host) + if sasl_auth == "GSSAPI": + sasl_client.setAttr("service", kerberos_service_name) + elif sasl_auth == "PLAIN": + sasl_client.setAttr("username", username) + sasl_client.setAttr("password", password) else: raise AssertionError sasl_client.init() return sasl_client - transport = thrift_sasl.TSaslClientTransport(sasl_factory, - sasl_auth, socket) + transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket) return transport -def _is_retryable_error(exc: Exception) -> Optional[str]: - message = getattr(exc, 'message', None) - if message is None: - return None - message = message.lower() - if 'pending' in message: - return exc.message - if 'temporarily_unavailable' in message: - return exc.message - return None +def _is_retryable_error(exc: Exception) -> str: + message = str(exc).lower() + if "pending" in message or "temporarily_unavailable" in message: + return str(exc) + else: + return "" diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index dae2aaa4f..12c42ab98 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -4,7 +4,9 @@ import base64 from concurrent.futures import Future from dataclasses import dataclass -from typing import Optional, List, Dict, Any, Union, Iterable +from typing import Any, Dict, Iterable, List, Optional, Union +from typing_extensions import TypeAlias + import agate from dbt.contracts.relation import RelationType @@ -12,7 +14,7 @@ import dbt.exceptions from dbt.adapters.base import AdapterConfig -from dbt.adapters.base.impl import catch_as_completed +from dbt.adapters.base.impl import catch_as_completed, log_code_execution from dbt.adapters.base.meta import available from 
dbt.adapters.sql import SQLAdapter from dbt.adapters.spark import SparkConnectionManager @@ -25,7 +27,7 @@ logger = AdapterLogger("Spark") -GET_COLUMNS_IN_RELATION_MACRO_NAME = "get_columns_in_relation" +GET_COLUMNS_IN_RELATION_RAW_MACRO_NAME = "get_columns_in_relation_raw" LIST_SCHEMAS_MACRO_NAME = "list_schemas" LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching" DROP_RELATION_MACRO_NAME = "drop_relation" @@ -78,10 +80,10 @@ class SparkAdapter(SQLAdapter): "_hoodie_file_name", ] - Relation = SparkRelation - Column = SparkColumn - ConnectionManager = SparkConnectionManager - AdapterSpecificConfigs = SparkConfig + Relation: TypeAlias = SparkRelation + Column: TypeAlias = SparkColumn + ConnectionManager: TypeAlias = SparkConnectionManager + AdapterSpecificConfigs: TypeAlias = SparkConfig @classmethod def date_function(cls) -> str: @@ -163,7 +165,7 @@ def list_relations_without_caching( def get_relation(self, database: str, schema: str, identifier: str) -> Optional[BaseRelation]: if not self.Relation.include_policy.database: - database = None + database = None # type: ignore return super().get_relation(database, schema, identifier) @@ -178,7 +180,7 @@ def parse_describe_extended( # Remove rows that start with a hash, they are comments rows = [row for row in raw_rows[0:pos] if not row["col_name"].startswith("#")] - metadata = {col["col_name"]: col["data_type"] for col in raw_rows[pos + 1:]} + metadata = {col["col_name"]: col["data_type"] for col in raw_rows[pos + 1 :]} raw_table_stats = metadata.get(KEY_TABLE_STATISTICS) table_stats = SparkColumn.convert_table_stats(raw_table_stats) @@ -225,7 +227,9 @@ def get_columns_in_relation(self, relation: Relation) -> List[SparkColumn]: # use get_columns_in_relation spark macro # which would execute 'describe extended tablename' query try: - rows: List[agate.Row] = super().get_columns_in_relation(relation) + rows: List[agate.Row] = self.execute_macro( + GET_COLUMNS_IN_RELATION_RAW_MACRO_NAME, kwargs={"relation": relation} + ) columns = self.parse_describe_extended(relation, rows) except dbt.exceptions.RuntimeException as e: # spark would throw error when table doesn't exist, where other @@ -384,7 +388,8 @@ def run_sql_for_tests(self, sql, fetch, conn): conn.transaction_open = False @available.parse_none - def submit_python_job(self, parsed_model:dict, compiled_code: str, timeout=None): + @log_code_execution + def submit_python_job(self, parsed_model: dict, compiled_code: str, timeout=None): # TODO improve the typing here. N.B. Jinja returns a `jinja2.runtime.Undefined` instead # of `None` which evaluates to True! 
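The `submit_python_job` hunks below adjust the timeout default and the run-status polling against the Databricks Jobs API (`/api/2.1/jobs/runs/get`), raising if the run does not finish in a successful terminal state. A minimal sketch of that poll-until-terminal pattern, for illustration only -- `host`, `token`, `run_id`, and the set of terminal states are assumptions here, not the adapter's exact code:

```python
import time

import requests


def wait_for_run(host: str, token: str, run_id: int, timeout: int = 60 * 60 * 24) -> dict:
    """Poll a Databricks job run until it reaches a terminal life_cycle_state or times out."""
    auth_header = {"Authorization": f"Bearer {token}"}
    # Assumed terminal states for illustration; the adapter only checks for TERMINATED.
    terminal_states = {"TERMINATED", "SKIPPED", "INTERNAL_ERROR"}
    start = time.time()
    while True:
        resp = requests.get(
            f"https://{host}/api/2.1/jobs/runs/get?run_id={run_id}",
            headers=auth_header,
        )
        state = resp.json()["state"]["life_cycle_state"]
        if state in terminal_states:
            return resp.json()
        if time.time() - start > timeout:
            raise TimeoutError(f"run {run_id} still in state {state} after {timeout}s")
        time.sleep(1)  # brief pause between polls to avoid hammering the API
```
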
@@ -392,7 +397,7 @@ def submit_python_job(self, parsed_model:dict, compiled_code: str, timeout=None) # assuming that for python job running over 1 day user would mannually overwrite this schema = getattr(parsed_model, "schema", self.config.credentials.schema) - identifier = parsed_model['alias'] + identifier = parsed_model["alias"] if not timeout: timeout = 60 * 60 * 24 if timeout <= 0: @@ -414,7 +419,7 @@ def submit_python_job(self, parsed_model:dict, compiled_code: str, timeout=None) ) if response.status_code != 200: raise dbt.exceptions.RuntimeException( - f"Error creating work_dir for python notebooks\n {response.content}" + f"Error creating work_dir for python notebooks\n {response.content!r}" ) # add notebook @@ -432,7 +437,7 @@ def submit_python_job(self, parsed_model:dict, compiled_code: str, timeout=None) ) if response.status_code != 200: raise dbt.exceptions.RuntimeException( - f"Error creating python notebook.\n {response.content}" + f"Error creating python notebook.\n {response.content!r}" ) # submit job @@ -449,7 +454,7 @@ def submit_python_job(self, parsed_model:dict, compiled_code: str, timeout=None) ) if submit_response.status_code != 200: raise dbt.exceptions.RuntimeException( - f"Error creating python run.\n {response.content}" + f"Error creating python run.\n {response.content!r}" ) # poll until job finish @@ -466,7 +471,7 @@ def submit_python_job(self, parsed_model:dict, compiled_code: str, timeout=None) ) json_resp = resp.json() state = json_resp["state"]["life_cycle_state"] - logger.debug(f"Polling.... in state: {state}") + # logger.debug(f"Polling.... in state: {state}") if state != "TERMINATED": raise dbt.exceptions.RuntimeException( "python model run ended in state" @@ -490,6 +495,22 @@ def submit_python_job(self, parsed_model:dict, compiled_code: str, timeout=None) ) return self.connections.get_response(None) + def standardize_grants_dict(self, grants_table: agate.Table) -> dict: + grants_dict: Dict[str, List[str]] = {} + for row in grants_table: + grantee = row["Principal"] + privilege = row["ActionType"] + object_type = row["ObjectType"] + + # we only want to consider grants on this object + # (view or table both appear as 'TABLE') + # and we don't want to consider the OWN privilege + if object_type == "TABLE" and privilege != "OWN": + if privilege in grants_dict.keys(): + grants_dict[privilege].append(grantee) + else: + grants_dict.update({privilege: [grantee]}) + return grants_dict # spark does something interesting with joins when both tables have the same diff --git a/dbt/adapters/spark/relation.py b/dbt/adapters/spark/relation.py index 043cabfa0..249caf0d7 100644 --- a/dbt/adapters/spark/relation.py +++ b/dbt/adapters/spark/relation.py @@ -24,19 +24,19 @@ class SparkIncludePolicy(Policy): class SparkRelation(BaseRelation): quote_policy: SparkQuotePolicy = SparkQuotePolicy() include_policy: SparkIncludePolicy = SparkIncludePolicy() - quote_character: str = '`' + quote_character: str = "`" is_delta: Optional[bool] = None is_hudi: Optional[bool] = None - information: str = None + information: Optional[str] = None def __post_init__(self): if self.database != self.schema and self.database: - raise RuntimeException('Cannot set database in spark!') + raise RuntimeException("Cannot set database in spark!") def render(self): if self.include_policy.database and self.include_policy.schema: raise RuntimeException( - 'Got a spark relation with schema and database set to ' - 'include, but only one can be set' + "Got a spark relation with schema and database set to " + 
"include, but only one can be set" ) return super().render() diff --git a/dbt/adapters/spark/session.py b/dbt/adapters/spark/session.py index 6010df920..beb77d548 100644 --- a/dbt/adapters/spark/session.py +++ b/dbt/adapters/spark/session.py @@ -4,7 +4,7 @@ import datetime as dt from types import TracebackType -from typing import Any +from typing import Any, List, Optional, Tuple from dbt.events import AdapterLogger from dbt.utils import DECIMALS @@ -25,17 +25,17 @@ class Cursor: """ def __init__(self) -> None: - self._df: DataFrame | None = None - self._rows: list[Row] | None = None + self._df: Optional[DataFrame] = None + self._rows: Optional[List[Row]] = None def __enter__(self) -> Cursor: return self def __exit__( self, - exc_type: type[BaseException] | None, - exc_val: Exception | None, - exc_tb: TracebackType | None, + exc_type: Optional[BaseException], + exc_val: Optional[Exception], + exc_tb: Optional[TracebackType], ) -> bool: self.close() return True @@ -43,13 +43,13 @@ def __exit__( @property def description( self, - ) -> list[tuple[str, str, None, None, None, None, bool]]: + ) -> List[Tuple[str, str, None, None, None, None, bool]]: """ Get the description. Returns ------- - out : list[tuple[str, str, None, None, None, None, bool]] + out : List[Tuple[str, str, None, None, None, None, bool]] The description. Source @@ -109,13 +109,13 @@ def execute(self, sql: str, *parameters: Any) -> None: spark_session = SparkSession.builder.enableHiveSupport().getOrCreate() self._df = spark_session.sql(sql) - def fetchall(self) -> list[Row] | None: + def fetchall(self) -> Optional[List[Row]]: """ Fetch all data. Returns ------- - out : list[Row] | None + out : Optional[List[Row]] The rows. Source @@ -126,7 +126,7 @@ def fetchall(self) -> list[Row] | None: self._rows = self._df.collect() return self._rows - def fetchone(self) -> Row | None: + def fetchone(self) -> Optional[Row]: """ Fetch the first output. diff --git a/dbt/include/spark/__init__.py b/dbt/include/spark/__init__.py index 564a3d1e8..b177e5d49 100644 --- a/dbt/include/spark/__init__.py +++ b/dbt/include/spark/__init__.py @@ -1,2 +1,3 @@ import os + PACKAGE_PATH = os.path.dirname(__file__) diff --git a/dbt/include/spark/macros/adapters.sql b/dbt/include/spark/macros/adapters.sql index 2811e7b45..05630ede5 100644 --- a/dbt/include/spark/macros/adapters.sql +++ b/dbt/include/spark/macros/adapters.sql @@ -148,14 +148,14 @@ {{ compiled_code }} {%- endif -%} {%- elif language == 'python' -%} - {#-- + {#-- N.B. Python models _can_ write to temp views HOWEVER they use a different session and have already expired by the time they need to be used (I.E. in merges for incremental models) - - TODO: Deep dive into spark sessions to see if we can reuse a single session for an entire + + TODO: Deep dive into spark sessions to see if we can reuse a single session for an entire dbt invocation. 
--#} - {{ py_complete_script(compiled_code=compiled_code, target_relation=relation) }} + {{ py_write_table(compiled_code=compiled_code, target_relation=relation) }} {%- endif -%} {%- endmacro -%} @@ -179,11 +179,19 @@ {%- endcall -%} {% endmacro %} -{% macro spark__get_columns_in_relation(relation) -%} - {% call statement('get_columns_in_relation', fetch_result=True) %} +{% macro get_columns_in_relation_raw(relation) -%} + {{ return(adapter.dispatch('get_columns_in_relation_raw', 'dbt')(relation)) }} +{%- endmacro -%} + +{% macro spark__get_columns_in_relation_raw(relation) -%} + {% call statement('get_columns_in_relation_raw', fetch_result=True) %} describe extended {{ relation.include(schema=(schema is not none)) }} {% endcall %} - {% do return(load_result('get_columns_in_relation').table) %} + {% do return(load_result('get_columns_in_relation_raw').table) %} +{% endmacro %} + +{% macro spark__get_columns_in_relation(relation) -%} + {{ return(adapter.get_columns_in_relation(relation)) }} {% endmacro %} {% macro spark__list_relations_without_caching(relation) %} @@ -242,7 +250,7 @@ {% set comment = column_dict[column_name]['description'] %} {% set escaped_comment = comment | replace('\'', '\\\'') %} {% set comment_query %} - alter table {{ relation }} change column + alter table {{ relation }} change column {{ adapter.quote(column_name) if column_dict[column_name]['quote'] else column_name }} comment '{{ escaped_comment }}'; {% endset %} @@ -271,28 +279,27 @@ {% macro spark__alter_relation_add_remove_columns(relation, add_columns, remove_columns) %} - + {% if remove_columns %} {% set platform_name = 'Delta Lake' if relation.is_delta else 'Apache Spark' %} {{ exceptions.raise_compiler_error(platform_name + ' does not support dropping columns from tables') }} {% endif %} - + {% if add_columns is none %} {% set add_columns = [] %} {% endif %} - + {% set sql -%} - + alter {{ relation.type }} {{ relation }} - + {% if add_columns %} add columns {% endif %} {% for column in add_columns %} {{ column.name }} {{ column.data_type }}{{ ',' if not loop.last }} {% endfor %} - + {%- endset -%} {% do run_query(sql) %} {% endmacro %} - diff --git a/dbt/include/spark/macros/apply_grants.sql b/dbt/include/spark/macros/apply_grants.sql new file mode 100644 index 000000000..49dae95dc --- /dev/null +++ b/dbt/include/spark/macros/apply_grants.sql @@ -0,0 +1,39 @@ +{% macro spark__copy_grants() %} + + {% if config.materialized == 'view' %} + {#-- Spark views don't copy grants when they're replaced --#} + {{ return(False) }} + + {% else %} + {#-- This depends on how we're replacing the table, which depends on its file format + -- Just play it safe by assuming that grants have been copied over, and need to be checked / possibly revoked + -- We can make this more efficient in the future + #} + {{ return(True) }} + + {% endif %} +{% endmacro %} + + +{%- macro spark__get_grant_sql(relation, privilege, grantees) -%} + grant {{ privilege }} on {{ relation }} to {{ adapter.quote(grantees[0]) }} +{%- endmacro %} + + +{%- macro spark__get_revoke_sql(relation, privilege, grantees) -%} + revoke {{ privilege }} on {{ relation }} from {{ adapter.quote(grantees[0]) }} +{%- endmacro %} + + +{%- macro spark__support_multiple_grantees_per_dcl_statement() -%} + {{ return(False) }} +{%- endmacro -%} + + +{% macro spark__call_dcl_statements(dcl_statement_list) %} + {% for dcl_statement in dcl_statement_list %} + {% call statement('grant_or_revoke') %} + {{ dcl_statement }} + {% endcall %} + {% endfor %} +{% endmacro %} diff --git 
a/dbt/include/spark/macros/materializations/incremental/incremental.sql b/dbt/include/spark/macros/materializations/incremental/incremental.sql index 878a338f2..01ab0a328 100644 --- a/dbt/include/spark/macros/materializations/incremental/incremental.sql +++ b/dbt/include/spark/macros/materializations/incremental/incremental.sql @@ -1,14 +1,16 @@ {% materialization incremental, adapter='spark' -%} {#-- Validate early so we don't run SQL if the file_format + strategy combo is invalid --#} {%- set raw_file_format = config.get('file_format', default='parquet') -%} - {%- set raw_strategy = config.get('incremental_strategy', default='append') -%} + {%- set raw_strategy = config.get('incremental_strategy') or 'append' -%} + {%- set grant_config = config.get('grants') -%} + {%- set file_format = dbt_spark_validate_get_file_format(raw_file_format) -%} {%- set strategy = dbt_spark_validate_get_incremental_strategy(raw_strategy, file_format) -%} {#-- Set vars --#} {%- set unique_key = config.get('unique_key', none) -%} - {%- set partition_by = config.get('partition_by', none) -%} - {%- set language = config.get('language') -%} + {%- set partition_by = config.get('partition_by', none) -%} + {%- set language = model['language'] -%} {%- set on_schema_change = incremental_validate_on_schema_change(config.get('on_schema_change'), default='ignore') -%} {%- set target_relation = this -%} {%- set existing_relation = load_relation(this) -%} @@ -47,7 +49,7 @@ {%- endcall -%} {%- if language == 'python' -%} {#-- - This is yucky. + This is yucky. See note in dbt-spark/dbt/include/spark/macros/adapters.sql re: python models and temporary views. @@ -58,9 +60,12 @@ {%- endcall %} {%- endif -%} {%- endif -%} - + + {% set should_revoke = should_revoke(existing_relation, full_refresh_mode) %} + {% do apply_grants(target_relation, grant_config, should_revoke) %} + {% do persist_docs(target_relation, model) %} - + {{ run_hooks(post_hooks) }} {{ return({'relations': [target_relation]}) }} diff --git a/dbt/include/spark/macros/materializations/incremental/strategies.sql b/dbt/include/spark/macros/materializations/incremental/strategies.sql index 215b5f3f9..28b8f2001 100644 --- a/dbt/include/spark/macros/materializations/incremental/strategies.sql +++ b/dbt/include/spark/macros/materializations/incremental/strategies.sql @@ -1,5 +1,5 @@ {% macro get_insert_overwrite_sql(source_relation, target_relation) %} - + {%- set dest_columns = adapter.get_columns_in_relation(target_relation) -%} {%- set dest_cols_csv = dest_columns | map(attribute='quoted') | join(', ') -%} insert overwrite table {{ target_relation }} @@ -41,20 +41,20 @@ {% else %} {% do predicates.append('FALSE') %} {% endif %} - + {{ sql_header if sql_header is not none }} - + merge into {{ target }} as DBT_INTERNAL_DEST using {{ source.include(schema=false) }} as DBT_INTERNAL_SOURCE on {{ predicates | join(' and ') }} - + when matched then update set {% if update_columns -%}{%- for column_name in update_columns %} {{ column_name }} = DBT_INTERNAL_SOURCE.{{ column_name }} {%- if not loop.last %}, {%- endif %} {%- endfor %} {%- else %} * {% endif %} - + when not matched then insert * {% endmacro %} diff --git a/dbt/include/spark/macros/materializations/incremental/validate.sql b/dbt/include/spark/macros/materializations/incremental/validate.sql index 3e9de359b..ffd56f106 100644 --- a/dbt/include/spark/macros/materializations/incremental/validate.sql +++ b/dbt/include/spark/macros/materializations/incremental/validate.sql @@ -28,13 +28,13 @@ Invalid incremental 
strategy provided: {{ raw_strategy }} You can only choose this strategy when file_format is set to 'delta' or 'hudi' {%- endset %} - + {% set invalid_insert_overwrite_delta_msg -%} Invalid incremental strategy provided: {{ raw_strategy }} You cannot use this strategy when file_format is set to 'delta' Use the 'append' or 'merge' strategy instead {%- endset %} - + {% set invalid_insert_overwrite_endpoint_msg -%} Invalid incremental strategy provided: {{ raw_strategy }} You cannot use this strategy when connecting via endpoint diff --git a/dbt/include/spark/macros/materializations/snapshot.sql b/dbt/include/spark/macros/materializations/snapshot.sql index 82d186ce2..6cf2358fe 100644 --- a/dbt/include/spark/macros/materializations/snapshot.sql +++ b/dbt/include/spark/macros/materializations/snapshot.sql @@ -32,7 +32,7 @@ {% macro spark_build_snapshot_staging_table(strategy, sql, target_relation) %} {% set tmp_identifier = target_relation.identifier ~ '__dbt_tmp' %} - + {%- set tmp_relation = api.Relation.create(identifier=tmp_identifier, schema=target_relation.schema, database=none, @@ -75,6 +75,7 @@ {%- set strategy_name = config.get('strategy') -%} {%- set unique_key = config.get('unique_key') %} {%- set file_format = config.get('file_format', 'parquet') -%} + {%- set grant_config = config.get('grants') -%} {% set target_relation_exists, target_relation = get_or_create_relation( database=none, @@ -116,7 +117,7 @@ {% if not target_relation_exists %} - {% set build_sql = build_snapshot_table(strategy, model['compiled_sql']) %} + {% set build_sql = build_snapshot_table(strategy, model['compiled_code']) %} {% set final_sql = create_table_as(False, target_relation, build_sql) %} {% else %} @@ -163,6 +164,9 @@ {{ final_sql }} {% endcall %} + {% set should_revoke = should_revoke(target_relation_exists, full_refresh_mode) %} + {% do apply_grants(target_relation, grant_config, should_revoke) %} + {% do persist_docs(target_relation, model) %} {{ run_hooks(post_hooks, inside_transaction=True) }} diff --git a/dbt/include/spark/macros/materializations/table.sql b/dbt/include/spark/macros/materializations/table.sql index 8ac005bf7..76cea9e57 100644 --- a/dbt/include/spark/macros/materializations/table.sql +++ b/dbt/include/spark/macros/materializations/table.sql @@ -1,6 +1,8 @@ {% materialization table, adapter = 'spark' %} - {%- set language = config.get('language') -%} + {%- set language = model['language'] -%} {%- set identifier = model['alias'] -%} + {%- set grant_config = config.get('grants') -%} + {%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%} {%- set target_relation = api.Relation.create(identifier=identifier, schema=schema, @@ -16,10 +18,14 @@ {{ adapter.drop_relation(old_relation) }} {%- endif %} + -- build model {%- call statement('main', language=language) -%} {{ create_table_as(False, target_relation, compiled_code, language) }} {%- endcall -%} + {% set should_revoke = should_revoke(old_relation, full_refresh_mode=True) %} + {% do apply_grants(target_relation, grant_config, should_revoke) %} + {% do persist_docs(target_relation, model) %} {{ run_hooks(post_hooks) }} @@ -29,11 +35,11 @@ {% endmaterialization %} -{% macro py_complete_script(compiled_code, target_relation) %} +{% macro py_write_table(compiled_code, target_relation) %} {{ compiled_code }} # --- Autogenerated dbt code below this line. Do not modify. 
--- # dbt = dbtObj(spark.table) -df = model(dbt) +df = model(dbt, spark) df.write.mode("overwrite").format("delta").saveAsTable("{{ target_relation }}") {%- endmacro -%} diff --git a/dbt/include/spark/macros/utils/any_value.sql b/dbt/include/spark/macros/utils/any_value.sql new file mode 100644 index 000000000..eb0a019b3 --- /dev/null +++ b/dbt/include/spark/macros/utils/any_value.sql @@ -0,0 +1,5 @@ +{% macro spark__any_value(expression) -%} + {#-- return any value (non-deterministic) --#} + first({{ expression }}) + +{%- endmacro %} diff --git a/dbt/include/spark/macros/utils/assert_not_null.sql b/dbt/include/spark/macros/utils/assert_not_null.sql new file mode 100644 index 000000000..e5454bce9 --- /dev/null +++ b/dbt/include/spark/macros/utils/assert_not_null.sql @@ -0,0 +1,9 @@ +{% macro assert_not_null(function, arg) -%} + {{ return(adapter.dispatch('assert_not_null', 'dbt')(function, arg)) }} +{%- endmacro %} + +{% macro spark__assert_not_null(function, arg) %} + + coalesce({{function}}({{arg}}), nvl2({{function}}({{arg}}), assert_true({{function}}({{arg}}) is not null), null)) + +{% endmacro %} diff --git a/dbt/include/spark/macros/utils/bool_or.sql b/dbt/include/spark/macros/utils/bool_or.sql new file mode 100644 index 000000000..60d705eb3 --- /dev/null +++ b/dbt/include/spark/macros/utils/bool_or.sql @@ -0,0 +1,11 @@ +{#-- Spark v3 supports 'bool_or' and 'any', but Spark v2 needs to use 'max' for this + -- https://spark.apache.org/docs/latest/api/sql/index.html#any + -- https://spark.apache.org/docs/latest/api/sql/index.html#bool_or + -- https://spark.apache.org/docs/latest/api/sql/index.html#max +#} + +{% macro spark__bool_or(expression) -%} + + max({{ expression }}) + +{%- endmacro %} diff --git a/dbt/include/spark/macros/utils/concat.sql b/dbt/include/spark/macros/utils/concat.sql new file mode 100644 index 000000000..30f1a420e --- /dev/null +++ b/dbt/include/spark/macros/utils/concat.sql @@ -0,0 +1,3 @@ +{% macro spark__concat(fields) -%} + concat({{ fields|join(', ') }}) +{%- endmacro %} diff --git a/dbt/include/spark/macros/utils/dateadd.sql b/dbt/include/spark/macros/utils/dateadd.sql new file mode 100644 index 000000000..e2a20d0f2 --- /dev/null +++ b/dbt/include/spark/macros/utils/dateadd.sql @@ -0,0 +1,62 @@ +{% macro spark__dateadd(datepart, interval, from_date_or_timestamp) %} + + {%- set clock_component -%} + {# make sure the dates + timestamps are real, otherwise raise an error asap #} + to_unix_timestamp({{ assert_not_null('to_timestamp', from_date_or_timestamp) }}) + - to_unix_timestamp({{ assert_not_null('date', from_date_or_timestamp) }}) + {%- endset -%} + + {%- if datepart in ['day', 'week'] -%} + + {%- set multiplier = 7 if datepart == 'week' else 1 -%} + + to_timestamp( + to_unix_timestamp( + date_add( + {{ assert_not_null('date', from_date_or_timestamp) }}, + cast({{interval}} * {{multiplier}} as int) + ) + ) + {{clock_component}} + ) + + {%- elif datepart in ['month', 'quarter', 'year'] -%} + + {%- set multiplier -%} + {%- if datepart == 'month' -%} 1 + {%- elif datepart == 'quarter' -%} 3 + {%- elif datepart == 'year' -%} 12 + {%- endif -%} + {%- endset -%} + + to_timestamp( + to_unix_timestamp( + add_months( + {{ assert_not_null('date', from_date_or_timestamp) }}, + cast({{interval}} * {{multiplier}} as int) + ) + ) + {{clock_component}} + ) + + {%- elif datepart in ('hour', 'minute', 'second', 'millisecond', 'microsecond') -%} + + {%- set multiplier -%} + {%- if datepart == 'hour' -%} 3600 + {%- elif datepart == 'minute' -%} 60 + {%- elif datepart == 
'second' -%} 1 + {%- elif datepart == 'millisecond' -%} (1/1000000) + {%- elif datepart == 'microsecond' -%} (1/1000000) + {%- endif -%} + {%- endset -%} + + to_timestamp( + {{ assert_not_null('to_unix_timestamp', from_date_or_timestamp) }} + + cast({{interval}} * {{multiplier}} as int) + ) + + {%- else -%} + + {{ exceptions.raise_compiler_error("macro dateadd not implemented for datepart ~ '" ~ datepart ~ "' ~ on Spark") }} + + {%- endif -%} + +{% endmacro %} diff --git a/dbt/include/spark/macros/utils/datediff.sql b/dbt/include/spark/macros/utils/datediff.sql new file mode 100644 index 000000000..d0e684c47 --- /dev/null +++ b/dbt/include/spark/macros/utils/datediff.sql @@ -0,0 +1,107 @@ +{% macro spark__datediff(first_date, second_date, datepart) %} + + {%- if datepart in ['day', 'week', 'month', 'quarter', 'year'] -%} + + {# make sure the dates are real, otherwise raise an error asap #} + {% set first_date = assert_not_null('date', first_date) %} + {% set second_date = assert_not_null('date', second_date) %} + + {%- endif -%} + + {%- if datepart == 'day' -%} + + datediff({{second_date}}, {{first_date}}) + + {%- elif datepart == 'week' -%} + + case when {{first_date}} < {{second_date}} + then floor(datediff({{second_date}}, {{first_date}})/7) + else ceil(datediff({{second_date}}, {{first_date}})/7) + end + + -- did we cross a week boundary (Sunday)? + + case + when {{first_date}} < {{second_date}} and dayofweek({{second_date}}) < dayofweek({{first_date}}) then 1 + when {{first_date}} > {{second_date}} and dayofweek({{second_date}}) > dayofweek({{first_date}}) then -1 + else 0 end + + {%- elif datepart == 'month' -%} + + case when {{first_date}} < {{second_date}} + then floor(months_between(date({{second_date}}), date({{first_date}}))) + else ceil(months_between(date({{second_date}}), date({{first_date}}))) + end + + -- did we cross a month boundary? + + case + when {{first_date}} < {{second_date}} and dayofmonth({{second_date}}) < dayofmonth({{first_date}}) then 1 + when {{first_date}} > {{second_date}} and dayofmonth({{second_date}}) > dayofmonth({{first_date}}) then -1 + else 0 end + + {%- elif datepart == 'quarter' -%} + + case when {{first_date}} < {{second_date}} + then floor(months_between(date({{second_date}}), date({{first_date}}))/3) + else ceil(months_between(date({{second_date}}), date({{first_date}}))/3) + end + + -- did we cross a quarter boundary? 
+ + case + when {{first_date}} < {{second_date}} and ( + (dayofyear({{second_date}}) - (quarter({{second_date}}) * 365/4)) + < (dayofyear({{first_date}}) - (quarter({{first_date}}) * 365/4)) + ) then 1 + when {{first_date}} > {{second_date}} and ( + (dayofyear({{second_date}}) - (quarter({{second_date}}) * 365/4)) + > (dayofyear({{first_date}}) - (quarter({{first_date}}) * 365/4)) + ) then -1 + else 0 end + + {%- elif datepart == 'year' -%} + + year({{second_date}}) - year({{first_date}}) + + {%- elif datepart in ('hour', 'minute', 'second', 'millisecond', 'microsecond') -%} + + {%- set divisor -%} + {%- if datepart == 'hour' -%} 3600 + {%- elif datepart == 'minute' -%} 60 + {%- elif datepart == 'second' -%} 1 + {%- elif datepart == 'millisecond' -%} (1/1000) + {%- elif datepart == 'microsecond' -%} (1/1000000) + {%- endif -%} + {%- endset -%} + + case when {{first_date}} < {{second_date}} + then ceil(( + {# make sure the timestamps are real, otherwise raise an error asap #} + {{ assert_not_null('to_unix_timestamp', assert_not_null('to_timestamp', second_date)) }} + - {{ assert_not_null('to_unix_timestamp', assert_not_null('to_timestamp', first_date)) }} + ) / {{divisor}}) + else floor(( + {{ assert_not_null('to_unix_timestamp', assert_not_null('to_timestamp', second_date)) }} + - {{ assert_not_null('to_unix_timestamp', assert_not_null('to_timestamp', first_date)) }} + ) / {{divisor}}) + end + + {% if datepart == 'millisecond' %} + + cast(date_format({{second_date}}, 'SSS') as int) + - cast(date_format({{first_date}}, 'SSS') as int) + {% endif %} + + {% if datepart == 'microsecond' %} + {% set capture_str = '[0-9]{4}-[0-9]{2}-[0-9]{2}.[0-9]{2}:[0-9]{2}:[0-9]{2}.([0-9]{6})' %} + -- Spark doesn't really support microseconds, so this is a massive hack! 
+ -- It will only work if the timestamp-string is of the format + -- 'yyyy-MM-dd-HH mm.ss.SSSSSS' + + cast(regexp_extract({{second_date}}, '{{capture_str}}', 1) as int) + - cast(regexp_extract({{first_date}}, '{{capture_str}}', 1) as int) + {% endif %} + + {%- else -%} + + {{ exceptions.raise_compiler_error("macro datediff not implemented for datepart ~ '" ~ datepart ~ "' ~ on Spark") }} + + {%- endif -%} + +{% endmacro %} diff --git a/dbt/include/spark/macros/utils/listagg.sql b/dbt/include/spark/macros/utils/listagg.sql new file mode 100644 index 000000000..3577edb71 --- /dev/null +++ b/dbt/include/spark/macros/utils/listagg.sql @@ -0,0 +1,17 @@ +{% macro spark__listagg(measure, delimiter_text, order_by_clause, limit_num) -%} + + {% if order_by_clause %} + {{ exceptions.warn("order_by_clause is not supported for listagg on Spark/Databricks") }} + {% endif %} + + {% set collect_list %} collect_list({{ measure }}) {% endset %} + + {% set limited %} slice({{ collect_list }}, 1, {{ limit_num }}) {% endset %} + + {% set collected = limited if limit_num else collect_list %} + + {% set final %} array_join({{ collected }}, {{ delimiter_text }}) {% endset %} + + {% do return(final) %} + +{%- endmacro %} diff --git a/dbt/include/spark/macros/utils/split_part.sql b/dbt/include/spark/macros/utils/split_part.sql new file mode 100644 index 000000000..d5ae30924 --- /dev/null +++ b/dbt/include/spark/macros/utils/split_part.sql @@ -0,0 +1,23 @@ +{% macro spark__split_part(string_text, delimiter_text, part_number) %} + + {% set delimiter_expr %} + + -- escape if starts with a special character + case when regexp_extract({{ delimiter_text }}, '([^A-Za-z0-9])(.*)', 1) != '_' + then concat('\\', {{ delimiter_text }}) + else {{ delimiter_text }} end + + {% endset %} + + {% set split_part_expr %} + + split( + {{ string_text }}, + {{ delimiter_expr }} + )[({{ part_number - 1 }})] + + {% endset %} + + {{ return(split_part_expr) }} + +{% endmacro %} diff --git a/dev_requirements.txt b/dev-requirements.txt similarity index 51% rename from dev_requirements.txt rename to dev-requirements.txt index 0f84cbd5d..2314439a0 100644 --- a/dev_requirements.txt +++ b/dev-requirements.txt @@ -1,20 +1,24 @@ # install latest changes in dbt-core # TODO: how to automate switching from develop to version branches? 
-git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-core&subdirectory=core -git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-tests-adapter&subdirectory=tests/adapter +git+https://github.com/dbt-labs/dbt-core.git@feature/python-model-v1#egg=dbt-core&subdirectory=core +git+https://github.com/dbt-labs/dbt-core.git@feature/python-model-v1#egg=dbt-tests-adapter&subdirectory=tests/adapter +black==22.3.0 +bumpversion +click~=8.0.4 +flake8 +flaky freezegun==0.3.9 -pytest>=6.0.2 +ipdb mock>=1.3.0 -flake8 +mypy==0.950 +pre-commit +pytest-csv +pytest-dotenv +pytest-xdist +pytest>=6.0.2 pytz -bumpversion tox>=3.2.0 -ipdb -pytest-xdist -pytest-dotenv -pytest-csv -flaky # Test requirements sasl>=0.2.1 diff --git a/docker-compose.yml b/docker-compose.yml index 8054dfd75..9bc9e509c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,8 +1,8 @@ version: "3.7" services: - dbt-spark2-thrift: - image: godatadriven/spark:3.0 + dbt-spark3-thrift: + image: godatadriven/spark:3.1.1 ports: - "10000:10000" - "4040:4040" diff --git a/docker/spark-defaults.conf b/docker/spark-defaults.conf index 48a0501c2..30ec59591 100644 --- a/docker/spark-defaults.conf +++ b/docker/spark-defaults.conf @@ -1,7 +1,9 @@ +spark.driver.memory 2g +spark.executor.memory 2g spark.hadoop.datanucleus.autoCreateTables true spark.hadoop.datanucleus.schema.autoCreateTables true spark.hadoop.datanucleus.fixedDatastore false spark.serializer org.apache.spark.serializer.KryoSerializer -spark.jars.packages org.apache.hudi:hudi-spark3-bundle_2.12:0.9.0 +spark.jars.packages org.apache.hudi:hudi-spark3-bundle_2.12:0.10.0 spark.sql.extensions org.apache.spark.sql.hudi.HoodieSparkSessionExtension spark.driver.userClassPathFirst true diff --git a/requirements.txt b/requirements.txt index e03320a41..aab6b56a9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ PyHive[hive]>=0.6.0,<0.7.0 -pyodbc>=4.0.30 +requests[python]>=2.28.1 +pyodbc==4.0.32 sqlparams>=3.0.0 thrift>=0.13.0 sqlparse>=0.4.2 # not directly required, pinned by Snyk to avoid a vulnerability diff --git a/scripts/build-dist.sh b/scripts/build-dist.sh index 65e6dbc97..3c3808399 100755 --- a/scripts/build-dist.sh +++ b/scripts/build-dist.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/bin/bash set -eo pipefail diff --git a/setup.py b/setup.py index e9ba3cc1e..cb0c40aec 100644 --- a/setup.py +++ b/setup.py @@ -5,41 +5,39 @@ # require python 3.7 or newer if sys.version_info < (3, 7): - print('Error: dbt does not support this version of Python.') - print('Please upgrade to Python 3.7 or higher.') + print("Error: dbt does not support this version of Python.") + print("Please upgrade to Python 3.7 or higher.") sys.exit(1) # require version of setuptools that supports find_namespace_packages from setuptools import setup + try: from setuptools import find_namespace_packages except ImportError: # the user has a downlevel version of setuptools. 
-    print('Error: dbt requires setuptools v40.1.0 or higher.')
-    print('Please upgrade setuptools with "pip install --upgrade setuptools" '
-          'and try again')
+    print("Error: dbt requires setuptools v40.1.0 or higher.")
+    print('Please upgrade setuptools with "pip install --upgrade setuptools" ' "and try again")
     sys.exit(1)
 # pull long description from README
 this_directory = os.path.abspath(os.path.dirname(__file__))
-with open(os.path.join(this_directory, 'README.md'), 'r', encoding='utf8') as f:
+with open(os.path.join(this_directory, "README.md"), "r", encoding="utf8") as f:
     long_description = f.read()
 # get this package's version from dbt/adapters/<plugin>/__version__.py
 def _get_plugin_version_dict():
-    _version_path = os.path.join(
-        this_directory, 'dbt', 'adapters', 'spark', '__version__.py'
-    )
-    _semver = r'''(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)'''
-    _pre = r'''((?P<prekind>a|b|rc)(?P<pre>\d+))?'''
-    _version_pattern = fr'''version\s*=\s*["']{_semver}{_pre}["']'''
+    _version_path = os.path.join(this_directory, "dbt", "adapters", "spark", "__version__.py")
+    _semver = r"""(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)"""
+    _pre = r"""((?P<prekind>a|b|rc)(?P<pre>\d+))?"""
+    _version_pattern = fr"""version\s*=\s*["']{_semver}{_pre}["']"""
     with open(_version_path) as f:
         match = re.search(_version_pattern, f.read().strip())
         if match is None:
-            raise ValueError(f'invalid version at {_version_path}')
+            raise ValueError(f"invalid version at {_version_path}")
         return match.groupdict()
 
 
@@ -47,7 +45,7 @@ def _get_plugin_version_dict():
 def _get_dbt_core_version():
     parts = _get_plugin_version_dict()
     minor = "{major}.{minor}.0".format(**parts)
-    pre = (parts["prekind"]+"1" if parts["prekind"] else "")
+    pre = parts["prekind"] + "1" if parts["prekind"] else ""
     return f"{minor}{pre}"
 
 
@@ -56,33 +54,28 @@ def _get_dbt_core_version():
 dbt_core_version = _get_dbt_core_version()
 description = """The Apache Spark adapter plugin for dbt"""
 
-odbc_extras = ['pyodbc>=4.0.30']
+odbc_extras = ["pyodbc>=4.0.30"]
 pyhive_extras = [
-    'PyHive[hive]>=0.6.0,<0.7.0',
-    'thrift>=0.11.0,<0.16.0',
-]
-session_extras = [
-    "pyspark>=3.0.0,<4.0.0"
+    "PyHive[hive]>=0.6.0,<0.7.0",
+    "thrift>=0.11.0,<0.16.0",
 ]
+session_extras = ["pyspark>=3.0.0,<4.0.0"]
 all_extras = odbc_extras + pyhive_extras + session_extras
 
 setup(
     name=package_name,
     version=package_version,
-
     description=description,
     long_description=long_description,
-    long_description_content_type='text/markdown',
-
-    author='dbt Labs',
-    author_email='info@dbtlabs.com',
-    url='https://github.com/dbt-labs/dbt-spark',
-
-    packages=find_namespace_packages(include=['dbt', 'dbt.*']),
+    long_description_content_type="text/markdown",
+    author="dbt Labs",
+    author_email="info@dbtlabs.com",
+    url="https://github.com/dbt-labs/dbt-spark",
+    packages=find_namespace_packages(include=["dbt", "dbt.*"]),
     include_package_data=True,
     install_requires=[
-        'dbt-core~={}'.format(dbt_core_version),
-        'sqlparams>=3.0.0',
+        "dbt-core~={}".format(dbt_core_version),
+        "sqlparams>=3.0.0",
     ],
     extras_require={
         "ODBC": odbc_extras,
@@ -92,17 +85,14 @@ def _get_dbt_core_version():
     },
     zip_safe=False,
     classifiers=[
-        'Development Status :: 5 - Production/Stable',
-        
-        'License :: OSI Approved :: Apache Software License',
-        
-        'Operating System :: Microsoft :: Windows',
-        'Operating System :: MacOS :: MacOS X',
-        'Operating System :: POSIX :: Linux',
-
-        'Programming Language :: Python :: 3.7',
-        'Programming Language :: Python :: 3.8',
-        'Programming Language :: Python :: 3.9',
+        "Development Status :: 5 - Production/Stable",
+        "License :: OSI Approved :: Apache Software License",
+        "Operating System :: Microsoft :: Windows",
+        "Operating System :: MacOS :: MacOS X",
+        "Operating System :: POSIX :: Linux",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
     ],
     python_requires=">=3.7",
 )
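For reference, the version-parsing pattern in _get_plugin_version_dict() above resolves a pre-release version string as shown below. This is a standalone sketch: the version = "1.3.0b1" literal is an invented example, not the actual contents of dbt/adapters/spark/__version__.py.

import re

_semver = r"""(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)"""
_pre = r"""((?P<prekind>a|b|rc)(?P<pre>\d+))?"""
_version_pattern = fr"""version\s*=\s*["']{_semver}{_pre}["']"""

# invented sample line; a real __version__.py would contain something similar
match = re.search(_version_pattern, 'version = "1.3.0b1"')
assert match is not None
assert match.groupdict() == {"major": "1", "minor": "3", "patch": "0", "prekind": "b", "pre": "1"}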
diff --git a/test.env.example b/test.env.example
new file mode 100644
index 000000000..e69f700b7
--- /dev/null
+++ b/test.env.example
@@ -0,0 +1,15 @@
+# Cluster ID
+DBT_DATABRICKS_CLUSTER_NAME=
+# SQL Endpoint
+DBT_DATABRICKS_ENDPOINT=
+# Server Hostname value
+DBT_DATABRICKS_HOST_NAME=
+# personal token
+DBT_DATABRICKS_TOKEN=
+# file path to local ODBC driver
+ODBC_DRIVER=
+
+# users for testing 'grants' functionality
+DBT_TEST_USER_1=
+DBT_TEST_USER_2=
+DBT_TEST_USER_3=
diff --git a/tests/conftest.py b/tests/conftest.py
index f2f0abcf0..3aefaecc6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -8,7 +8,7 @@ def pytest_addoption(parser):
     parser.addoption("--profile", action="store", default="apache_spark", type=str)
 
 
-# Using @pytest.mark.skip_adapter('apache_spark') uses the 'skip_by_adapter_type'
+# The @pytest.mark.skip_profile('apache_spark') marker relies on the 'skip_by_profile_type'
 # autouse fixture below
 def pytest_configure(config):
     config.addinivalue_line(
@@ -109,4 +109,4 @@ def skip_by_profile_type(request):
     if request.node.get_closest_marker("skip_profile"):
         for skip_profile_type in request.node.get_closest_marker("skip_profile").args:
             if skip_profile_type == profile_type:
-                pytest.skip("skipped on '{profile_type}' profile")
+                pytest.skip(f"skipped on '{profile_type}' profile")
diff --git a/tests/functional/adapter/test_basic.py b/tests/functional/adapter/test_basic.py
index 70f3267a4..bdccf169d 100644
--- a/tests/functional/adapter/test_basic.py
+++ b/tests/functional/adapter/test_basic.py
@@ -64,7 +64,7 @@ def project_config_update(self):
         }
 
 
-#hese tests were not enabled in the dbtspec files, so skipping here.
+# These tests were not enabled in the dbtspec files, so skipping here.
 # Error encountered was: Error running query: java.lang.ClassNotFoundException: delta.DefaultSource
 @pytest.mark.skip_profile('apache_spark', 'spark_session')
 class TestSnapshotTimestampSpark(BaseSnapshotTimestamp):
@@ -79,5 +79,6 @@ def project_config_update(self):
             }
         }
 
+@pytest.mark.skip_profile('spark_session')
 class TestBaseAdapterMethod(BaseAdapterMethod):
-    pass
\ No newline at end of file
+    pass
diff --git a/tests/functional/adapter/test_grants.py b/tests/functional/adapter/test_grants.py
new file mode 100644
index 000000000..8e0341df6
--- /dev/null
+++ b/tests/functional/adapter/test_grants.py
@@ -0,0 +1,60 @@
+import pytest
+from dbt.tests.adapter.grants.test_model_grants import BaseModelGrants
+from dbt.tests.adapter.grants.test_incremental_grants import BaseIncrementalGrants
+from dbt.tests.adapter.grants.test_invalid_grants import BaseInvalidGrants
+from dbt.tests.adapter.grants.test_seed_grants import BaseSeedGrants
+from dbt.tests.adapter.grants.test_snapshot_grants import BaseSnapshotGrants
+
+
+@pytest.mark.skip_profile("apache_spark", "spark_session")
+class TestModelGrantsSpark(BaseModelGrants):
+    def privilege_grantee_name_overrides(self):
+        # insert --> modify
+        return {
+            "select": "select",
+            "insert": "modify",
+            "fake_privilege": "fake_privilege",
+            "invalid_user": "invalid_user",
+        }
+
+
+@pytest.mark.skip_profile("apache_spark", "spark_session")
+class TestIncrementalGrantsSpark(BaseIncrementalGrants):
+    @pytest.fixture(scope="class")
+    def project_config_update(self):
+        return {
+            "models": {
+                "+file_format": "delta",
+                "+incremental_strategy": "merge",
+            }
+        }
+
+
+@pytest.mark.skip_profile("apache_spark", "spark_session")
+class TestSeedGrantsSpark(BaseSeedGrants):
+    # seeds in dbt-spark are currently always "full refreshed" (dropped and recreated),
+    # so existing grants are not carried over
+    # see https://github.com/dbt-labs/dbt-spark/issues/388
+    def seeds_support_partial_refresh(self):
+        return False
+
+
+@pytest.mark.skip_profile("apache_spark", "spark_session")
+class TestSnapshotGrantsSpark(BaseSnapshotGrants):
+    @pytest.fixture(scope="class")
+    def project_config_update(self):
+        return {
+            "snapshots": {
+                "+file_format": "delta",
+                "+incremental_strategy": "merge",
+            }
+        }
+
+
+@pytest.mark.skip_profile("apache_spark", "spark_session")
+class TestInvalidGrantsSpark(BaseInvalidGrants):
+    def grantee_does_not_exist_error(self):
+        return "RESOURCE_DOES_NOT_EXIST"
+        
+    def privilege_does_not_exist_error(self):
+        return "Action Unknown"
diff --git a/tests/functional/adapter/test_python.py b/tests/functional/adapter/test_python.py
deleted file mode 100644
index 103ed1829..000000000
--- a/tests/functional/adapter/test_python.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import pytest
-from dbt.tests.util import run_dbt, write_file
-from dbt.tests.adapter.python_model.test_python_model import BasePythonModelTests
-
-class TestPythonModelSpark(BasePythonModelTests):
-    pass
-
-models__simple_python_model = """
-import pandas
-
-def model(dbt):
-    dbt.config(
-        materialized='table',
-    )
-    data = [[1,2]] * 10
-    return spark.createDataFrame(data, schema=['test', 'test2'])
-"""
-models__simple_python_model_v2 = """
-import pandas
-
-def model(dbt):
-    dbt.config(
-        materialized='table',
-    )
-    data = [[1,2]] * 10
-    return spark.createDataFrame(data, schema=['test1', 'test3'])
-"""
-
-class TestChangingSchemaSnowflake:
-    @pytest.fixture(scope="class")
-    def models(self):
-        return {
-            "simple_python_model.py": models__simple_python_model
-            }
-    def test_changing_schema(self,project):
-        run_dbt(["run"])
-        write_file(models__simple_python_model_v2, project.project_root + '/models', "simple_python_model.py")
-        run_dbt(["run"])
\ No newline at end of file
diff --git a/tests/functional/adapter/test_python_model.py b/tests/functional/adapter/test_python_model.py
new file mode 100644
index 000000000..7cf2d596c
--- /dev/null
+++ b/tests/functional/adapter/test_python_model.py
@@ -0,0 +1,55 @@
+import os
+import pytest
+from dbt.tests.util import run_dbt, write_file, run_dbt_and_capture
+from dbt.tests.adapter.python_model.test_python_model import BasePythonModelTests
+
+@pytest.skip("Need to supply extra config", allow_module_level=True)
+class TestPythonModelSpark(BasePythonModelTests):
+    pass
+
+
+models__simple_python_model = """
+import pandas
+
+def model(dbt, spark):
+    dbt.config(
+        materialized='table',
+    )
+    data = [[1,2]] * 10
+    return spark.createDataFrame(data, schema=['test', 'test2'])
+"""
+models__simple_python_model_v2 = """
+import pandas
+
+def model(dbt, spark):
+    dbt.config(
+        materialized='table',
+    )
+    data = [[1,2]] * 10
+    return spark.createDataFrame(data, schema=['test1', 'test3'])
+"""
+
+
+
+@pytest.skip("Need to supply extra config", allow_module_level=True)
+@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
+class TestChangingSchemaSpark:
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {"simple_python_model.py": models__simple_python_model}
+
+    def test_changing_schema_with_log_validation(self, project, logs_dir):
+        run_dbt(["run"])
+        write_file(
+            models__simple_python_model_v2,
+            project.project_root + "/models",
+            "simple_python_model.py",
+        )
+        run_dbt(["run"])
+        log_file = os.path.join(logs_dir, "dbt.log")
+        with open(log_file, "r") as f:
+            log = f.read()
+            # validate #5510 log_code_execution works
+            assert "On model.test.simple_python_model:" in log
+            assert "spark.createDataFrame(data, schema=['test1', 'test3'])" in log
+            assert "Execution status: OK in" in log
diff --git a/tests/functional/adapter/utils/fixture_listagg.py b/tests/functional/adapter/utils/fixture_listagg.py
new file mode 100644
index 000000000..0262ca234
--- /dev/null
+++ b/tests/functional/adapter/utils/fixture_listagg.py
@@ -0,0 +1,61 @@
+# SparkSQL does not support 'order by' for its 'listagg' equivalent
+# the argument is ignored, so let's ignore those fields when checking equivalency
+
+models__test_listagg_no_order_by_sql = """
+with data as (
+    select * from {{ ref('data_listagg') }}
+),
+data_output as (
+    select * from {{ ref('data_listagg_output') }}
+),
+calculate as (
+/*
+
+    select
+        group_col,
+        {{ listagg('string_text', "'_|_'", "order by order_col") }} as actual,
+        'bottom_ordered' as version
+    from data
+    group by group_col
+    union all
+    select
+        group_col,
+        {{ listagg('string_text', "'_|_'", "order by order_col", 2) }} as actual,
+        'bottom_ordered_limited' as version
+    from data
+    group by group_col
+    union all
+
+*/
+    select
+        group_col,
+        {{ listagg('string_text', "', '") }} as actual,
+        'comma_whitespace_unordered' as version
+    from data
+    where group_col = 3
+    group by group_col
+    union all
+    select
+        group_col,
+        {{ listagg('DISTINCT string_text', "','") }} as actual,
+        'distinct_comma' as version
+    from data
+    where group_col = 3
+    group by group_col
+    union all
+    select
+        group_col,
+        {{ listagg('string_text') }} as actual,
+        'no_params' as version
+    from data
+    where group_col = 3
+    group by group_col
+)
+select
+    calculate.actual,
+    data_output.expected
+from calculate
+left join data_output
+on calculate.group_col = data_output.group_col
+and calculate.version = data_output.version
+"""
diff --git a/tests/functional/adapter/utils/test_data_types.py b/tests/functional/adapter/utils/test_data_types.py
new file mode 100644
index 000000000..65a24a3a9
--- /dev/null
+++ b/tests/functional/adapter/utils/test_data_types.py
@@ -0,0 +1,67 @@
+import pytest
+from dbt.tests.adapter.utils.data_types.test_type_bigint import BaseTypeBigInt
+from dbt.tests.adapter.utils.data_types.test_type_float import (
+    BaseTypeFloat, seeds__expected_csv as seeds__float_expected_csv
+)
+from dbt.tests.adapter.utils.data_types.test_type_int import (
+    BaseTypeInt, seeds__expected_csv as seeds__int_expected_csv
+)
+from dbt.tests.adapter.utils.data_types.test_type_numeric import BaseTypeNumeric
+from dbt.tests.adapter.utils.data_types.test_type_string import BaseTypeString
+from dbt.tests.adapter.utils.data_types.test_type_timestamp import BaseTypeTimestamp
+
+
+class TestTypeBigInt(BaseTypeBigInt):
+    pass
+
+
+# need to explicitly cast this to avoid it being inferred/loaded as a DOUBLE on Spark
+# in SparkSQL, the two are equivalent for `=` comparison, but distinct for EXCEPT comparison
+seeds__float_expected_yml = """
+version: 2
+seeds:
+  - name: expected
+    config:
+      column_types:
+        float_col: float
+"""
+
+class TestTypeFloat(BaseTypeFloat):
+    @pytest.fixture(scope="class")
+    def seeds(self):
+        return {
+            "expected.csv": seeds__float_expected_csv,
+            "expected.yml": seeds__float_expected_yml,
+        }
+
+
+# need to explicitly cast this to avoid it being inferred/loaded as a BIGINT on Spark
+seeds__int_expected_yml = """
+version: 2
+seeds:
+  - name: expected
+    config:
+      column_types:
+        int_col: int
+"""
+
+class TestTypeInt(BaseTypeInt):
+    @pytest.fixture(scope="class")
+    def seeds(self):
+        return {
+            "expected.csv": seeds__int_expected_csv,
+            "expected.yml": seeds__int_expected_yml,
+        }
+
+    
+class TestTypeNumeric(BaseTypeNumeric):
+    def numeric_fixture_type(self):
+        return "decimal(28,6)"
+
+    
+class TestTypeString(BaseTypeString):
+    pass
+
+    
+class TestTypeTimestamp(BaseTypeTimestamp):
+    pass
diff --git a/tests/functional/adapter/utils/test_utils.py b/tests/functional/adapter/utils/test_utils.py
new file mode 100644
index 000000000..c71161e65
--- /dev/null
+++ b/tests/functional/adapter/utils/test_utils.py
@@ -0,0 +1,122 @@
+import pytest
+
+from dbt.tests.adapter.utils.test_any_value import BaseAnyValue
+from dbt.tests.adapter.utils.test_bool_or import BaseBoolOr
+from dbt.tests.adapter.utils.test_cast_bool_to_text import BaseCastBoolToText
+from dbt.tests.adapter.utils.test_concat import BaseConcat
+from dbt.tests.adapter.utils.test_dateadd import BaseDateAdd
+from dbt.tests.adapter.utils.test_datediff import BaseDateDiff
+from dbt.tests.adapter.utils.test_date_trunc import BaseDateTrunc
+from dbt.tests.adapter.utils.test_escape_single_quotes import BaseEscapeSingleQuotesQuote
+from dbt.tests.adapter.utils.test_escape_single_quotes import BaseEscapeSingleQuotesBackslash
+from dbt.tests.adapter.utils.test_except import BaseExcept
+from dbt.tests.adapter.utils.test_hash import BaseHash
+from dbt.tests.adapter.utils.test_intersect import BaseIntersect
+from dbt.tests.adapter.utils.test_last_day import BaseLastDay
+from dbt.tests.adapter.utils.test_length import BaseLength
+from dbt.tests.adapter.utils.test_position import BasePosition
+from dbt.tests.adapter.utils.test_replace import BaseReplace
+from dbt.tests.adapter.utils.test_right import BaseRight
+from dbt.tests.adapter.utils.test_safe_cast import BaseSafeCast
+from dbt.tests.adapter.utils.test_split_part import BaseSplitPart
+from dbt.tests.adapter.utils.test_string_literal import BaseStringLiteral
+
+# requires modification
+from dbt.tests.adapter.utils.test_listagg import BaseListagg
+from dbt.tests.adapter.utils.fixture_listagg import models__test_listagg_yml
+from tests.functional.adapter.utils.fixture_listagg import models__test_listagg_no_order_by_sql
+
+
+class TestAnyValue(BaseAnyValue):
+    pass
+
+
+class TestBoolOr(BaseBoolOr):
+    pass
+
+
+class TestCastBoolToText(BaseCastBoolToText):
+    pass
+
+
+@pytest.mark.skip_profile('spark_session')
+class TestConcat(BaseConcat):
+    pass
+
+
+class TestDateAdd(BaseDateAdd):
+    pass
+
+
+# this generates too much SQL to run successfully in our testing environments :(
+@pytest.mark.skip_profile('apache_spark', 'spark_session')
+class TestDateDiff(BaseDateDiff):
+    pass
+
+
+class TestDateTrunc(BaseDateTrunc):
+    pass
+
+
+class TestEscapeSingleQuotes(BaseEscapeSingleQuotesQuote):
+    pass
+
+
+class TestExcept(BaseExcept):
+    pass
+
+
+@pytest.mark.skip_profile('spark_session')
+class TestHash(BaseHash):
+    pass
+
+
+class TestIntersect(BaseIntersect):
+    pass
+
+
+class TestLastDay(BaseLastDay):
+    pass
+
+
+class TestLength(BaseLength):
+    pass
+
+
+# SparkSQL does not support 'order by' for its 'listagg' equivalent
+# the argument is ignored, so let's ignore those fields when checking equivalency
+class TestListagg(BaseListagg):
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {
+            "test_listagg.yml": models__test_listagg_yml,
+            "test_listagg.sql": self.interpolate_macro_namespace(
+                models__test_listagg_no_order_by_sql, "listagg"
+            ),
+        }
+
+
+class TestPosition(BasePosition):
+    pass
+
+
+@pytest.mark.skip_profile('spark_session')
+class TestReplace(BaseReplace):
+    pass
+
+
+@pytest.mark.skip_profile('spark_session')
+class TestRight(BaseRight):
+    pass
+
+
+class TestSafeCast(BaseSafeCast):
+    pass
+
+
+class TestSplitPart(BaseSplitPart):
+    pass
+
+
+class TestStringLiteral(BaseStringLiteral):
+    pass
diff --git a/tox.ini b/tox.ini
index 1e0e2b8b6..a75e2a26a 100644
--- a/tox.ini
+++ b/tox.ini
@@ -2,21 +2,13 @@
 skipsdist = True
 envlist = unit, flake8, integration-spark-thrift
 
-
-[testenv:flake8]
-basepython = python3.8
-commands = /bin/bash -c '$(which flake8) --max-line-length 99 --select=E,W,F --ignore=W504 dbt/'
-passenv = DBT_* PYTEST_ADDOPTS
-deps =
-     -r{toxinidir}/dev_requirements.txt
-
 [testenv:unit]
 basepython = python3.8
 commands = /bin/bash -c '{envpython} -m pytest -v {posargs} tests/unit'
 passenv = DBT_* PYTEST_ADDOPTS
 deps =
     -r{toxinidir}/requirements.txt
-    -r{toxinidir}/dev_requirements.txt
+    -r{toxinidir}/dev-requirements.txt
 
 [testenv:integration-spark-databricks-http]
 basepython = python3.8
@@ -24,7 +16,7 @@ commands = /bin/bash -c '{envpython} -m pytest -v --profile databricks_http_clus
 passenv = DBT_* PYTEST_ADDOPTS
 deps =
     -r{toxinidir}/requirements.txt
-    -r{toxinidir}/dev_requirements.txt
+    -r{toxinidir}/dev-requirements.txt
     -e.
 
 [testenv:integration-spark-databricks-odbc-cluster]
@@ -34,7 +26,7 @@ commands = /bin/bash -c '{envpython} -m pytest -v --profile databricks_cluster {
 passenv = DBT_* PYTEST_ADDOPTS ODBC_DRIVER
 deps =
     -r{toxinidir}/requirements.txt
-    -r{toxinidir}/dev_requirements.txt
+    -r{toxinidir}/dev-requirements.txt
     -e.
 
 [testenv:integration-spark-databricks-odbc-sql-endpoint]
@@ -44,7 +36,7 @@ commands = /bin/bash -c '{envpython} -m pytest -v --profile databricks_sql_endpo
 passenv = DBT_* PYTEST_ADDOPTS ODBC_DRIVER
 deps =
     -r{toxinidir}/requirements.txt
-    -r{toxinidir}/dev_requirements.txt
+    -r{toxinidir}/dev-requirements.txt
     -e.
 
 
@@ -55,7 +47,7 @@ commands = /bin/bash -c '{envpython} -m pytest -v --profile apache_spark {posarg
 passenv = DBT_* PYTEST_ADDOPTS
 deps =
     -r{toxinidir}/requirements.txt
-    -r{toxinidir}/dev_requirements.txt
+    -r{toxinidir}/dev-requirements.txt
     -e.
 
 [testenv:integration-spark-session]
@@ -67,5 +59,5 @@ passenv =
     PIP_CACHE_DIR
 deps =
     -r{toxinidir}/requirements.txt
-    -r{toxinidir}/dev_requirements.txt
+    -r{toxinidir}/dev-requirements.txt
     -e.[session]

From c907a002af14520c28b5032a5ca82fd4b75ecaa8 Mon Sep 17 00:00:00 2001
From: Chenyu Li 
Date: Mon, 25 Jul 2022 21:32:05 -0700
Subject: [PATCH 30/35] make black happy

---
 dbt/adapters/spark/impl.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py
index 0c202402f..12c42ab98 100644
--- a/dbt/adapters/spark/impl.py
+++ b/dbt/adapters/spark/impl.py
@@ -163,7 +163,6 @@ def list_relations_without_caching(
 
         return relations
 
-
     def get_relation(self, database: str, schema: str, identifier: str) -> Optional[BaseRelation]:
         if not self.Relation.include_policy.database:
             database = None  # type: ignore

From 0aebf04c4686212e79abe281de30bc95bb0a7cee Mon Sep 17 00:00:00 2001
From: Chenyu Li 
Date: Wed, 27 Jul 2022 11:54:33 -0700
Subject: [PATCH 31/35] enable python model test (#409)

---
 tests/conftest.py                             | 3 ++-
 tests/functional/adapter/test_python_model.py | 4 +---
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 3aefaecc6..2fa50d6c7 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -60,6 +60,7 @@ def databricks_cluster_target():
         "connect_retries": 3,
         "connect_timeout": 5,
         "retry_all": True,
+        "user": os.getenv('DBT_DATABRICKS_USER'),
     }
 
 
@@ -91,7 +92,7 @@ def databricks_http_cluster_target():
         "connect_retries": 5,
         "connect_timeout": 60, 
         "retry_all": bool(os.getenv('DBT_DATABRICKS_RETRY_ALL', False)),
-        "user": os.getenv('DBT_DATABRICKS_USER')
+        "user": os.getenv('DBT_DATABRICKS_USER'),
     }
 
 
diff --git a/tests/functional/adapter/test_python_model.py b/tests/functional/adapter/test_python_model.py
index 7cf2d596c..c6f7d7cb7 100644
--- a/tests/functional/adapter/test_python_model.py
+++ b/tests/functional/adapter/test_python_model.py
@@ -3,7 +3,7 @@
 from dbt.tests.util import run_dbt, write_file, run_dbt_and_capture
 from dbt.tests.adapter.python_model.test_python_model import BasePythonModelTests
 
-@pytest.skip("Need to supply extra config", allow_module_level=True)
+@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
 class TestPythonModelSpark(BasePythonModelTests):
     pass
 
@@ -30,8 +30,6 @@ def model(dbt, spark):
 """
 
 
-
-@pytest.skip("Need to supply extra config", allow_module_level=True)
 @pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
 class TestChangingSchemaSpark:
     @pytest.fixture(scope="class")

From 6efee9ea497ec65156b2f711e89663af1fef56fc Mon Sep 17 00:00:00 2001
From: Chenyu Li 
Date: Wed, 27 Jul 2022 16:42:07 -0700
Subject: [PATCH 32/35] skip test that failed on main

---
 .../incremental_strategies/test_incremental_strategies.py  | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tests/integration/incremental_strategies/test_incremental_strategies.py b/tests/integration/incremental_strategies/test_incremental_strategies.py
index 839f167e6..f578b88b5 100644
--- a/tests/integration/incremental_strategies/test_incremental_strategies.py
+++ b/tests/integration/incremental_strategies/test_incremental_strategies.py
@@ -60,9 +60,10 @@ def run_and_test(self):
     def test_insert_overwrite_apache_spark(self):
         self.run_and_test()
 
-    @use_profile("databricks_cluster")
-    def test_insert_overwrite_databricks_cluster(self):
-        self.run_and_test()
+    # https://github.com/dbt-labs/dbt-spark/issues/410 tracks it
+    # @use_profile("databricks_cluster")
+    # def test_insert_overwrite_databricks_cluster(self):
+    #     self.run_and_test()
 
 
 class TestDeltaStrategies(TestIncrementalStrategies):

From 7e8943b0bd4e6afca41224f343929486f5d1f9cf Mon Sep 17 00:00:00 2001
From: Chenyu Li 
Date: Wed, 27 Jul 2022 17:38:13 -0700
Subject: [PATCH 33/35] add comment to run code

---
 dbt/include/spark/macros/materializations/table.sql | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/dbt/include/spark/macros/materializations/table.sql b/dbt/include/spark/macros/materializations/table.sql
index 3acbdf053..6a02ea164 100644
--- a/dbt/include/spark/macros/materializations/table.sql
+++ b/dbt/include/spark/macros/materializations/table.sql
@@ -38,9 +38,14 @@
 
 {% macro py_write_table(compiled_code, target_relation) %}
 {{ compiled_code }}
-
-# --- Autogenerated dbt code below this line. Do not modify. --- #
+# --- Autogenerated dbt materialization code. --- #
 dbt = dbtObj(spark.table)
 df = model(dbt, spark)
 df.write.mode("overwrite").format("delta").saveAsTable("{{ target_relation }}")
 {%- endmacro -%}
+
+{%macro py_script_comment()%}
+# how to execute python model in notebook
+# dbt = dbtObj(spark.table)
+# df = model(dbt, spark)
+{%endmacro%}
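For context, the script submitted to the cluster ends up looking roughly like the sketch below once py_write_table has wrapped a model: the compiled model code followed by the autogenerated materialization lines. This is a hedged illustration, not the adapter's literal output: the model body mirrors the test fixtures, my_schema.my_python_model is an invented target relation, and the snippet assumes a notebook environment where spark already exists and the compiled preamble from dbt-core defines dbtObj.

# --- compiled model code (illustrative) ---
import pandas  # the sample fixtures import pandas even though it is unused

def model(dbt, spark):
    dbt.config(materialized="table")
    data = [[1, 2]] * 10
    return spark.createDataFrame(data, schema=["test", "test2"])

# --- Autogenerated dbt materialization code. --- #
dbt = dbtObj(spark.table)  # assumption: dbtObj is defined by the compiled code dbt-core injects
df = model(dbt, spark)
df.write.mode("overwrite").format("delta").saveAsTable("my_schema.my_python_model")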

From 34f144e6538ad1493651eba834d192d94c4514d9 Mon Sep 17 00:00:00 2001
From: Chenyu Li 
Date: Thu, 28 Jul 2022 11:52:11 -0700
Subject: [PATCH 34/35] using core code and bring back incremental test

---
 dev-requirements.txt                                     | 4 ++--
 tests/functional/adapter/test_python_model.py            | 8 +++++++-
 .../test_incremental_strategies.py                       | 9 +++++----
 3 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/dev-requirements.txt b/dev-requirements.txt
index 409f8c50d..5b29e5e9d 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,7 +1,7 @@
 # install latest changes in dbt-core
 # TODO: how to automate switching from develop to version branches?
-git+https://github.com/dbt-labs/dbt-core.git@feature/python-model-v1#egg=dbt-core&subdirectory=core
-git+https://github.com/dbt-labs/dbt-core.git@feature/python-model-v1#egg=dbt-tests-adapter&subdirectory=tests/adapter
+git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-core&subdirectory=core
+git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-tests-adapter&subdirectory=tests/adapter
 
 
 
diff --git a/tests/functional/adapter/test_python_model.py b/tests/functional/adapter/test_python_model.py
index c6f7d7cb7..059412f10 100644
--- a/tests/functional/adapter/test_python_model.py
+++ b/tests/functional/adapter/test_python_model.py
@@ -1,12 +1,18 @@
 import os
 import pytest
 from dbt.tests.util import run_dbt, write_file, run_dbt_and_capture
-from dbt.tests.adapter.python_model.test_python_model import BasePythonModelTests
+from dbt.tests.adapter.python_model.test_python_model import BasePythonModelTests, BasePythonIncrementalTests
 
 @pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
 class TestPythonModelSpark(BasePythonModelTests):
     pass
 
+@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
+class TestPythonIncrementalModelSpark(BasePythonIncrementalTests):
+    @pytest.fixture(scope="class")
+    def project_config_update(self):
+        return {}
+
 
 models__simple_python_model = """
 import pandas
diff --git a/tests/integration/incremental_strategies/test_incremental_strategies.py b/tests/integration/incremental_strategies/test_incremental_strategies.py
index f578b88b5..3848d11ae 100644
--- a/tests/integration/incremental_strategies/test_incremental_strategies.py
+++ b/tests/integration/incremental_strategies/test_incremental_strategies.py
@@ -60,10 +60,11 @@ def run_and_test(self):
     def test_insert_overwrite_apache_spark(self):
         self.run_and_test()
 
-    # https://github.com/dbt-labs/dbt-spark/issues/410 tracks it
-    # @use_profile("databricks_cluster")
-    # def test_insert_overwrite_databricks_cluster(self):
-    #     self.run_and_test()
+    # This test requires settings on the test cluster
+    # more info at https://docs.getdbt.com/reference/resource-configs/spark-configs#the-insert_overwrite-strategy
+    @use_profile("databricks_cluster")
+    def test_insert_overwrite_databricks_cluster(self):
+        self.run_and_test()
 
 
 class TestDeltaStrategies(TestIncrementalStrategies):
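A note on the re-enabled test_insert_overwrite_databricks_cluster above: the linked docs describe Spark settings that have to be present on the test cluster. As a rough illustration only, assuming dynamic partition overwrite is the setting in question, the session-level equivalent of that cluster config would look like this:

# illustrative sketch, not the required setup: the docs call for this in the cluster's
# Spark config; it is shown as a session conf here and assumes an active `spark` session
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")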

From fa303d9682021d4a876be7b74d907dc927a4b2f2 Mon Sep 17 00:00:00 2001
From: Chenyu Li 
Date: Thu, 28 Jul 2022 12:19:44 -0700
Subject: [PATCH 35/35] add changelog

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 28f7e138b..d015a26c7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,8 @@
 ## dbt-spark 1.3.0b1 (Release TBD)
 
+### Features
+- Support Python models by submitting them as notebooks; the table and incremental materializations are currently supported. ([#377](https://github.com/dbt-labs/dbt-spark/pull/377))
+
 ### Fixes
 - Pin `pyodbc` to version 4.0.32 to prevent overwriting `libodbc.so` and `libltdl.so` on Linux ([#397](https://github.com/dbt-labs/dbt-spark/issues/397/), [#398](https://github.com/dbt-labs/dbt-spark/pull/398/))