Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Groupby operation #82

Merged
merged 3 commits into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,7 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
pytest.ini
pytest.ini

# scripts
run-dbt.sh
43 changes: 42 additions & 1 deletion dbt_automation/assets/operations.template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,30 @@ operations:
operator: <"=" or "!=" or "<" or ">" or "<=" or ">=" >
value: <value>
sql_snippet: < custom sql snippet assume its formatted; eg. col1 != 5 >

- type: groupby
config:
input:
input_type: <"source" or "model" of table1>
input_name: <name of source table or ref model table1>
source_name: <name of the source defined in source.yml; will be null for type "model" table1>
source_columns:
- <column name>
- <column name>
- <column name>
dest_schema: <destination schema>
output_name: <name of the output model>
aggregate_on:
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>


- type: mergeoperations
config:
Expand Down Expand Up @@ -424,4 +448,21 @@ operations:
- column: <column name>
operator: <"=" or "!=" or "<" or ">" or "<=" or ">=" >
value: <value>
sql_snippet: < custom sql snippet assume its formatted; eg. col1 != 5 >
sql_snippet: < custom sql snippet assume its formatted; eg. col1 != 5 >

- type: groupby
config:
source_columns:
- <column name>
- <column name>
- <column name>
aggregate_on:
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>
- column: <column name from source column>
operation: <"sum" or "avg" or "count" or "min" or "max" or "countdistinct">
output_col_name: <output col name>
123 changes: 123 additions & 0 deletions dbt_automation/operations/groupby.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
Generates a model after grouping by and aggregating columns
"""

from logging import basicConfig, getLogger, INFO

from dbt_automation.utils.dbtproject import dbtProject
from dbt_automation.utils.columnutils import quote_columnname
from dbt_automation.utils.interfaces.warehouse_interface import WarehouseInterface
from dbt_automation.utils.tableutils import source_or_ref

basicConfig(level=INFO)
logger = getLogger()

# sql, columns = groupby.groupby_dbt_sql({
# "source_columns": ["NGO", "Month"],
# "aggregate_on": [
# {
# "operation": "count",
# "column": "measure1",
# "output_col_name": "measure1__count",
# },
# {
# "operation": "countdistinct",
# "column": "measure2",
# "output_col_name": "measure2__count",
# },
# {
# "operation": "sum",
# "column": "Indicator",
# "output_col_name": "sum_of_indicator"
# },
# ],
# "input": {
# "input_type": "source",
# "source_name": "pytest_intermediate",
# "input_name": "arithmetic_add",
# },
# }, wc_client)
#
# =>
#
# SELECT
# "NGO",
# "Month",
# COUNT("measure1") AS "measure1__count",
# COUNT(DISTINCT "measure2") AS "measure2__count",
# SUM("Indicator") AS "sum_of_indicator"
# FROM {{source('pytest_intermediate', 'arithmetic_add')}}
# GROUP BY "NGO","Month"


# pylint:disable=unused-argument,logging-fstring-interpolation
def groupby_dbt_sql(
    config: dict,
    warehouse: WarehouseInterface,
):
    """
    Generate SQL code for the groupby operation: SELECT the group-by
    columns plus one aggregate expression per entry in "aggregate_on",
    then GROUP BY the group-by columns.

    config keys:
      source_columns: list of column names to group by (selected as-is)
      aggregate_on: list of dicts with "column", "operation"
        ("sum"/"avg"/"count"/"min"/"max"/"countdistinct") and
        "output_col_name" (alias for the aggregate expression)
      input: table reference forwarded to source_or_ref; "input_type"
        may be "source", "model" or "cte"

    Returns a (sql_string, output_column_names) tuple; output columns are
    the group-by columns followed by the aggregate output columns.
    """
    source_columns = config["source_columns"]
    aggregate_on: list[dict] = config.get("aggregate_on", [])
    input_table = config["input"]

    dbt_code = "SELECT\n"

    # group-by columns are selected unchanged, quoted per warehouse dialect
    dbt_code += ",\n".join(
        [quote_columnname(col_name, warehouse.name) for col_name in source_columns]
    )

    for agg_col in aggregate_on:
        # "count" and "countdistinct" need special SQL spellings; every
        # other operation maps directly to its upper-cased function name
        if agg_col["operation"] == "count":
            dbt_code += (
                f",\n COUNT({quote_columnname(agg_col['column'], warehouse.name)})"
            )
        elif agg_col["operation"] == "countdistinct":
            dbt_code += f",\n COUNT(DISTINCT {quote_columnname(agg_col['column'], warehouse.name)})"
        else:
            dbt_code += f",\n {agg_col['operation'].upper()}({quote_columnname(agg_col['column'], warehouse.name)})"

        dbt_code += (
            f" AS {quote_columnname(agg_col['output_col_name'], warehouse.name)}"
        )

    dbt_code += "\n"
    select_from = source_or_ref(**input_table)
    # a CTE is referenced by its plain name; sources/models need the
    # jinja source()/ref() call wrapped in literal {{ }} braces
    if input_table["input_type"] == "cte":
        dbt_code += f"FROM {select_from}\n"
    else:
        dbt_code += f"FROM {{{{{select_from}}}}}\n"

    if len(source_columns) > 0:
        dbt_code += "GROUP BY "
        dbt_code += ",".join(
            [quote_columnname(col_name, warehouse.name) for col_name in source_columns]
        )

    output_columns = source_columns + [col["output_col_name"] for col in aggregate_on]

    return dbt_code, output_columns


def groupby(config: dict, warehouse: WarehouseInterface, project_dir: str):
    """
    Build a group-by/aggregate DBT model and write it into the project.

    config must contain everything groupby_dbt_sql needs, plus
    "dest_schema" (target schema for the model) and "output_name"
    (name of the output model file).

    Returns (model_sql_path, output_columns).
    """
    # hoist repeated config lookups (the original read dest_schema three times)
    dest_schema = config["dest_schema"]
    output_name = config["output_name"]

    dbt_sql = ""
    # CTE sub-operations are inlined by the caller (mergeoperations) and
    # must not carry their own materialization config header
    if config["input"]["input_type"] != "cte":
        dbt_sql = "{{ config(materialized='table', schema='" + dest_schema + "') }}"

    select_statement, output_cols = groupby_dbt_sql(config, warehouse)
    dbt_sql += "\n" + select_statement

    dbt_project = dbtProject(project_dir)
    dbt_project.ensure_models_dir(dest_schema)

    model_sql_path = dbt_project.write_model(dest_schema, output_name, dbt_sql)

    return model_sql_path, output_cols
5 changes: 5 additions & 0 deletions dbt_automation/operations/mergeoperations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from dbt_automation.operations.replace import replace_dbt_sql
from dbt_automation.operations.joins import joins_sql
from dbt_automation.operations.wherefilter import where_filter_sql
from dbt_automation.operations.groupby import groupby_dbt_sql


def merge_operations_sql(
Expand Down Expand Up @@ -96,6 +97,10 @@ def merge_operations_sql(
op_select_statement, out_cols = where_filter_sql(
operation["config"], warehouse
)
elif operation["type"] == "groupby":
op_select_statement, out_cols = groupby_dbt_sql(
operation["config"], warehouse
)

output_cols = out_cols

Expand Down
9 changes: 9 additions & 0 deletions run-dbt.example.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/bin/sh
# Make a copy of this file and rename it to run-dbt.sh

# Variables
project_dir="/Path/to/dbt/project/dir"
virtual_env_dir="/Path/to/dbt/environment/"

# Run dbt via the virtual environment's dbt executable (no activation
# needed — calling the venv's bin/dbt directly picks up its interpreter),
# pointing both the project and profiles dirs at the dbt project
"$virtual_env_dir"/bin/dbt run --project-dir "$project_dir" --profiles-dir "$project_dir"/profiles
5 changes: 4 additions & 1 deletion scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from dbt_automation.operations.replace import replace
from dbt_automation.operations.joins import join
from dbt_automation.operations.wherefilter import where_filter
from dbt_automation.operations.groupby import groupby

OPERATIONS_DICT = {
"flatten": flatten_operation,
Expand All @@ -42,6 +43,7 @@
"replace": replace,
"join": join,
"where": where_filter,
"groupby": groupby,
}

load_dotenv("./../dbconnection.env")
Expand Down Expand Up @@ -104,9 +106,10 @@

logger.info(f"running the {op_type} operation")
logger.info(f"using config {config}")
OPERATIONS_DICT[op_type](
output = OPERATIONS_DICT[op_type](
config=config, warehouse=warehouse, project_dir=project_dir
)
logger.info(f"finished running the {op_type} operation")
logger.info(output)

warehouse.close()
Loading