From 9b72fc002cd773d0fb9c62cdff87bfd431e0d50c Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Thu, 8 Feb 2024 11:15:21 -0600 Subject: [PATCH] fix(trino): re-enable native TABLESAMPLE support (#8284) --- ibis/backends/trino/compiler.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/ibis/backends/trino/compiler.py b/ibis/backends/trino/compiler.py index bb4a5e402ab0..80b93e348b04 100644 --- a/ibis/backends/trino/compiler.py +++ b/ibis/backends/trino/compiler.py @@ -10,14 +10,19 @@ import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.backends.base.sqlglot.compiler import FALSE, NULL, SQLGlotCompiler, paren +from ibis.backends.base.sqlglot.compiler import ( + FALSE, + NULL, + STAR, + SQLGlotCompiler, + paren, +) from ibis.backends.base.sqlglot.datatypes import TrinoType from ibis.backends.base.sqlglot.dialects import Trino from ibis.backends.base.sqlglot.rewrites import ( exclude_unsupported_window_frame_from_ops, rewrite_first_to_first_value, rewrite_last_to_last_value, - rewrite_sample_as_filter, ) @@ -27,7 +32,6 @@ class TrinoCompiler(SQLGlotCompiler): dialect = Trino type_mapper = TrinoType rewrites = ( - rewrite_sample_as_filter, rewrite_first_to_first_value, rewrite_last_to_last_value, exclude_unsupported_window_frame_from_ops, @@ -60,6 +64,22 @@ def _minimize_spec(start, end, spec): def visit_node(self, op, **kw): return super().visit_node(op, **kw) + @visit_node.register(ops.Sample) + def visit_Sample( + self, op, *, parent, fraction: float, method: str, seed: int | None, **_ + ): + if op.seed is not None: + raise com.UnsupportedOperationError( + "`Table.sample` with a random seed is unsupported" + ) + sample = sge.TableSample( + this=parent, + method="bernoulli" if method == "row" else "system", + percent=sge.convert(fraction * 100.0), + seed=None if seed is None else sge.convert(seed), + ) + return sg.select(STAR).from_(sample) + @visit_node.register(ops.Correlation) def visit_Correlation(self, op, *, left, right, how, where): if how == "sample":