diff --git a/dbt_automation/assets/operations.template.yml b/dbt_automation/assets/operations.template.yml index 0e47824..1a75447 100644 --- a/dbt_automation/assets/operations.template.yml +++ b/dbt_automation/assets/operations.template.yml @@ -285,6 +285,25 @@ operations: - column: operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct"> output_col_name: + + - type: aggregate + config: + input: + input_type: <"source" or "model" of table1> + input_name: + source_name: + dest_schema: + output_name: + aggregate_on: + - column: + operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct"> + output_col_name: + - column: + operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct"> + output_col_name: + - column: + operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct"> + output_col_name: - type: mergeoperations @@ -471,4 +490,18 @@ operations: output_col_name: - column: operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct"> - output_col_name: \ No newline at end of file + output_col_name: + + - type: aggregate + config: + aggregate_on: + - column: + operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct"> + output_col_name: + - column: + operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct"> + output_col_name: + - column: + operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct"> + output_col_name: + \ No newline at end of file diff --git a/dbt_automation/operations/aggregate.py b/dbt_automation/operations/aggregate.py new file mode 100644 index 0000000..c125f59 --- /dev/null +++ b/dbt_automation/operations/aggregate.py @@ -0,0 +1,106 @@ +""" +Generates a model after grouping by and aggregating columns +""" + +from logging import basicConfig, getLogger, INFO + +from dbt_automation.utils.dbtproject import dbtProject +from dbt_automation.utils.columnutils import quote_columnname +from 
# pylint:disable=unused-argument,logging-fstring-interpolation
def aggregate_dbt_sql(
    config: dict,
    warehouse: WarehouseInterface,
):
    """
    Generate SQL code for the aggregate operation.

    config keys:
        input: dict forwarded to source_or_ref() to resolve the FROM target
            (input_type is "source", "model" or "cte")
        aggregate_on: list of {"column", "operation", "output_col_name"} where
            operation is one of sum/avg/count/min/max/countdistinct

    Returns a tuple of (sql string, list of output column names).
    """
    # NOTE: config may carry "source_columns", but they are deliberately NOT
    # selected here — a pure aggregate query can only project aggregated
    # expressions, selecting raw columns alongside them would make the SQL fail
    aggregate_on: list[dict] = config.get("aggregate_on", [])
    input_table = config["input"]

    dbt_code = "SELECT\n"

    # build one "<AGG>(<col>)  AS <alias>" fragment per aggregation and join
    # them with commas (equivalent to appending a trailing comma per fragment
    # and stripping the last one, but safe when aggregate_on is empty)
    fragments = []
    for agg_col in aggregate_on:
        operation = agg_col["operation"]
        column = quote_columnname(agg_col["column"], warehouse.name)
        alias = quote_columnname(agg_col["output_col_name"], warehouse.name)
        if operation == "count":
            fragment = f" COUNT({column}) "
        elif operation == "countdistinct":
            fragment = f" COUNT(DISTINCT {column}) "
        else:
            # sum / avg / min / max map directly onto the SQL function name
            fragment = f" {operation.upper()}({column}) "
        fragments.append(f"{fragment} AS {alias}")

    dbt_code += ",".join(fragments)
    dbt_code += "\n"

    select_from = source_or_ref(**input_table)
    if input_table["input_type"] == "cte":
        # a CTE is referenced by its bare name, not via jinja source()/ref()
        dbt_code += f"FROM {select_from}\n"
    else:
        # wrap in {{ }} so dbt resolves the source()/ref() call
        dbt_code += f"FROM {{{{{select_from}}}}}\n"

    return dbt_code, [col["output_col_name"] for col in aggregate_on]


def aggregate(config: dict, warehouse: WarehouseInterface, project_dir: str):
    """
    Perform the aggregate operation and write the resulting dbt model.

    Returns a tuple of (path to the written model .sql file, output columns).
    """
    dbt_sql = ""
    if config["input"]["input_type"] != "cte":
        # standalone model: materialize as a table in the destination schema
        dbt_sql = (
            "{{ config(materialized='table', schema='" + config["dest_schema"] + "') }}"
        )

    select_statement, output_cols = aggregate_dbt_sql(config, warehouse)
    dbt_sql += "\n" + select_statement

    dbt_project = dbtProject(project_dir)
    dbt_project.ensure_models_dir(config["dest_schema"])

    output_name = config["output_name"]
    dest_schema = config["dest_schema"]
    model_sql_path = dbt_project.write_model(dest_schema, output_name, dbt_sql)

    return model_sql_path, output_cols
dbt_automation.operations.groupby import groupby +from dbt_automation.operations.aggregate import aggregate OPERATIONS_DICT = { "flatten": flatten_operation, @@ -44,6 +45,7 @@ "join": join, "where": where_filter, "groupby": groupby, + "aggregate": aggregate, } load_dotenv("./../dbconnection.env") diff --git a/tests/warehouse/test_bigquery_ops.py b/tests/warehouse/test_bigquery_ops.py index 3a0e81a..d54cb3e 100644 --- a/tests/warehouse/test_bigquery_ops.py +++ b/tests/warehouse/test_bigquery_ops.py @@ -21,6 +21,7 @@ from dbt_automation.operations.castdatatypes import cast_datatypes from dbt_automation.operations.regexextraction import regex_extraction from dbt_automation.operations.mergetables import union_tables +from dbt_automation.operations.aggregate import aggregate basicConfig(level=INFO) @@ -561,6 +562,46 @@ def test_regexextract(self): else (regex["NGO"] is None) ) + def test_aggregate(self): + """test aggregate col operation""" + wc_client = TestBigqueryOperations.wc_client + output_name = "aggregate_col" + + config = { + "input": { + "input_type": "model", + "input_name": "cast", + "source_name": None, + }, + "dest_schema": "pytest_intermediate", + "output_name": output_name, + "aggregate_on": [ + {"column": "measure1", "operation": "sum", "output_col_name": "agg1"}, + {"column": "measure2", "operation": "sum", "output_col_name": "agg2"}, + ], + } + + aggregate( + config, + wc_client, + TestBigqueryOperations.test_project_dir, + ) + + TestBigqueryOperations.execute_dbt("run", output_name) + + cols = [ + col_dict["name"] + for col_dict in wc_client.get_table_columns( + "pytest_intermediate", output_name + ) + ] + assert "agg1" in cols + assert "agg2" in cols + table_data_agg = wc_client.get_table_data( + "pytest_intermediate", output_name, 10 + ) + assert len(table_data_agg) == 1 + def test_mergetables(self): """test merge tables""" wc_client = TestBigqueryOperations.wc_client diff --git a/tests/warehouse/test_postgres_ops.py 
b/tests/warehouse/test_postgres_ops.py index e6dd9f8..3ee1860 100644 --- a/tests/warehouse/test_postgres_ops.py +++ b/tests/warehouse/test_postgres_ops.py @@ -19,6 +19,7 @@ from dbt_automation.utils.dbtproject import dbtProject from dbt_automation.operations.regexextraction import regex_extraction from dbt_automation.operations.mergetables import union_tables +from dbt_automation.operations.aggregate import aggregate basicConfig(level=INFO) @@ -572,6 +573,46 @@ def test_regexextract(self): else (regex["NGO"] is None) ) + def test_aggregate(self): + """test aggregate col operation""" + wc_client = TestPostgresOperations.wc_client + output_name = "aggregate_col" + + config = { + "input": { + "input_type": "model", + "input_name": "cast", + "source_name": None, + }, + "dest_schema": "pytest_intermediate", + "output_name": output_name, + "aggregate_on": [ + {"column": "measure1", "operation": "sum", "output_col_name": "agg1"}, + {"column": "measure2", "operation": "sum", "output_col_name": "agg2"}, + ], + } + + aggregate( + config, + wc_client, + TestPostgresOperations.test_project_dir, + ) + + TestPostgresOperations.execute_dbt("run", output_name) + + cols = [ + col_dict["name"] + for col_dict in wc_client.get_table_columns( + "pytest_intermediate", output_name + ) + ] + assert "agg1" in cols + assert "agg2" in cols + table_data_agg = wc_client.get_table_data( + "pytest_intermediate", output_name, 10 + ) + assert len(table_data_agg) == 1 + def test_mergetables(self): """test merge tables""" wc_client = TestPostgresOperations.wc_client