Skip to content

Commit

Permalink
feat: Bump SQLAlchemy dependency to 2.0.34 (#220)
Browse files Browse the repository at this point in the history
Closes #32.
  • Loading branch information
edgarrmondragon authored Sep 16, 2024
1 parent c6c6305 commit ec5c2bb
Show file tree
Hide file tree
Showing 6 changed files with 461 additions and 415 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ jobs:
TARGET_SNOWFLAKE_ROLE: ${{secrets.TARGET_SNOWFLAKE_ROLE}}
strategy:
fail-fast: false
max-parallel: 2
matrix:
python-version:
- "3.12"
Expand Down
800 changes: 423 additions & 377 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ python = ">=3.8"
cryptography = ">=40,<44"
snowflake-sqlalchemy = "~=1.6.1"
snowflake-connector-python = { version = "<4.0.0", extras = ["secure-local-storage"] }
sqlalchemy = "<2"
sqlalchemy = "~=2.0.31"

[tool.poetry.dependencies.singer-sdk]
version = "~=0.40.0a1"
Expand Down
14 changes: 7 additions & 7 deletions target_snowflake/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,23 +469,23 @@ def put_batches_to_stage(self, sync_id: str, files: Sequence[str]) -> None:
sync_id: The sync ID for the batch.
files: The files containing records to upload.
"""
with self._connect() as conn:
with self._connect() as conn, conn.begin():
for file_uri in files:
put_statement, kwargs = self._get_put_statement(
sync_id=sync_id,
file_uri=file_uri,
)
# sqlalchemy.text stripped a slash, which caused windows to fail so we used bound parameters instead
# See https://github.com/MeltanoLabs/target-snowflake/issues/87 for more information about this error
conn.execute(put_statement, file_uri=file_uri, **kwargs)
conn.execute(put_statement, {"file_uri": file_uri, **kwargs})

def create_file_format(self, file_format: str) -> None:
"""Create a file format in the schema.
Args:
file_format: The name of the file format.
"""
with self._connect() as conn:
with self._connect() as conn, conn.begin():
file_format_statement, kwargs = self._get_file_format_statement(
file_format=file_format,
)
Expand All @@ -510,7 +510,7 @@ def merge_from_stage(
schema: The schema of the data.
key_properties: The primary key properties of the data.
"""
with self._connect() as conn:
with self._connect() as conn, conn.begin():
merge_statement, kwargs = self._get_merge_from_stage_statement(
full_table_name=full_table_name,
schema=schema,
Expand All @@ -536,7 +536,7 @@ def copy_from_stage(
sync_id: The sync ID for the batch.
file_format: The name of the file format.
"""
with self._connect() as conn:
with self._connect() as conn, conn.begin():
copy_statement, kwargs = self._get_copy_statement(
full_table_name=full_table_name,
schema=schema,
Expand All @@ -552,7 +552,7 @@ def drop_file_format(self, file_format: str) -> None:
Args:
file_format: The name of the file format.
"""
with self._connect() as conn:
with self._connect() as conn, conn.begin():
drop_statement, kwargs = self._get_drop_file_format_statement(
file_format=file_format,
)
Expand All @@ -565,7 +565,7 @@ def remove_staged_files(self, sync_id: str) -> None:
Args:
sync_id: The sync ID for the batch.
"""
with self._connect() as conn:
with self._connect() as conn, conn.begin():
remove_statement, kwargs = self._get_stage_files_remove_statement(
sync_id=sync_id,
)
Expand Down
54 changes: 26 additions & 28 deletions tests/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest
import snowflake.sqlalchemy.custom_types as sct
import sqlalchemy
import sqlalchemy as sa
from singer_sdk.testing.suites import TestSuite
from singer_sdk.testing.target_tests import (
TargetArrayData,
Expand Down Expand Up @@ -33,7 +33,7 @@ def validate(self) -> None:
f"{self.target.config['database']}.{self.target.config['default_target_schema']}.test_{self.name}".upper()
)
result = connector.connection.execute(
f"select * from {table} order by 1",
sa.text(f"select * from {table} order by 1"),
)
assert result.rowcount == 4
row = result.first()
Expand All @@ -45,7 +45,7 @@ def validate(self) -> None:
assert row[1] == '[\n "apple",\n "orange",\n "pear"\n]'
table_schema = connector.get_table(table)
expected_types = {
"id": sqlalchemy.DECIMAL,
"id": sa.DECIMAL,
"fruits": sct.VARIANT,
"_sdc_extracted_at": sct.TIMESTAMP_NTZ,
"_sdc_batched_at": sct.TIMESTAMP_NTZ,
Expand All @@ -69,8 +69,8 @@ def validate(self) -> None:
table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.ForecastingTypeToCategory".upper() # noqa: E501
table_schema = connector.get_table(table)
expected_types = {
"id": sqlalchemy.VARCHAR,
"isdeleted": sqlalchemy.types.BOOLEAN,
"id": sa.VARCHAR,
"isdeleted": sa.types.BOOLEAN,
"createddate": sct.TIMESTAMP_NTZ,
"createdbyid": sct.STRING,
"lastmodifieddate": sct.TIMESTAMP_NTZ,
Expand All @@ -79,8 +79,8 @@ def validate(self) -> None:
"forecastingtypeid": sct.STRING,
"forecastingitemcategory": sct.STRING,
"displayposition": sct.NUMBER,
"isadjustable": sqlalchemy.types.BOOLEAN,
"isowneradjustable": sqlalchemy.types.BOOLEAN,
"isadjustable": sa.types.BOOLEAN,
"isowneradjustable": sa.types.BOOLEAN,
"age": sct.NUMBER,
"newcamelcasedattribute": sct.STRING,
"_attribute_startswith_underscore": sct.STRING,
Expand All @@ -107,7 +107,7 @@ def validate(self) -> None:
f"{self.target.config['database']}.{self.target.config['default_target_schema']}.test_{self.name}".upper()
)
result = connector.connection.execute(
f"select * from {table} order by 1",
sa.text(f"select * from {table} order by 1"),
)
expected_value = {
1: 100,
Expand Down Expand Up @@ -150,7 +150,7 @@ def validate(self) -> None:
f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{self.stream_name}".upper()
)
connector.connection.execute(
f"select * from {table} order by 1",
sa.text(f"select * from {table} order by 1"),
)

table_schema = connector.get_table(table)
Expand Down Expand Up @@ -185,7 +185,7 @@ def validate(self) -> None:
f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{table_name}".upper()
)
connector.connection.execute(
f"select * from {table} order by 1",
sa.text(f"select * from {table} order by 1"),
)
# TODO: more assertions

Expand Down Expand Up @@ -241,7 +241,7 @@ def validate(self) -> None:
f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{table_name}".upper()
)
result = connector.connection.execute(
f"select * from {table} order by 1",
sa.text(f"select * from {table} order by 1"),
)
assert result.rowcount == 2
row = result.first()
Expand Down Expand Up @@ -276,7 +276,7 @@ def validate(self) -> None:
f"{self.target.config['database']}.{self.target.config['default_target_schema']}.test_{self.name}".upper()
)
result = connector.connection.execute(
f"select * from {table} order by 1",
sa.text(f"select * from {table} order by 1"),
)
assert result.rowcount == 6
row = result.first()
Expand All @@ -291,7 +291,7 @@ def validate(self) -> None:
"id": sct.NUMBER,
"a1": sct.DOUBLE,
"a2": sct.STRING,
"a3": sqlalchemy.types.BOOLEAN,
"a3": sa.types.BOOLEAN,
"a4": sct.VARIANT,
"a5": sct.VARIANT,
"a6": sct.NUMBER,
Expand Down Expand Up @@ -325,7 +325,7 @@ def validate(self) -> None:
connector = self.target.default_sink_class.connector_class(self.target.config)
table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{self.name}".upper()
result = connector.connection.execute(
f"select * from {table}",
sa.text(f"select * from {table}"),
)
assert result.rowcount == 2
row = result.first()
Expand All @@ -347,7 +347,7 @@ def validate(self) -> None:
connector = self.target.default_sink_class.connector_class(self.target.config)
table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{self.name}".upper()
result = connector.connection.execute(
f"select * from {table}",
sa.text(f"select * from {table}"),
)
assert result.rowcount == 1
row = result.first()
Expand All @@ -366,7 +366,7 @@ def validate(self) -> None:
connector = self.target.default_sink_class.connector_class(self.target.config)
table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{self.name}".upper()
result = connector.connection.execute(
f"select * from {table}",
sa.text(f"select * from {table}"),
)
assert result.rowcount == 1
row = result.first()
Expand Down Expand Up @@ -400,7 +400,7 @@ def setup(self) -> None:
connector = self.target.default_sink_class.connector_class(self.target.config)
table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{self.name}".upper()
connector.connection.execute(
f"""
sa.text(f"""
CREATE OR REPLACE TABLE {table} (
ID VARCHAR(16777216),
COL_STR VARCHAR(16777216),
Expand All @@ -416,14 +416,14 @@ def setup(self) -> None:
_SDC_TABLE_VERSION NUMBER(38,0),
PRIMARY KEY (ID)
)
""",
"""),
)

def validate(self) -> None:
connector = self.target.default_sink_class.connector_class(self.target.config)
table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{self.name}".upper()
result = connector.connection.execute(
f"select * from {table}",
sa.text(f"select * from {table}"),
)
assert result.rowcount == 1
row = result.first()
Expand All @@ -438,7 +438,7 @@ def setup(self) -> None:
connector = self.target.default_sink_class.connector_class(self.target.config)
table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{self.name}".upper()
connector.connection.execute(
f"""
sa.text(f"""
CREATE OR REPLACE TABLE {table} (
ID VARCHAR(16777216),
COL_STR VARCHAR(16777216),
Expand All @@ -454,7 +454,7 @@ def setup(self) -> None:
_SDC_TABLE_VERSION NUMBER(38,0),
PRIMARY KEY (ID)
)
""",
"""),
)


Expand All @@ -471,7 +471,7 @@ def setup(self) -> None:
connector = self.target.default_sink_class.connector_class(self.target.config)
table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.\"order\"".upper()
connector.connection.execute(
f"""
sa.text(f"""
CREATE OR REPLACE TABLE {table} (
ID VARCHAR(16777216),
COL_STR VARCHAR(16777216),
Expand All @@ -487,7 +487,7 @@ def setup(self) -> None:
_SDC_TABLE_VERSION NUMBER(38,0),
PRIMARY KEY (ID)
)
""",
"""),
)


Expand All @@ -505,9 +505,7 @@ def singer_filepath(self) -> Path:
def validate(self) -> None:
connector = self.target.default_sink_class.connector_class(self.target.config)
table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.\"order\"".upper()
result = connector.connection.execute(
f"select * from {table}",
)
result = connector.connection.execute(sa.text(f"select * from {table}"))
assert result.rowcount == 1
row = result.first()
assert len(row) == 13, f"Row has unexpected length {len(row)}"
Expand Down Expand Up @@ -554,13 +552,13 @@ def setup(self) -> None:
table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{self.name}".upper()
# Seed the 2 columns from tap schema and an unused third column to assert explicit inserts are working
connector.connection.execute(
f"""
sa.text(f"""
CREATE OR REPLACE TABLE {table} (
COL1 VARCHAR(16777216),
COL3 TIMESTAMP_NTZ(9),
COL2 BOOLEAN
)
""",
"""),
)

@property
Expand Down
5 changes: 3 additions & 2 deletions tests/test_target_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any

import pytest
import sqlalchemy as sa
from singer_sdk.testing import TargetTestRunner, get_target_test_class

from target_snowflake.target import TargetSnowflake
Expand Down Expand Up @@ -46,11 +47,11 @@ def resource(self, runner, connection):
https://github.com/meltano/sdk/tree/main/tests/samples
"""
connection.execute(
f"create schema {runner.config['database']}.{runner.config['default_target_schema']}",
sa.text(f"create schema {runner.config['database']}.{runner.config['default_target_schema']}"),
)
yield
connection.execute(
f"drop schema if exists {runner.config['database']}.{runner.config['default_target_schema']}",
sa.text(f"drop schema if exists {runner.config['database']}.{runner.config['default_target_schema']}"),
)


Expand Down

0 comments on commit ec5c2bb

Please sign in to comment.