Skip to content

Commit

Permalink
Safe_cast macro addition (#198)
Browse files Browse the repository at this point in the history
* added safe_cast, removed format_row macro, removed terminal logging og query_band

* added test case for safe_cast testing
  • Loading branch information
VarunSharma15 authored Nov 4, 2024
1 parent 89f7cc7 commit 90e0494
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 42 deletions.
1 change: 0 additions & 1 deletion dbt/adapters/teradata/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,6 @@ def apply_query_band(cls, handle, query_band_text):
cur.execute("sel GetQueryBand();")
rows = cur.fetchone()
logger.debug("Query Band set to {}".format(rows)) # To log in dbt.log
logger.info("Query Band set to {}".format(rows)) # To log in terminal
except teradatasql.Error as ex:
logger.debug(ex)
logger.info("Please verify query_band parameter in profiles.yml file")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,41 +82,6 @@ from SYS_CALENDAR.CALENDAR where day_of_calendar = 1
{% endmacro %}


-- We need to override "format_row" macro to remove the safe_cast() used in the default implementation.
-- We had to remove safe_cast because N/A was being picked as column_type in safe_casting, which was later running into issues

{%- macro format_row(row, column_name_to_data_types) -%}
{#-- generate case-insensitive formatted row --#}
{% set formatted_row = {} %}
{%- for column_name, column_value in row.items() -%}
{% set column_name = column_name|lower %}

{%- if column_name not in column_name_to_data_types %}
{#-- if user-provided row contains column name that relation does not contain, raise an error --#}
{% set fixture_name = "expected output" if model.resource_type == 'unit_test' else ("'" ~ model.name ~ "'") %}
{{ exceptions.raise_compiler_error(
"Invalid column name: '" ~ column_name ~ "' in unit test fixture for " ~ fixture_name ~ "."
"\nAccepted columns for " ~ fixture_name ~ " are: " ~ (column_name_to_data_types.keys()|list)
) }}
{%- endif -%}

{%- set column_type = column_name_to_data_types[column_name] %}

{#-- sanitize column_value: wrap yaml strings in quotes, apply cast --#}
{%- set column_value_clean = column_value -%}
{%- if column_value is string -%}
{%- set column_value_clean = dbt.string_literal(dbt.escape_single_quotes(column_value)) -%}
{%- elif column_value is none -%}
{%- set column_value_clean = 'null' -%}
{%- endif -%}

{%- set row_update = {column_name: column_value_clean} -%}
{%- do formatted_row.update(row_update) -%}
{%- endfor -%}
{{ return(formatted_row) }}
{%- endmacro -%}



-- Overridden "get_unit_test_sql" macro to avoid right truncation of data
-- We are selecting "dbt_internal_unit_test_expected" first then doing union all with "dbt_internal_unit_test_actual"
Expand Down
3 changes: 3 additions & 0 deletions dbt/include/teradata/macros/materializations/unit/unit.sql
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
{%- materialization unit, adapter='teradata' -%}

-- calling the macro set_query_band() which will set the query_band for this materialization as per the user_configuration
{% do set_query_band() %}

{% set relations = [] %}

{% set expected_rows = config.get('expected_rows') %}
Expand Down
4 changes: 4 additions & 0 deletions dbt/include/teradata/macros/utils/safe_cast.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{% macro teradata__safe_cast(field, type) %}
{%- set field_as_string = dbt.string_literal(field) if field is number else field -%}
trycast({{field_as_string}} as {{type}})
{% endmacro %}
105 changes: 99 additions & 6 deletions tests/functional/adapter/test_unit_tests.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,47 @@
import pytest
from dbt.tests.util import write_file, run_dbt
from dbt.tests.adapter.unit_testing.test_types import BaseUnitTestingTypes
from dbt.tests.adapter.unit_testing.test_case_insensitivity import BaseUnitTestCaseInsensivity
from dbt.tests.adapter.unit_testing.test_invalid_input import BaseUnitTestInvalidInput


safe_cast_sql = """
select
cast(substr(opened_at,1,10) AS date format 'yyyy-mm-dd') as opened_date from {{ ref('seed')}}
"""

seed_csv = """
id,name,tax_rate,opened_at
1,Philadelphia,0.2,2016-09-01T00:00:00
2,New York,0.22,2017-03-15T00:00:00
3,Los Angeles,0.18,2018-06-10T00:00:00
""".lstrip()

test_safe_cast_yml = """
unit_tests:
- name: test_safe_cast
model: safe_cast
given:
- input: ref('seed')
rows:
- {opened_at: "2023-05-14T00:00:00"}
expect:
rows:
- {opened_date: 2023-05-14}
"""

class TestTestingTypesTeradata(BaseUnitTestingTypes):

@pytest.fixture(scope="class")
def project_config_update(self):
return {
"name": "test_testing_types",
"models": {
"test_testing_types": {
"materialized": "table"
}
}
}
@pytest.fixture
def data_types(self):
# sql_value, yaml_value
Expand All @@ -14,16 +51,72 @@ def data_types(self):
["'true'", "'true'"],
["DATE '2020-01-02'", "2020-01-02"],
["TIMESTAMP '2013-11-03 00:00:00'", "2013-11-03 00:00:00"],
[
"""cast('{"bar": "baz", "balance": 7.77, "active": false}'as json)""",
"""'{"bar": "baz", "balance": 7.77, "active": false}'""",
],
# [
# """cast('{"bar": "baz", "balance": 7.77, "active": false}'as json)""",
# """'{"bar": "baz", "balance": 7.77, "active": false}'""",
# ],
]
# had to comment the last testcase related to the json data type because it was failing with below error
#[Teradata Database] [Error 5771] Index not supported by UDT 'TD_JSONLATIN_LOB'. Indexes are not supported for LOB UDTs.


class TestUnitTestCaseInsensitivityTeradata(BaseUnitTestCaseInsensivity):
pass
@pytest.fixture(scope="class")
def project_config_update(self):
return {
"name": "test_case_insensitivity",
"models":{
"test_case_insensitivity":{
"materialized": "table"
}
}
}



class TestUnitTestInvalidInput(BaseUnitTestInvalidInput):
pass
@pytest.fixture(scope="class")
def project_config_update(self):
return {
"name": "test_unit_test_invalid_input",
"models": {
"test_unit_test_invalid_input": {
"materialized": "table"
}
}
}

class TestSafeCast():

@pytest.fixture(scope="class")
def project_config_update(self):
return {
"name": "test_safe_cast",
"seeds":{
"test_safe_cast":{
"seed":{
"+column_types":{
"opened_at": "varchar(20)"
}
}
}
}
}

@pytest.fixture(scope="class")
def seeds(self):
return{
"seed.csv": seed_csv
}
@pytest.fixture(scope="class")
def models(self):
return {
"safe_cast.sql": safe_cast_sql,
"test_safe_cast.yml": test_safe_cast_yml
}

def test_safe_cast(self, project):
result1 = run_dbt(["seed"])
results = run_dbt(["run"])

results = run_dbt(["test"])

0 comments on commit 90e0494

Please sign in to comment.