diff --git a/lib/columnutils.py b/lib/columnutils.py
index 5af867d..8a6949b 100644
--- a/lib/columnutils.py
+++ b/lib/columnutils.py
@@ -47,3 +47,13 @@ def fmt_colname(colname: str, warehouse: str):
         return colname.lower()
     else:
         raise ValueError(f"unsupported warehouse: {warehouse}")
+
+
+def quote_columnname(colname: str, warehouse: str):
+    """encloses the column name within the appropriate quotes"""
+    if warehouse == "postgres":
+        return '"' + colname + '"'
+    elif warehouse == "bigquery":
+        return "`" + colname + "`"
+    else:
+        raise ValueError(f"unsupported warehouse: {warehouse}")
diff --git a/main.py b/main.py
index bb99c56..6fe92f4 100644
--- a/main.py
+++ b/main.py
@@ -11,12 +11,14 @@
 from operations.mergetables import union_tables
 from operations.syncsources import sync_sources
 from operations.castdatatypes import cast_datatypes
+from operations.coalescecolumns import coalesce_columns

 OPERATIONS_DICT = {
     "flatten": flatten_operation,
     "unionall": union_tables,
     "syncsources": sync_sources,
     "castdatatypes": cast_datatypes,
+    "coalescecolumns": coalesce_columns,
 }

 load_dotenv("dbconnection.env")
@@ -52,6 +54,7 @@
 warehouse = config_data["warehouse"]

 # run operations to generate dbt model(s)
+# pylint:disable=logging-fstring-interpolation
 for op_data in config_data["operations"]:
     op_type = op_data["type"]
     config = op_data["config"]
diff --git a/operations.yaml.template b/operations.yaml.template
index 4a93407..c6eb9fa 100644
--- a/operations.yaml.template
+++ b/operations.yaml.template
@@ -18,4 +18,22 @@ operations:
     config:
       source_name:
       source_schema:

+  - type: castdatatypes
+    config:
+      dest_schema:
+      input_name:
+      output_name:
+      columns:
+        - columnname:
+          columntype:
+  - type: coalescecolumns
+    config:
+      dest_schema:
+      input_name:
+      output_name:
+      columns:
+        - columnname:
+        - columnname:
+        - ...
+      output_column_name:
diff --git a/operations/castdatatypes.py b/operations/castdatatypes.py
index 11f2150..9287718 100644
--- a/operations/castdatatypes.py
+++ b/operations/castdatatypes.py
@@ -3,6 +3,7 @@
 from logging import basicConfig, getLogger, INFO

 from lib.dbtproject import dbtProject
+from lib.columnutils import quote_columnname

 basicConfig(level=INFO)
 logger = getLogger()
@@ -37,11 +38,11 @@ def cast_datatypes(config: dict, warehouse: str, project_dir: str):
         )
         union_code += (
             ", CAST("
-            + column["columnname"]
+            + quote_columnname(column["columnname"], warehouse)
             + " AS "
             + warehouse_column_type
             + ") AS "
-            + column["columnname"]
+            + quote_columnname(column["columnname"], warehouse)
         )

     union_code += " FROM " + "{{ref('" + input_name + "')}}" + "\n"
diff --git a/operations/coalescecolumns.py b/operations/coalescecolumns.py
new file mode 100644
index 0000000..e1dfa63
--- /dev/null
+++ b/operations/coalescecolumns.py
@@ -0,0 +1,44 @@
+"""generates a model which coalesces columns"""
+
+from logging import basicConfig, getLogger, INFO
+
+from lib.dbtproject import dbtProject
+from lib.columnutils import quote_columnname
+
+basicConfig(level=INFO)
+logger = getLogger()
+
+
+# pylint:disable=unused-argument,logging-fstring-interpolation
+def coalesce_columns(config: dict, warehouse: str, project_dir: str):
+    """coalesces columns"""
+    dest_schema = config["dest_schema"]
+    output_name = config["output_name"]
+    input_name = config["input_name"]
+
+    dbtproject = dbtProject(project_dir)
+    dbtproject.ensure_models_dir(dest_schema)
+
+    union_code = "{{ config(materialized='table',) }}\n"
+
+    columns = config["columns"]
+    columnnames = [c["columnname"] for c in columns]
+    union_code += "SELECT {{dbt_utils.star(from=ref('" + input_name + "'), except=["
+    union_code += ",".join([f'"{columnname}"' for columnname in columnnames])
+    union_code += "])}}"
+
+    union_code += ", COALESCE("
+
+    for column in config["columns"]:
+        union_code += quote_columnname(column["columnname"], warehouse) + ", "
+    union_code = union_code[:-2] + ") AS " + config["output_column_name"]
+
+    union_code += " FROM " + "{{ref('" + input_name + "')}}" + "\n"
+
+    logger.info(f"writing dbt model {union_code}")
+    dbtproject.write_model(
+        dest_schema,
+        output_name,
+        union_code,
+    )
+    logger.info(f"dbt model {output_name} successfully created")
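
Usage sketch (not part of the patch): calling the new operation directly, with made-up schema, table, column names and project path (intermediate, sheet1, phone, mobile, contact_number). Tracing coalesce_columns above with warehouse="postgres" gives the model shown in the trailing comment.

    # minimal sketch, assuming it is run from the repo root; all names/paths below are hypothetical
    from operations.coalescecolumns import coalesce_columns

    coalesce_columns(
        config={
            "dest_schema": "intermediate",
            "input_name": "sheet1",
            "output_name": "sheet1_coalesced",
            "columns": [{"columnname": "phone"}, {"columnname": "mobile"}],
            "output_column_name": "contact_number",
        },
        warehouse="postgres",
        project_dir="/path/to/dbt/project",
    )
    # writes a dbt model roughly equivalent to (SELECT emitted on one line):
    #   {{ config(materialized='table',) }}
    #   SELECT {{dbt_utils.star(from=ref('sheet1'), except=["phone","mobile"])}},
    #   COALESCE("phone", "mobile") AS contact_number FROM {{ref('sheet1')}}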