Skip to content

Commit

Permalink
refactor(snowflake): use sqlglot for the snowflake backend
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Jan 3, 2024
1 parent 1638630 commit 7755e6a
Show file tree
Hide file tree
Showing 62 changed files with 2,637 additions and 2,621 deletions.
5 changes: 1 addition & 4 deletions .github/renovate.json
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,7 @@
"addLabels": ["pyspark"]
},
{
"matchPackagePatterns": [
"snowflake-connector-python",
"snowflake-sqlalchemy"
],
"matchPackagePatterns": ["snowflake-connector-python"],
"addLabels": ["snowflake"]
},
{
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ibis-backends.yml
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ jobs:
# - run: python -m pip install --upgrade pip 'poetry==1.7.1'
#
# - name: remove deps that are not compatible with sqlalchemy 2
# run: poetry remove snowflake-sqlalchemy sqlalchemy-exasol
# run: poetry remove sqlalchemy-exasol
#
# - name: add sqlalchemy 2
# run: poetry add --lock --optional 'sqlalchemy>=2,<3'
Expand Down
15 changes: 7 additions & 8 deletions ibis/backends/base/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,13 @@ def __getitem__(self, key: str) -> partial:


class FuncGen:
__slots__ = ()
__slots__ = ("namespace",)

def __init__(self, namespace: str | None = None) -> None:
    """Store the optional namespace used to qualify generated function names.

    Parameters
    ----------
    namespace
        Prefix joined to function names with a ``.`` by ``__getattr__``;
        ``None`` (the default) leaves names unqualified.
    """
    self.namespace = namespace

def __getattr__(self, name: str) -> partial:
    """Return a builder for the SQL function called ``name``.

    The function name is prefixed with ``self.namespace`` (dot-separated)
    when a namespace is set. Positional arguments given to the returned
    callable are passed through ``sge.convert`` before being handed to
    ``sg.func``; keyword arguments are forwarded unchanged.
    """
    # Drop empty parts so a missing namespace leaves ``name`` untouched.
    qualified = ".".join(part for part in (self.namespace, name) if part)
    return lambda *args, **kwargs: sg.func(
        qualified, *(sge.convert(arg) for arg in args), **kwargs
    )

def __getitem__(self, key: str) -> partial:
Expand Down Expand Up @@ -413,15 +417,10 @@ def visit_Time(self, op, *, arg):

@visit_node.register(ops.TimestampNow)
def visit_TimestampNow(self, op):
"""DuckDB current timestamp defaults to timestamp + tz."""
return self.cast(sge.CurrentTimestamp(), dt.timestamp)
return sge.CurrentTimestamp()

@visit_node.register(ops.Strftime)
def visit_Strftime(self, op, *, arg, format_str):
if not isinstance(op.format_str, ops.Literal):
raise com.UnsupportedOperationError(
f"{self.dialect} `format_str` must be a literal `str`; got {type(op.format_str)}"
)
return sge.TimeToStr(this=arg, format=format_str)

@visit_node.register(ops.ExtractEpochSeconds)
Expand Down Expand Up @@ -541,7 +540,7 @@ def visit_StringFind(self, op, *, arg, substr, start, end):

@visit_node.register(ops.RegexSearch)
def visit_RegexSearch(self, op, *, arg, pattern):
return self.f.regexp_matches(arg, pattern, "s")
return sge.RegexpLike(this=arg, expression=pattern, flag=sge.convert("s"))

@visit_node.register(ops.RegexReplace)
def visit_RegexReplace(self, op, *, arg, pattern, replacement):
Expand Down
20 changes: 20 additions & 0 deletions ibis/backends/base/sqlglot/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,10 @@ class OracleType(SqlglotType):

class SnowflakeType(SqlglotType):
dialect = "snowflake"

default_decimal_precision = 38
default_decimal_scale = 9

default_temporal_scale = 9

@classmethod
Expand All @@ -478,6 +482,22 @@ def _from_sqlglot_ARRAY(cls, value_type=None) -> dt.Array:
assert value_type is None
return dt.Array(dt.json, nullable=cls.default_nullable)

@classmethod
def _from_ibis_JSON(cls, dtype: dt.JSON) -> sge.DataType:
    """Translate an ibis JSON type into a sqlglot ``VARIANT`` data type.

    ``dtype`` is accepted for signature compatibility but is not inspected.
    """
    variant = sge.DataType.Type.VARIANT
    return sge.DataType(this=variant)

@classmethod
def _from_ibis_Array(cls, dtype: dt.Array) -> sge.DataType:
    """Translate an ibis array type into a sqlglot ``ARRAY`` data type.

    The element type of ``dtype`` is not encoded in the result; the data
    type is built with ``nested=True`` and no type parameters.
    """
    array = sge.DataType.Type.ARRAY
    return sge.DataType(this=array, nested=True)

@classmethod
def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
    """Translate an ibis map type into a sqlglot ``OBJECT`` data type.

    Key and value types of ``dtype`` are not encoded in the result; the
    data type is built with ``nested=True`` and no type parameters.
    """
    object_type = sge.DataType.Type.OBJECT
    return sge.DataType(this=object_type, nested=True)

@classmethod
def _from_ibis_Struct(cls, dtype: dt.Struct) -> sge.DataType:
    """Translate an ibis struct type into a sqlglot ``OBJECT`` data type.

    Field names and types of ``dtype`` are not encoded in the result; the
    data type is built with ``nested=True`` and no type parameters.
    """
    object_type = sge.DataType.Type.OBJECT
    return sge.DataType(this=object_type, nested=True)


class SQLiteType(SqlglotType):
dialect = "sqlite"
10 changes: 10 additions & 0 deletions ibis/backends/base/sqlglot/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@ def window_function_to_window(_):
)


@replace(p.Log2)
def replace_log2(_):
    """Rewrite a ``Log2`` node as a generic ``Log`` with an explicit base of 2."""
    operand = _.arg
    return ops.Log(operand, base=2)


@replace(p.Log10)
def replace_log10(_):
    """Rewrite a ``Log10`` node as a generic ``Log`` with an explicit base of 10."""
    operand = _.arg
    return ops.Log(operand, base=10)


@replace(Object(Select, Object(Select)))
def merge_select_select(_):
"""Merge subsequent Select relations into one.
Expand Down
11 changes: 1 addition & 10 deletions ibis/backends/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,16 +537,7 @@ def ddl_con(ddl_backend):

@pytest.fixture(
params=_get_backends_to_test(
keep=(
"exasol",
"mssql",
"mysql",
"oracle",
"postgres",
"snowflake",
"sqlite",
"trino",
)
keep=("exasol", "mssql", "mysql", "oracle", "postgres", "sqlite", "trino")
),
scope="session",
)
Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,6 @@ def visit_RegexExtract(self, op, *, arg, pattern, index):
)
return self.f.regexp_match(arg, self.f.concat("(", pattern, ")"))[index]

# @visit_node.register(ops.RegexReplace)
# def regex_replace(self, op, *, arg, pattern, replacement):
# return self.f.regexp_replace(arg, pattern, replacement, sg.exp.convert("g"))

@visit_node.register(ops.StringFind)
def visit_StringFind(self, op, *, arg, substr, start, end):
if end is not None:
Expand Down
5 changes: 5 additions & 0 deletions ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,11 @@ def visit_GeoConvert(self, op, *, arg, source, target):
# matches the behavior of the equivalent geopandas functionality
return self.f.st_transform(arg, source, target, True)

@visit_node.register(ops.TimestampNow)
def visit_TimestampNow(self, op):
    """Compile ``TimestampNow`` and cast the result to ``dt.timestamp``.

    DuckDB current timestamp defaults to timestamp + tz, so the expression
    produced by the parent compiler is wrapped in a cast.
    """
    current = super().visit_TimestampNow(op)
    return self.cast(current, dt.timestamp)


_SIMPLE_OPS = {
ops.ArrayPosition: "list_indexof",
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/duckdb/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,6 @@ def test_insert(con):
assert t.count().execute() == 2


@pytest.mark.xfail(reason="snowflake backend not yet rewritten")
def test_to_other_sql(con, snapshot):
pytest.importorskip("snowflake.connector")

Expand Down
7 changes: 3 additions & 4 deletions ibis/backends/oracle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
# Wow, this is truly horrible
# Get out your clippers, it's time to shave a yak.
#
# 1. snowflake-sqlalchemy doesn't support sqlalchemy 2.0
# 2. oracledb is only supported in sqlalchemy 2.0
# 3. Ergo, module hacking is required to avoid doing a silly amount of work
# 1. oracledb is only supported in sqlalchemy 2.0
# 2. Ergo, module hacking is required to avoid doing a silly amount of work
# to create multiple lockfiles or port snowflake away from sqlalchemy
# 4. Also the version needs to be spoofed to be >= 7 or else the cx_Oracle
# 3. Also the version needs to be spoofed to be >= 7 or else the cx_Oracle
# dialect barfs
oracledb.__version__ = oracledb.version = "7"

Expand Down
Loading

0 comments on commit 7755e6a

Please sign in to comment.