Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

83 op aggregate over table #87

Merged
merged 3 commits into from
Mar 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion dbt_automation/assets/operations.template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,25 @@ operations:
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>

- type: aggregate
config:
input:
input_type: <"source" or "model" of table1>
input_name: <name of source table or ref model table1>
source_name: <name of the source defined in source.yml; will be null for type "model" table1>
dest_schema: <destination schema>
output_name: <name of the output model>
aggregate_on:
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>


- type: mergeoperations
Expand Down Expand Up @@ -471,4 +490,18 @@ operations:
output_col_name: <output col name>
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>
output_col_name: <output col name>

- type: aggregate
config:
aggregate_on:
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>

106 changes: 106 additions & 0 deletions dbt_automation/operations/aggregate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""
Generates a model after grouping by and aggregating columns
"""

from logging import basicConfig, getLogger, INFO

from dbt_automation.utils.dbtproject import dbtProject
from dbt_automation.utils.columnutils import quote_columnname
from dbt_automation.utils.interfaces.warehouse_interface import WarehouseInterface
from dbt_automation.utils.tableutils import source_or_ref

basicConfig(level=INFO)
logger = getLogger()

# sql, len_output_set = aggregate.aggregate_dbt_sql({
# "input": {
# "input_type": "source",
# "source_name": "pytest_intermediate",
# "input_name": "arithmetic_add",
# },
# "aggregate_on": [
# {
# "operation": "count",
# "column": "NGO",
# "output_col_name": "count__ngo"
# },
# {
# "operation": "countdistinct",
# "column": "Month",
# "output_col_name": "distinctmonths"
# },
# ],

# }, wc_client)

# SELECT
# COUNT("NGO") AS "count__ngo", COUNT(DISTINCT "Month") AS "distinctmonths"
# FROM {{source('pytest_intermediate', 'arithmetic_add')}}


# pylint:disable=unused-argument,logging-fstring-interpolation
def aggregate_dbt_sql(
    config: dict,
    warehouse: WarehouseInterface,
):
    """
    Generate SQL for the aggregate operation (whole-table aggregation, no GROUP BY).

    config keys:
        input: dict accepted by source_or_ref (input_type, input_name, source_name).
        aggregate_on: list of dicts, each with
            column: source column to aggregate
            operation: "sum" | "avg" | "count" | "min" | "max" | "countdistinct"
            output_col_name: alias for the aggregated column
        source_columns: intentionally ignored — selecting raw columns alongside
            aggregates without a GROUP BY would produce invalid SQL.

    Returns:
        tuple[str, list[str]]: the dbt SELECT statement and the list of
        output column names (one per aggregate_on entry).
    """
    aggregate_on: list[dict] = config.get("aggregate_on", [])
    input_table = config["input"]

    # Build one "<AGG>(<col>)  AS <alias>" clause per requested aggregate.
    agg_clauses = []
    for agg_col in aggregate_on:
        quoted_col = quote_columnname(agg_col["column"], warehouse.name)
        operation = agg_col["operation"]
        if operation == "count":
            clause = f" COUNT({quoted_col}) "
        elif operation == "countdistinct":
            # "countdistinct" is not a SQL function name; expand to COUNT(DISTINCT ...)
            clause = f" COUNT(DISTINCT {quoted_col}) "
        else:
            # sum/avg/min/max map directly onto their SQL function names
            clause = f" {operation.upper()}({quoted_col}) "
        clause += f" AS {quote_columnname(agg_col['output_col_name'], warehouse.name)}"
        agg_clauses.append(clause)

    dbt_code = "SELECT\n" + ",".join(agg_clauses) + "\n"

    select_from = source_or_ref(**input_table)
    if input_table["input_type"] == "cte":
        # CTE name is referenced directly, not through a dbt jinja call
        dbt_code += f"FROM {select_from}\n"
    else:
        # wrap in {{ ... }} so dbt resolves source()/ref() at compile time
        dbt_code += f"FROM {{{{{select_from}}}}}\n"

    return dbt_code, [col["output_col_name"] for col in aggregate_on]


def aggregate(config: dict, warehouse: WarehouseInterface, project_dir: str):
    """
    Generate a DBT model that aggregates columns of the input table and write
    it to the project.

    Args:
        config: operation config; requires "input", "aggregate_on",
            "dest_schema" and "output_name" (see aggregate_dbt_sql).
        warehouse: warehouse client used for column quoting.
        project_dir: path to the dbt project the model is written into.

    Returns:
        tuple[str, list[str]]: path of the written model .sql file and the
        list of output column names.
    """
    dbt_sql = ""
    if config["input"]["input_type"] != "cte":
        # standalone models get a config() header; CTE snippets must not
        dbt_sql = (
            "{{ config(materialized='table', schema='" + config["dest_schema"] + "') }}"
        )

    select_statement, output_cols = aggregate_dbt_sql(config, warehouse)
    dbt_sql += "\n" + select_statement

    dest_schema = config["dest_schema"]
    output_name = config["output_name"]

    dbt_project = dbtProject(project_dir)
    dbt_project.ensure_models_dir(dest_schema)
    model_sql_path = dbt_project.write_model(dest_schema, output_name, dbt_sql)

    return model_sql_path, output_cols
5 changes: 5 additions & 0 deletions dbt_automation/operations/mergeoperations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from dbt_automation.operations.joins import joins_sql
from dbt_automation.operations.wherefilter import where_filter_sql
from dbt_automation.operations.groupby import groupby_dbt_sql
from dbt_automation.operations.aggregate import aggregate_dbt_sql


def merge_operations_sql(
Expand Down Expand Up @@ -101,6 +102,10 @@ def merge_operations_sql(
op_select_statement, out_cols = groupby_dbt_sql(
operation["config"], warehouse
)
elif operation["type"] == "aggregate":
op_select_statement, out_cols = aggregate_dbt_sql(
operation["config"], warehouse
)

output_cols = out_cols

Expand Down
2 changes: 2 additions & 0 deletions scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from dbt_automation.operations.joins import join
from dbt_automation.operations.wherefilter import where_filter
from dbt_automation.operations.groupby import groupby
from dbt_automation.operations.aggregate import aggregate

OPERATIONS_DICT = {
"flatten": flatten_operation,
Expand All @@ -44,6 +45,7 @@
"join": join,
"where": where_filter,
"groupby": groupby,
"aggregate": aggregate,
}

load_dotenv("./../dbconnection.env")
Expand Down
41 changes: 41 additions & 0 deletions tests/warehouse/test_bigquery_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from dbt_automation.operations.castdatatypes import cast_datatypes
from dbt_automation.operations.regexextraction import regex_extraction
from dbt_automation.operations.mergetables import union_tables
from dbt_automation.operations.aggregate import aggregate


basicConfig(level=INFO)
Expand Down Expand Up @@ -561,6 +562,46 @@ def test_regexextract(self):
else (regex["NGO"] is None)
)

def test_aggregate(self):
    """Run the aggregate op over the casted model and verify its output."""
    wc_client = TestBigqueryOperations.wc_client
    output_name = "aggregate_col"

    aggregate(
        {
            "input": {
                "input_type": "model",
                "input_name": "cast",
                "source_name": None,
            },
            "dest_schema": "pytest_intermediate",
            "output_name": output_name,
            "aggregate_on": [
                {"column": "measure1", "operation": "sum", "output_col_name": "agg1"},
                {"column": "measure2", "operation": "sum", "output_col_name": "agg2"},
            ],
        },
        wc_client,
        TestBigqueryOperations.test_project_dir,
    )

    TestBigqueryOperations.execute_dbt("run", output_name)

    column_names = {
        col_dict["name"]
        for col_dict in wc_client.get_table_columns("pytest_intermediate", output_name)
    }
    # both aggregate aliases must appear in the resulting table
    assert {"agg1", "agg2"} <= column_names

    # a whole-table aggregation collapses everything into a single row
    rows = wc_client.get_table_data("pytest_intermediate", output_name, 10)
    assert len(rows) == 1

def test_mergetables(self):
"""test merge tables"""
wc_client = TestBigqueryOperations.wc_client
Expand Down
41 changes: 41 additions & 0 deletions tests/warehouse/test_postgres_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dbt_automation.utils.dbtproject import dbtProject
from dbt_automation.operations.regexextraction import regex_extraction
from dbt_automation.operations.mergetables import union_tables
from dbt_automation.operations.aggregate import aggregate


basicConfig(level=INFO)
Expand Down Expand Up @@ -572,6 +573,46 @@ def test_regexextract(self):
else (regex["NGO"] is None)
)

def test_aggregate(self):
    """Run the aggregate op over the casted model and verify its output."""
    wc_client = TestPostgresOperations.wc_client
    output_name = "aggregate_col"

    aggregate(
        {
            "input": {
                "input_type": "model",
                "input_name": "cast",
                "source_name": None,
            },
            "dest_schema": "pytest_intermediate",
            "output_name": output_name,
            "aggregate_on": [
                {"column": "measure1", "operation": "sum", "output_col_name": "agg1"},
                {"column": "measure2", "operation": "sum", "output_col_name": "agg2"},
            ],
        },
        wc_client,
        TestPostgresOperations.test_project_dir,
    )

    TestPostgresOperations.execute_dbt("run", output_name)

    column_names = {
        col_dict["name"]
        for col_dict in wc_client.get_table_columns("pytest_intermediate", output_name)
    }
    # both aggregate aliases must appear in the resulting table
    assert {"agg1", "agg2"} <= column_names

    # a whole-table aggregation collapses everything into a single row
    rows = wc_client.get_table_data("pytest_intermediate", output_name, 10)
    assert len(rows) == 1

def test_mergetables(self):
"""test merge tables"""
wc_client = TestPostgresOperations.wc_client
Expand Down
Loading