From 73fb0f960008f0ee6d114cb89fb8cd7fe68f5f91 Mon Sep 17 00:00:00 2001 From: Abhishek-N Date: Thu, 2 May 2024 10:36:12 +0530 Subject: [PATCH 01/12] add raw sql generic op --- dbt_automation/operations/rawsql.py | 60 +++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 dbt_automation/operations/rawsql.py diff --git a/dbt_automation/operations/rawsql.py b/dbt_automation/operations/rawsql.py new file mode 100644 index 0000000..71f6ce1 --- /dev/null +++ b/dbt_automation/operations/rawsql.py @@ -0,0 +1,60 @@ +from sql_metadata import Parser + +from dbt_automation.utils.columnutils import quote_columnname +from dbt_automation.utils.dbtproject import dbtProject +from dbt_automation.utils.interfaces.warehouse_interface import WarehouseInterface +from dbt_automation.utils.tableutils import source_or_ref + +def raw_generic_dbt_sql( + config: str, + warehouse: WarehouseInterface, +): + """ + Parses the given SQL statement to extract tables and columns and generates DBT code. + """ + source_columns = config.get('source_columns', []) + sql_statement = config.get('raw_sql') + parser = Parser(sql_statement) + tables = parser.tables + columns = parser.columns + + if columns == "*": + dbt_code = "SELECT *" + else: + dbt_code = f"SELECT {', '.join([quote_columnname(col, warehouse.name) for col in columns])}" + + if len(tables) == 1: + config['input']['input_name'] = tables[0] + select_from = source_or_ref(tables[0], tables[0], "source") + else: + select_from = " JOIN ".join([f"{{{{ source('{table}', '{table}') }}}}" for table in tables]) + + select_from = source_or_ref(**config["input"]) + if config["input"]["input_type"] == "cte": + dbt_code += "\n FROM " + select_from + "\n" + else: + dbt_code += "\n FROM " + "{{" + select_from + "}}" + "\n" + + breakpoint() + return dbt_code, source_columns + +def generic_sql_function(config: dict, warehouse: WarehouseInterface, project_dir: str): + """ + Perform a generic SQL function operation. + """ + dbt_sql = "" + if config["input"]["input_type"] != "cte": + dbt_sql = ( + "{{ config(materialized='table', schema='" + config["dest_schema"] + "') }}" + ) + + select_statement, output_cols = raw_generic_dbt_sql(config, warehouse) + + dest_schema = config["dest_schema"] + output_name = config["output_model_name"] + + dbtproject = dbtProject(project_dir) + dbtproject.ensure_models_dir(dest_schema) + model_sql_path = dbtproject.write_model(dest_schema, output_name, dbt_sql + select_statement) + + return model_sql_path, output_cols From ef8e430b49e18f89acd284d2e46cd2671e6c4238 Mon Sep 17 00:00:00 2001 From: Abhishek-N Date: Thu, 2 May 2024 10:36:31 +0530 Subject: [PATCH 02/12] include raw sql in merge --- dbt_automation/operations/mergeoperations.py | 5 +++++ scripts/main.py | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/dbt_automation/operations/mergeoperations.py b/dbt_automation/operations/mergeoperations.py index 77260cc..5a901f3 100644 --- a/dbt_automation/operations/mergeoperations.py +++ b/dbt_automation/operations/mergeoperations.py @@ -25,6 +25,7 @@ from dbt_automation.operations.mergetables import union_tables_sql from dbt_automation.operations.pivot import pivot_dbt_sql from dbt_automation.operations.unpivot import unpivot_dbt_sql +from dbt_automation.operations.rawsql import raw_generic_dbt_sql def merge_operations_sql( @@ -132,6 +133,10 @@ def merge_operations_sql( op_select_statement, out_cols = generic_function_dbt_sql( operation["config"], warehouse ) + elif operation["type"] == "rawsql": + op_select_statement, out_cols = raw_generic_dbt_sql( + operation["config"], warehouse + ) output_cols = out_cols diff --git a/scripts/main.py b/scripts/main.py index bc7ee79..7242c55 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -11,6 +11,7 @@ from dbt_automation.operations.generic import generic_function from dbt_automation.operations.arithmetic import arithmetic from dbt_automation.operations.mergeoperations import merge_operations +from dbt_automation.operations.rawsql import generic_sql_function from dbt_automation.operations.scaffold import scaffold from dbt_automation.utils.warehouseclient import get_client from dbt_automation.operations.droprenamecolumns import drop_columns, rename_columns @@ -53,7 +54,8 @@ "casewhen": casewhen, "pivot": pivot, "unpivot": unpivot, - "generic": generic_function + "generic": generic_function, + "rawsql": generic_sql_function, } load_dotenv("./../dbconnection.env") From 3252d45233fd6b292b56ab5943c176f715cea60c Mon Sep 17 00:00:00 2001 From: Abhishek-N Date: Thu, 2 May 2024 11:28:27 +0530 Subject: [PATCH 03/12] fix input --- dbt_automation/operations/rawsql.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dbt_automation/operations/rawsql.py b/dbt_automation/operations/rawsql.py index 71f6ce1..82aeb4b 100644 --- a/dbt_automation/operations/rawsql.py +++ b/dbt_automation/operations/rawsql.py @@ -35,7 +35,6 @@ def raw_generic_dbt_sql( else: dbt_code += "\n FROM " + "{{" + select_from + "}}" + "\n" - breakpoint() return dbt_code, source_columns def generic_sql_function(config: dict, warehouse: WarehouseInterface, project_dir: str): From 9bbbdb412978752fbac09f68af75e458337964d7 Mon Sep 17 00:00:00 2001 From: Abhishek-N Date: Thu, 2 May 2024 11:39:11 +0530 Subject: [PATCH 04/12] check for table --- dbt_automation/operations/rawsql.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dbt_automation/operations/rawsql.py b/dbt_automation/operations/rawsql.py index 82aeb4b..272485a 100644 --- a/dbt_automation/operations/rawsql.py +++ b/dbt_automation/operations/rawsql.py @@ -23,11 +23,10 @@ def raw_generic_dbt_sql( else: dbt_code = f"SELECT {', '.join([quote_columnname(col, warehouse.name) for col in columns])}" - if len(tables) == 1: - config['input']['input_name'] = tables[0] - select_from = source_or_ref(tables[0], tables[0], "source") + if not tables: + raise ValueError("No tables provided") else: - select_from = " JOIN ".join([f"{{{{ source('{table}', '{table}') }}}}" for table in tables]) + config['input']['input_name'] = tables[0] select_from = source_or_ref(**config["input"]) if config["input"]["input_type"] == "cte": @@ -35,6 +34,7 @@ def raw_generic_dbt_sql( else: dbt_code += "\n FROM " + "{{" + select_from + "}}" + "\n" + breakpoint() return dbt_code, source_columns def generic_sql_function(config: dict, warehouse: WarehouseInterface, project_dir: str): From 7f4192327315ce96df7554cca291c42f6a9a6714 Mon Sep 17 00:00:00 2001 From: Abhishek-N Date: Thu, 2 May 2024 11:40:11 +0530 Subject: [PATCH 05/12] remove breakpoint --- dbt_automation/operations/rawsql.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dbt_automation/operations/rawsql.py b/dbt_automation/operations/rawsql.py index 272485a..bcfebf9 100644 --- a/dbt_automation/operations/rawsql.py +++ b/dbt_automation/operations/rawsql.py @@ -34,7 +34,6 @@ def raw_generic_dbt_sql( else: dbt_code += "\n FROM " + "{{" + select_from + "}}" + "\n" - breakpoint() return dbt_code, source_columns def generic_sql_function(config: dict, warehouse: WarehouseInterface, project_dir: str): From 80f55c038307bba18b77474273809c27fd2b052d Mon Sep 17 00:00:00 2001 From: Abhishek-N Date: Thu, 2 May 2024 23:54:57 +0530 Subject: [PATCH 06/12] remove parser --- dbt_automation/operations/rawsql.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/dbt_automation/operations/rawsql.py b/dbt_automation/operations/rawsql.py index bcfebf9..9d8607c 100644 --- a/dbt_automation/operations/rawsql.py +++ b/dbt_automation/operations/rawsql.py @@ -14,19 +14,18 @@ def raw_generic_dbt_sql( """ source_columns = config.get('source_columns', []) sql_statement = config.get('raw_sql') - parser = Parser(sql_statement) - tables = parser.tables - columns = parser.columns + # parser = Parser(sql_statement) + # tables = parser.tables + # columns = parser.columns - if columns == "*": - dbt_code = "SELECT *" - else: - dbt_code = f"SELECT {', '.join([quote_columnname(col, warehouse.name) for col in columns])}" + if not sql_statement: + raise ValueError("Query fragment is required") - if not tables: - raise ValueError("No tables provided") - else: - config['input']['input_name'] = tables[0] + dbt_code = f"{sql_statement}" + + input_config = config.get("input") + if not input_config or "input_name" not in input_config: + raise ValueError("Input configuration must include 'input_name'") select_from = source_or_ref(**config["input"]) if config["input"]["input_type"] == "cte": From 21b8295f386c76b302d69480e08dc4be6cf62dec Mon Sep 17 00:00:00 2001 From: Abhishek-N Date: Sat, 4 May 2024 22:27:47 +0530 Subject: [PATCH 07/12] sql_statement_1 and sql_statement_2 added to raw_sql --- dbt_automation/operations/mergeoperations.py | 3 +- dbt_automation/operations/rawsql.py | 38 +++++++++----------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/dbt_automation/operations/mergeoperations.py b/dbt_automation/operations/mergeoperations.py index 5a901f3..09ae698 100644 --- a/dbt_automation/operations/mergeoperations.py +++ b/dbt_automation/operations/mergeoperations.py @@ -138,7 +138,8 @@ def merge_operations_sql( operation["config"], warehouse ) - output_cols = out_cols + if output_cols: + output_cols = out_cols cte_sql = f" , {operation['as_cte']} as (\n" if cte_counter == 0: diff --git a/dbt_automation/operations/rawsql.py b/dbt_automation/operations/rawsql.py index 9d8607c..239798f 100644 --- a/dbt_automation/operations/rawsql.py +++ b/dbt_automation/operations/rawsql.py @@ -1,39 +1,35 @@ -from sql_metadata import Parser - -from dbt_automation.utils.columnutils import quote_columnname from dbt_automation.utils.dbtproject import dbtProject from dbt_automation.utils.interfaces.warehouse_interface import WarehouseInterface from dbt_automation.utils.tableutils import source_or_ref def raw_generic_dbt_sql( config: str, - warehouse: WarehouseInterface, ): """ - Parses the given SQL statement to extract tables and columns and generates DBT code. + Parses the given SQL statements to generate DBT code, handling an optional WHERE clause. """ - source_columns = config.get('source_columns', []) - sql_statement = config.get('raw_sql') - # parser = Parser(sql_statement) - # tables = parser.tables - # columns = parser.columns + sql_statement_1 = config.get('sql_statement_1') + sql_statement_2 = config.get('sql_statement_2', '') - if not sql_statement: - raise ValueError("Query fragment is required") + if not sql_statement_1: + raise ValueError("Primary SQL statement (sql_statement_1) is required") - dbt_code = f"{sql_statement}" + # Check if 'SELECT' is part of the sql_statement_1, if not, prepend it + if not sql_statement_1.strip().lower().startswith('select'): + sql_statement_1 = "SELECT " + sql_statement_1 - input_config = config.get("input") - if not input_config or "input_name" not in input_config: - raise ValueError("Input configuration must include 'input_name'") + dbt_code = f"{sql_statement_1}" select_from = source_or_ref(**config["input"]) if config["input"]["input_type"] == "cte": - dbt_code += "\n FROM " + select_from + "\n" + dbt_code += " FROM " + select_from else: - dbt_code += "\n FROM " + "{{" + select_from + "}}" + "\n" + dbt_code += " FROM " + "{{" + select_from + "}}" + + if sql_statement_2: + dbt_code += " " + sql_statement_2 - return dbt_code, source_columns + return dbt_code def generic_sql_function(config: dict, warehouse: WarehouseInterface, project_dir: str): """ @@ -45,7 +41,7 @@ def generic_sql_function(config: dict, warehouse: WarehouseInterface, project_di "{{ config(materialized='table', schema='" + config["dest_schema"] + "') }}" ) - select_statement, output_cols = raw_generic_dbt_sql(config, warehouse) + select_statement = raw_generic_dbt_sql(config) dest_schema = config["dest_schema"] output_name = config["output_model_name"] @@ -54,4 +50,4 @@ def generic_sql_function(config: dict, warehouse: WarehouseInterface, project_di dbtproject.ensure_models_dir(dest_schema) model_sql_path = dbtproject.write_model(dest_schema, output_name, dbt_sql + select_statement) - return model_sql_path, output_cols + return model_sql_path From 918c6ae9cdff6356208ceb25254e3c2ba6db9167 Mon Sep 17 00:00:00 2001 From: Abhishek-N Date: Sat, 4 May 2024 23:10:31 +0530 Subject: [PATCH 08/12] update merge op for raw sql --- dbt_automation/operations/mergeoperations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dbt_automation/operations/mergeoperations.py b/dbt_automation/operations/mergeoperations.py index 09ae698..1c9fa55 100644 --- a/dbt_automation/operations/mergeoperations.py +++ b/dbt_automation/operations/mergeoperations.py @@ -134,8 +134,8 @@ def merge_operations_sql( operation["config"], warehouse ) elif operation["type"] == "rawsql": - op_select_statement, out_cols = raw_generic_dbt_sql( - operation["config"], warehouse + op_select_statement = raw_generic_dbt_sql( + operation["config"] ) if output_cols: From 48721b524b35e17d6b98a80de34328c0cc32a11a Mon Sep 17 00:00:00 2001 From: Abhishek-N Date: Sun, 5 May 2024 16:56:45 +0530 Subject: [PATCH 09/12] rawsql tests --- tests/warehouse/test_bigquery_ops.py | 37 ++++++++++++++++++++++++++++ tests/warehouse/test_postgres_ops.py | 34 +++++++++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/tests/warehouse/test_bigquery_ops.py b/tests/warehouse/test_bigquery_ops.py index 0638d44..a1b01b9 100644 --- a/tests/warehouse/test_bigquery_ops.py +++ b/tests/warehouse/test_bigquery_ops.py @@ -12,6 +12,7 @@ from dbt_automation.operations.mergeoperations import ( merge_operations, ) +from dbt_automation.operations.rawsql import generic_sql_function from dbt_automation.utils.warehouseclient import get_client from dbt_automation.utils.dbtproject import dbtProject from dbt_automation.operations.scaffold import scaffold @@ -986,6 +987,13 @@ def test_merge_operation(self): ], }, }, + { + "type": "rawsql", + "config": { + "sql_statement_1": "*", + "sql_statement_2": "WHERE CAST(measure1 AS INT64) != 0" + }, + }, ], } @@ -1036,6 +1044,9 @@ def test_merge_operation(self): == 0 ) + assert all(row['measure1'] != 0 for row in table_data) + + def test_flattenjson(self): """Test flattenjson.""" wc_client = TestBigqueryOperations.wc_client @@ -1130,3 +1141,29 @@ def test_generic(self): for value in ngo_column: assert value == value.lower(), f"Value {value} in 'NGO' column is not lowercase" + + + def test_generic_sql_function(self): + """ test generic raw sql""" + wc_client = TestBigqueryOperations.wc_client + output_name = "rawsql" + + config = { + "input": { + "input_type": "model", + "input_name": "_airbyte_raw_Sheet1", + "source_name": None, + }, + "dest_schema": "pytest_intermediate", + "output_model_name": output_name, + "sql_statement_1": "measure1, measure2", + "sql_statement_2": "WHERE measure = '183'" + } + + generic_sql_function(config, wc_client, TestBigqueryOperations.test_project_dir) + + TestBigqueryOperations.execute_dbt("run", output_name) + + col_data = wc_client.get_table_data("pytest_intermediate", output_name, 1) + assert "183" in col_data[0]['measure1'] + diff --git a/tests/warehouse/test_postgres_ops.py b/tests/warehouse/test_postgres_ops.py index 1527357..6f52e1e 100644 --- a/tests/warehouse/test_postgres_ops.py +++ b/tests/warehouse/test_postgres_ops.py @@ -9,6 +9,7 @@ from dbt_automation.operations.droprenamecolumns import rename_columns, drop_columns from dbt_automation.operations.generic import generic_function from dbt_automation.operations.mergeoperations import merge_operations +from dbt_automation.operations.rawsql import generic_sql_function from dbt_automation.utils.warehouseclient import get_client from dbt_automation.operations.scaffold import scaffold from dbt_automation.operations.syncsources import sync_sources @@ -1043,6 +1044,13 @@ def test_merge_operation(self): ], }, }, + { + "type": "rawsql", + "config": { + "sql_statement_1": "*", + "sql_statement_2": "WHERE CAST(measure1 AS INT64) != 0" + }, + }, ], } @@ -1093,6 +1101,8 @@ def test_merge_operation(self): == 0 ) + assert all(row['measure1'] != 0 for row in table_data) + def test_generic(self): """test generic operation""" wc_client = TestPostgresOperations.wc_client @@ -1148,3 +1158,27 @@ def test_generic(self): for value in ngo_column: assert value == value.lower(), f"Value {value} in 'NGO' column is not lowercase" + + def test_generic_sql_function(self): + """ test generic raw sql""" + wc_client = TestPostgresOperations.wc_client + output_name = "rawsql" + + config = { + "input": { + "input_type": "model", + "input_name": "_airbyte_raw_Sheet1", + "source_name": None, + }, + "dest_schema": "pytest_intermediate", + "output_model_name": output_name, + "sql_statement_1": "measure1, measure2", + "sql_statement_2": "WHERE measure1 = '183'" + } + + generic_sql_function(config, wc_client, TestPostgresOperations.test_project_dir) + + TestPostgresOperations.execute_dbt("run", output_name) + + col_data = wc_client.get_table_data("pytest_intermediate", output_name, 1) + assert "183" in col_data[0]['measure1'] From 26dc64b1e9431a64b0bd89a8c4327bf7fd2d7795 Mon Sep 17 00:00:00 2001 From: Abhishek-N Date: Mon, 6 May 2024 10:11:56 +0530 Subject: [PATCH 10/12] add rawsql to operations yaml --- dbt_automation/assets/operations.template.yml | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/dbt_automation/assets/operations.template.yml b/dbt_automation/assets/operations.template.yml index 5ff29f3..1dbb30a 100644 --- a/dbt_automation/assets/operations.template.yml +++ b/dbt_automation/assets/operations.template.yml @@ -429,7 +429,16 @@ operations: cast_to: dest_schema: output_name: - + - type: rawsql + config: + - input: + input_type: <"source" or "model" of table1> + input_name: + source_name: + sql_statement_1: + sql_statement_2: + dest_schema: + output_model_name: - type: mergeoperations config: @@ -744,3 +753,13 @@ operations: - value: is_col: output_column_name: + - type: rawsql + config: + - input: + input_type: <"source" or "model" of table1> + input_name: + source_name: + sql_statement_1: + sql_statement_2: + dest_schema: + output_model_name: From 83f626600440bc31566834f13d333009fb949d8a Mon Sep 17 00:00:00 2001 From: Abhishek-N Date: Mon, 6 May 2024 16:04:19 +0530 Subject: [PATCH 11/12] send output_cols and warehouse --- dbt_automation/operations/mergeoperations.py | 7 +++---- dbt_automation/operations/rawsql.py | 8 +++++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/dbt_automation/operations/mergeoperations.py b/dbt_automation/operations/mergeoperations.py index 1c9fa55..5a901f3 100644 --- a/dbt_automation/operations/mergeoperations.py +++ b/dbt_automation/operations/mergeoperations.py @@ -134,12 +134,11 @@ def merge_operations_sql( operation["config"], warehouse ) elif operation["type"] == "rawsql": - op_select_statement = raw_generic_dbt_sql( - operation["config"] + op_select_statement, out_cols = raw_generic_dbt_sql( + operation["config"], warehouse ) - if output_cols: - output_cols = out_cols + output_cols = out_cols cte_sql = f" , {operation['as_cte']} as (\n" if cte_counter == 0: diff --git a/dbt_automation/operations/rawsql.py b/dbt_automation/operations/rawsql.py index 239798f..4895a04 100644 --- a/dbt_automation/operations/rawsql.py +++ b/dbt_automation/operations/rawsql.py @@ -4,12 +4,14 @@ def raw_generic_dbt_sql( config: str, + warehouse: WarehouseInterface, ): """ Parses the given SQL statements to generate DBT code, handling an optional WHERE clause. """ sql_statement_1 = config.get('sql_statement_1') sql_statement_2 = config.get('sql_statement_2', '') + output_cols = [] if not sql_statement_1: raise ValueError("Primary SQL statement (sql_statement_1) is required") @@ -29,7 +31,7 @@ def raw_generic_dbt_sql( if sql_statement_2: dbt_code += " " + sql_statement_2 - return dbt_code + return dbt_code, output_cols def generic_sql_function(config: dict, warehouse: WarehouseInterface, project_dir: str): """ @@ -41,7 +43,7 @@ def generic_sql_function(config: dict, warehouse: WarehouseInterface, project_di "{{ config(materialized='table', schema='" + config["dest_schema"] + "') }}" ) - select_statement = raw_generic_dbt_sql(config) + select_statement, output_cols = raw_generic_dbt_sql(config, warehouse) dest_schema = config["dest_schema"] output_name = config["output_model_name"] @@ -50,4 +52,4 @@ def generic_sql_function(config: dict, warehouse: WarehouseInterface, project_di dbtproject.ensure_models_dir(dest_schema) model_sql_path = dbtproject.write_model(dest_schema, output_name, dbt_sql + select_statement) - return model_sql_path + return model_sql_path, output_cols From 125786666688d33b11a6c39546678e9ec8cee7cd Mon Sep 17 00:00:00 2001 From: Abhishek-N Date: Mon, 6 May 2024 16:19:38 +0530 Subject: [PATCH 12/12] measure1 not measure --- tests/warehouse/test_bigquery_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/warehouse/test_bigquery_ops.py b/tests/warehouse/test_bigquery_ops.py index a1b01b9..184ed69 100644 --- a/tests/warehouse/test_bigquery_ops.py +++ b/tests/warehouse/test_bigquery_ops.py @@ -1157,7 +1157,7 @@ def test_generic_sql_function(self): "dest_schema": "pytest_intermediate", "output_model_name": output_name, "sql_statement_1": "measure1, measure2", - "sql_statement_2": "WHERE measure = '183'" + "sql_statement_2": "WHERE measure1 = '183'" } generic_sql_function(config, wc_client, TestBigqueryOperations.test_project_dir)