Merge pull request #110 from DalgoT4D/106-unpivot-operation
unpivot op
fatchat authored Apr 11, 2024
2 parents 959334a + 22b1fc1 commit ef626df
Showing 8 changed files with 317 additions and 10 deletions.
43 changes: 42 additions & 1 deletion dbt_automation/assets/operations.template.yml
@@ -382,6 +382,30 @@ operations:
        - <pivot col value3>
      dest_schema: <destination schema>
      output_name: <name of the output model>

  - type: unpivot
    config:
      input:
        input_type: <"source" or "model" of table1>
        input_name: <name of source table or ref model table1>
        source_name: <name of the source defined in source.yml; will be null for type "model" table1>
      source_columns:
        - <column name>
        - <column name>
        - <column name>
      exclude_columns:
        - <column name>
        - <column name>
        - <column name>
      unpivot_columns:
        - <column name>
        - <column name>
        - <column name>
      unpivot_field_name: <by default - "field_name">
      unpivot_value_name: <by default - "value">
      cast_to: <data type to cast values to - "varchar" for postgres & "STRING" for bigquery>
      dest_schema: <destination schema>
      output_name: <name of the output model>


  - type: mergeoperations
@@ -656,4 +680,21 @@ operations:
            pivot_column_values:
              - <pivot col value1>
              - <pivot col value2>
              - <pivot col value3>
        - type: unpivot
          config:
            source_columns:
              - <column name>
              - <column name>
              - <column name>
            exclude_columns:
              - <column name>
              - <column name>
              - <column name>
            unpivot_columns:
              - <column name>
              - <column name>
              - <column name>
            cast_to: <data type to cast values to - "varchar" for postgres & "STRING" for bigquery>
            unpivot_field_name: <by default - "field_name">
            unpivot_value_name: <by default - "value">
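
For orientation, here is a hypothetical unpivot config written out as the Python dict it becomes once the YAML is loaded; every table, schema, and column name below is invented for illustration:

# Hypothetical example of an unpivot config (all names invented).
unpivot_config = {
    "input": {
        "input_type": "model",            # or "source"
        "input_name": "monthly_kpis",
        "source_name": None,              # null when input_type is "model"
    },
    "source_columns": ["ngo", "month", "measure1", "measure2"],   # all input columns
    "exclude_columns": ["ngo", "month"],          # kept as-is on every output row
    "unpivot_columns": ["measure1", "measure2"],  # melted into field/value pairs
    "unpivot_field_name": "field_name",
    "unpivot_value_name": "value",
    "cast_to": "varchar",                 # use "STRING" on BigQuery
    "dest_schema": "intermediate",
    "output_name": "monthly_kpis_unpivoted",
}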
67 changes: 67 additions & 0 deletions dbt_automation/assets/unpivot.sql
@@ -0,0 +1,67 @@
{#
Pivot values from columns to rows. Similar to pandas DataFrame melt() function.

Example Usage: {{ unpivot(relation=ref('users'), cast_to='integer', exclude=['id','created_at']) }}

Arguments:
    relation: Relation object, required.
    cast_to: The datatype to cast all unpivoted columns to. Default is varchar.
    exclude: A list of columns to keep but exclude from the unpivot operation. Default is none.
    remove: A list of columns to remove from the resulting table. Default is none.
    field_name: Destination table column name for the source table column names.
    value_name: Destination table column name for the pivoted values.
#}

{% macro unpivot(relation=none, cast_to='varchar', exclude=none, remove=none, field_name='field_name', value_name='value', quote_identifiers=True) -%}
    {{ return(adapter.dispatch('unpivot', 'dbt_utils')(relation, cast_to, exclude, remove, field_name, value_name, quote_identifiers)) }}
{% endmacro %}

{% macro default__unpivot(relation=none, cast_to='varchar', exclude=none, remove=none, field_name='field_name', value_name='value', quote_identifiers=True) -%}

    {% if not relation %}
        {{ exceptions.raise_compiler_error("Error: argument `relation` is required for `unpivot` macro.") }}
    {% endif %}

    {%- set exclude = exclude if exclude is not none else [] %}
    {%- set remove = remove if remove is not none else [] %}

    {%- set include_cols = [] %}

    {%- set table_columns = {} %}

    {%- do table_columns.update({relation: []}) %}

    {%- do dbt_utils._is_relation(relation, 'unpivot') -%}
    {%- do dbt_utils._is_ephemeral(relation, 'unpivot') -%}
    {%- set cols = adapter.get_columns_in_relation(relation) %}

    {%- for col in cols -%}
        {%- if col.column.lower() not in remove|map('lower') and col.column.lower() not in exclude|map('lower') -%}
            {% do include_cols.append(col) %}
        {%- endif %}
    {%- endfor %}

    {%- for col in include_cols -%}
        {%- set current_col_name = adapter.quote(col.column) if quote_identifiers else col.column -%}
        select
            {%- for exclude_col in exclude %}
            {{ adapter.quote(exclude_col) if quote_identifiers else exclude_col }},
            {%- endfor %}

            cast('{{ col.column }}' as {{ dbt.type_string() }}) as {{ adapter.quote(field_name) if quote_identifiers else field_name }},
            cast( {% if col.data_type == 'boolean' %}
                {{ dbt.cast_bool_to_text(current_col_name) }}
                {% else %}
                {{ current_col_name }}
                {% endif %}
                as {{ cast_to }}) as {{ adapter.quote(value_name) if quote_identifiers else value_name }}

        from {{ relation }}

        {% if not loop.last -%}
            union all
        {% endif -%}
    {%- endfor -%}

{%- endmacro %}
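
The macro melts one column per iteration: each unpivoted column gets its own select that carries the excluded columns through, and the selects are stitched together with union all. A plain-Python sketch of the same expansion, for a hypothetical users(id, age, active) table (names invented, cast hard-coded to varchar):

# Illustrative sketch of the SQL shape the macro builds: one SELECT per
# melted column, combined with UNION ALL. Table/column names are invented.
exclude = ["id"]                  # kept on every row, not melted
melt_cols = ["age", "active"]     # become (field_name, value) pairs

selects = []
for col in melt_cols:
    keep = ", ".join(exclude)
    selects.append(
        f"select {keep}, cast('{col}' as varchar) as field_name, "
        f"cast({col} as varchar) as value from users"
    )
sql = "\nunion all\n".join(selects)
print(sql)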
5 changes: 5 additions & 0 deletions dbt_automation/operations/mergeoperations.py
@@ -23,6 +23,7 @@
from dbt_automation.operations.flattenjson import flattenjson_dbt_sql
from dbt_automation.operations.mergetables import union_tables_sql
from dbt_automation.operations.pivot import pivot_dbt_sql
from dbt_automation.operations.unpivot import unpivot_dbt_sql


def merge_operations_sql(
@@ -122,6 +123,10 @@ def merge_operations_sql(
            op_select_statement, out_cols = pivot_dbt_sql(
                operation["config"], warehouse
            )
        elif operation["type"] == "unpivot":
            op_select_statement, out_cols = unpivot_dbt_sql(
                operation["config"], warehouse
            )

        output_cols = out_cols

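Per the note in dbt_automation/operations/unpivot.py below, unpivot only works as the first step of a mergeoperations chain. A hypothetical merged config would therefore look roughly like this (schema, table, and column names invented):

# Hypothetical mergeoperations config with unpivot as the required first step.
merge_config = {
    "dest_schema": "intermediate",
    "output_name": "kpis_long",
    "input": {
        "input_type": "model",
        "input_name": "monthly_kpis",
        "source_name": None,
    },
    "operations": [
        {
            "type": "unpivot",  # must come first in the chain
            "config": {
                "source_columns": ["ngo", "month", "measure1", "measure2"],
                "exclude_columns": ["ngo", "month"],
                "unpivot_columns": ["measure1", "measure2"],
                "unpivot_field_name": "field_name",
                "unpivot_value_name": "value",
                "cast_to": "varchar",
            },
        },
        # ... later steps receive the unpivoted columns as their input ...
    ],
}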
23 changes: 15 additions & 8 deletions dbt_automation/operations/scaffold.py
@@ -1,5 +1,6 @@
"""setup the dbt project"""

import glob
import os, shutil, yaml
from pathlib import Path
from string import Template
@@ -44,14 +45,20 @@ def scaffold(config: dict, warehouse: WarehouseInterface, project_dir: str):
    (Path(project_dir) / "models" / "staging").mkdir()
    (Path(project_dir) / "models" / "intermediate").mkdir()

-    flatten_json_target = Path(project_dir) / "macros" / "flatten_json.sql"
-    custom_schema_target = Path(project_dir) / "macros" / "generate_schema_name.sql"
-    logger.info("created %s", flatten_json_target)
-    source_schema_name_macro_path = os.path.abspath(
-        os.path.join(os.path.abspath(assets.__file__), "..", "generate_schema_name.sql")
-    )
-    shutil.copy(source_schema_name_macro_path, custom_schema_target)
-    logger.info("created %s", custom_schema_target)
+    # copy all .sql files from assets/ to project_dir/macros
+    # create if the file is not present in project_dir/macros
+    assets_dir = assets.__path__[0]
+
+    # loop over all sql macros with .sql extension
+    for sql_file_path in glob.glob(os.path.join(assets_dir, "*.sql")):
+        # Get the target path in the project_dir/macros directory
+        target_path = Path(project_dir) / "macros" / Path(sql_file_path).name
+
+        # Copy the .sql file to the target path
+        shutil.copy(sql_file_path, target_path)
+
+        # Log the creation of the file
+        logger.info("created %s", target_path)

    dbtproject_filename = Path(project_dir) / "dbt_project.yml"
    PROJECT_TEMPLATE = Template(
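After scaffolding, the project's macros directory should contain every .sql asset, including the new unpivot.sql. A quick way to verify (project path hypothetical):

# Quick check of the copied macros (the project path is invented).
from pathlib import Path

project_dir = Path("/tmp/my_dbt_project")
for sql in sorted((project_dir / "macros").glob("*.sql")):
    print(sql.name)  # expect flatten_json.sql, generate_schema_name.sql, unpivot.sql, ...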
89 changes: 89 additions & 0 deletions dbt_automation/operations/unpivot.py
@@ -0,0 +1,89 @@
"""
Generates a dbt model for unpivot
This operation will only work in the chain of mergeoperations if its at the first step
"""

from logging import basicConfig, getLogger, INFO

from dbt_automation.utils.dbtproject import dbtProject
from dbt_automation.utils.columnutils import quote_columnname
from dbt_automation.utils.interfaces.warehouse_interface import WarehouseInterface
from dbt_automation.utils.columnutils import quote_columnname, quote_constvalue
from dbt_automation.utils.tableutils import source_or_ref

basicConfig(level=INFO)
logger = getLogger()


# pylint:disable=unused-argument,logging-fstring-interpolation
def unpivot_dbt_sql(
config: dict,
warehouse: WarehouseInterface,
):
"""
Generate SQL code for the coalesce_columns operation.
"""
source_columns = config.get("source_columns", []) # all columns
exclude_columns = config.get(
"exclude_columns", []
) # exclude from unpivot but keep in the resulting table
unpivot_on_columns = config.get("unpivot_columns", []) # columns to unpivot
input_table = config["input"]
field_name = config.get("unpivot_field_name", "field_name")
value_name = config.get("unpivot_value_name", "value")
cast_datatype_to = config.get("cast_to", "varchar")
if not cast_datatype_to and warehouse.name == "bigquery":
cast_datatype_to = "STRING"

if len(unpivot_on_columns) == 0:
raise ValueError("No columns specified for unpivot")

output_columns = list(set(exclude_columns) | set(unpivot_on_columns)) # union
remove_columns = list(set(source_columns) - set(output_columns))

dbt_code = "{{ unpivot("
dbt_code += source_or_ref(**input_table)
dbt_code += ", exclude="
dbt_code += (
"["
+ ",".join(
[quote_constvalue(col_name, warehouse.name) for col_name in exclude_columns]
)
+ "] ,"
)
dbt_code += f"cast_to={quote_constvalue(cast_datatype_to, warehouse.name)}, "
dbt_code += "remove="
dbt_code += (
"["
+ ",".join(
[quote_constvalue(col_name, warehouse.name) for col_name in remove_columns]
)
+ "] ,"
)
dbt_code += f"field_name={quote_constvalue(field_name, warehouse.name)}, value_name={quote_constvalue(value_name, warehouse.name)}"
dbt_code += ")}}\n"

return dbt_code, output_columns


def unpivot(config: dict, warehouse: WarehouseInterface, project_dir: str):
"""
Perform coalescing of columns and generate a DBT model.
"""
dbt_sql = ""
if config["input"]["input_type"] != "cte":
dbt_sql = (
"{{ config(materialized='table', schema='" + config["dest_schema"] + "') }}"
)

select_statement, output_cols = unpivot_dbt_sql(config, warehouse)
dbt_sql += "\n" + select_statement

dbt_project = dbtProject(project_dir)
dbt_project.ensure_models_dir(config["dest_schema"])

output_name = config["output_name"]
dest_schema = config["dest_schema"]
model_sql_path = dbt_project.write_model(dest_schema, output_name, dbt_sql)

return model_sql_path, output_cols
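
To make the column bookkeeping concrete, here is a small sketch of what unpivot_dbt_sql computes for a hypothetical config; the final Jinja line is approximate and assumes quote_constvalue produces single-quoted constants (e.g. on Postgres):

# Sketch of the column set arithmetic inside unpivot_dbt_sql (names invented).
source_columns = ["ngo", "month", "measure1", "measure2", "_airbyte_ab_id"]
exclude_columns = ["ngo", "month"]
unpivot_columns = ["measure1", "measure2"]

# what the function reports back: the union of kept and unpivoted columns
output_columns = list(set(exclude_columns) | set(unpivot_columns))
# what gets passed to the macro's remove= argument: everything else
remove_columns = list(set(source_columns) - set(output_columns))

print(sorted(output_columns))  # ['measure1', 'measure2', 'month', 'ngo']
print(sorted(remove_columns))  # ['_airbyte_ab_id']

# The emitted model body would then be roughly:
# {{ unpivot(ref('monthly_kpis'), exclude=['ngo','month'], cast_to='varchar',
#            remove=['_airbyte_ab_id'], field_name='field_name', value_name='value') }}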
2 changes: 2 additions & 0 deletions scripts/main.py
@@ -28,6 +28,7 @@
from dbt_automation.operations.aggregate import aggregate
from dbt_automation.operations.casewhen import casewhen
from dbt_automation.operations.pivot import pivot
from dbt_automation.operations.unpivot import unpivot

OPERATIONS_DICT = {
    "flatten": flatten_operation,
@@ -50,6 +51,7 @@
    "aggregate": aggregate,
    "casewhen": casewhen,
    "pivot": pivot,
    "unpivot": unpivot,
}

load_dotenv("./../dbconnection.env")
48 changes: 48 additions & 0 deletions tests/warehouse/test_bigquery_ops.py
@@ -25,6 +25,7 @@
from dbt_automation.operations.aggregate import aggregate
from dbt_automation.operations.casewhen import casewhen
from dbt_automation.operations.pivot import pivot
from dbt_automation.operations.unpivot import unpivot


basicConfig(level=INFO)
@@ -763,6 +764,53 @@ def test_mergetables(self):

assert len(table_data1) + len(table_data2) == len(table_data_union)

def test_unpivot(self):
"""test unpivot operation"""
wc_client = TestBigqueryOperations.wc_client
output_name = "unpivot_op"

config = {
"input": {
"input_type": "model",
"input_name": "_airbyte_raw_Sheet2",
"source_name": None,
},
"dest_schema": "pytest_intermediate",
"output_name": output_name,
"source_columns": [
"NGO",
"SPOC",
"Month",
"measure1",
"_airbyte_ab_id",
"measure2",
"Indicator",
],
"exclude_columns": [],
"unpivot_columns": ["NGO", "SPOC"],
"unpivot_field_name": "col_field",
"unpivot_value_name": "col_val",
}

unpivot(
config,
wc_client,
TestBigqueryOperations.test_project_dir,
)

TestBigqueryOperations.execute_dbt("run", output_name)

cols = [
col_dict["name"]
for col_dict in wc_client.get_table_columns(
"pytest_intermediate", output_name
)
]
assert len(cols) == 2
assert sorted(cols) == sorted(
[config["unpivot_field_name"], config["unpivot_value_name"]]
)

def test_merge_operation(self):
"""test merge_operation"""
wc_client = TestBigqueryOperations.wc_client