Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

115 generic sql queries #116

Merged
merged 12 commits into from
May 6, 2024
21 changes: 20 additions & 1 deletion dbt_automation/assets/operations.template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,16 @@ operations:
cast_to: <data type to cast values to - "varchar" for postgres & "STRING" for bigquery>
dest_schema: <destination schema>
output_name: <name of the output model>

- type: rawsql
config:
- input:
input_type: <"source" or "model" of table1>
input_name: <name of source table or ref model table1>
source_name: <name of the source defined in source.yml; will be null for type "model" table1>
sql_statement_1: <sql statement for select>
sql_statement_2: <optional sql statement for where or other clause and filters>
dest_schema: <destination schema>
output_model_name: <name of the output model>

- type: mergeoperations
config:
Expand Down Expand Up @@ -744,3 +753,13 @@ operations:
- value: <string (column name or const)>
is_col: <boolean>
output_column_name: <output column name>
- type: rawsql
config:
- input:
input_type: <"source" or "model" of table1>
input_name: <name of source table or ref model table1>
source_name: <name of the source defined in source.yml; will be null for type "model" table1>
sql_statement_1: <sql statement for select>
sql_statement_2: <optional sql statement for where or other clause and filters>
dest_schema: <destination schema>
output_model_name: <name of the output model>
5 changes: 5 additions & 0 deletions dbt_automation/operations/mergeoperations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from dbt_automation.operations.mergetables import union_tables_sql
from dbt_automation.operations.pivot import pivot_dbt_sql
from dbt_automation.operations.unpivot import unpivot_dbt_sql
from dbt_automation.operations.rawsql import raw_generic_dbt_sql


def merge_operations_sql(
Expand Down Expand Up @@ -132,6 +133,10 @@ def merge_operations_sql(
op_select_statement, out_cols = generic_function_dbt_sql(
operation["config"], warehouse
)
elif operation["type"] == "rawsql":
op_select_statement, out_cols = raw_generic_dbt_sql(
operation["config"], warehouse
)

output_cols = out_cols

Expand Down
55 changes: 55 additions & 0 deletions dbt_automation/operations/rawsql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from dbt_automation.utils.dbtproject import dbtProject
from dbt_automation.utils.interfaces.warehouse_interface import WarehouseInterface
from dbt_automation.utils.tableutils import source_or_ref

def raw_generic_dbt_sql(
    config: dict,  # was annotated `str`, but is accessed as a dict throughout
    warehouse: WarehouseInterface,
):
    """
    Build the SELECT portion of a dbt model from raw SQL fragments.

    Parameters
    ----------
    config : dict
        Must contain:
        - "sql_statement_1": the select list (or full SELECT statement);
          "SELECT " is prepended automatically if missing.
        - "input": kwargs for source_or_ref (input_type/input_name/source_name).
        May contain:
        - "sql_statement_2": optional trailing clause (e.g. WHERE ...).
    warehouse : WarehouseInterface
        Unused here; kept for signature parity with the other operations.

    Returns
    -------
    tuple[str, list]
        The assembled SQL string and the output columns (always empty,
        since columns cannot be inferred from raw SQL).

    Raises
    ------
    ValueError
        If "sql_statement_1" is missing or empty.
    """
    sql_statement_1 = config.get("sql_statement_1")
    sql_statement_2 = config.get("sql_statement_2", "")
    output_cols = []  # unknown for raw SQL; callers get an empty list

    if not sql_statement_1:
        raise ValueError("Primary SQL statement (sql_statement_1) is required")

    # Allow callers to pass either a bare column list or a full statement.
    if not sql_statement_1.strip().lower().startswith("select"):
        sql_statement_1 = "SELECT " + sql_statement_1

    dbt_code = sql_statement_1

    select_from = source_or_ref(**config["input"])
    if config["input"]["input_type"] == "cte":
        # CTE names are plain identifiers, not Jinja source()/ref() calls.
        dbt_code += " FROM " + select_from
    else:
        dbt_code += " FROM " + "{{" + select_from + "}}"

    if sql_statement_2:
        dbt_code += " " + sql_statement_2

    return dbt_code, output_cols

def generic_sql_function(config: dict, warehouse: WarehouseInterface, project_dir: str):
    """
    Perform a generic SQL function operation.

    Assembles the model SQL via raw_generic_dbt_sql, writes it into the dbt
    project under config["dest_schema"] / config["output_model_name"], and
    returns the written model path plus the output columns.
    """
    # A materialization header is only emitted for a real model; a CTE
    # fragment must stay header-free.
    config_header = ""
    if config["input"]["input_type"] != "cte":
        config_header = (
            "{{ config(materialized='table', schema='"
            + config["dest_schema"]
            + "') }}"
        )

    select_sql, out_columns = raw_generic_dbt_sql(config, warehouse)

    project = dbtProject(project_dir)
    project.ensure_models_dir(config["dest_schema"])
    model_path = project.write_model(
        config["dest_schema"],
        config["output_model_name"],
        config_header + select_sql,
    )

    return model_path, out_columns
4 changes: 3 additions & 1 deletion scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dbt_automation.operations.generic import generic_function
from dbt_automation.operations.arithmetic import arithmetic
from dbt_automation.operations.mergeoperations import merge_operations
from dbt_automation.operations.rawsql import generic_sql_function
from dbt_automation.operations.scaffold import scaffold
from dbt_automation.utils.warehouseclient import get_client
from dbt_automation.operations.droprenamecolumns import drop_columns, rename_columns
Expand Down Expand Up @@ -53,7 +54,8 @@
"casewhen": casewhen,
"pivot": pivot,
"unpivot": unpivot,
"generic": generic_function
"generic": generic_function,
"rawsql": generic_sql_function,
}

load_dotenv("./../dbconnection.env")
Expand Down
37 changes: 37 additions & 0 deletions tests/warehouse/test_bigquery_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dbt_automation.operations.mergeoperations import (
merge_operations,
)
from dbt_automation.operations.rawsql import generic_sql_function
from dbt_automation.utils.warehouseclient import get_client
from dbt_automation.utils.dbtproject import dbtProject
from dbt_automation.operations.scaffold import scaffold
Expand Down Expand Up @@ -986,6 +987,13 @@ def test_merge_operation(self):
],
},
},
{
"type": "rawsql",
"config": {
"sql_statement_1": "*",
"sql_statement_2": "WHERE CAST(measure1 AS INT64) != 0"
},
},
],
}

Expand Down Expand Up @@ -1036,6 +1044,9 @@ def test_merge_operation(self):
== 0
)

assert all(row['measure1'] != 0 for row in table_data)


def test_flattenjson(self):
"""Test flattenjson."""
wc_client = TestBigqueryOperations.wc_client
Expand Down Expand Up @@ -1130,3 +1141,29 @@ def test_generic(self):

for value in ngo_column:
assert value == value.lower(), f"Value {value} in 'NGO' column is not lowercase"


def test_generic_sql_function(self):
    """Raw-SQL operation on BigQuery: select two columns with a WHERE filter."""
    client = TestBigqueryOperations.wc_client
    model_name = "rawsql"

    op_config = {
        "input": {
            "input_type": "model",
            "input_name": "_airbyte_raw_Sheet1",
            "source_name": None,
        },
        "dest_schema": "pytest_intermediate",
        "output_model_name": model_name,
        "sql_statement_1": "measure1, measure2",
        "sql_statement_2": "WHERE measure1 = '183'",
    }

    # Write the model, run it, then confirm the filter was applied.
    generic_sql_function(op_config, client, TestBigqueryOperations.test_project_dir)
    TestBigqueryOperations.execute_dbt("run", model_name)

    rows = client.get_table_data("pytest_intermediate", model_name, 1)
    assert "183" in rows[0]["measure1"]
34 changes: 34 additions & 0 deletions tests/warehouse/test_postgres_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dbt_automation.operations.droprenamecolumns import rename_columns, drop_columns
from dbt_automation.operations.generic import generic_function
from dbt_automation.operations.mergeoperations import merge_operations
from dbt_automation.operations.rawsql import generic_sql_function
from dbt_automation.utils.warehouseclient import get_client
from dbt_automation.operations.scaffold import scaffold
from dbt_automation.operations.syncsources import sync_sources
Expand Down Expand Up @@ -1043,6 +1044,13 @@ def test_merge_operation(self):
],
},
},
{
"type": "rawsql",
"config": {
"sql_statement_1": "*",
"sql_statement_2": "WHERE CAST(measure1 AS INT64) != 0"
},
},
],
}

Expand Down Expand Up @@ -1093,6 +1101,8 @@ def test_merge_operation(self):
== 0
)

assert all(row['measure1'] != 0 for row in table_data)

def test_generic(self):
"""test generic operation"""
wc_client = TestPostgresOperations.wc_client
Expand Down Expand Up @@ -1148,3 +1158,27 @@ def test_generic(self):
for value in ngo_column:
assert value == value.lower(), f"Value {value} in 'NGO' column is not lowercase"


def test_generic_sql_function(self):
    """Raw-SQL operation on Postgres: select two columns with a WHERE filter."""
    client = TestPostgresOperations.wc_client
    model_name = "rawsql"

    op_config = {
        "input": {
            "input_type": "model",
            "input_name": "_airbyte_raw_Sheet1",
            "source_name": None,
        },
        "dest_schema": "pytest_intermediate",
        "output_model_name": model_name,
        "sql_statement_1": "measure1, measure2",
        "sql_statement_2": "WHERE measure1 = '183'",
    }

    # Write the model, run it, then confirm the filter was applied.
    generic_sql_function(op_config, client, TestPostgresOperations.test_project_dir)
    TestPostgresOperations.execute_dbt("run", model_name)

    rows = client.get_table_data("pytest_intermediate", model_name, 1)
    assert "183" in rows[0]["measure1"]
Loading