diff --git a/dbt_automation/assets/operations.template.yml b/dbt_automation/assets/operations.template.yml index 61d9429..5ff29f3 100644 --- a/dbt_automation/assets/operations.template.yml +++ b/dbt_automation/assets/operations.template.yml @@ -364,6 +364,29 @@ operations: is_col: output_column_name: sql_snippet: + - type: generic + config: + input: + input_type: <"source" or "model"> + input_name: + source_name: + source_columns: + - + - + - + - ... + dest_schema: + output_model_name: + computed_columns: + - function_name: + operands: + - value: + is_col: + - value: + is_col: + - value: + is_col: + output_column_name: - type: pivot config: @@ -697,4 +720,27 @@ operations: - cast_to: unpivot_field_name: - unpivot_value_name: \ No newline at end of file + unpivot_value_name: + - type: generic + config: + input: + input_type: <"source" or "model"> + input_name: + source_name: + source_columns: + - + - + - + - ... + dest_schema: + output_model_name: + computed_columns: + - function_name: + operands: + - value: + is_col: + - value: + is_col: + - value: + is_col: + output_column_name: diff --git a/dbt_automation/operations/generic.py b/dbt_automation/operations/generic.py new file mode 100644 index 0000000..81e66a4 --- /dev/null +++ b/dbt_automation/operations/generic.py @@ -0,0 +1,75 @@ +""" +This file contains the generic operations for dbt automation +""" + +from logging import basicConfig, getLogger, INFO +from dbt_automation.utils.dbtproject import dbtProject +from dbt_automation.utils.interfaces.warehouse_interface import WarehouseInterface +from dbt_automation.utils.columnutils import quote_columnname, quote_constvalue + +from dbt_automation.utils.tableutils import source_or_ref + +basicConfig(level=INFO) +logger = getLogger() + + +def generic_function_dbt_sql( + config: dict, + warehouse: WarehouseInterface, +): + """ + source_columns: list of columns to copy from the input model + computed_columns: list of computed columns with function_name, operands, and 
output_column_name + function_name: name of the function to be used + operands: list of operands to be passed to the function + output_column_name: name of the output column + """ + source_columns = config["source_columns"] + computed_columns = config["computed_columns"] + + if source_columns == "*": + dbt_code = "SELECT *" + else: + dbt_code = f"SELECT {', '.join([quote_columnname(col, warehouse.name) for col in source_columns])}" + + for computed_column in computed_columns: + function_name = computed_column["function_name"] + operands = [ + quote_columnname(str(operand["value"]), warehouse.name) + if operand["is_col"] + else quote_constvalue(str(operand["value"]), warehouse.name) + for operand in computed_column["operands"] + ] + output_column_name = computed_column["output_column_name"] + + dbt_code += f", {function_name}({', '.join(operands)}) AS {output_column_name}" + + select_from = source_or_ref(**config["input"]) + if config["input"]["input_type"] == "cte": + dbt_code += f" FROM {select_from}" + else: + dbt_code += f" FROM {{{{{select_from}}}}}" + + return dbt_code, source_columns + + +def generic_function(config: dict, warehouse: WarehouseInterface, project_dir: str): + """ + Perform a generic SQL function operation. 
+ """ + dbt_sql = "" + if config["input"]["input_type"] != "cte": + dbt_sql = ( + "{{ config(materialized='table', schema='" + config["dest_schema"] + "') }}" + ) + + select_statement, output_cols = generic_function_dbt_sql(config, warehouse) + + dest_schema = config["dest_schema"] + output_name = config["output_model_name"] + + dbtproject = dbtProject(project_dir) + dbtproject.ensure_models_dir(dest_schema) + model_sql_path = dbtproject.write_model(dest_schema, output_name, dbt_sql + select_statement) + + return model_sql_path, output_cols \ No newline at end of file diff --git a/dbt_automation/operations/mergeoperations.py b/dbt_automation/operations/mergeoperations.py index 08c2c69..77260cc 100644 --- a/dbt_automation/operations/mergeoperations.py +++ b/dbt_automation/operations/mergeoperations.py @@ -9,6 +9,7 @@ rename_columns_dbt_sql, ) from dbt_automation.operations.flattenjson import flattenjson_dbt_sql +from dbt_automation.operations.generic import generic_function_dbt_sql from dbt_automation.operations.mergetables import union_tables_sql from dbt_automation.operations.regexextraction import regex_extraction_sql from dbt_automation.utils.dbtproject import dbtProject @@ -127,6 +128,10 @@ def merge_operations_sql( op_select_statement, out_cols = unpivot_dbt_sql( operation["config"], warehouse ) + elif operation["type"] == "generic": + op_select_statement, out_cols = generic_function_dbt_sql( + operation["config"], warehouse + ) output_cols = out_cols diff --git a/scripts/main.py b/scripts/main.py index 08e1da1..bc7ee79 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -8,6 +8,7 @@ import json import yaml from dotenv import load_dotenv +from dbt_automation.operations.generic import generic_function from dbt_automation.operations.arithmetic import arithmetic from dbt_automation.operations.mergeoperations import merge_operations from dbt_automation.operations.scaffold import scaffold @@ -52,6 +53,7 @@ "casewhen": casewhen, "pivot": pivot, "unpivot": unpivot, 
+ "generic": generic_function } load_dotenv("./../dbconnection.env") diff --git a/tests/warehouse/test_bigquery_ops.py b/tests/warehouse/test_bigquery_ops.py index f6e6900..e1551b8 100644 --- a/tests/warehouse/test_bigquery_ops.py +++ b/tests/warehouse/test_bigquery_ops.py @@ -8,6 +8,7 @@ from logging import basicConfig, getLogger, INFO from dbt_automation.operations.droprenamecolumns import rename_columns, drop_columns from dbt_automation.operations.flattenjson import flattenjson +from dbt_automation.operations.generic import generic_function from dbt_automation.operations.mergeoperations import ( merge_operations, ) @@ -1078,3 +1079,53 @@ def test_flattenjson(self): assert "_airbyte_data_NGO" in cols assert "_airbyte_data_Month" in cols assert "_airbyte_ab_id" in cols + + def test_generic(self): + """test generic function""" + wc_client = TestBigqueryOperations.wc_client + output_name = "generic_table" + + config = { + "input": { + "input_type": "model", + "input_name": "_airbyte_raw_Sheet2", + "source_name": None, + }, + "dest_schema": "pytest_intermediate", + "output_model_name": output_name, + "source_columns": ["NGO", "Month", "measure1", "measure2", "Indicator"], + "computed_columns": [ + { + "function_name": "LOWER", + "operands": [ + {"value": "NGO", "is_col": True} + ], + "output_column_name": "ngo_lower" + }, + { + "function_name": "TRIM", + "operands": [ + {"value": "measure1", "is_col": True} + ], + "output_column_name": "trimmed_indicator" + } + ], + } + + generic_function(config, wc_client, TestBigqueryOperations.test_project_dir) + + TestBigqueryOperations.execute_dbt("run", output_name) + + cols = [ + col_dict["name"] + for col_dict in wc_client.get_table_columns( + "pytest_intermediate", output_name + ) + ] + assert "NGO" in cols + assert "Indicator" in cols + table_data = wc_client.get_table_data("pytest_intermediate", output_name, 1) + ngo_column = [row['ngo_lower'] for row in table_data] + + for value in ngo_column: + assert value == 
value.lower(), f"Value {value} in 'NGO' column is not lowercase" diff --git a/tests/warehouse/test_postgres_ops.py b/tests/warehouse/test_postgres_ops.py index b9bc9fd..7ed230f 100644 --- a/tests/warehouse/test_postgres_ops.py +++ b/tests/warehouse/test_postgres_ops.py @@ -7,6 +7,7 @@ from logging import basicConfig, getLogger, INFO from dbt_automation.operations.flattenjson import flattenjson from dbt_automation.operations.droprenamecolumns import rename_columns, drop_columns +from dbt_automation.operations.generic import generic_function from dbt_automation.operations.mergeoperations import merge_operations from dbt_automation.utils.warehouseclient import get_client from dbt_automation.operations.scaffold import scaffold @@ -1091,3 +1092,59 @@ def test_merge_operation(self): ) == 0 ) + + def test_generic(self): + """test generic operation""" + wc_client = TestPostgresOperations.wc_client + output_name = "generic" + + config = { + "input": { + "input_type": "model", + "input_name": "_airbyte_raw_Sheet2", + "source_name": None, + }, + "dest_schema": "pytest_intermediate", + "output_model_name": output_name, + "source_columns": ["NGO", "Month", "measure1", "measure2", "Indicator"], + "computed_columns": [ + { + "function_name": "LOWER", + "operands": [ + {"value": "NGO", "is_col": True} + ], + "output_column_name": "ngo_lower" + }, + { + "function_name": "TRIM", + "operands": [ + {"value": "measure1", "is_col": True} + ], + "output_column_name": "trimmed_measure_1" + } + ], + } + + generic_function( + config, + wc_client, + TestPostgresOperations.test_project_dir, + ) + + TestPostgresOperations.execute_dbt("run", output_name) + + cols = [ + col_dict["name"] + for col_dict in wc_client.get_table_columns( + "pytest_intermediate", output_name + ) + ] + + assert "NGO" in cols + assert "Indicator" in cols + table_data = wc_client.get_table_data("pytest_intermediate", output_name, 1) + ngo_column = [row['ngo_lower'] for row in table_data] + + for value in ngo_column: + 
assert value == value.lower(), f"Value {value} in 'NGO' column is not lowercase" +