diff --git a/.gitignore b/.gitignore index 04b7133..280e2b1 100644 --- a/.gitignore +++ b/.gitignore @@ -42,4 +42,7 @@ coverage.xml *.py,cover .hypothesis/ .pytest_cache/ -pytest.ini \ No newline at end of file +pytest.ini + +# scripts +run-dbt.sh \ No newline at end of file diff --git a/dbt_automation/assets/operations.template.yml b/dbt_automation/assets/operations.template.yml index 91693bf..412e580 100644 --- a/dbt_automation/assets/operations.template.yml +++ b/dbt_automation/assets/operations.template.yml @@ -259,6 +259,30 @@ operations: operator: <"=" or "!=" or "<" or ">" or "<=" or ">=" > value: sql_snippet: < custom sql snippet assume its formatted; eg. col1 != 5 > + + - type: groupby + config: + input: + input_type: <"source" or "model" of table1> + input_name: + source_name: + source_columns: + - + - + - + dest_schema: + output_name: + aggregate_on: + - column: + operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct"> + output_col_name: + - column: + operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct"> + output_col_name: + - column: + operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct"> + output_col_name: + - type: mergeoperations config: @@ -424,4 +448,21 @@ operations: - column: operator: <"=" or "!=" or "<" or ">" or "<=" or ">=" > value: - sql_snippet: < custom sql snippet assume its formatted; eg. col1 != 5 > \ No newline at end of file + sql_snippet: < custom sql snippet assume its formatted; eg. col1 != 5 > + + - type: groupby + config: + source_columns: + - + - + - + aggregate_on: + - column: + operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct"> + output_col_name: + - column: + operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct"> + output_col_name: + - column: + operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct"> + output_col_name: \ No newline at end of file diff --git a/dbt_automation/operations/groupby.py b/dbt_automation/operations/groupby.py new file mode 100644 index 0000000..e9bf443 --- /dev/null +++ b/dbt_automation/operations/groupby.py @@ -0,0 +1,123 @@ +""" +Generates a model after grouping by and aggregating columns +""" + +from logging import basicConfig, getLogger, INFO + +from dbt_automation.utils.dbtproject import dbtProject +from dbt_automation.utils.columnutils import quote_columnname +from dbt_automation.utils.interfaces.warehouse_interface import WarehouseInterface +from dbt_automation.utils.tableutils import source_or_ref + +basicConfig(level=INFO) +logger = getLogger() + +# sql, columns = groupby.groupby_dbt_sql({ +# "source_columns": ["NGO", "Month"], +# "aggregate_on": [ +# { +# "operation": "count", +# "column": "measure1", +# "output_col_name": "measure1__count", +# }, +# { +# "operation": "countdistinct", +# "column": "measure2", +# "output_col_name": "measure2__count", +# }, +# { +# "operation": "sum", +# "column": "Indicator", +# "output_col_name": "sum_of_indicator" +# }, +# ], +# "input": { +# "input_type": "source", +# "source_name": "pytest_intermediate", +# "input_name": "arithmetic_add", +# }, +# }, wc_client) +# +# => +# +# SELECT +# "NGO", +# "Month", +# COUNT("measure1") AS "measure1__count", +# COUNT(DISTINCT "measure2") AS "measure2__count", +# SUM("Indicator") AS "sum_of_indicator" +# FROM {{source('pytest_intermediate', 'arithmetic_add')}} +# GROUP BY "NGO","Month" + + +# pylint:disable=unused-argument,logging-fstring-interpolation +def groupby_dbt_sql( + config: dict, + warehouse: WarehouseInterface, +): + """ + Generate SQL code for the coalesce_columns operation. + """ + source_columns = config["source_columns"] + aggregate_on: list[dict] = config.get("aggregate_on", []) + input_table = config["input"] + + dbt_code = "SELECT\n" + + dbt_code += ",\n".join( + [quote_columnname(col_name, warehouse.name) for col_name in source_columns] + ) + + for agg_col in aggregate_on: + if agg_col["operation"] == "count": + dbt_code += ( + f",\n COUNT({quote_columnname(agg_col['column'], warehouse.name)})" + ) + elif agg_col["operation"] == "countdistinct": + dbt_code += f",\n COUNT(DISTINCT {quote_columnname(agg_col['column'], warehouse.name)})" + else: + dbt_code += f",\n {agg_col['operation'].upper()}({quote_columnname(agg_col['column'], warehouse.name)})" + + dbt_code += ( + f" AS {quote_columnname(agg_col['output_col_name'], warehouse.name)}" + ) + + dbt_code += "\n" + select_from = source_or_ref(**input_table) + if input_table["input_type"] == "cte": + dbt_code += f"FROM {select_from}\n" + else: + dbt_code += f"FROM {{{{{select_from}}}}}\n" + + if len(source_columns) > 0: + dbt_code += "GROUP BY " + dbt_code += ",".join( + [quote_columnname(col_name, warehouse.name) for col_name in source_columns] + ) + + output_columns = source_columns + [col["output_col_name"] for col in aggregate_on] + + return dbt_code, output_columns + + +def groupby(config: dict, warehouse: WarehouseInterface, project_dir: str): + """ + Perform coalescing of columns and generate a DBT model. + """ + dbt_sql = "" + if config["input"]["input_type"] != "cte": + dbt_sql = ( + "{{ config(materialized='table', schema='" + config["dest_schema"] + "') }}" + ) + + select_statement, output_cols = groupby_dbt_sql(config, warehouse) + dbt_sql += "\n" + select_statement + + dbt_project = dbtProject(project_dir) + dbt_project.ensure_models_dir(config["dest_schema"]) + + output_name = config["output_name"] + dest_schema = config["dest_schema"] + model_sql_path = dbt_project.write_model(dest_schema, output_name, dbt_sql) + + return model_sql_path, output_cols diff --git a/dbt_automation/operations/mergeoperations.py b/dbt_automation/operations/mergeoperations.py index b2cf557..fe03f05 100644 --- a/dbt_automation/operations/mergeoperations.py +++ b/dbt_automation/operations/mergeoperations.py @@ -17,6 +17,7 @@ from dbt_automation.operations.replace import replace_dbt_sql from dbt_automation.operations.joins import joins_sql from dbt_automation.operations.wherefilter import where_filter_sql +from dbt_automation.operations.groupby import groupby_dbt_sql def merge_operations_sql( @@ -96,6 +97,10 @@ def merge_operations_sql( op_select_statement, out_cols = where_filter_sql( operation["config"], warehouse ) + elif operation["type"] == "groupby": + op_select_statement, out_cols = groupby_dbt_sql( + operation["config"], warehouse + ) output_cols = out_cols diff --git a/run-dbt.example.sh b/run-dbt.example.sh new file mode 100644 index 0000000..858342a --- /dev/null +++ b/run-dbt.example.sh @@ -0,0 +1,9 @@ +#!/bin/sh +# Make a copy of this file and rename it to run-dbt.sh + +# Variables +project_dir="/Path/to/dbt/project/dir" +virtual_env_dir="/Path/to/dbt/environment/" + +# Activate the virtual environment +"$virtual_env_dir"/bin/dbt run --project-dir "$project_dir" --profiles-dir "$project_dir"/profiles diff --git a/scripts/main.py b/scripts/main.py index 4d011f0..959d1da 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -24,6 +24,7 @@ from dbt_automation.operations.replace import replace from dbt_automation.operations.joins import join from dbt_automation.operations.wherefilter import where_filter +from dbt_automation.operations.groupby import groupby OPERATIONS_DICT = { "flatten": flatten_operation, @@ -42,6 +43,7 @@ "replace": replace, "join": join, "where": where_filter, + "groupby": groupby, } load_dotenv("./../dbconnection.env") @@ -104,9 +106,10 @@ logger.info(f"running the {op_type} operation") logger.info(f"using config {config}") - OPERATIONS_DICT[op_type]( + output = OPERATIONS_DICT[op_type]( config=config, warehouse=warehouse, project_dir=project_dir ) logger.info(f"finished running the {op_type} operation") + logger.info(output) warehouse.close()