Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(targets): Correctly serialize decimal.Decimal in JSON fields of SQL targets #1898

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
name = "singer-sdk"
version = "0.30.0"
description = "A framework for building Singer taps"
authors = ["Meltano Team and Contributors"]
maintainers = ["Meltano Team and Contributors"]
authors = ["Meltano Team and Contributors <[email protected]>"]
maintainers = ["Meltano Team and Contributors <[email protected]>"]
readme = "README.md"
homepage = "https://sdk.meltano.com/en/latest/"
repository = "https://github.com/meltano/sdk"
Expand Down Expand Up @@ -144,7 +144,7 @@ name = "cz_version_bump"
version = "0.30.0"
tag_format = "v$major.$minor.$patch$prerelease"
version_files = [
"docs/conf.py",
"docs/conf.py:^release =",
"pyproject.toml:^version =",
"cookiecutter/tap-template/{{cookiecutter.tap_id}}/pyproject.toml:singer-sdk",
"cookiecutter/target-template/{{cookiecutter.target_id}}/pyproject.toml:singer-sdk",
Expand Down
42 changes: 41 additions & 1 deletion singer_sdk/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

from __future__ import annotations

import decimal
import json
import logging
import typing as t
import warnings
from contextlib import contextmanager
from datetime import datetime
from functools import lru_cache

import simplejson
import sqlalchemy
from sqlalchemy.engine import Engine

Expand Down Expand Up @@ -316,7 +319,12 @@ def create_engine(self) -> Engine:
Returns:
A new SQLAlchemy Engine.
"""
return sqlalchemy.create_engine(self.sqlalchemy_url, echo=False)
return sqlalchemy.create_engine(
self.sqlalchemy_url,
echo=False,
json_serializer=self.serialize_json,
json_deserializer=self.deserialize_json,
)

def quote(self, name: str) -> str:
"""Quote a name if it needs quoting, using '.' as a name-part delimiter.
Expand Down Expand Up @@ -1124,3 +1132,35 @@ def _adapt_column_type(
)
with self._connect() as conn:
conn.execute(alter_column_ddl)

def serialize_json(self, obj: object) -> str:
"""Serialize an object to a JSON string.

Target connectors may override this method to provide custom serialization logic
for JSON types.

Args:
obj: The object to serialize.

Returns:
The JSON string.

.. versionadded:: 0.31.0
"""
return simplejson.dumps(obj, use_decimal=True)

def deserialize_json(self, json_str: str) -> object:
"""Deserialize a JSON string to an object.

Tap connectors may override this method to provide custom deserialization
logic for JSON types.

Args:
json_str: The JSON string to deserialize.

Returns:
The deserialized object.

.. versionadded:: 0.31.0
"""
return json.loads(json_str, parse_float=decimal.Decimal)
25 changes: 25 additions & 0 deletions tests/core/test_connector_sql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from decimal import Decimal
from unittest import mock

import pytest
Expand Down Expand Up @@ -258,3 +259,27 @@ def test_merge_generic_sql_types(
):
merged_type = connector.merge_sql_types(types)
assert isinstance(merged_type, expected_type)

def test_engine_json_serialization(self, connector: SQLConnector):
engine = connector._engine
meta = sqlalchemy.MetaData()
table = sqlalchemy.Table(
"test_table",
meta,
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
sqlalchemy.Column("attrs", sqlalchemy.JSON),
)
meta.create_all(engine)
with engine.connect() as conn:
conn.execute(
table.insert(),
[
{"attrs": {"x": Decimal("1.0")}},
{"attrs": {"x": Decimal("2.0"), "y": [1, 2, 3]}},
],
)
result = conn.execute(table.select())
assert result.fetchall() == [
(1, {"x": Decimal("1.0")}),
(2, {"x": Decimal("2.0"), "y": [1, 2, 3]}),
]