Skip to content

Commit

Permalink
refactor(snowflake): get query schema using dscribe last_query_id();
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Aug 25, 2023
1 parent deea624 commit 4874821
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 10 deletions.
25 changes: 25 additions & 0 deletions ibis/backends/base/sql/glot/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,14 @@ def _from_sqlglot_TIMESTAMPTZ(cls, scale=None) -> dt.Timestamp:
nullable=cls.default_nullable,
)

@classmethod
def _from_sqlglot_TIMESTAMPLTZ(cls, scale=None) -> dt.Timestamp:
return dt.Timestamp(
timezone="UTC",
scale=cls.default_temporal_scale if scale is None else int(scale.this.this),
nullable=cls.default_nullable,
)

@classmethod
def _from_sqlglot_INTERVAL(
cls, precision: sge.DataTypeParam | None = None
Expand Down Expand Up @@ -371,6 +379,23 @@ class OracleType(SqlglotType):

class SnowflakeType(SqlglotType):
dialect = "snowflake"
default_temporal_scale = 9

@classmethod
def _from_sqlglot_FLOAT(cls) -> dt.Float64:
return dt.Float64(nullable=cls.default_nullable)

@classmethod
def _from_sqlglot_DECIMAL(cls, precision=None, scale=None) -> dt.Decimal:
if scale is None or int(scale.this.this) == 0:
return dt.Int64(nullable=cls.default_nullable)
else:
return super()._from_sqlglot_DECIMAL(precision, scale)

@classmethod
def _from_sqlglot_ARRAY(cls, value_type=None) -> dt.Array:
assert value_type is None
return dt.Array(dt.json, nullable=cls.default_nullable)


class SQLiteType(SqlglotType):
Expand Down
17 changes: 9 additions & 8 deletions ibis/backends/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import pyarrow as pa
import sqlalchemy as sa
from snowflake.connector.constants import FIELD_ID_TO_NAME
from snowflake.sqlalchemy import ARRAY, DOUBLE, OBJECT, URL
from sqlalchemy.ext.compiler import compiles

Expand Down Expand Up @@ -409,13 +408,15 @@ def _make_batch_iter(
)

def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
with self.begin() as con, con.connection.cursor() as cur:
result = cur.describe(query)

for name, type_code, _, _, _, _, is_nullable in result:
typ_name = FIELD_ID_TO_NAME[type_code]
typ = SnowflakeType.from_string(typ_name)
yield name, typ.copy(nullable=is_nullable)
with self.begin() as con:
con.exec_driver_sql(query)
result = con.exec_driver_sql("DESC RESULT last_query_id()").mappings().all()

for field in result:
name = field["name"]
type_string = field["type"]
is_nullable = field["null?"] == "Y"
yield name, SnowflakeType.from_string(type_string, nullable=is_nullable)

def list_databases(self, like: str | None = None) -> list[str]:
with self.begin() as con:
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/snowflake/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import ibis
import ibis.expr.datatypes as dt
from ibis.backends.snowflake.datatypes import parse
from ibis.backends.snowflake.datatypes import SnowflakeType
from ibis.backends.snowflake.tests.conftest import _get_url
from ibis.util import gen_name

Expand Down Expand Up @@ -35,7 +35,7 @@
],
)
def test_parse(snowflake_type, ibis_type):
assert parse(snowflake_type.upper()) == ibis_type
assert SnowflakeType.from_string(snowflake_type.upper()) == ibis_type


@pytest.fixture(scope="module")
Expand Down

0 comments on commit 4874821

Please sign in to comment.