diff --git a/ibis/backends/sql/__init__.py b/ibis/backends/sql/__init__.py index e201d4a8930a6..7e2f65ef34114 100644 --- a/ibis/backends/sql/__init__.py +++ b/ibis/backends/sql/__init__.py @@ -107,7 +107,7 @@ def _to_sqlglot( assert not isinstance(sql, sge.Subquery) if isinstance(sql, sge.Table): - sql = sg.select(STAR).from_(sql) + sql = sg.select(STAR, copy=False).from_(sql, copy=False) assert not isinstance(sql, sge.Subquery) return sql @@ -117,7 +117,7 @@ def compile( ): """Compile an Ibis expression to a SQL string.""" query = self._to_sqlglot(expr, limit=limit, params=params, **kwargs) - sql = query.sql(dialect=self.dialect, pretty=True) + sql = query.sql(dialect=self.dialect, pretty=True, copy=False) self._log(sql) return sql diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py index 654d926c4ac51..f140b4f57ec94 100644 --- a/ibis/backends/sql/compiler.py +++ b/ibis/backends/sql/compiler.py @@ -76,10 +76,12 @@ def __getitem__(self, key: str) -> sge.Var: class AnonymousFuncGen: __slots__ = () + def _anonymous(self, name: str, *args: Any) -> sge.Anonymous: + exprs = [sge.convert(arg, copy=False) for arg in args] + return sge.Anonymous(this=name, expressions=exprs) + def __getattr__(self, name: str) -> Callable[..., sge.Anonymous]: - return lambda *args: sge.Anonymous( - this=name, expressions=list(map(sge.convert, args)) - ) + return partial(self._anonymous, name) def __getitem__(self, key: str) -> Callable[..., sge.Anonymous]: return getattr(self, key) @@ -92,9 +94,13 @@ def __init__(self, namespace: str | None = None) -> None: self.namespace = namespace self.anon = AnonymousFuncGen() - def __getattr__(self, name: str) -> Callable[..., sge.Func]: + def _func(self, name: str, *args: Any, **kwargs: Any) -> sge.Func: name = ".".join(filter(None, (self.namespace, name))) - return lambda *args, **kwargs: sg.func(name, *map(sge.convert, args), **kwargs) + args = (sge.convert(arg, copy=False) for arg in args) + return sg.func(name, *args, **kwargs, copy=False) + + def __getattr__(self, name: str) -> Callable[..., sge.Func]: + return partial(self._func, name) def __getitem__(self, key: str) -> Callable[..., sge.Func]: return getattr(self, key) @@ -110,7 +116,8 @@ def array(self, *args: Any) -> sge.Array: not rest ), "only one argument allowed when `first` is a select statement" - return sge.Array(expressions=list(map(sge.convert, (first, *rest)))) + exprs = [sge.convert(elem, copy=False) for elem in (first, *rest)] + return sge.Array(expressions=exprs) def tuple(self, *args: Any) -> sge.Anonymous: return self.anon.tuple(*args) @@ -119,7 +126,8 @@ def exists(self, query: sge.Expression) -> sge.Exists: return sge.Exists(this=query) def concat(self, *args: Any) -> sge.Concat: - return sge.Concat(expressions=list(map(sge.convert, args))) + exprs = [sge.convert(arg, copy=False) for arg in args] + return sge.Concat(expressions=exprs) def map(self, keys: Iterable, values: Iterable) -> sge.Map: return sge.Map(keys=keys, values=values) @@ -132,15 +140,15 @@ def __init__(self, table: str | None = None) -> None: self.table = table def __getattr__(self, name: str) -> sge.Column: - return sg.column(name, table=self.table) + return sg.column(name, table=self.table, copy=False) def __getitem__(self, key: str) -> sge.Column: - return sg.column(key, table=self.table) + return sg.column(key, table=self.table, copy=False) def paren(expr): """Wrap a sqlglot expression in parentheses.""" - return sge.Paren(this=sge.convert(expr)) + return sge.Paren(this=sge.convert(expr, copy=False)) def parenthesize(op, arg): @@ -402,13 +410,15 @@ def _aggregate(self, funcname, *args, where): def if_(self, condition, true, false: sge.Expression | None = None) -> sge.If: return sge.If( - this=sge.convert(condition), - true=sge.convert(true), - false=false if false is None else sge.convert(false), + this=sge.convert(condition, copy=False), + true=sge.convert(true, copy=False), + false=false if false is None else sge.convert(false, copy=False), ) def cast(self, arg, to: dt.DataType) -> sge.Cast: - return sg.cast(sge.convert(arg), to=self.type_mapper.from_ibis(to)) + return sg.cast( + sge.convert(arg, copy=False), to=self.type_mapper.from_ibis(to), copy=False + ) def _prepare_params(self, params): result = {} @@ -460,11 +470,14 @@ def fn(node, _, **kwargs): alias = node.name if isinstance(node, ops.View) else f"t{next(counter)}" aliases[node] = alias - alias = sg.to_identifier(alias, quoted=self.quoted) + alias = sg.to_identifier(alias, quoted=self.quoted, copy=False) try: - return result.subquery(alias) + # any nested select must be a subquery unless it is the root + return result.subquery(alias, copy=False) except AttributeError: - return result.as_(alias, quoted=self.quoted) + # although tables cannot be turned into subqueries, but we still + # need to alias them in which case copying is necessary + return result.as_(alias, quoted=self.quoted, copy=True) # apply translate rules in topological order results = op.map(fn) @@ -472,14 +485,16 @@ def fn(node, _, **kwargs): # get the root node as a sqlglot select statement out = results[op] if isinstance(out, sge.Table): - out = sg.select(STAR).from_(out) + out = sg.select(STAR, copy=False).from_(out, copy=False) elif isinstance(out, sge.Subquery): out = out.this # add cte definitions to the select statement for cte in ctes: - alias = sg.to_identifier(aliases[cte], quoted=self.quoted) - out = out.with_(alias, as_=results[cte].this, dialect=self.dialect) + alias = sg.to_identifier(aliases[cte], quoted=self.quoted, copy=False) + out = out.with_( + alias, as_=results[cte].this, dialect=self.dialect, copy=False + ) return out @@ -620,7 +635,7 @@ def visit_DefaultLiteral(self, op, *, value, dtype): items = [ self.visit_Literal( ops.Literal(v, field_dtype), value=v, dtype=field_dtype - ).as_(k, quoted=self.quoted) + ).as_(k, quoted=self.quoted, copy=False) for field_dtype, (k, v) in zip(dtype.types, value.items()) ] return sge.Struct.from_arg_list(items) @@ -831,7 +846,7 @@ def visit_IsNull(self, op, *, arg): return arg.is_(NULL) def visit_NotNull(self, op, *, arg): - return arg.is_(sg.not_(NULL)) + return arg.is_(sg.not_(NULL, copy=False)) def visit_InValues(self, op, *, value, options): return value.isin(*options) @@ -934,7 +949,9 @@ def visit_StructColumn(self, op, *, names, values): ) def visit_StructField(self, op, *, arg, field): - return sge.Dot(this=arg, expression=sg.to_identifier(field, quoted=self.quoted)) + return sge.Dot( + this=arg, expression=sg.to_identifier(field, quoted=self.quoted, copy=False) + ) def visit_IdenticalTo(self, op, *, left, right): return sge.NullSafeEQ(this=left, expression=right) @@ -1015,10 +1032,12 @@ def visit_LagLead(self, op, *, arg, offset, default): visit_Lag = visit_Lead = visit_LagLead def visit_Argument(self, op, *, name: str, shape, dtype): - return sg.to_identifier(op.param) + return sg.to_identifier(op.param, copy=False) def visit_RowID(self, op, *, table): - return sg.column(op.name, table=table.alias_or_name, quoted=self.quoted) + return sg.column( + op.name, table=table.alias_or_name, quoted=self.quoted, copy=False + ) # TODO(kszucs): this should be renamed to something UDF related def __sql_name__(self, op: ops.ScalarUDF | ops.AggUDF) -> str: @@ -1067,11 +1086,10 @@ def _dedup_name( self, key: str, value: sge.Expression ) -> Iterator[sge.Alias | sge.Column]: """Don't alias columns that are already named the same as their alias.""" - return ( - value - if isinstance(value, sge.Column) and key == value.name - else value.as_(key, quoted=self.quoted) - ) + if isinstance(value, sge.Column) and key == value.name: + return value + else: + return value.as_(key, quoted=self.quoted, copy=False) @staticmethod def _gen_valid_name(name: str) -> str: @@ -1098,18 +1116,20 @@ def visit_Select(self, op, *, parent, selections, predicates, sort_keys): result = parent if selections: - result = sg.select(*self._cleanup_names(selections)).from_(result) + result = sg.select(*self._cleanup_names(selections), copy=False).from_( + result, copy=False + ) if predicates: - result = result.where(*predicates) + result = result.where(*predicates, copy=False) if sort_keys: - result = result.order_by(*sort_keys) + result = result.order_by(*sort_keys, copy=False) return result def visit_DummyTable(self, op, *, values): - return sg.select(*self._cleanup_names(values)) + return sg.select(*self._cleanup_names(values), copy=False) def visit_UnboundTable( self, op, *, name: str, schema: sch.Schema, namespace: ops.Namespace @@ -1140,12 +1160,14 @@ def visit_SelfReference(self, op, *, parent, identifier): return parent def visit_JoinChain(self, op, *, first, rest, values): - result = sg.select(*self._cleanup_names(values)).from_(first) + result = sg.select(*self._cleanup_names(values), copy=False).from_( + first, copy=False + ) for link in rest: if isinstance(link, sge.Alias): link = link.this - result = result.join(link) + result = result.join(link, copy=False) return result def visit_JoinLink(self, op, *, how, table, predicates): @@ -1186,11 +1208,11 @@ def _generate_groups(groups): def visit_Aggregate(self, op, *, parent, groups, metrics): sel = sg.select( - *self._cleanup_names(groups), *self._cleanup_names(metrics) - ).from_(parent) + *self._cleanup_names(groups), *self._cleanup_names(metrics), copy=False + ).from_(parent, copy=False) if groups: - sel = sel.group_by(*self._generate_groups(groups.values())) + sel = sel.group_by(*self._generate_groups(groups.values()), copy=False) return sel @@ -1206,104 +1228,126 @@ def visit_Filter(self, op, *, parent, predicates): for raw_predicate, predicate in zip(op.predicates, predicates) ) try: - return parent.where(*predicates) + return parent.where(*predicates, copy=False) except AttributeError: - return sg.select(STAR).from_(parent).where(*predicates) + return ( + sg.select(STAR, copy=False) + .from_(parent, copy=False) + .where(*predicates, copy=False) + ) def visit_Sort(self, op, *, parent, keys): try: - return parent.order_by(*keys) + return parent.order_by(*keys, copy=False) except AttributeError: - return sg.select(STAR).from_(parent).order_by(*keys) + return ( + sg.select(STAR, copy=False) + .from_(parent, copy=False) + .order_by(*keys, copy=False) + ) def visit_Union(self, op, *, left, right, distinct): if isinstance(left, (sge.Table, sge.Subquery)): - left = sg.select(STAR).from_(left) + left = sg.select(STAR, copy=False).from_(left, copy=False) if isinstance(right, (sge.Table, sge.Subquery)): - right = sg.select(STAR).from_(right) + right = sg.select(STAR, copy=False).from_(right, copy=False) return sg.union( left.args.get("this", left), right.args.get("this", right), distinct=distinct, + copy=False, ) def visit_Intersection(self, op, *, left, right, distinct): if isinstance(left, (sge.Table, sge.Subquery)): - left = sg.select(STAR).from_(left) + left = sg.select(STAR, copy=False).from_(left, copy=False) if isinstance(right, (sge.Table, sge.Subquery)): - right = sg.select(STAR).from_(right) + right = sg.select(STAR, copy=False).from_(right, copy=False) return sg.intersect( left.args.get("this", left), right.args.get("this", right), distinct=distinct, + copy=False, ) def visit_Difference(self, op, *, left, right, distinct): if isinstance(left, (sge.Table, sge.Subquery)): - left = sg.select(STAR).from_(left) + left = sg.select(STAR, copy=False).from_(left, copy=False) if isinstance(right, (sge.Table, sge.Subquery)): - right = sg.select(STAR).from_(right) + right = sg.select(STAR, copy=False).from_(right, copy=False) return sg.except_( left.args.get("this", left), right.args.get("this", right), distinct=distinct, + copy=False, ) def visit_Limit(self, op, *, parent, n, offset): # push limit/offset into subqueries if isinstance(parent, sge.Subquery) and parent.this.args.get("limit") is None: - result = parent.this + result = parent.this.copy() alias = parent.alias else: - result = sg.select(STAR).from_(parent) + result = sg.select(STAR, copy=False).from_(parent, copy=False) alias = None if isinstance(n, int): - result = result.limit(n) + result = result.limit(n, copy=False) elif n is not None: - result = result.limit(sg.select(n).from_(parent).subquery()) + result = result.limit( + sg.select(n, copy=False).from_(parent, copy=False).subquery(copy=False), + copy=False, + ) else: assert n is None, n if self.no_limit_value is not None: - result = result.limit(self.no_limit_value) + result = result.limit(self.no_limit_value, copy=False) assert offset is not None, "offset is None" if not isinstance(offset, int): skip = offset - skip = sg.select(skip).from_(parent).subquery() + skip = ( + sg.select(skip, copy=False) + .from_(parent, copy=False) + .subquery(copy=False) + ) elif not offset: if alias is not None: - return result.subquery(alias) + return result.subquery(alias, copy=False) return result else: skip = offset - result = result.offset(skip) + result = result.offset(skip, copy=False) if alias is not None: - return result.subquery(alias) + return result.subquery(alias, copy=False) return result def visit_Distinct(self, op, *, parent): - return sg.select(STAR).distinct().from_(parent) + return ( + sg.select(STAR, copy=False).distinct(copy=False).from_(parent, copy=False) + ) def visit_DropNa(self, op, *, parent, how, subset): if subset is None: subset = [ - sg.column(name, table=parent.alias_or_name, quoted=self.quoted) + sg.column( + name, table=parent.alias_or_name, quoted=self.quoted, copy=False + ) for name in op.schema.names ] if subset: predicate = reduce( sg.and_ if how == "any" else sg.or_, - (sg.not_(col.is_(NULL)) for col in subset), + (sg.not_(col.is_(NULL), copy=False) for col in subset), ) elif how == "all": predicate = FALSE @@ -1314,9 +1358,13 @@ def visit_DropNa(self, op, *, parent, how, subset): return parent try: - return parent.where(predicate) + return parent.where(predicate, copy=False) except AttributeError: - return sg.select(STAR).from_(parent).where(predicate) + return ( + sg.select(STAR, copy=False) + .from_(parent, copy=False) + .where(predicate, copy=False) + ) def visit_FillNa(self, op, *, parent, replacements): if isinstance(replacements, Mapping): @@ -1329,31 +1377,38 @@ def visit_FillNa(self, op, *, parent, replacements): } exprs = { col: ( - self.f.coalesce(sg.column(col, quoted=self.quoted), sge.convert(alt)) + self.f.coalesce( + sg.column(col, quoted=self.quoted, copy=False), + sge.convert(alt, copy=False), + ) if (alt := mapping.get(col)) is not None else sg.column(col, quoted=self.quoted) ) for col in op.schema.keys() } - return sg.select(*self._cleanup_names(exprs)).from_(parent) + return sg.select(*self._cleanup_names(exprs), copy=False).from_( + parent, copy=False + ) def visit_CTE(self, op, *, parent): return sg.table(parent.alias_or_name, quoted=self.quoted) def visit_View(self, op, *, child, name: str): if isinstance(child, sge.Table): - child = sg.select(STAR).from_(child) + child = sg.select(STAR, copy=False).from_(child, copy=False) + else: + child = child.copy() try: - return child.subquery(name) + return child.subquery(name, copy=False) except AttributeError: - return child.as_(name) + return child.as_(name, copy=False) def visit_SQLStringView(self, op, *, query: str, child, schema): return sg.parse_one(query, read=self.dialect) def visit_SQLQueryResult(self, op, *, query, schema, source): - return sg.parse_one(query, dialect=self.dialect).subquery() + return sg.parse_one(query, dialect=self.dialect).subquery(copy=False) def visit_JoinTable(self, op, *, parent, index): return parent