Skip to content

Commit

Permalink
chore: checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Nov 29, 2023
1 parent 7ed5149 commit f88d058
Show file tree
Hide file tree
Showing 14 changed files with 1,054 additions and 1,037 deletions.
84 changes: 84 additions & 0 deletions ibis/backends/base/sqlglot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING, Any, ClassVar

import sqlglot as sg
Expand Down Expand Up @@ -113,3 +114,86 @@ def sql(
def _get_schema_using_query(self, query: str) -> sch.Schema:
    """Build an ibis schema describing the result of a SQL string.

    The output of ``self._metadata(query)`` is handed unchanged to
    ``Schema.from_tuples``.
    """
    column_info = self._metadata(query)
    return sch.Schema.from_tuples(column_info)

def create_view(
    self,
    name: str,
    obj: ir.Table,
    *,
    database: str | None = None,
    overwrite: bool = False,
) -> ir.Table:
    """Create a view named ``name`` from the table expression ``obj``.

    Parameters
    ----------
    name
        Name of the view to create.
    obj
        Table expression whose compiled query becomes the view body.
    database
        Optional database qualifier for the view name.
    overwrite
        Forwarded as the ``replace`` flag of the ``CREATE`` statement.

    Returns
    -------
    ir.Table
        The result of ``self.table(name, database=database)``.
    """
    view_target = sg.table(name, db=database)
    view_query = self._to_sqlglot(obj)
    create_stmt = sg.exp.Create(
        this=view_target,
        kind="VIEW",
        replace=overwrite,
        expression=view_query,
    )

    # In-memory tables referenced by the expression are registered before the
    # statement runs, and also passed along as external tables.
    self._register_in_memory_tables(obj)
    memtables = self._collect_in_memory_tables(obj)

    with self._safe_raw_sql(create_stmt, external_tables=memtables):
        pass

    return self.table(name, database=database)

def _register_in_memory_tables(self, expr: ir.Expr) -> None:
    """Register each ``InMemoryTable`` found in ``expr``'s operation tree."""
    root = expr.op()
    for in_memory_op in root.find(ops.InMemoryTable):
        self._register_in_memory_table(in_memory_op)

def drop_view(
    self, name: str, *, database: str | None = None, force: bool = False
) -> None:
    """Drop the view called ``name``.

    Parameters
    ----------
    name
        Name of the view to drop.
    database
        Optional database qualifier for the view name.
    force
        Forwarded as the ``exists`` flag of the ``DROP`` statement
        (i.e. ``DROP VIEW IF EXISTS``).
    """
    src = sg.exp.Drop(this=sg.table(name, db=database), kind="VIEW", exists=force)
    # Go through the backend's `_safe_raw_sql` hook, exactly like `create_view`
    # and `execute`, so each backend decides how (and whether) the returned
    # cursor is released instead of unconditionally calling `.close()` on it.
    with self._safe_raw_sql(src):
        pass

def _get_temp_view_definition(self, name: str, definition: str) -> str:
    """Return the SQL text that creates or replaces a temporary view.

    Parameters
    ----------
    name
        Name of the temporary view; quoted according to
        ``self.compiler.quoted``.
    definition
        The query the view is defined as.

    Returns
    -------
    str
        A ``CREATE OR REPLACE TEMPORARY VIEW`` statement rendered in the
        dialect named by ``self.name``.
    """
    # This must `return`, not `yield`: the result is passed directly to
    # `_safe_raw_sql` by `_create_temp_view`, which needs the SQL string
    # itself rather than a generator object.
    return sg.exp.Create(
        this=sg.to_identifier(name, quoted=self.compiler.quoted),
        kind="VIEW",
        expression=definition,
        replace=True,
        properties=sg.exp.Properties(expressions=[sg.exp.TemporaryProperty()]),
    ).sql(self.name)

def _create_temp_view(self, table_name, source):
if table_name not in self._temp_views and table_name in self.list_tables():
raise ValueError(
f"{table_name} already exists as a non-temporary table or view"
)
with self._safe_raw_sql(self._get_temp_view_definition(table_name, source)):
pass
self._temp_views.add(table_name)
self._register_temp_view_cleanup(table_name)

def _register_temp_view_cleanup(self, name: str) -> None:
    """Register a clean up function for a temporary view.

    No-op by default.

    Parameters
    ----------
    name
        The temporary view to register for clean up.
    """

def _load_into_cache(self, name, expr):
self.create_table(name, expr, schema=expr.schema(), temp=True)

def _clean_up_cached_table(self, op):
self.drop_table(op.name)

def execute(
    self, expr: ir.Expr, limit: str | None = "default", **kwargs: Any
) -> Any:
    """Compile ``expr``, run it, and return the fetched result.

    Parameters
    ----------
    expr
        Expression to execute.
    limit
        Forwarded to ``self.compile``.
    kwargs
        Extra keyword arguments forwarded to ``self.compile``.

    Returns
    -------
    Any
        Whatever ``expr.__pandas_result__`` produces from the fetched data.
    """
    self._run_pre_execute_hooks(expr)

    # Always compile the table form of the expression; the original
    # expression converts the fetched data back into its own shape below.
    table_expr = expr.as_table()
    query = self.compile(table_expr, limit=limit, **kwargs)
    result_schema = table_expr.schema()
    self._log(query)

    with self._safe_raw_sql(query) as cursor:
        fetched = self.fetch_from_cursor(cursor, result_schema)

    return expr.__pandas_result__(fetched)
24 changes: 22 additions & 2 deletions ibis/backends/base/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import calendar
import functools
import itertools
import math
import operator
import string
from collections.abc import Mapping
Expand Down Expand Up @@ -144,6 +145,15 @@ class SQLGlotCompiler(abc.ABC):
quoted: bool | None = None
"""Whether to always quote identifiers."""

NAN = sg.exp.Literal.number("'NaN'::double")
"""Backend's NaN literal."""

POS_INF = sg.exp.Literal.number("'Inf'::double")
"""Backend's positive infinity literal."""

NEG_INF = sg.exp.Literal.number("'-Inf'::double")
"""Backend's negative infinity literal."""

def __init__(self) -> None:
self.agg = AggGen(aggfunc=self._aggregate)
self.f = FuncGen()
Expand Down Expand Up @@ -217,10 +227,10 @@ def fn(node, _, **kwargs):
return result

alias_index = next(gen_alias_index)
alias = f"t{alias_index:d}"
alias = sg.to_identifier(f"t{alias_index:d}", quoted=quoted)

try:
return result.subquery(sg.exp.TableAlias(this=alias, quoted=quoted))
return result.subquery(alias)
except AttributeError:
return result.as_(alias, quoted=quoted)

Expand Down Expand Up @@ -269,6 +279,16 @@ def visit_Literal(self, op, *, value, dtype, **kw):
raise com.UnsupportedOperationError(
f"Unsupported NULL for non-nullable type: {dtype!r}"
)
elif dtype.is_integer():
return sg.exp.convert(value)
elif dtype.is_floating():
if math.isnan(value):
return self.NAN
elif math.isinf(value):
return self.POS_INF if value < 0 else self.NEG_INF
return sg.exp.convert(value)
elif dtype.is_decimal():
return self.cast(sg.exp.convert(str(value)), dtype)
elif dtype.is_interval():
return sg.exp.Interval(
this=sg.exp.convert(str(value)), unit=dtype.resolution.upper()
Expand Down
21 changes: 0 additions & 21 deletions ibis/backends/base/sqlglot/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,26 +405,5 @@ class OracleType(SqlglotType):
dialect = "oracle"


class SnowflakeType(SqlglotType):
    """Type-mapping overrides for the ``snowflake`` sqlglot dialect."""

    dialect = "snowflake"
    default_temporal_scale = 9

    @classmethod
    def _from_sqlglot_FLOAT(cls) -> dt.Float64:
        """Map ``FLOAT`` to a 64-bit float."""
        return dt.Float64(nullable=cls.default_nullable)

    @classmethod
    def _from_sqlglot_DECIMAL(cls, precision=None, scale=None) -> dt.Decimal:
        """Map ``DECIMAL``; a missing or zero scale yields ``Int64`` instead."""
        has_fractional_digits = scale is not None and int(scale.this.this) != 0
        if has_fractional_digits:
            return super()._from_sqlglot_DECIMAL(precision, scale)
        return dt.Int64(nullable=cls.default_nullable)

    @classmethod
    def _from_sqlglot_ARRAY(cls, value_type=None) -> dt.Array:
        """Map ``ARRAY`` (asserted to carry no element type) to an array of JSON."""
        assert value_type is None
        return dt.Array(dt.json, nullable=cls.default_nullable)


class SQLiteType(SqlglotType):
    # Dialect identifier for SQLite.
    # NOTE(review): presumably consumed by SqlglotType when parsing and
    # generating type strings -- confirm against the base class.
    dialect = "sqlite"
27 changes: 1 addition & 26 deletions ibis/backends/clickhouse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,35 +709,10 @@ def create_view(
expression=self._to_sqlglot(obj),
)
external_tables = self._collect_in_memory_tables(obj)
with closing(self.raw_sql(src, external_tables=external_tables)):
with self._safe_raw_sql(src, external_tables=external_tables):
pass
return self.table(name, database=database)

def drop_view(
    self, name: str, *, database: str | None = None, force: bool = False
) -> None:
    """Drop the view ``name``, optionally qualified by ``database``.

    ``force`` is forwarded as the ``exists`` flag of the ``DROP`` statement.
    """
    target = sg.table(name, db=database)
    drop_stmt = sg.exp.Drop(this=target, kind="VIEW", exists=force)
    # The statement returns nothing of interest; the result is opened only so
    # that it gets closed again.
    with closing(self.raw_sql(drop_stmt)):
        pass

def _load_into_cache(self, name, expr):
self.create_table(name, expr, schema=expr.schema(), temp=True)

def _clean_up_cached_table(self, op):
self.drop_table(op.name)

def _create_temp_view(self, table_name, source):
    """Create or replace a view used as a temporary view and track it.

    Parameters
    ----------
    table_name
        Name of the view to create.
    source
        The query the view is defined as.

    Raises
    ------
    ValueError
        If ``table_name`` is not a tracked temporary view but already appears
        in ``self.list_tables()``.
    """
    if table_name not in self._temp_views and table_name in self.list_tables():
        raise ValueError(
            f"{table_name} already exists as a non-temporary table or view"
        )
    src = sg.exp.Create(
        this=sg.table(table_name), kind="VIEW", replace=True, expression=source
    )
    # Close the result of the DDL statement instead of discarding it, matching
    # how `drop_view` in this backend handles `raw_sql`.
    with closing(self.raw_sql(src)):
        pass
    self._temp_views.add(table_name)
    self._register_temp_view_cleanup(table_name)

def _register_temp_view_cleanup(self, name: str) -> None:
def drop(self, name: str, query: str):
self.raw_sql(query)
Expand Down
46 changes: 4 additions & 42 deletions ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,10 @@ def get_schema(
}
)

@contextlib.contextmanager
def _safe_raw_sql(self, *args, **kwargs):
yield self.raw_sql(*args, **kwargs)

def list_databases(self, like: str | None = None) -> list[str]:
col = "catalog_name"
query = sg.select(sg.exp.Distinct(expressions=[sg.column(col)])).from_(
Expand Down Expand Up @@ -412,26 +416,6 @@ def _from_url(self, url: str, **kwargs) -> BaseBackend:
self._convert_kwargs(kwargs)
return self.connect(**kwargs)

def execute(
    self, expr: ir.Expr, limit: str | None = "default", **kwargs: Any
) -> Any:
    """Execute an expression.

    Parameters
    ----------
    expr
        Expression to execute.
    limit
        Forwarded to ``self.compile``.
    kwargs
        Extra keyword arguments forwarded to ``self.compile``.

    Returns
    -------
    Any
        Whatever ``expr.__pandas_result__`` produces from the fetched data.

    Raises
    ------
    exc.IbisError
        If DuckDB raises a ``CatalogException`` while running the query.
    """

    self._run_pre_execute_hooks(expr)
    table = expr.as_table()
    sql = self.compile(table, limit=limit, **kwargs)

    schema = table.schema()
    self._log(sql)

    try:
        cur = self.con.execute(sql)
    except duckdb.CatalogException as e:
        # Chain explicitly so the DuckDB error is recorded as the direct cause
        # of the ibis-level error.
        raise exc.IbisError(e) from e

    result = self.fetch_from_cursor(cur, schema)
    return expr.__pandas_result__(result)

def load_extension(self, extension: str, force_install: bool = False) -> None:
"""Install and load a duckdb extension by name or path.
Expand Down Expand Up @@ -532,25 +516,6 @@ def _register_failure(self):
f"please call one of {msg} directly"
)

def _create_temp_view(self, table_name, source):
    """Create or replace a temporary DuckDB view and record its name.

    Raises
    ------
    ValueError
        If ``table_name`` is not a tracked temporary view but already appears
        in ``self.list_tables()``.
    """
    is_tracked = table_name in self._temp_views
    if not is_tracked and table_name in self.list_tables():
        raise ValueError(
            f"{table_name} already exists as a non-temporary table or view"
        )

    # Pieces of: CREATE OR REPLACE TEMPORARY VIEW "<table_name>" AS <source>
    quoted_name = sg.exp.Identifier(this=table_name, quoted=True)
    temporary = sg.exp.Properties(expressions=[sg.exp.TemporaryProperty()])
    create_stmt = sg.exp.Create(
        this=quoted_name,
        kind="VIEW",
        replace=True,
        properties=temporary,
        expression=source,
    )

    self.raw_sql(create_stmt.sql("duckdb"))
    self._temp_views.add(table_name)

@util.experimental
def read_json(
self,
Expand Down Expand Up @@ -1352,9 +1317,6 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
except duckdb.NotImplementedException:
self.con.register(name, data.to_pyarrow(schema))

def _get_temp_view_definition(self, name: str, definition) -> str:
    """Yield the statement that creates or replaces temporary view ``name``.

    ``name`` and ``definition`` are interpolated into the statement as-is.

    NOTE(review): because of the ``yield`` this is a generator function, so
    callers receive an iterator producing one SQL string rather than the
    ``str`` the return annotation advertises.
    """
    yield f"CREATE OR REPLACE TEMPORARY VIEW {name} AS {definition}"

def _register_udfs(self, expr: ir.Expr) -> None:
import ibis.expr.operations as ops

Expand Down
Loading

0 comments on commit f88d058

Please sign in to comment.