From 975de51e324f068b6cf7ddf43c94ae3e33b956ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Fri, 8 Mar 2024 14:38:40 +0100 Subject: [PATCH] perf(sql): prevent sqlglot from extensive deepcopying every time we create a sqlglot object --- ibis/backends/bigquery/compiler.py | 6 +- ibis/backends/clickhouse/compiler.py | 11 +- ibis/backends/datafusion/compiler.py | 14 +- ibis/backends/duckdb/compiler.py | 5 +- ibis/backends/flink/compiler.py | 6 +- ibis/backends/mssql/compiler.py | 3 +- ibis/backends/postgres/compiler.py | 10 +- ibis/backends/sql/__init__.py | 4 +- ibis/backends/sql/compiler.py | 184 +++++++++++++++------------ ibis/backends/trino/compiler.py | 12 +- 10 files changed, 132 insertions(+), 123 deletions(-) diff --git a/ibis/backends/bigquery/compiler.py b/ibis/backends/bigquery/compiler.py index 69dec3405704..442719e85e6e 100644 --- a/ibis/backends/bigquery/compiler.py +++ b/ibis/backends/bigquery/compiler.py @@ -12,7 +12,7 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis import util -from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler, paren +from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler from ibis.backends.sql.datatypes import BigQueryType, BigQueryUDFType from ibis.backends.sql.rewrites import ( exclude_unsupported_window_frame_from_ops, @@ -707,10 +707,10 @@ def visit_CountStar(self, op, *, arg, where): return self.f.count(STAR) def visit_Degrees(self, op, *, arg): - return paren(180 * arg / self.f.acos(-1)) + return sge.paren(180 * arg / self.f.acos(-1), copy=False) def visit_Radians(self, op, *, arg): - return paren(self.f.acos(-1) * arg / 180) + return sge.paren(self.f.acos(-1) * arg / 180, copy=False) def visit_CountDistinct(self, op, *, arg, where): if where is not None: diff --git a/ibis/backends/clickhouse/compiler.py b/ibis/backends/clickhouse/compiler.py index 3201e11ca811..ba24173c6ae1 100644 --- a/ibis/backends/clickhouse/compiler.py +++ b/ibis/backends/clickhouse/compiler.py @@ -11,12 +11,7 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis import util -from ibis.backends.sql.compiler import ( - NULL, - STAR, - SQLGlotCompiler, - parenthesize, -) +from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler from ibis.backends.sql.datatypes import ClickHouseType from ibis.backends.sql.dialects import ClickHouse from ibis.backends.sql.rewrites import rewrite_sample_as_filter @@ -163,11 +158,11 @@ def visit_ArrayRepeat(self, op, *, arg, times): return self.f.arrayFlatten(self.f.arrayMap(func, self.f.range(times))) def visit_ArraySlice(self, op, *, arg, start, stop): - start = parenthesize(op.start, start) + start = self._add_parens(op.start, start) start_correct = self.if_(start < 0, start, start + 1) if stop is not None: - stop = parenthesize(op.stop, stop) + stop = self._add_parens(op.stop, stop) length = self.if_( stop < 0, diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py index 64c4de25e85d..17a724fdf240 100644 --- a/ibis/backends/datafusion/compiler.py +++ b/ibis/backends/datafusion/compiler.py @@ -11,13 +11,7 @@ import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.backends.sql.compiler import ( - FALSE, - NULL, - STAR, - SQLGlotCompiler, - paren, -) +from ibis.backends.sql.compiler import FALSE, NULL, STAR, SQLGlotCompiler from ibis.backends.sql.datatypes import DataFusionType from ibis.backends.sql.dialects import DataFusion from ibis.backends.sql.rewrites import rewrite_sample_as_filter @@ -261,7 +255,7 @@ def visit_DayOfWeekIndex(self, op, *, arg): def visit_DayOfWeekName(self, op, *, arg): return sg.exp.Case( - this=paren(self.f.date_part("dow", arg) + 6) % 7, + this=sge.paren(self.f.date_part("dow", arg) + 6, copy=False) % 7, ifs=list(starmap(self.if_, enumerate(calendar.day_name))), ) @@ -438,7 +432,7 @@ def visit_StringConcat(self, op, *, arg): def visit_Aggregate(self, op, *, parent, groups, metrics): """Support `GROUP BY` expressions in `SELECT` since DataFusion does not.""" quoted = self.quoted - metrics = tuple(starmap(self._dedup_name, metrics.items())) + metrics = tuple(self._cleanup_names(metrics)) if groups: # datafusion doesn't support count distinct aggregations alongside @@ -459,7 +453,7 @@ def visit_Aggregate(self, op, *, parent, groups, metrics): ) ) table = ( - sg.select(*cols, *starmap(self._dedup_name, groups.items())) + sg.select(*cols, *self._cleanup_names(groups)) .from_(parent) .subquery(parent.alias) ) diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py index 3e6f85981004..a597ad75e9b6 100644 --- a/ibis/backends/duckdb/compiler.py +++ b/ibis/backends/duckdb/compiler.py @@ -11,7 +11,7 @@ import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler, paren +from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler from ibis.backends.sql.datatypes import DuckDBType _INTERVAL_SUFFIXES = { @@ -402,6 +402,7 @@ def visit_StructField(self, op, *, arg, field): if not isinstance(op.arg, (ops.Field, sge.Struct)): # parenthesize anything that isn't a simple field access return sge.Dot( - this=paren(arg), expression=sg.to_identifier(field, quoted=self.quoted) + this=sge.paren(arg), + expression=sg.to_identifier(field, quoted=self.quoted), ) return super().visit_StructField(op, arg=arg, field=field) diff --git a/ibis/backends/flink/compiler.py b/ibis/backends/flink/compiler.py index 38392812a934..c4eeab523dc6 100644 --- a/ibis/backends/flink/compiler.py +++ b/ibis/backends/flink/compiler.py @@ -8,7 +8,7 @@ import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler, paren +from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler from ibis.backends.sql.datatypes import FlinkType from ibis.backends.sql.dialects import Flink from ibis.backends.sql.rewrites import ( @@ -467,12 +467,12 @@ def visit_TimestampBucket(self, op, *, arg, interval, offset): bucket_width = op.interval.value unit_func = self.f["dayofmonth" if unit.upper() == "DAY" else unit] - arg = self.f.anon.timestampadd(unit_var, -paren(offset), arg) + arg = self.f.anon.timestampadd(unit_var, -sge.paren(offset, copy=False), arg) mod = unit_func(arg) % bucket_width return self.f.anon.timestampadd( unit_var, - -paren(mod) + offset, + -sge.paren(mod, copy=False) + offset, self.v[f"FLOOR({arg.sql(self.dialect)} TO {unit_var.sql(self.dialect)})"], ) diff --git a/ibis/backends/mssql/compiler.py b/ibis/backends/mssql/compiler.py index bd97c198d00e..eef723836804 100644 --- a/ibis/backends/mssql/compiler.py +++ b/ibis/backends/mssql/compiler.py @@ -15,7 +15,6 @@ STAR, TRUE, SQLGlotCompiler, - paren, ) from ibis.backends.sql.datatypes import MSSQLType from ibis.backends.sql.dialects import MSSQL @@ -189,7 +188,7 @@ def visit_StringLength(self, op, *, arg): Thanks to @arkanovicz for this glorious hack. """ - return paren(self.f.len(self.f.concat("A", arg, "Z")) - 2) + return sge.paren(self.f.len(self.f.concat("A", arg, "Z")) - 2, copy=False) def visit_GroupConcat(self, op, *, arg, sep, where): if where is not None: diff --git a/ibis/backends/postgres/compiler.py b/ibis/backends/postgres/compiler.py index e904952624f9..817709d9751e 100644 --- a/ibis/backends/postgres/compiler.py +++ b/ibis/backends/postgres/compiler.py @@ -11,7 +11,7 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.rules as rlz -from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler, paren +from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler from ibis.backends.sql.datatypes import PostgresType from ibis.backends.sql.dialects import Postgres from ibis.backends.sql.rewrites import rewrite_sample_as_filter @@ -125,7 +125,7 @@ def visit_ArgMinMax(self, op, *, arg, key, where, desc: bool): sge.Ordered(this=sge.Order(this=arg, expressions=[key]), desc=desc), where=sg.and_(*conditions), ) - return paren(agg)[0] + return sge.paren(agg, copy=False)[0] def visit_ArgMin(self, op, *, arg, key, where): return self.visit_ArgMinMax(op, arg=arg, key=key, where=where, desc=False) @@ -381,7 +381,7 @@ def visit_Modulus(self, op, *, left, right): def visit_RegexExtract(self, op, *, arg, pattern, index): pattern = self.f.concat("(", pattern, ")") matches = self.f.regexp_match(arg, pattern) - return self.if_(arg.rlike(pattern), paren(matches)[index], NULL) + return self.if_(arg.rlike(pattern), sge.paren(matches, copy=False)[index], NULL) def visit_FindInSet(self, op, *, needle, values): return self.f.coalesce( @@ -466,7 +466,7 @@ def visit_ExtractEpochSeconds(self, op, *, arg): def visit_ArrayIndex(self, op, *, arg, index): index = self.if_(index < 0, self.f.cardinality(arg) + index, index) - return paren(arg)[index + 1] + return sge.paren(arg, copy=False)[index + 1] def visit_ArraySlice(self, op, *, arg, start, stop): neg_to_pos_index = lambda n, index: self.if_(index < 0, n + index, index) @@ -484,7 +484,7 @@ def visit_ArraySlice(self, op, *, arg, start, stop): stop = neg_to_pos_index(arg_length, stop) slice_expr = sge.Slice(this=start + 1, expression=stop) - return paren(arg)[slice_expr] + return sge.paren(arg, copy=False)[slice_expr] def visit_IntervalFromInteger(self, op, *, arg, unit): plural = unit.plural diff --git a/ibis/backends/sql/__init__.py b/ibis/backends/sql/__init__.py index e201d4a8930a..7e2f65ef3411 100644 --- a/ibis/backends/sql/__init__.py +++ b/ibis/backends/sql/__init__.py @@ -107,7 +107,7 @@ def _to_sqlglot( assert not isinstance(sql, sge.Subquery) if isinstance(sql, sge.Table): - sql = sg.select(STAR).from_(sql) + sql = sg.select(STAR, copy=False).from_(sql, copy=False) assert not isinstance(sql, sge.Subquery) return sql @@ -117,7 +117,7 @@ def compile( ): """Compile an Ibis expression to a SQL string.""" query = self._to_sqlglot(expr, limit=limit, params=params, **kwargs) - sql = query.sql(dialect=self.dialect, pretty=True) + sql = query.sql(dialect=self.dialect, pretty=True, copy=False) self._log(sql) return sql diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py index 82e34159f821..216b43850576 100644 --- a/ibis/backends/sql/compiler.py +++ b/ibis/backends/sql/compiler.py @@ -5,14 +5,12 @@ import itertools import math import string -from collections.abc import Iterator, Mapping +from collections.abc import Mapping from functools import partial, reduce -from itertools import starmap from typing import TYPE_CHECKING, Any, Callable, ClassVar import sqlglot as sg import sqlglot.expressions as sge -import toolz from public import public import ibis.common.exceptions as com @@ -132,22 +130,10 @@ def __init__(self, table: str | None = None) -> None: self.table = table def __getattr__(self, name: str) -> sge.Column: - return sg.column(name, table=self.table) + return sg.column(name, table=self.table, copy=False) def __getitem__(self, key: str) -> sge.Column: - return sg.column(key, table=self.table) - - -def paren(expr): - """Wrap a sqlglot expression in parentheses.""" - return sge.Paren(this=sge.convert(expr)) - - -def parenthesize(op, arg): - if isinstance(op, (ops.Binary, ops.Unary)): - return paren(arg) - # function calls don't need parens - return arg + return sg.column(key, table=self.table, copy=False) C = ColGen() @@ -404,11 +390,11 @@ def if_(self, condition, true, false: sge.Expression | None = None) -> sge.If: return sge.If( this=sge.convert(condition), true=sge.convert(true), - false=false if false is None else sge.convert(false), + false=None if false is None else sge.convert(false), ) def cast(self, arg, to: dt.DataType) -> sge.Cast: - return sg.cast(sge.convert(arg), to=self.type_mapper.from_ibis(to)) + return sg.cast(sge.convert(arg), to=self.type_mapper.from_ibis(to), copy=False) def _prepare_params(self, params): result = {} @@ -465,7 +451,7 @@ def fn(node, _, **kwargs): return result.as_(alias, quoted=self.quoted) else: try: - return result.subquery(alias) + return result.subquery(alias, copy=False) except AttributeError: return result.as_(alias, quoted=self.quoted) @@ -475,14 +461,16 @@ def fn(node, _, **kwargs): # get the root node as a sqlglot select statement out = results[op] if isinstance(out, sge.Table): - out = sg.select(STAR).from_(out) + out = sg.select(STAR, copy=False).from_(out, copy=False) elif isinstance(out, sge.Subquery): out = out.this # add cte definitions to the select statement for cte in ctes: alias = sg.to_identifier(aliases[cte], quoted=self.quoted) - out = out.with_(alias, as_=results[cte].this, dialect=self.dialect) + out = out.with_( + alias, as_=results[cte].this, dialect=self.dialect, copy=False + ) return out @@ -509,7 +497,7 @@ def visit_Cast(self, op, *, arg, to): return self.cast(arg, to) def visit_ScalarSubquery(self, op, *, rel): - return rel.this.subquery() + return rel.this.subquery(copy=False) def visit_Alias(self, op, *, arg, name): return arg @@ -687,12 +675,14 @@ def visit_Between(self, op, *, arg, lower_bound, upper_bound): return sge.Between(this=arg, low=lower_bound, high=upper_bound) def visit_Negate(self, op, *, arg): - return -paren(arg) + return -sge.paren(arg, copy=False) def visit_Not(self, op, *, arg): if isinstance(arg, sge.Filter): - return sge.Filter(this=sg.not_(arg.this), expression=arg.expression) - return sg.not_(paren(arg)) + return sge.Filter( + this=sg.not_(arg.this, copy=False), expression=arg.expression + ) + return sg.not_(sge.paren(arg, copy=False)) ### Timey McTimeFace @@ -834,7 +824,7 @@ def visit_IsNull(self, op, *, arg): return arg.is_(NULL) def visit_NotNull(self, op, *, arg): - return arg.is_(sg.not_(NULL)) + return arg.is_(sg.not_(NULL, copy=False)) def visit_InValues(self, op, *, value, options): return value.isin(*options) @@ -1021,7 +1011,9 @@ def visit_Argument(self, op, *, name: str, shape, dtype): return sg.to_identifier(op.param) def visit_RowID(self, op, *, table): - return sg.column(op.name, table=table.alias_or_name, quoted=self.quoted) + return sg.column( + op.name, table=table.alias_or_name, quoted=self.quoted, copy=False + ) # TODO(kszucs): this should be renamed to something UDF related def __sql_name__(self, op: ops.ScalarUDF | ops.AggUDF) -> str: @@ -1066,16 +1058,6 @@ def visit_ArrayConcat(self, op, *, arg): ## relations - def _dedup_name( - self, key: str, value: sge.Expression - ) -> Iterator[sge.Alias | sge.Column]: - """Don't alias columns that are already named the same as their alias.""" - return ( - value - if isinstance(value, sge.Column) and key == value.name - else value.as_(key, quoted=self.quoted) - ) - @staticmethod def _gen_valid_name(name: str) -> str: """Generate a valid name for a value expression. @@ -1089,9 +1071,14 @@ def _gen_valid_name(name: str) -> str: def _cleanup_names(self, exprs: Mapping[str, sge.Expression]): """Compose `_gen_valid_name` and `_dedup_name` to clean up names in projections.""" - return starmap( - self._dedup_name, toolz.keymap(self._gen_valid_name, exprs).items() - ) + + for name, value in exprs.items(): + name = self._gen_valid_name(name) + if isinstance(value, sge.Column) and name == value.name: + # don't alias columns that are already named the same as their alias + yield value + else: + yield value.as_(name, quoted=self.quoted, copy=False) def visit_Select(self, op, *, parent, selections, predicates, sort_keys): # if we've constructed a useless projection return the parent relation @@ -1101,18 +1088,20 @@ def visit_Select(self, op, *, parent, selections, predicates, sort_keys): result = parent if selections: - result = sg.select(*self._cleanup_names(selections)).from_(result) + result = sg.select(*self._cleanup_names(selections), copy=False).from_( + result, copy=False + ) if predicates: - result = result.where(*predicates) + result = result.where(*predicates, copy=False) if sort_keys: - result = result.order_by(*sort_keys) + result = result.order_by(*sort_keys, copy=False) return result def visit_DummyTable(self, op, *, values): - return sg.select(*self._cleanup_names(values)) + return sg.select(*self._cleanup_names(values), copy=False) def visit_UnboundTable( self, op, *, name: str, schema: sch.Schema, namespace: ops.Namespace @@ -1143,12 +1132,14 @@ def visit_SelfReference(self, op, *, parent, identifier): return parent def visit_JoinChain(self, op, *, first, rest, values): - result = sg.select(*self._cleanup_names(values)).from_(first) + result = sg.select(*self._cleanup_names(values), copy=False).from_( + first, copy=False + ) for link in rest: if isinstance(link, sge.Alias): link = link.this - result = result.join(link) + result = result.join(link, copy=False) return result def visit_JoinLink(self, op, *, how, table, predicates): @@ -1189,18 +1180,18 @@ def _generate_groups(groups): def visit_Aggregate(self, op, *, parent, groups, metrics): sel = sg.select( - *self._cleanup_names(groups), *self._cleanup_names(metrics) - ).from_(parent) + *self._cleanup_names(groups), *self._cleanup_names(metrics), copy=False + ).from_(parent, copy=False) if groups: - sel = sel.group_by(*self._generate_groups(groups.values())) + sel = sel.group_by(*self._generate_groups(groups.values()), copy=False) return sel @classmethod def _add_parens(cls, op, sg_expr): if isinstance(op, cls.NEEDS_PARENS): - return paren(sg_expr) + return sge.paren(sg_expr, copy=False) return sg_expr def visit_Filter(self, op, *, parent, predicates): @@ -1209,104 +1200,126 @@ def visit_Filter(self, op, *, parent, predicates): for raw_predicate, predicate in zip(op.predicates, predicates) ) try: - return parent.where(*predicates) + return parent.where(*predicates, copy=False) except AttributeError: - return sg.select(STAR).from_(parent).where(*predicates) + return ( + sg.select(STAR, copy=False) + .from_(parent, copy=False) + .where(*predicates, copy=False) + ) def visit_Sort(self, op, *, parent, keys): try: - return parent.order_by(*keys) + return parent.order_by(*keys, copy=False) except AttributeError: - return sg.select(STAR).from_(parent).order_by(*keys) + return ( + sg.select(STAR, copy=False) + .from_(parent, copy=False) + .order_by(*keys, copy=False) + ) def visit_Union(self, op, *, left, right, distinct): if isinstance(left, (sge.Table, sge.Subquery)): - left = sg.select(STAR).from_(left) + left = sg.select(STAR, copy=False).from_(left, copy=False) if isinstance(right, (sge.Table, sge.Subquery)): - right = sg.select(STAR).from_(right) + right = sg.select(STAR, copy=False).from_(right, copy=False) return sg.union( left.args.get("this", left), right.args.get("this", right), distinct=distinct, + copy=False, ) def visit_Intersection(self, op, *, left, right, distinct): if isinstance(left, (sge.Table, sge.Subquery)): - left = sg.select(STAR).from_(left) + left = sg.select(STAR, copy=False).from_(left, copy=False) if isinstance(right, (sge.Table, sge.Subquery)): - right = sg.select(STAR).from_(right) + right = sg.select(STAR, copy=False).from_(right, copy=False) return sg.intersect( left.args.get("this", left), right.args.get("this", right), distinct=distinct, + copy=False, ) def visit_Difference(self, op, *, left, right, distinct): if isinstance(left, (sge.Table, sge.Subquery)): - left = sg.select(STAR).from_(left) + left = sg.select(STAR, copy=False).from_(left, copy=False) if isinstance(right, (sge.Table, sge.Subquery)): - right = sg.select(STAR).from_(right) + right = sg.select(STAR, copy=False).from_(right, copy=False) return sg.except_( left.args.get("this", left), right.args.get("this", right), distinct=distinct, + copy=False, ) def visit_Limit(self, op, *, parent, n, offset): # push limit/offset into subqueries if isinstance(parent, sge.Subquery) and parent.this.args.get("limit") is None: - result = parent.this + result = parent.this.copy() alias = parent.alias else: - result = sg.select(STAR).from_(parent) + result = sg.select(STAR, copy=False).from_(parent, copy=False) alias = None if isinstance(n, int): - result = result.limit(n) + result = result.limit(n, copy=False) elif n is not None: - result = result.limit(sg.select(n).from_(parent).subquery()) + result = result.limit( + sg.select(n, copy=False).from_(parent, copy=False).subquery(copy=False), + copy=False, + ) else: assert n is None, n if self.no_limit_value is not None: - result = result.limit(self.no_limit_value) + result = result.limit(self.no_limit_value, copy=False) assert offset is not None, "offset is None" if not isinstance(offset, int): skip = offset - skip = sg.select(skip).from_(parent).subquery() + skip = ( + sg.select(skip, copy=False) + .from_(parent, copy=False) + .subquery(copy=False) + ) elif not offset: if alias is not None: - return result.subquery(alias) + return result.subquery(alias, copy=False) return result else: skip = offset - result = result.offset(skip) + result = result.offset(skip, copy=False) if alias is not None: - return result.subquery(alias) + return result.subquery(alias, copy=False) return result def visit_Distinct(self, op, *, parent): - return sg.select(STAR).distinct().from_(parent) + return ( + sg.select(STAR, copy=False).distinct(copy=False).from_(parent, copy=False) + ) def visit_DropNa(self, op, *, parent, how, subset): if subset is None: subset = [ - sg.column(name, table=parent.alias_or_name, quoted=self.quoted) + sg.column( + name, table=parent.alias_or_name, quoted=self.quoted, copy=False + ) for name in op.schema.names ] if subset: predicate = reduce( sg.and_ if how == "any" else sg.or_, - (sg.not_(col.is_(NULL)) for col in subset), + (sg.not_(col.is_(NULL), copy=False) for col in subset), ) elif how == "all": predicate = FALSE @@ -1317,9 +1330,13 @@ def visit_DropNa(self, op, *, parent, how, subset): return parent try: - return parent.where(predicate) + return parent.where(predicate, copy=False) except AttributeError: - return sg.select(STAR).from_(parent).where(predicate) + return ( + sg.select(STAR, copy=False) + .from_(parent, copy=False) + .where(predicate, copy=False) + ) def visit_FillNa(self, op, *, parent, replacements): if isinstance(replacements, Mapping): @@ -1332,26 +1349,33 @@ def visit_FillNa(self, op, *, parent, replacements): } exprs = { col: ( - self.f.coalesce(sg.column(col, quoted=self.quoted), sge.convert(alt)) + self.f.coalesce( + sg.column(col, quoted=self.quoted, copy=False), + sge.convert(alt), + ) if (alt := mapping.get(col)) is not None else sg.column(col, quoted=self.quoted) ) for col in op.schema.keys() } - return sg.select(*self._cleanup_names(exprs)).from_(parent) + return sg.select(*self._cleanup_names(exprs), copy=False).from_( + parent, copy=False + ) def visit_CTE(self, op, *, parent): return sg.table(parent.alias_or_name, quoted=self.quoted) def visit_View(self, op, *, child, name: str): if isinstance(child, sge.Table): - child = sg.select(STAR).from_(child) + child = sg.select(STAR, copy=False).from_(child, copy=False) + else: + child = child.copy() if isinstance(child, sge.Subquery): return child.as_(name, quoted=self.quoted) else: try: - return child.subquery(name) + return child.subquery(name, copy=False) except AttributeError: return child.as_(name, quoted=self.quoted) @@ -1359,7 +1383,7 @@ def visit_SQLStringView(self, op, *, query: str, child, schema): return sg.parse_one(query, read=self.dialect) def visit_SQLQueryResult(self, op, *, query, schema, source): - return sg.parse_one(query, dialect=self.dialect).subquery() + return sg.parse_one(query, dialect=self.dialect).subquery(copy=False) def visit_JoinTable(self, op, *, parent, index): return parent diff --git a/ibis/backends/trino/compiler.py b/ibis/backends/trino/compiler.py index 67ada7c4ec49..a32cc72f1527 100644 --- a/ibis/backends/trino/compiler.py +++ b/ibis/backends/trino/compiler.py @@ -10,13 +10,7 @@ import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.backends.sql.compiler import ( - FALSE, - NULL, - STAR, - SQLGlotCompiler, - paren, -) +from ibis.backends.sql.compiler import FALSE, NULL, STAR, SQLGlotCompiler from ibis.backends.sql.datatypes import TrinoType from ibis.backends.sql.dialects import Trino from ibis.backends.sql.rewrites import exclude_unsupported_window_frame_from_ops @@ -181,7 +175,9 @@ def visit_JSONGetItem(self, op, *, arg, index): return self.f.json_extract(arg, self.f.format(f"$[{fmt}]", index)) def visit_DayOfWeekIndex(self, op, *, arg): - return self.cast(paren(self.f.day_of_week(arg) + 6) % 7, op.dtype) + return self.cast( + sge.paren(self.f.day_of_week(arg) + 6, copy=False) % 7, op.dtype + ) def visit_DayOfWeekName(self, op, *, arg): return self.f.date_format(arg, "%W")