Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

107 generic function operation #111

Merged
merged 14 commits into from
Apr 12, 2024
48 changes: 47 additions & 1 deletion dbt_automation/assets/operations.template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,29 @@ operations:
is_col: <boolean>
output_column_name: <output column name>
sql_snippet: <custom sql snippet of CASE WHEN END AS 'output_column_name'>
- type: generic
config:
input:
input_type: <"source" or "model">
input_name: <name of source table or ref model>
source_name: <name of the source defined in source.yml; will be null for type "model">
source_columns:
- <column name>
- <column name>
- <column name>
- ...
dest_schema: <destination schema>
output_model_name: <name of the output model>
computed_columns:
- function_name: <name of the sql function>
operands:
- value: <string (column name or const)>
is_col: <boolean>
- value: <string (column name or const)>
is_col: <boolean>
- value: <string (column name or const)>
is_col: <boolean>
output_column_name: <output column name>

- type: pivot
config:
Expand Down Expand Up @@ -697,4 +720,27 @@ operations:
- <column name>
cast_to: <data type to cast values to - "varchar" for postgres & "STRING" for bigquery>
unpivot_field_name: <by default - "field_name">
unpivot_value_name: <by default - "value">
unpivot_value_name: <by default - "value">
- type: generic
config:
input:
input_type: <"source" or "model">
input_name: <name of source table or ref model>
source_name: <name of the source defined in source.yml; will be null for type "model">
source_columns:
- <column name>
- <column name>
- <column name>
- ...
dest_schema: <destination schema>
output_model_name: <name of the output model>
computed_columns:
- function_name: <name of the sql function>
operands:
- value: <string (column name or const)>
is_col: <boolean>
- value: <string (column name or const)>
is_col: <boolean>
- value: <string (column name or const)>
is_col: <boolean>
output_column_name: <output column name>
75 changes: 75 additions & 0 deletions dbt_automation/operations/generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
This file contains the airthmetic operations for dbt automation
"""

from logging import basicConfig, getLogger, INFO
from dbt_automation.utils.dbtproject import dbtProject
from dbt_automation.utils.interfaces.warehouse_interface import WarehouseInterface
from dbt_automation.utils.columnutils import quote_columnname, quote_constvalue

from dbt_automation.utils.tableutils import source_or_ref

basicConfig(level=INFO)
logger = getLogger()


def generic_function_dbt_sql(
    config: dict,
    warehouse: WarehouseInterface,
):
    """
    Build the SELECT statement for a generic SQL function operation.

    config keys:
        source_columns: list of column names to pass through, or "*" for all
        computed_columns: list of dicts, each with
            function_name: SQL function to apply (e.g. LOWER, TRIM)
            operands: list of {"value": <str>, "is_col": <bool>} dicts;
                column operands are quoted as identifiers, constants as literals
            output_column_name: alias for the computed column
        input: kwargs for source_or_ref, resolves the FROM clause

    Returns a tuple of (dbt SQL string, source_columns).
    """
    source_columns = config["source_columns"]
    computed_columns = config["computed_columns"]

    if source_columns == "*":
        dbt_code = "SELECT *"
    else:
        select_list = ", ".join(
            quote_columnname(col, warehouse.name) for col in source_columns
        )
        dbt_code = f"SELECT {select_list}"

    for computed_column in computed_columns:
        function_name = computed_column["function_name"]
        operands = [
            (
                quote_columnname(str(operand["value"]), warehouse.name)
                if operand["is_col"]
                else quote_constvalue(str(operand["value"]), warehouse.name)
            )
            for operand in computed_column["operands"]
        ]
        # quote the alias too, for consistency with the operand quoting and so
        # output names containing spaces or reserved words are safe
        output_column_name = quote_columnname(
            computed_column["output_column_name"], warehouse.name
        )

        dbt_code += f", {function_name}({', '.join(operands)}) AS {output_column_name}"

    select_from = source_or_ref(**config["input"])
    if config["input"]["input_type"] == "cte":
        # a CTE name is referenced directly, not through jinja
        dbt_code += f" FROM {select_from}"
    else:
        # wrap in {{ }} so dbt resolves the source()/ref() call
        dbt_code += f" FROM {{{{{select_from}}}}}"

    return dbt_code, source_columns


def generic_function(config: dict, warehouse: WarehouseInterface, project_dir: str):
    """
    Perform a generic SQL function operation.

    Writes a dbt model applying the configured SQL functions and returns
    (path of the written model file, list of output columns).
    """
    dest_schema = config["dest_schema"]
    output_name = config["output_model_name"]

    # materialize as a table unless this operation is chained as a CTE
    model_config_sql = ""
    if config["input"]["input_type"] != "cte":
        model_config_sql = (
            "{{ config(materialized='table', schema='" + dest_schema + "') }}"
        )

    select_statement, output_cols = generic_function_dbt_sql(config, warehouse)

    dbtproject = dbtProject(project_dir)
    dbtproject.ensure_models_dir(dest_schema)
    model_sql_path = dbtproject.write_model(
        dest_schema, output_name, model_config_sql + select_statement
    )

    return model_sql_path, output_cols
5 changes: 5 additions & 0 deletions dbt_automation/operations/mergeoperations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
rename_columns_dbt_sql,
)
from dbt_automation.operations.flattenjson import flattenjson_dbt_sql
from dbt_automation.operations.generic import generic_function_dbt_sql
from dbt_automation.operations.mergetables import union_tables_sql
from dbt_automation.operations.regexextraction import regex_extraction_sql
from dbt_automation.utils.dbtproject import dbtProject
Expand Down Expand Up @@ -127,6 +128,10 @@ def merge_operations_sql(
op_select_statement, out_cols = unpivot_dbt_sql(
operation["config"], warehouse
)
elif operation["type"] == "generic":
op_select_statement, out_cols = generic_function_dbt_sql(
operation["config"], warehouse
)

output_cols = out_cols

Expand Down
2 changes: 2 additions & 0 deletions scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import json
import yaml
from dotenv import load_dotenv
from dbt_automation.operations.generic import generic_function
from dbt_automation.operations.arithmetic import arithmetic
from dbt_automation.operations.mergeoperations import merge_operations
from dbt_automation.operations.scaffold import scaffold
Expand Down Expand Up @@ -52,6 +53,7 @@
"casewhen": casewhen,
"pivot": pivot,
"unpivot": unpivot,
"generic": generic_function
}

load_dotenv("./../dbconnection.env")
Expand Down
51 changes: 51 additions & 0 deletions tests/warehouse/test_bigquery_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from logging import basicConfig, getLogger, INFO
from dbt_automation.operations.droprenamecolumns import rename_columns, drop_columns
from dbt_automation.operations.flattenjson import flattenjson
from dbt_automation.operations.generic import generic_function
from dbt_automation.operations.mergeoperations import (
merge_operations,
)
Expand Down Expand Up @@ -1078,3 +1079,53 @@ def test_flattenjson(self):
assert "_airbyte_data_NGO" in cols
assert "_airbyte_data_Month" in cols
assert "_airbyte_ab_id" in cols

def test_generic(self):
    """Apply LOWER and TRIM computed columns via the generic operation
    and verify the resulting model on BigQuery."""
    wc_client = TestBigqueryOperations.wc_client
    output_name = "generic_table"

    config = {
        "input": {
            "input_type": "model",
            "input_name": "_airbyte_raw_Sheet2",
            "source_name": None,  # null because input_type is "model"
        },
        "dest_schema": "pytest_intermediate",
        "output_model_name": output_name,
        "source_columns": ["NGO", "Month", "measure1", "measure2", "Indicator"],
        "computed_columns": [
            {
                "function_name": "LOWER",
                "operands": [
                    {"value": "NGO", "is_col": True}
                ],
                "output_column_name": "ngo_lower"
            },
            {
                "function_name": "TRIM",
                "operands": [
                    {"value": "measure1", "is_col": True}
                ],
                # NOTE(review): alias says "indicator" but the operand is
                # measure1 — confirm intended name
                "output_column_name": "trimmed_indicator"
            }
        ],
    }

    # write the model sql file into the test dbt project
    generic_function(config, wc_client, TestBigqueryOperations.test_project_dir)

    # build only this model
    TestBigqueryOperations.execute_dbt("run", output_name)

    cols = [
        col_dict["name"]
        for col_dict in wc_client.get_table_columns(
            "pytest_intermediate", output_name
        )
    ]
    # pass-through source columns must survive
    assert "NGO" in cols
    assert "Indicator" in cols
    table_data = wc_client.get_table_data("pytest_intermediate", output_name, 1)
    ngo_column = [row['ngo_lower'] for row in table_data]

    # LOWER() output must already be lowercase
    for value in ngo_column:
        assert value == value.lower(), f"Value {value} in 'NGO' column is not lowercase"
57 changes: 57 additions & 0 deletions tests/warehouse/test_postgres_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from logging import basicConfig, getLogger, INFO
from dbt_automation.operations.flattenjson import flattenjson
from dbt_automation.operations.droprenamecolumns import rename_columns, drop_columns
from dbt_automation.operations.generic import generic_function
from dbt_automation.operations.mergeoperations import merge_operations
from dbt_automation.utils.warehouseclient import get_client
from dbt_automation.operations.scaffold import scaffold
Expand Down Expand Up @@ -1091,3 +1092,59 @@ def test_merge_operation(self):
)
== 0
)

def test_generic(self):
    """Apply LOWER and TRIM computed columns via the generic operation
    and verify the resulting model on Postgres."""
    wc_client = TestPostgresOperations.wc_client
    output_name = "generic"

    config = {
        "input": {
            "input_type": "model",
            "input_name": "_airbyte_raw_Sheet2",
            "source_name": None,  # null because input_type is "model"
        },
        "dest_schema": "pytest_intermediate",
        "output_model_name": output_name,
        "source_columns": ["NGO", "Month", "measure1", "measure2", "Indicator"],
        "computed_columns": [
            {
                "function_name": "LOWER",
                "operands": [
                    {"value": "NGO", "is_col": True}
                ],
                "output_column_name": "ngo_lower"
            },
            {
                "function_name": "TRIM",
                "operands": [
                    {"value": "measure1", "is_col": True}
                ],
                "output_column_name": "trimmed_measure_1"
            }
        ],
    }

    # write the model sql file into the test dbt project
    generic_function(
        config,
        wc_client,
        TestPostgresOperations.test_project_dir,
    )

    # build only this model
    TestPostgresOperations.execute_dbt("run", output_name)

    cols = [
        col_dict["name"]
        for col_dict in wc_client.get_table_columns(
            "pytest_intermediate", output_name
        )
    ]

    # pass-through source columns must survive
    assert "NGO" in cols
    assert "Indicator" in cols
    table_data = wc_client.get_table_data("pytest_intermediate", output_name, 1)
    ngo_column = [row['ngo_lower'] for row in table_data]

    # LOWER() output must already be lowercase
    for value in ngo_column:
        assert value == value.lower(), f"Value {value} in 'NGO' column is not lowercase"

Loading