From 2f9de3ea51a3bfff986ef7eca78d8941820307e9 Mon Sep 17 00:00:00 2001 From: Yohann Jardin Date: Mon, 7 Oct 2024 20:10:38 +0200 Subject: [PATCH] Feat: Support Snowflake's travel time --- airbyte/_processors/sql/snowflake.py | 11 ++++ airbyte/shared/sql_processor.py | 7 +++ tests/unit_tests/test_processors.py | 75 ++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+) create mode 100644 tests/unit_tests/test_processors.py diff --git a/airbyte/_processors/sql/snowflake.py b/airbyte/_processors/sql/snowflake.py index fdbc47b5..bf9b9844 100644 --- a/airbyte/_processors/sql/snowflake.py +++ b/airbyte/_processors/sql/snowflake.py @@ -42,6 +42,17 @@ class SnowflakeConfig(SqlConfig): database: str role: str schema_name: str = Field(default=DEFAULT_CACHE_SCHEMA_NAME) + data_retention_time_in_days: int | None = None + + @overrides + def get_create_table_extra_clauses(self) -> list[str]: + """Return a list of clauses to append on CREATE TABLE statements.""" + clauses = [] + + if self.data_retention_time_in_days is not None: + clauses.append(f"DATA_RETENTION_TIME_IN_DAYS = {self.data_retention_time_in_days}") + + return clauses @overrides def get_database_name(self) -> str: diff --git a/airbyte/shared/sql_processor.py b/airbyte/shared/sql_processor.py index d4899b7e..b52c2593 100644 --- a/airbyte/shared/sql_processor.py +++ b/airbyte/shared/sql_processor.py @@ -124,6 +124,10 @@ def config_hash(self) -> str | None: ) ) + def get_create_table_extra_clauses(self) -> list[str]: + """Return a list of clauses to append on CREATE TABLE statements.""" + return [] + def get_sql_engine(self) -> Engine: """Return a new SQL engine to use.""" return create_engine( @@ -653,10 +657,13 @@ def _create_table( pk_str = ", ".join(primary_keys) column_definition_str += f",\n PRIMARY KEY ({pk_str})" + extra_clauses = "\n".join(self.sql_config.get_create_table_extra_clauses()) + cmd = f""" CREATE TABLE {self._fully_qualified(table_name)} ( {column_definition_str} ) + {extra_clauses} """ _ = self._execute_sql(cmd) diff --git a/tests/unit_tests/test_processors.py b/tests/unit_tests/test_processors.py new file mode 100644 index 00000000..6781be01 --- /dev/null +++ b/tests/unit_tests/test_processors.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +from __future__ import annotations + +from pathlib import Path +from typing import Optional +import pytest_mock +from airbyte.caches.snowflake import SnowflakeSqlProcessor, SnowflakeConfig +from airbyte_protocol.models import ConfiguredAirbyteCatalog +from airbyte.secrets.base import SecretString +from airbyte.shared.catalog_providers import CatalogProvider + + +def test_snowflake_cache_config_data_retention_time_in_days( + mocker: pytest_mock.MockFixture, +): + expected_cmd = """ + CREATE TABLE airbyte_raw."table_name" ( + col_name type + ) + DATA_RETENTION_TIME_IN_DAYS = 1 + """ + + def _execute_sql(cmd): + global actual_cmd + actual_cmd = cmd + + mocker.patch.object(SnowflakeSqlProcessor, "_execute_sql", side_effect=_execute_sql) + config = _build_mocked_snowflake_processor(mocker, data_retention_time_in_days=1) + config._create_table(table_name="table_name", column_definition_str="col_name type") + + assert actual_cmd == expected_cmd + + +def test_snowflake_cache_config_no_data_retention_time_in_days( + mocker: pytest_mock.MockFixture, +): + expected_cmd = """ + CREATE TABLE airbyte_raw."table_name" ( + col_name type + ) + \n """ + + def _execute_sql(cmd): + global actual_cmd + actual_cmd = cmd + + mocker.patch.object(SnowflakeSqlProcessor, "_execute_sql", side_effect=_execute_sql) + config = _build_mocked_snowflake_processor(mocker) + config._create_table(table_name="table_name", column_definition_str="col_name type") + + assert actual_cmd == expected_cmd + + +def _build_mocked_snowflake_processor( + mocker: pytest_mock.MockFixture, data_retention_time_in_days: Optional[int] = None +): + sql_config = SnowflakeConfig( + account="foo", + username="foo", + password=SecretString("foo"), + warehouse="foo", + database="foo", + role="foo", + data_retention_time_in_days=data_retention_time_in_days, + ) + + mocker.patch.object( + SnowflakeSqlProcessor, "_ensure_schema_exists", return_value=None + ) + return SnowflakeSqlProcessor( + catalog_provider=CatalogProvider(ConfiguredAirbyteCatalog(streams=[])), + temp_dir=Path(), + temp_file_cleanup=True, + sql_config=sql_config, + )