Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add sql-datatype to the SDK discovery and catalog #1872

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions singer_sdk/_singerlib/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@

from singer_sdk._singerlib.schema import Schema

if t.TYPE_CHECKING:
from typing_extensions import TypeAlias


Breadcrumb = tuple[str, ...]

logger = logging.getLogger(__name__)
Expand All @@ -35,7 +31,7 @@ def __missing__(self, breadcrumb: Breadcrumb) -> bool:


@dataclass
class Metadata:
class _BaseMetadata:
"""Base stream or property metadata."""

class InclusionType(str, enum.Enum):
Expand All @@ -50,7 +46,7 @@ class InclusionType(str, enum.Enum):
selected_by_default: bool | None = None

@classmethod
def from_dict(cls: type[Metadata], value: dict[str, t.Any]) -> Metadata:
def from_dict(cls: type[_BaseMetadata], value: dict[str, t.Any]) -> _BaseMetadata:
"""Parse metadata dictionary.

Args:
Expand Down Expand Up @@ -82,6 +78,11 @@ def to_dict(self) -> dict[str, t.Any]:
return result


@dataclass
class Metadata(_BaseMetadata):
    """Metadata entry that can additionally carry a SQL datatype."""

    # SQL datatype of the described entry as a string; ``None`` when no
    # datatype was supplied.
    sql_datatype: str | None = None


@dataclass
class StreamMetadata(Metadata):
"""Stream metadata."""
Expand All @@ -93,10 +94,7 @@ class StreamMetadata(Metadata):
schema_name: str | None = None


AnyMetadata: TypeAlias = t.Union[Metadata, StreamMetadata]


class MetadataMapping(dict[Breadcrumb, AnyMetadata]):
class MetadataMapping(dict[Breadcrumb, _BaseMetadata]):
"""Stream metadata mapping."""

@classmethod
Expand Down Expand Up @@ -133,7 +131,7 @@ def to_list(self) -> list[dict[str, t.Any]]:
{"breadcrumb": list(k), "metadata": v.to_dict()} for k, v in self.items()
]

def __missing__(self, breadcrumb: Breadcrumb) -> AnyMetadata:
def __missing__(self, breadcrumb: Breadcrumb) -> _BaseMetadata:
"""Handle missing metadata entries.

Args:
Expand Down Expand Up @@ -164,6 +162,7 @@ def get_standard_metadata(
valid_replication_keys: list[str] | None = None,
replication_method: str | None = None,
selected_by_default: bool | None = None,
sql_datatypes: dict[str, str] | None = None,
) -> MetadataMapping:
"""Get default metadata for a stream.

Expand All @@ -174,6 +173,7 @@ def get_standard_metadata(
valid_replication_keys: Stream valid replication keys.
replication_method: Stream replication method.
selected_by_default: Whether the stream is selected by default.
sql_datatypes: SQL datatypes present in the stream.

Returns:
Metadata mapping.
Expand All @@ -193,12 +193,23 @@ def get_standard_metadata(
root.schema_name = schema_name

for field_name in schema.get("properties", {}):
if sql_datatypes and field_name in sql_datatypes:
sql_datatype = sql_datatypes[field_name]
else:
sql_datatype = None

if (key_properties and field_name in key_properties) or (
valid_replication_keys and field_name in valid_replication_keys
):
entry = Metadata(inclusion=Metadata.InclusionType.AUTOMATIC)
entry = Metadata(
inclusion=Metadata.InclusionType.AUTOMATIC,
sql_datatype=sql_datatype,
)
else:
entry = Metadata(inclusion=Metadata.InclusionType.AVAILABLE)
entry = Metadata(
inclusion=Metadata.InclusionType.AVAILABLE,
sql_datatype=sql_datatype,
)

mapping["properties", field_name] = entry

Expand Down
28 changes: 28 additions & 0 deletions singer_sdk/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,30 @@ def get_object_names(
view_names = []
return [(t, False) for t in table_names] + [(v, True) for v in view_names]

def _sa_type_to_str(self, data_type: sa.types.TypeEngine) -> str:  # noqa: PLR6301
    """Return the SQL datatype as a string for use in the catalog.

    The result is the class name of ``data_type`` followed, in parentheses,
    by its ``length``, ``scale`` and ``precision`` attributes (in that
    order). Attributes that are missing or falsy (``None``, ``0``) are
    omitted.

    Args:
        data_type: The given data type as a ``sqlalchemy.types.TypeEngine``.

    Returns:
        A string describing the given data type, for example
        ``"VARCHAR(length=15)"`` or ``"INTEGER()"``.
    """
    datatype_attributes = ("length", "scale", "precision")

    # Look each attribute up once; a missing attribute is treated like None.
    attribute_values = (
        (attribute, getattr(data_type, attribute, None))
        for attribute in datatype_attributes
    )

    # Joining avoids having to strip a trailing separator afterwards.
    rendered = ", ".join(
        f"{attribute}={value}" for attribute, value in attribute_values if value
    )

    return f"{type(data_type).__name__}({rendered})"

# TODO maybe should be split into smaller parts?
def discover_catalog_entry(
self,
Expand Down Expand Up @@ -881,6 +905,7 @@ def discover_catalog_entry(

# Initialize columns list
table_schema = th.PropertiesList()
datatypes: dict[str, str] = {}
for column_def in inspected.get_columns(table_name, schema=schema_name):
column_name = column_def["name"]
is_nullable = column_def.get("nullable", False)
Expand All @@ -893,6 +918,8 @@ def discover_catalog_entry(
required=column_name in key_properties if key_properties else False,
),
)
datatypes[column_def["name"]] = self._sa_type_to_str(column_def["type"])

schema = table_schema.to_dict()

# Initialize available replication methods
Expand All @@ -919,6 +946,7 @@ def discover_catalog_entry(
replication_method=replication_method,
key_properties=key_properties,
valid_replication_keys=None, # Must be defined by user
sql_datatypes=datatypes,
),
database=None, # Expects single-database context
row_count=None,
Expand Down
7 changes: 6 additions & 1 deletion tests/samples/test_tap_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from samples.sample_tap_sqlite import SQLiteTap
from samples.sample_target_csv.csv_target import SampleTargetCSV
from singer_sdk import SQLStream
from singer_sdk._singerlib import MetadataMapping, StreamMetadata
from singer_sdk._singerlib import Metadata, MetadataMapping, StreamMetadata
from singer_sdk.testing import (
get_standard_tap_tests,
tap_sync_test,
Expand Down Expand Up @@ -83,6 +83,11 @@ def test_sqlite_discovery(sqlite_sample_tap: SQLTap):

assert stream.metadata.root.table_key_properties == ["c1"]
assert stream.primary_keys == ["c1"]

field_metadata = stream.metadata["properties", "c1"]
assert isinstance(field_metadata, Metadata)
assert field_metadata.sql_datatype == "INTEGER()"

assert stream.schema["properties"]["c1"] == {"type": ["integer"]}
assert stream.schema["required"] == ["c1"]

Expand Down