Skip to content

Commit

Permalink
Merge pull request #87 from DalgoT4D/83-op-aggregate-over-table
Browse files Browse the repository at this point in the history
83 op aggregate over table
  • Loading branch information
fatchat authored Mar 17, 2024
2 parents 7df33c2 + 6165cb3 commit af1d781
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 1 deletion.
35 changes: 34 additions & 1 deletion dbt_automation/assets/operations.template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,25 @@ operations:
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>

- type: aggregate
config:
input:
input_type: <"source" or "model" of table1>
input_name: <name of source table or ref model table1>
source_name: <name of the source defined in source.yml; will be null for type "model" table1>
dest_schema: <destination schema>
output_name: <name of the output model>
aggregate_on:
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>


- type: mergeoperations
Expand Down Expand Up @@ -471,4 +490,18 @@ operations:
output_col_name: <output col name>
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>
output_col_name: <output col name>

- type: aggregate
config:
aggregate_on:
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>

106 changes: 106 additions & 0 deletions dbt_automation/operations/aggregate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""
Generates a model after grouping by and aggregating columns
"""

from logging import basicConfig, getLogger, INFO

from dbt_automation.utils.dbtproject import dbtProject
from dbt_automation.utils.columnutils import quote_columnname
from dbt_automation.utils.interfaces.warehouse_interface import WarehouseInterface
from dbt_automation.utils.tableutils import source_or_ref

basicConfig(level=INFO)
logger = getLogger()

# sql, len_output_set = aggregate.aggregate_dbt_sql({
# "input": {
# "input_type": "source",
# "source_name": "pytest_intermediate",
# "input_name": "arithmetic_add",
# },
# "aggregate_on": [
# {
# "operation": "count",
# "column": "NGO",
# "output_col_name": "count__ngo"
# },
# {
# "operation": "countdistinct",
# "column": "Month",
# "output_col_name": "distinctmonths"
# },
# ],

# }, wc_client)

# SELECT
# COUNT("NGO") AS "count__ngo", COUNT(DISTINCT "Month") AS "distinctmonths"
# FROM {{source('pytest_intermediate', 'arithmetic_add')}}


# pylint:disable=unused-argument,logging-fstring-interpolation
def aggregate_dbt_sql(
    config: dict,
    warehouse: WarehouseInterface,
):
    """
    Generate SQL code for the aggregate operation.

    config keys:
        input: dict describing the table to read from; must contain
            "input_type" ("source", "model" or "cte") plus the fields
            source_or_ref() expects.
        aggregate_on: list of dicts, each with
            column: source column to aggregate over
            operation: "sum" | "avg" | "count" | "min" | "max" | "countdistinct"
            output_col_name: alias for the aggregated column

    Returns a tuple of (sql string, list of output column names).
    """
    # NOTE: config may carry "source_columns" but it is intentionally ignored:
    # selecting plain columns alongside aggregates (with no GROUP BY) would
    # produce invalid SQL, so only the aggregate expressions are selected.
    aggregate_on: list[dict] = config.get("aggregate_on", [])
    input_table = config["input"]

    dbt_code = "SELECT\n"

    aggregate_expressions = []
    for agg_col in aggregate_on:
        column = quote_columnname(agg_col["column"], warehouse.name)
        if agg_col["operation"] == "count":
            expression = f" COUNT({column}) "
        elif agg_col["operation"] == "countdistinct":
            # "countdistinct" has no single-word SQL function; expand it
            expression = f" COUNT(DISTINCT {column}) "
        else:
            # sum / avg / min / max map directly onto SQL function names
            expression = f" {agg_col['operation'].upper()}({column}) "
        output_column = quote_columnname(agg_col["output_col_name"], warehouse.name)
        aggregate_expressions.append(f"{expression} AS {output_column}")

    dbt_code += ",".join(aggregate_expressions)
    dbt_code += "\n"
    select_from = source_or_ref(**input_table)
    if input_table["input_type"] == "cte":
        # a CTE name is referenced directly, not through dbt's jinja helpers
        dbt_code += f"FROM {select_from}\n"
    else:
        # wrap in {{ }} so dbt resolves the source()/ref() at compile time
        dbt_code += f"FROM {{{{{select_from}}}}}\n"

    return dbt_code, [col["output_col_name"] for col in aggregate_on]


def aggregate(config: dict, warehouse: WarehouseInterface, project_dir: str):
    """
    Generate a DBT model that aggregates columns of the input table.

    config keys (in addition to those used by aggregate_dbt_sql):
        dest_schema: schema under which the model file is written
        output_name: name of the output model

    Returns a tuple of (path of the written model .sql file,
    list of output column names).
    """
    dbt_sql = ""
    if config["input"]["input_type"] != "cte":
        # a cte is embedded inside another model and must not carry its own
        # materialization config; standalone models are materialized as tables
        dbt_sql = (
            "{{ config(materialized='table', schema='" + config["dest_schema"] + "') }}"
        )

    select_statement, output_cols = aggregate_dbt_sql(config, warehouse)
    dbt_sql += "\n" + select_statement

    dbt_project = dbtProject(project_dir)
    dbt_project.ensure_models_dir(config["dest_schema"])

    model_sql_path = dbt_project.write_model(
        config["dest_schema"], config["output_name"], dbt_sql
    )

    return model_sql_path, output_cols
5 changes: 5 additions & 0 deletions dbt_automation/operations/mergeoperations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from dbt_automation.operations.joins import joins_sql
from dbt_automation.operations.wherefilter import where_filter_sql
from dbt_automation.operations.groupby import groupby_dbt_sql
from dbt_automation.operations.aggregate import aggregate_dbt_sql


def merge_operations_sql(
Expand Down Expand Up @@ -101,6 +102,10 @@ def merge_operations_sql(
op_select_statement, out_cols = groupby_dbt_sql(
operation["config"], warehouse
)
elif operation["type"] == "aggregate":
op_select_statement, out_cols = aggregate_dbt_sql(
operation["config"], warehouse
)

output_cols = out_cols

Expand Down
2 changes: 2 additions & 0 deletions scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from dbt_automation.operations.joins import join
from dbt_automation.operations.wherefilter import where_filter
from dbt_automation.operations.groupby import groupby
from dbt_automation.operations.aggregate import aggregate

OPERATIONS_DICT = {
"flatten": flatten_operation,
Expand All @@ -44,6 +45,7 @@
"join": join,
"where": where_filter,
"groupby": groupby,
"aggregate": aggregate,
}

load_dotenv("./../dbconnection.env")
Expand Down
41 changes: 41 additions & 0 deletions tests/warehouse/test_bigquery_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from dbt_automation.operations.castdatatypes import cast_datatypes
from dbt_automation.operations.regexextraction import regex_extraction
from dbt_automation.operations.mergetables import union_tables
from dbt_automation.operations.aggregate import aggregate


basicConfig(level=INFO)
Expand Down Expand Up @@ -561,6 +562,46 @@ def test_regexextract(self):
else (regex["NGO"] is None)
)

def test_aggregate(self):
    """test aggregate col operation"""
    wc_client = TestBigqueryOperations.wc_client
    output_name = "aggregate_col"

    # aggregate two numeric measures from the "cast" model into one row
    config = {
        "input": {
            "input_type": "model",
            "input_name": "cast",
            "source_name": None,
        },
        "dest_schema": "pytest_intermediate",
        "output_name": output_name,
        "aggregate_on": [
            {"column": "measure1", "operation": "sum", "output_col_name": "agg1"},
            {"column": "measure2", "operation": "sum", "output_col_name": "agg2"},
        ],
    }

    aggregate(config, wc_client, TestBigqueryOperations.test_project_dir)

    TestBigqueryOperations.execute_dbt("run", output_name)

    column_names = [
        col_dict["name"]
        for col_dict in wc_client.get_table_columns(
            "pytest_intermediate", output_name
        )
    ]
    for expected_col in ("agg1", "agg2"):
        assert expected_col in column_names

    # a pure aggregation (no GROUP BY) must collapse to a single row
    rows = wc_client.get_table_data("pytest_intermediate", output_name, 10)
    assert len(rows) == 1

def test_mergetables(self):
"""test merge tables"""
wc_client = TestBigqueryOperations.wc_client
Expand Down
41 changes: 41 additions & 0 deletions tests/warehouse/test_postgres_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dbt_automation.utils.dbtproject import dbtProject
from dbt_automation.operations.regexextraction import regex_extraction
from dbt_automation.operations.mergetables import union_tables
from dbt_automation.operations.aggregate import aggregate


basicConfig(level=INFO)
Expand Down Expand Up @@ -572,6 +573,46 @@ def test_regexextract(self):
else (regex["NGO"] is None)
)

def test_aggregate(self):
    """test aggregate col operation"""
    wc_client = TestPostgresOperations.wc_client
    output_name = "aggregate_col"

    # aggregate two numeric measures from the "cast" model into one row
    config = {
        "input": {
            "input_type": "model",
            "input_name": "cast",
            "source_name": None,
        },
        "dest_schema": "pytest_intermediate",
        "output_name": output_name,
        "aggregate_on": [
            {"column": "measure1", "operation": "sum", "output_col_name": "agg1"},
            {"column": "measure2", "operation": "sum", "output_col_name": "agg2"},
        ],
    }

    aggregate(config, wc_client, TestPostgresOperations.test_project_dir)

    TestPostgresOperations.execute_dbt("run", output_name)

    column_names = [
        col_dict["name"]
        for col_dict in wc_client.get_table_columns(
            "pytest_intermediate", output_name
        )
    ]
    for expected_col in ("agg1", "agg2"):
        assert expected_col in column_names

    # a pure aggregation (no GROUP BY) must collapse to a single row
    rows = wc_client.get_table_data("pytest_intermediate", output_name, 10)
    assert len(rows) == 1

def test_mergetables(self):
"""test merge tables"""
wc_client = TestPostgresOperations.wc_client
Expand Down

0 comments on commit af1d781

Please sign in to comment.