Skip to content

Commit

Permalink
Merge pull request #111 from DalgoT4D/107-generic-function-operation
Browse files Browse the repository at this point in the history
107 generic function operation
  • Loading branch information
fatchat authored Apr 12, 2024
2 parents ef626df + 4e20d4b commit 648f7f5
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 1 deletion.
48 changes: 47 additions & 1 deletion dbt_automation/assets/operations.template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,29 @@ operations:
is_col: <boolean>
output_column_name: <output column name>
sql_snippet: <custom sql snippet of CASE WHEN END AS 'output_column_name'>
- type: generic
config:
input:
input_type: <"source" or "model">
input_name: <name of source table or ref model>
source_name: <name of the source defined in source.yml; will be null for type "model">
source_columns:
- <column name>
- <column name>
- <column name>
- ...
dest_schema: <destination schema>
output_model_name: <name of the output model>
computed_columns:
- function_name: <name of the sql function>
operands:
- value: <string (column name or const)>
is_col: <boolean>
- value: <string (column name or const)>
is_col: <boolean>
- value: <string (column name or const)>
is_col: <boolean>
output_column_name: <output column name>

- type: pivot
config:
Expand Down Expand Up @@ -697,4 +720,27 @@ operations:
- <column name>
cast_to: <data type to cast values to - "varchar" for postgres & "STRING" for bigquery>
unpivot_field_name: <by default - "field_name">
unpivot_value_name: <by default - "value">
unpivot_value_name: <by default - "value">
- type: generic
config:
input:
input_type: <"source" or "model">
input_name: <name of source table or ref model>
source_name: <name of the source defined in source.yml; will be null for type "model">
source_columns:
- <column name>
- <column name>
- <column name>
- ...
dest_schema: <destination schema>
output_model_name: <name of the output model>
computed_columns:
- function_name: <name of the sql function>
operands:
- value: <string (column name or const)>
is_col: <boolean>
- value: <string (column name or const)>
is_col: <boolean>
- value: <string (column name or const)>
is_col: <boolean>
output_column_name: <output column name>
75 changes: 75 additions & 0 deletions dbt_automation/operations/generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
This file contains the generic function operation for dbt automation
"""

from logging import basicConfig, getLogger, INFO
from dbt_automation.utils.dbtproject import dbtProject
from dbt_automation.utils.interfaces.warehouse_interface import WarehouseInterface
from dbt_automation.utils.columnutils import quote_columnname, quote_constvalue

from dbt_automation.utils.tableutils import source_or_ref

basicConfig(level=INFO)
logger = getLogger()


def generic_function_dbt_sql(
    config: dict,
    warehouse: WarehouseInterface,
):
    """
    Build the SELECT statement applying generic SQL functions as computed columns.

    config keys:
        source_columns: list of columns to copy from the input model, or "*"
        computed_columns: list of computed columns, each with
            function_name: name of the sql function to apply
            operands: list of {"value": <column name or constant>, "is_col": <bool>}
            output_column_name: name of the output column
        input: dict passed to source_or_ref() to resolve the FROM target

    Returns (dbt_code, source_columns).
    """
    source_columns = config["source_columns"]
    computed_columns = config["computed_columns"]

    if source_columns == "*":
        dbt_code = "SELECT *"
    else:
        dbt_code = f"SELECT {', '.join([quote_columnname(col, warehouse.name) for col in source_columns])}"

    for computed_column in computed_columns:
        function_name = computed_column["function_name"]
        operands = [
            (
                quote_columnname(str(operand["value"]), warehouse.name)
                if operand["is_col"]
                else quote_constvalue(str(operand["value"]), warehouse.name)
            )
            for operand in computed_column["operands"]
        ]
        # quote the alias too, so output names with spaces or mixed case
        # survive in the generated SQL just like the source columns do
        output_column_name = quote_columnname(
            str(computed_column["output_column_name"]), warehouse.name
        )

        dbt_code += f", {function_name}({', '.join(operands)}) AS {output_column_name}"

    select_from = source_or_ref(**config["input"])
    if config["input"]["input_type"] == "cte":
        # a CTE name is referenced directly in the FROM clause
        dbt_code += f" FROM {select_from}"
    else:
        # wrap in jinja braces so dbt resolves source()/ref() at compile time
        dbt_code += " FROM {{" + select_from + "}}"

    return dbt_code, source_columns


def generic_function(config: dict, warehouse: WarehouseInterface, project_dir: str):
    """
    Perform a generic SQL function operation: build the SELECT via
    generic_function_dbt_sql and write it out as a dbt model.

    Returns (path of the written model .sql file, output column list).
    """
    # only a standalone model gets a dbt config() header; a cte is
    # embedded in another model and must not have one
    model_config_sql = ""
    if config["input"]["input_type"] != "cte":
        model_config_sql = (
            "{{ config(materialized='table', schema='" + config["dest_schema"] + "') }}"
        )

    select_statement, output_cols = generic_function_dbt_sql(config, warehouse)

    project = dbtProject(project_dir)
    project.ensure_models_dir(config["dest_schema"])
    model_sql_path = project.write_model(
        config["dest_schema"],
        config["output_model_name"],
        model_config_sql + select_statement,
    )

    return model_sql_path, output_cols
5 changes: 5 additions & 0 deletions dbt_automation/operations/mergeoperations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
rename_columns_dbt_sql,
)
from dbt_automation.operations.flattenjson import flattenjson_dbt_sql
from dbt_automation.operations.generic import generic_function_dbt_sql
from dbt_automation.operations.mergetables import union_tables_sql
from dbt_automation.operations.regexextraction import regex_extraction_sql
from dbt_automation.utils.dbtproject import dbtProject
Expand Down Expand Up @@ -127,6 +128,10 @@ def merge_operations_sql(
op_select_statement, out_cols = unpivot_dbt_sql(
operation["config"], warehouse
)
elif operation["type"] == "generic":
op_select_statement, out_cols = generic_function_dbt_sql(
operation["config"], warehouse
)

output_cols = out_cols

Expand Down
2 changes: 2 additions & 0 deletions scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import json
import yaml
from dotenv import load_dotenv
from dbt_automation.operations.generic import generic_function
from dbt_automation.operations.arithmetic import arithmetic
from dbt_automation.operations.mergeoperations import merge_operations
from dbt_automation.operations.scaffold import scaffold
Expand Down Expand Up @@ -52,6 +53,7 @@
"casewhen": casewhen,
"pivot": pivot,
"unpivot": unpivot,
"generic": generic_function
}

load_dotenv("./../dbconnection.env")
Expand Down
51 changes: 51 additions & 0 deletions tests/warehouse/test_bigquery_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from logging import basicConfig, getLogger, INFO
from dbt_automation.operations.droprenamecolumns import rename_columns, drop_columns
from dbt_automation.operations.flattenjson import flattenjson
from dbt_automation.operations.generic import generic_function
from dbt_automation.operations.mergeoperations import (
merge_operations,
)
Expand Down Expand Up @@ -1078,3 +1079,53 @@ def test_flattenjson(self):
assert "_airbyte_data_NGO" in cols
assert "_airbyte_data_Month" in cols
assert "_airbyte_ab_id" in cols

def test_generic(self):
    """Run the generic-function operation (LOWER and TRIM computed columns)
    against BigQuery and verify the output model's columns and values."""
    wc_client = TestBigqueryOperations.wc_client
    output_name = "generic_table"

    config = {
        "input": {
            "input_type": "model",
            "input_name": "_airbyte_raw_Sheet2",
            "source_name": None,
        },
        "dest_schema": "pytest_intermediate",
        "output_model_name": output_name,
        "source_columns": ["NGO", "Month", "measure1", "measure2", "Indicator"],
        "computed_columns": [
            {
                "function_name": "LOWER",
                "operands": [
                    {"value": "NGO", "is_col": True}
                ],
                "output_column_name": "ngo_lower"
            },
            {
                "function_name": "TRIM",
                "operands": [
                    {"value": "measure1", "is_col": True}
                ],
                # named after the trimmed column, consistent with the
                # postgres version of this test
                "output_column_name": "trimmed_measure_1"
            }
        ],
    }

    generic_function(config, wc_client, TestBigqueryOperations.test_project_dir)

    TestBigqueryOperations.execute_dbt("run", output_name)

    cols = [
        col_dict["name"]
        for col_dict in wc_client.get_table_columns(
            "pytest_intermediate", output_name
        )
    ]
    # pass-through source columns must survive into the output model
    assert "NGO" in cols
    assert "Indicator" in cols
    table_data = wc_client.get_table_data("pytest_intermediate", output_name, 1)
    ngo_column = [row['ngo_lower'] for row in table_data]

    # LOWER() must have been applied to every value of NGO
    for value in ngo_column:
        assert value == value.lower(), f"Value {value} in 'NGO' column is not lowercase"
57 changes: 57 additions & 0 deletions tests/warehouse/test_postgres_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from logging import basicConfig, getLogger, INFO
from dbt_automation.operations.flattenjson import flattenjson
from dbt_automation.operations.droprenamecolumns import rename_columns, drop_columns
from dbt_automation.operations.generic import generic_function
from dbt_automation.operations.mergeoperations import merge_operations
from dbt_automation.utils.warehouseclient import get_client
from dbt_automation.operations.scaffold import scaffold
Expand Down Expand Up @@ -1091,3 +1092,59 @@ def test_merge_operation(self):
)
== 0
)

def test_generic(self):
    """Run the generic-function operation (LOWER and TRIM computed columns)
    against Postgres and verify the output model's columns and values."""
    client = TestPostgresOperations.wc_client
    model_name = "generic"

    op_config = {
        "input": {
            "input_type": "model",
            "input_name": "_airbyte_raw_Sheet2",
            "source_name": None,
        },
        "dest_schema": "pytest_intermediate",
        "output_model_name": model_name,
        "source_columns": ["NGO", "Month", "measure1", "measure2", "Indicator"],
        "computed_columns": [
            {
                "function_name": "LOWER",
                "operands": [{"value": "NGO", "is_col": True}],
                "output_column_name": "ngo_lower",
            },
            {
                "function_name": "TRIM",
                "operands": [{"value": "measure1", "is_col": True}],
                "output_column_name": "trimmed_measure_1",
            },
        ],
    }

    generic_function(op_config, client, TestPostgresOperations.test_project_dir)

    TestPostgresOperations.execute_dbt("run", model_name)

    column_names = [
        col["name"]
        for col in client.get_table_columns("pytest_intermediate", model_name)
    ]

    # pass-through source columns must survive into the output model
    assert "NGO" in column_names
    assert "Indicator" in column_names

    rows = client.get_table_data("pytest_intermediate", model_name, 1)
    lowered_values = [row["ngo_lower"] for row in rows]

    # LOWER() must have been applied to every value of NGO
    for value in lowered_values:
        assert value == value.lower(), f"Value {value} in 'NGO' column is not lowercase"

0 comments on commit 648f7f5

Please sign in to comment.