diff --git a/docs/_code/setup_penguins.qmd b/docs/_code/setup_penguins.qmd index 229f56801c93..28ad81c02445 100644 --- a/docs/_code/setup_penguins.qmd +++ b/docs/_code/setup_penguins.qmd @@ -1,6 +1,7 @@ ```{python} import ibis # <1> import ibis.selectors as s # <1> +from ibis import _ ibis.options.interactive = True # <2> diff --git a/docs/how-to/analytics/basics.qmd b/docs/how-to/analytics/basics.qmd index 0e93ee6e61f0..9a219b2342ad 100644 --- a/docs/how-to/analytics/basics.qmd +++ b/docs/how-to/analytics/basics.qmd @@ -58,7 +58,7 @@ t.mutate(bill_length_cm=t["bill_length_mm"] / 10).relocate( Use the `.join()` method to join data: ```{python} -t.join(t, t["species"] == t["species"], how="left_semi") +t.join(t, ["species"], how="left_semi") ``` ## Combining it all together @@ -66,12 +66,12 @@ t.join(t, t["species"] == t["species"], how="left_semi") We can use [the underscore to chain expressions together](./chain_expressions.qmd). ```{python} -t.join(t, t["species"] == t["species"], how="left_semi").filter( - ibis._["species"] != "Adelie" +t.join(t, ["species"], how="left_semi").filter( + _.species != "Adelie" ).group_by(["species", "island"]).aggregate( - avg_bill_length=ibis._["bill_length_mm"].mean() + avg_bill_length=_.bill_length_mm.mean() ).order_by( - ibis._["avg_bill_length"].desc() + _.avg_bill_length.desc() ) ``` diff --git a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_projection_fusion_only_peeks_at_immediate_parent/out.sql b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_projection_fusion_only_peeks_at_immediate_parent/out.sql index dca603f5351e..9aeb0e7cc4d8 100644 --- a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_projection_fusion_only_peeks_at_immediate_parent/out.sql +++ b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_projection_fusion_only_peeks_at_immediate_parent/out.sql @@ -15,5 +15,5 @@ SELECT `t3`.`val`, `t3`.`XYZ` FROM `t1` AS `t3` -INNER JOIN `t1` AS `t5` +INNER JOIN `t1` AS `t4` ON TRUE \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/test_compiler.py b/ibis/backends/bigquery/tests/unit/test_compiler.py index 0d95438df2c6..42df4c2c708d 100644 --- a/ibis/backends/bigquery/tests/unit/test_compiler.py +++ b/ibis/backends/bigquery/tests/unit/test_compiler.py @@ -274,7 +274,8 @@ class MockBackend(ibis.backends.bigquery.Backend): table = ops.SQLQueryResult("select * from t", schema, ibis_client).to_expr() for _ in range(num_joins): # noqa: F402 table = table.mutate(dummy=ibis.literal("")) - table = table.left_join(table, ["dummy"])[[table]] + table_ = table.view() + table = table.left_join(table_, ["dummy"])[[table_]] start = time.time() table.compile() diff --git a/ibis/backends/clickhouse/tests/snapshots/test_select/test_join_self_reference/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_select/test_join_self_reference/out.sql index 66518891122e..059d27e85c4a 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_select/test_join_self_reference/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_select/test_join_self_reference/out.sql @@ -13,5 +13,5 @@ SELECT "t1"."year", "t1"."month" FROM "functional_alltypes" AS "t1" -INNER JOIN "functional_alltypes" AS "t3" - ON "t1"."id" = "t3"."id" \ No newline at end of file +INNER JOIN "functional_alltypes" AS "t2" + ON "t1"."id" = "t2"."id" \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_sql/test_join_key_name2/out.sql b/ibis/backends/impala/tests/snapshots/test_sql/test_join_key_name2/out.sql index bff7789e00ed..a9ee7c0eb57f 100644 --- a/ibis/backends/impala/tests/snapshots/test_sql/test_join_key_name2/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_sql/test_join_key_name2/out.sql @@ -1 +1 @@ -WITH `t9` AS (SELECT EXTRACT(year FROM `t8`.`odate`) AS `year`, COUNT(*) AS `CountStar()` FROM (SELECT `t6`.`c_custkey`, `t6`.`c_name`, `t6`.`c_address`, `t6`.`c_nationkey`, `t6`.`c_phone`, `t6`.`c_acctbal`, `t6`.`c_mktsegment`, `t6`.`c_comment`, `t4`.`r_name` AS `region`, `t7`.`o_totalprice`, CAST(`t7`.`o_orderdate` AS TIMESTAMP) AS `odate` FROM `tpch_region` AS `t4` INNER JOIN `tpch_nation` AS `t5` ON `t4`.`r_regionkey` = `t5`.`n_regionkey` INNER JOIN `tpch_customer` AS `t6` ON `t6`.`c_nationkey` = `t5`.`n_nationkey` INNER JOIN `tpch_orders` AS `t7` ON `t7`.`o_custkey` = `t6`.`c_custkey`) AS `t8` GROUP BY 1) SELECT `t11`.`year`, `t11`.`CountStar()` AS `pre_count`, `t13`.`CountStar()` AS `post_count` FROM `t9` AS `t11` INNER JOIN `t9` AS `t13` ON `t11`.`year` = `t13`.`year` \ No newline at end of file +WITH `t9` AS (SELECT EXTRACT(year FROM `t8`.`odate`) AS `year`, COUNT(*) AS `CountStar()` FROM (SELECT `t6`.`c_custkey`, `t6`.`c_name`, `t6`.`c_address`, `t6`.`c_nationkey`, `t6`.`c_phone`, `t6`.`c_acctbal`, `t6`.`c_mktsegment`, `t6`.`c_comment`, `t4`.`r_name` AS `region`, `t7`.`o_totalprice`, CAST(`t7`.`o_orderdate` AS TIMESTAMP) AS `odate` FROM `tpch_region` AS `t4` INNER JOIN `tpch_nation` AS `t5` ON `t4`.`r_regionkey` = `t5`.`n_regionkey` INNER JOIN `tpch_customer` AS `t6` ON `t6`.`c_nationkey` = `t5`.`n_nationkey` INNER JOIN `tpch_orders` AS `t7` ON `t7`.`o_custkey` = `t6`.`c_custkey`) AS `t8` GROUP BY 1) SELECT `t11`.`year`, `t11`.`CountStar()` AS `pre_count`, `t12`.`CountStar()` AS `post_count` FROM `t9` AS `t11` INNER JOIN `t9` AS `t12` ON `t11`.`year` = `t12`.`year` \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_sql/test_join_with_nested_or_condition/out.sql b/ibis/backends/impala/tests/snapshots/test_sql/test_join_with_nested_or_condition/out.sql index dda0c97b4843..7a17eb283958 100644 --- a/ibis/backends/impala/tests/snapshots/test_sql/test_join_with_nested_or_condition/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_sql/test_join_with_nested_or_condition/out.sql @@ -2,12 +2,12 @@ SELECT `t1`.`a`, `t1`.`b` FROM `t` AS `t1` -INNER JOIN `t` AS `t3` - ON `t1`.`a` = `t3`.`a` +INNER JOIN `t` AS `t2` + ON `t1`.`a` = `t2`.`a` AND ( ( - `t1`.`a` <> `t3`.`b` + `t1`.`a` <> `t2`.`b` ) OR ( - `t1`.`b` <> `t3`.`a` + `t1`.`b` <> `t2`.`a` ) ) \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_sql/test_join_with_nested_xor_condition/out.sql b/ibis/backends/impala/tests/snapshots/test_sql/test_join_with_nested_xor_condition/out.sql index 22c41b392f86..ad5bfb402644 100644 --- a/ibis/backends/impala/tests/snapshots/test_sql/test_join_with_nested_xor_condition/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_sql/test_join_with_nested_xor_condition/out.sql @@ -2,13 +2,13 @@ SELECT `t1`.`a`, `t1`.`b` FROM `t` AS `t1` -INNER JOIN `t` AS `t3` - ON `t1`.`a` = `t3`.`a` +INNER JOIN `t` AS `t2` + ON `t1`.`a` = `t2`.`a` AND ( ( - `t1`.`a` <> `t3`.`b` OR `t1`.`b` <> `t3`.`a` + `t1`.`a` <> `t2`.`b` OR `t1`.`b` <> `t2`.`a` ) AND NOT ( - `t1`.`a` <> `t3`.`b` AND `t1`.`b` <> `t3`.`a` + `t1`.`a` <> `t2`.`b` AND `t1`.`b` <> `t2`.`a` ) ) \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_sql/test_limit_cte_extract/out.sql b/ibis/backends/impala/tests/snapshots/test_sql/test_limit_cte_extract/out.sql index cfb42ddfab2c..0812de4a5bbb 100644 --- a/ibis/backends/impala/tests/snapshots/test_sql/test_limit_cte_extract/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_sql/test_limit_cte_extract/out.sql @@ -19,5 +19,5 @@ SELECT `t3`.`year`, `t3`.`month` FROM `t1` AS `t3` -INNER JOIN `t1` AS `t5` +INNER JOIN `t1` AS `t4` ON TRUE \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_sql/test_nested_join_base/out.sql b/ibis/backends/impala/tests/snapshots/test_sql/test_nested_join_base/out.sql index b5629eddffa6..bb3f1f7abd5e 100644 --- a/ibis/backends/impala/tests/snapshots/test_sql/test_nested_join_base/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_sql/test_nested_join_base/out.sql @@ -7,7 +7,7 @@ WITH `t1` AS ( 1 ) SELECT - `t5`.`uuid`, + `t3`.`uuid`, `t3`.`CountStar(t)` FROM ( SELECT diff --git a/ibis/backends/impala/tests/snapshots/test_sql/test_nested_joins_single_cte/out.sql b/ibis/backends/impala/tests/snapshots/test_sql/test_nested_joins_single_cte/out.sql index 794146f5495e..2c428cb3d9b7 100644 --- a/ibis/backends/impala/tests/snapshots/test_sql/test_nested_joins_single_cte/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_sql/test_nested_joins_single_cte/out.sql @@ -7,7 +7,7 @@ WITH `t1` AS ( 1 ) SELECT - `t7`.`uuid`, + `t4`.`uuid`, `t4`.`CountStar(t)`, `t5`.`last_visit` FROM ( @@ -28,4 +28,4 @@ LEFT OUTER JOIN ( GROUP BY 1 ) AS `t5` - ON `t7`.`uuid` = `t5`.`uuid` \ No newline at end of file + ON `t4`.`uuid` = `t5`.`uuid` \ No newline at end of file diff --git a/ibis/backends/pandas/executor.py b/ibis/backends/pandas/executor.py index 5c97bd270415..78d784a7d20d 100644 --- a/ibis/backends/pandas/executor.py +++ b/ibis/backends/pandas/executor.py @@ -540,7 +540,7 @@ def visit(cls, op: ops.DummyTable, values): return df @classmethod - def visit(cls, op: ops.SelfReference | ops.JoinTable, parent, **kwargs): + def visit(cls, op: ops.Reference, parent, **kwargs): return parent @classmethod diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index c28e224c19c2..840738b25c87 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -1248,13 +1248,8 @@ def execute_view(op, *, ctx: pl.SQLContext, **kw): return child -@translate.register(ops.SelfReference) -def execute_self_reference(op, **kw): - return translate(op.parent, **kw) - - -@translate.register(ops.JoinTable) -def execute_join_table(op, **kw): +@translate.register(ops.Reference) +def execute_reference(op, **kw): return translate(op.parent, **kw) diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py index 89bc0db53f5e..dd2317cc9fe4 100644 --- a/ibis/backends/sql/compiler.py +++ b/ibis/backends/sql/compiler.py @@ -1158,6 +1158,8 @@ def visit_DatabaseTable( def visit_SelfReference(self, op, *, parent, identifier): return parent + visit_JoinReference = visit_SelfReference + def visit_JoinChain(self, op, *, first, rest, values): result = sg.select(*self._cleanup_names(values), copy=False).from_( first, copy=False @@ -1388,9 +1390,6 @@ def visit_SQLStringView(self, op, *, query: str, child, schema): def visit_SQLQueryResult(self, op, *, query, schema, source): return sg.parse_one(query, dialect=self.dialect).subquery(copy=False) - def visit_JoinTable(self, op, *, parent, index): - return parent - def visit_RegexExtract(self, op, *, arg, pattern, index): return self.f.regexp_extract(arg, pattern, index, dialect=self.dialect) diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_cte_extract/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_cte_extract/out.sql index ca9b89c74630..e94ea21a070d 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_cte_extract/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_cte_extract/out.sql @@ -19,5 +19,5 @@ SELECT "t3"."year", "t3"."month" FROM "t1" AS "t3" -INNER JOIN "t1" AS "t5" +INNER JOIN "t1" AS "t4" ON TRUE \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_with_self_join/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_with_self_join/out.sql index 2cda81f9a69c..bd5a9e2ba9db 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_with_self_join/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_limit_with_self_join/out.sql @@ -15,20 +15,20 @@ FROM ( "t1"."timestamp_col", "t1"."year", "t1"."month", - "t3"."id" AS "id_right", - "t3"."bool_col" AS "bool_col_right", - "t3"."tinyint_col" AS "tinyint_col_right", - "t3"."smallint_col" AS "smallint_col_right", - "t3"."int_col" AS "int_col_right", - "t3"."bigint_col" AS "bigint_col_right", - "t3"."float_col" AS "float_col_right", - "t3"."double_col" AS "double_col_right", - "t3"."date_string_col" AS "date_string_col_right", - "t3"."string_col" AS "string_col_right", - "t3"."timestamp_col" AS "timestamp_col_right", - "t3"."year" AS "year_right", - "t3"."month" AS "month_right" + "t2"."id" AS "id_right", + "t2"."bool_col" AS "bool_col_right", + "t2"."tinyint_col" AS "tinyint_col_right", + "t2"."smallint_col" AS "smallint_col_right", + "t2"."int_col" AS "int_col_right", + "t2"."bigint_col" AS "bigint_col_right", + "t2"."float_col" AS "float_col_right", + "t2"."double_col" AS "double_col_right", + "t2"."date_string_col" AS "date_string_col_right", + "t2"."string_col" AS "string_col_right", + "t2"."timestamp_col" AS "timestamp_col_right", + "t2"."year" AS "year_right", + "t2"."month" AS "month_right" FROM "functional_alltypes" AS "t1" - INNER JOIN "functional_alltypes" AS "t3" - ON "t1"."tinyint_col" < EXTRACT(minute FROM "t3"."timestamp_col") -) AS "t4" \ No newline at end of file + INNER JOIN "functional_alltypes" AS "t2" + ON "t1"."tinyint_col" < EXTRACT(minute FROM "t2"."timestamp_col") +) AS "t3" \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_self_join_subquery_distinct_equal/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_self_join_subquery_distinct_equal/out.sql index 52fe23dbf9e4..321a92d13028 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_self_join_subquery_distinct_equal/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_self_join_subquery_distinct_equal/out.sql @@ -1,6 +1,6 @@ SELECT "t2"."r_name", - "t6"."n_name" + "t5"."n_name" FROM "tpch_region" AS "t2" INNER JOIN "tpch_nation" AS "t3" ON "t2"."r_regionkey" = "t3"."n_regionkey" @@ -16,5 +16,5 @@ INNER JOIN ( FROM "tpch_region" AS "t2" INNER JOIN "tpch_nation" AS "t3" ON "t2"."r_regionkey" = "t3"."n_regionkey" -) AS "t6" - ON "t2"."r_regionkey" = "t6"."r_regionkey" \ No newline at end of file +) AS "t5" + ON "t2"."r_regionkey" = "t5"."r_regionkey" \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_subquery_in_union/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_subquery_in_union/out.sql index 0d28860d3c25..6c4beaaccf15 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_subquery_in_union/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_subquery_in_union/out.sql @@ -7,25 +7,25 @@ WITH "t1" AS ( GROUP BY 1, 2 -), "t6" AS ( +), "t5" AS ( SELECT "t3"."a", "t3"."g", "t3"."metric" FROM "t1" AS "t3" - INNER JOIN "t1" AS "t5" - ON "t3"."g" = "t5"."g" + INNER JOIN "t1" AS "t4" + ON "t3"."g" = "t4"."g" ) SELECT - "t9"."a", - "t9"."g", - "t9"."metric" + "t8"."a", + "t8"."g", + "t8"."metric" FROM ( SELECT * - FROM "t6" AS "t7" + FROM "t5" AS "t6" UNION ALL SELECT * - FROM "t6" AS "t8" -) AS "t9" \ No newline at end of file + FROM "t5" AS "t7" +) AS "t8" \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_subquery_used_for_self_join/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_subquery_used_for_self_join/out.sql index 4039868d3eac..a9b27450205b 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_subquery_used_for_self_join/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_subquery_used_for_self_join/out.sql @@ -11,21 +11,21 @@ WITH "t1" AS ( 3 ) SELECT - "t6"."g", - MAX("t6"."total" - "t6"."total_right") AS "metric" + "t5"."g", + MAX("t5"."total" - "t5"."total_right") AS "metric" FROM ( SELECT "t3"."g", "t3"."a", "t3"."b", "t3"."total", - "t5"."g" AS "g_right", - "t5"."a" AS "a_right", - "t5"."b" AS "b_right", - "t5"."total" AS "total_right" + "t4"."g" AS "g_right", + "t4"."a" AS "a_right", + "t4"."b" AS "b_right", + "t4"."total" AS "total_right" FROM "t1" AS "t3" - INNER JOIN "t1" AS "t5" - ON "t3"."a" = "t5"."b" -) AS "t6" + INNER JOIN "t1" AS "t4" + ON "t3"."a" = "t4"."b" +) AS "t5" GROUP BY 1 \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_tpch_self_join_failure/out.sql b/ibis/backends/tests/sql/snapshots/test_select_sql/test_tpch_self_join_failure/out.sql index bf5316341bdf..18f3ee847814 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_tpch_self_join_failure/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_tpch_self_join_failure/out.sql @@ -24,9 +24,9 @@ WITH "t9" AS ( SELECT "t11"."region", "t11"."year", - "t11"."total" - "t13"."total" AS "yoy_change" + "t11"."total" - "t12"."total" AS "yoy_change" FROM "t9" AS "t11" -INNER JOIN "t9" AS "t13" +INNER JOIN "t9" AS "t12" ON "t11"."year" = ( - "t13"."year" - CAST(1 AS TINYINT) + "t12"."year" - CAST(1 AS TINYINT) ) \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_sql/test_cte_factor_distinct_but_equal/out.sql b/ibis/backends/tests/sql/snapshots/test_sql/test_cte_factor_distinct_but_equal/out.sql index 47a9c20dfbed..d4c24d4f2192 100644 --- a/ibis/backends/tests/sql/snapshots/test_sql/test_cte_factor_distinct_but_equal/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_sql/test_cte_factor_distinct_but_equal/out.sql @@ -16,5 +16,5 @@ INNER JOIN ( FROM "alltypes" AS "t1" GROUP BY 1 -) AS "t6" - ON "t3"."g" = "t6"."g" \ No newline at end of file +) AS "t5" + ON "t3"."g" = "t5"."g" \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_sql/test_self_reference_join/out.sql b/ibis/backends/tests/sql/snapshots/test_sql/test_self_reference_join/out.sql index ac92348404fd..927c6c569ef6 100644 --- a/ibis/backends/tests/sql/snapshots/test_sql/test_self_reference_join/out.sql +++ b/ibis/backends/tests/sql/snapshots/test_sql/test_self_reference_join/out.sql @@ -4,5 +4,5 @@ SELECT "t1"."foo_id", "t1"."bar_id" FROM "star1" AS "t1" -INNER JOIN "star1" AS "t3" - ON "t1"."foo_id" = "t3"."bar_id" \ No newline at end of file +INNER JOIN "star1" AS "t2" + ON "t1"."foo_id" = "t2"."bar_id" \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/duckdb/h07.sql b/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/duckdb/h07.sql index fca7ba9abc97..8a5689ab25a2 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/duckdb/h07.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/duckdb/h07.sql @@ -1,27 +1,27 @@ SELECT - "t14"."supp_nation", - "t14"."cust_nation", - "t14"."l_year", - "t14"."revenue" + "t13"."supp_nation", + "t13"."cust_nation", + "t13"."l_year", + "t13"."revenue" FROM ( SELECT - "t13"."supp_nation", - "t13"."cust_nation", - "t13"."l_year", - SUM("t13"."volume") AS "revenue" + "t12"."supp_nation", + "t12"."cust_nation", + "t12"."l_year", + SUM("t12"."volume") AS "revenue" FROM ( SELECT - "t12"."supp_nation", - "t12"."cust_nation", - "t12"."l_shipdate", - "t12"."l_extendedprice", - "t12"."l_discount", - "t12"."l_year", - "t12"."volume" + "t11"."supp_nation", + "t11"."cust_nation", + "t11"."l_shipdate", + "t11"."l_extendedprice", + "t11"."l_discount", + "t11"."l_year", + "t11"."volume" FROM ( SELECT "t9"."n_name" AS "supp_nation", - "t11"."n_name" AS "cust_nation", + "t10"."n_name" AS "cust_nation", "t6"."l_shipdate", "t6"."l_extendedprice", "t6"."l_discount", @@ -38,34 +38,34 @@ FROM ( ON "t8"."c_custkey" = "t7"."o_custkey" INNER JOIN "nation" AS "t9" ON "t5"."s_nationkey" = "t9"."n_nationkey" - INNER JOIN "nation" AS "t11" - ON "t8"."c_nationkey" = "t11"."n_nationkey" - ) AS "t12" + INNER JOIN "nation" AS "t10" + ON "t8"."c_nationkey" = "t10"."n_nationkey" + ) AS "t11" WHERE ( ( ( - "t12"."cust_nation" = 'FRANCE' + "t11"."cust_nation" = 'FRANCE' ) AND ( - "t12"."supp_nation" = 'GERMANY' + "t11"."supp_nation" = 'GERMANY' ) ) OR ( ( - "t12"."cust_nation" = 'GERMANY' + "t11"."cust_nation" = 'GERMANY' ) AND ( - "t12"."supp_nation" = 'FRANCE' + "t11"."supp_nation" = 'FRANCE' ) ) ) - AND "t12"."l_shipdate" BETWEEN MAKE_DATE(1995, 1, 1) AND MAKE_DATE(1996, 12, 31) - ) AS "t13" + AND "t11"."l_shipdate" BETWEEN MAKE_DATE(1995, 1, 1) AND MAKE_DATE(1996, 12, 31) + ) AS "t12" GROUP BY 1, 2, 3 -) AS "t14" +) AS "t13" ORDER BY - "t14"."supp_nation" ASC, - "t14"."cust_nation" ASC, - "t14"."l_year" ASC \ No newline at end of file + "t13"."supp_nation" ASC, + "t13"."cust_nation" ASC, + "t13"."l_year" ASC \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/trino/h07.sql b/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/trino/h07.sql index 986668485bbf..8fee2e970185 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/trino/h07.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/trino/h07.sql @@ -7,29 +7,29 @@ WITH "t5" AS ( FROM "hive"."ibis_sf1"."nation" AS "t4" ) SELECT - "t20"."supp_nation", - "t20"."cust_nation", - "t20"."l_year", - "t20"."revenue" + "t19"."supp_nation", + "t19"."cust_nation", + "t19"."l_year", + "t19"."revenue" FROM ( SELECT - "t19"."supp_nation", - "t19"."cust_nation", - "t19"."l_year", - SUM("t19"."volume") AS "revenue" + "t18"."supp_nation", + "t18"."cust_nation", + "t18"."l_year", + SUM("t18"."volume") AS "revenue" FROM ( SELECT - "t18"."supp_nation", - "t18"."cust_nation", - "t18"."l_shipdate", - "t18"."l_extendedprice", - "t18"."l_discount", - "t18"."l_year", - "t18"."volume" + "t17"."supp_nation", + "t17"."cust_nation", + "t17"."l_shipdate", + "t17"."l_extendedprice", + "t17"."l_discount", + "t17"."l_year", + "t17"."volume" FROM ( SELECT "t15"."n_name" AS "supp_nation", - "t17"."n_name" AS "cust_nation", + "t16"."n_name" AS "cust_nation", "t12"."l_shipdate", "t12"."l_extendedprice", "t12"."l_discount", @@ -98,34 +98,34 @@ FROM ( ON "t14"."c_custkey" = "t13"."o_custkey" INNER JOIN "t5" AS "t15" ON "t11"."s_nationkey" = "t15"."n_nationkey" - INNER JOIN "t5" AS "t17" - ON "t14"."c_nationkey" = "t17"."n_nationkey" - ) AS "t18" + INNER JOIN "t5" AS "t16" + ON "t14"."c_nationkey" = "t16"."n_nationkey" + ) AS "t17" WHERE ( ( ( - "t18"."cust_nation" = 'FRANCE' + "t17"."cust_nation" = 'FRANCE' ) AND ( - "t18"."supp_nation" = 'GERMANY' + "t17"."supp_nation" = 'GERMANY' ) ) OR ( ( - "t18"."cust_nation" = 'GERMANY' + "t17"."cust_nation" = 'GERMANY' ) AND ( - "t18"."supp_nation" = 'FRANCE' + "t17"."supp_nation" = 'FRANCE' ) ) ) - AND "t18"."l_shipdate" BETWEEN FROM_ISO8601_DATE('1995-01-01') AND FROM_ISO8601_DATE('1996-12-31') - ) AS "t19" + AND "t17"."l_shipdate" BETWEEN FROM_ISO8601_DATE('1995-01-01') AND FROM_ISO8601_DATE('1996-12-31') + ) AS "t18" GROUP BY 1, 2, 3 -) AS "t20" +) AS "t19" ORDER BY - "t20"."supp_nation" ASC, - "t20"."cust_nation" ASC, - "t20"."l_year" ASC \ No newline at end of file + "t19"."supp_nation" ASC, + "t19"."cust_nation" ASC, + "t19"."l_year" ASC \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/duckdb/h08.sql b/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/duckdb/h08.sql index e2a61a509543..3fe20e69b03d 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/duckdb/h08.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/duckdb/h08.sql @@ -1,26 +1,26 @@ SELECT - "t18"."o_year", - "t18"."mkt_share" + "t17"."o_year", + "t17"."mkt_share" FROM ( SELECT - "t17"."o_year", - SUM("t17"."nation_volume") / SUM("t17"."volume") AS "mkt_share" + "t16"."o_year", + SUM("t16"."nation_volume") / SUM("t16"."volume") AS "mkt_share" FROM ( SELECT - "t16"."o_year", - "t16"."volume", - "t16"."nation", - "t16"."r_name", - "t16"."o_orderdate", - "t16"."p_type", - CASE WHEN "t16"."nation" = 'BRAZIL' THEN "t16"."volume" ELSE CAST(0 AS TINYINT) END AS "nation_volume" + "t15"."o_year", + "t15"."volume", + "t15"."nation", + "t15"."r_name", + "t15"."o_orderdate", + "t15"."p_type", + CASE WHEN "t15"."nation" = 'BRAZIL' THEN "t15"."volume" ELSE CAST(0 AS TINYINT) END AS "nation_volume" FROM ( SELECT EXTRACT(year FROM "t10"."o_orderdate") AS "o_year", "t8"."l_extendedprice" * ( CAST(1 AS TINYINT) - "t8"."l_discount" ) AS "volume", - "t15"."n_name" AS "nation", + "t13"."n_name" AS "nation", "t14"."r_name", "t10"."o_orderdate", "t7"."p_type" @@ -37,16 +37,16 @@ FROM ( ON "t11"."c_nationkey" = "t12"."n_nationkey" INNER JOIN "region" AS "t14" ON "t12"."n_regionkey" = "t14"."r_regionkey" - INNER JOIN "nation" AS "t15" - ON "t9"."s_nationkey" = "t15"."n_nationkey" - ) AS "t16" + INNER JOIN "nation" AS "t13" + ON "t9"."s_nationkey" = "t13"."n_nationkey" + ) AS "t15" WHERE - "t16"."r_name" = 'AMERICA' - AND "t16"."o_orderdate" BETWEEN MAKE_DATE(1995, 1, 1) AND MAKE_DATE(1996, 12, 31) - AND "t16"."p_type" = 'ECONOMY ANODIZED STEEL' - ) AS "t17" + "t15"."r_name" = 'AMERICA' + AND "t15"."o_orderdate" BETWEEN MAKE_DATE(1995, 1, 1) AND MAKE_DATE(1996, 12, 31) + AND "t15"."p_type" = 'ECONOMY ANODIZED STEEL' + ) AS "t16" GROUP BY 1 -) AS "t18" +) AS "t17" ORDER BY - "t18"."o_year" ASC \ No newline at end of file + "t17"."o_year" ASC \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/trino/h08.sql b/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/trino/h08.sql index 80ae67d1ebd1..5ce6e33feef2 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/trino/h08.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/trino/h08.sql @@ -7,28 +7,28 @@ WITH "t8" AS ( FROM "hive"."ibis_sf1"."nation" AS "t6" ) SELECT - "t26"."o_year", - "t26"."mkt_share" + "t25"."o_year", + "t25"."mkt_share" FROM ( SELECT - "t25"."o_year", - CAST(SUM("t25"."nation_volume") AS DOUBLE) / SUM("t25"."volume") AS "mkt_share" + "t24"."o_year", + CAST(SUM("t24"."nation_volume") AS DOUBLE) / SUM("t24"."volume") AS "mkt_share" FROM ( SELECT - "t24"."o_year", - "t24"."volume", - "t24"."nation", - "t24"."r_name", - "t24"."o_orderdate", - "t24"."p_type", - CASE WHEN "t24"."nation" = 'BRAZIL' THEN "t24"."volume" ELSE 0 END AS "nation_volume" + "t23"."o_year", + "t23"."volume", + "t23"."nation", + "t23"."r_name", + "t23"."o_orderdate", + "t23"."p_type", + CASE WHEN "t23"."nation" = 'BRAZIL' THEN "t23"."volume" ELSE 0 END AS "nation_volume" FROM ( SELECT EXTRACT(year FROM "t19"."o_orderdate") AS "o_year", "t17"."l_extendedprice" * ( 1 - "t17"."l_discount" ) AS "volume", - "t23"."n_name" AS "nation", + "t22"."n_name" AS "nation", "t14"."r_name", "t19"."o_orderdate", "t16"."p_type" @@ -115,16 +115,16 @@ FROM ( FROM "hive"."ibis_sf1"."region" AS "t5" ) AS "t14" ON "t21"."n_regionkey" = "t14"."r_regionkey" - INNER JOIN "t8" AS "t23" - ON "t18"."s_nationkey" = "t23"."n_nationkey" - ) AS "t24" + INNER JOIN "t8" AS "t22" + ON "t18"."s_nationkey" = "t22"."n_nationkey" + ) AS "t23" WHERE - "t24"."r_name" = 'AMERICA' - AND "t24"."o_orderdate" BETWEEN FROM_ISO8601_DATE('1995-01-01') AND FROM_ISO8601_DATE('1996-12-31') - AND "t24"."p_type" = 'ECONOMY ANODIZED STEEL' - ) AS "t25" + "t23"."r_name" = 'AMERICA' + AND "t23"."o_orderdate" BETWEEN FROM_ISO8601_DATE('1995-01-01') AND FROM_ISO8601_DATE('1996-12-31') + AND "t23"."p_type" = 'ECONOMY ANODIZED STEEL' + ) AS "t24" GROUP BY 1 -) AS "t26" +) AS "t25" ORDER BY - "t26"."o_year" ASC \ No newline at end of file + "t25"."o_year" ASC \ No newline at end of file diff --git a/ibis/common/graph.py b/ibis/common/graph.py index c4975e37d360..ec7ea763ce78 100644 --- a/ibis/common/graph.py +++ b/ibis/common/graph.py @@ -627,13 +627,12 @@ def traverse( The Node expression or a list of expressions. """ - - args = reversed(node) if isinstance(node, Sequence) else [node] - todo: deque[Node] = deque(args) + nodes = list(_flatten_collections(promote_list(node))) + queue: deque[Node] = deque(reversed(nodes)) seen: set[Node] = set() - while todo: - node = todo.pop() + while queue: + node = queue.pop() if node in seen: continue @@ -654,7 +653,7 @@ def traverse( "an instance of boolean or iterable" ) - todo.extend(reversed(children)) + queue.extend(reversed(children)) def bfs(root: Node) -> Graph: diff --git a/ibis/common/tests/test_graph.py b/ibis/common/tests/test_graph.py index 7b05287b4ce6..ea0ea5c9136f 100644 --- a/ibis/common/tests/test_graph.py +++ b/ibis/common/tests/test_graph.py @@ -16,6 +16,7 @@ bfs_while, dfs, dfs_while, + traverse, ) from ibis.common.grounds import Annotable, Concrete from ibis.common.patterns import Eq, If, InstanceOf, Object, TupleOf, _ @@ -460,3 +461,11 @@ def record_result_keys(node, results, **kwargs): result = X.map_clear(record_result_keys) assert result == X assert result_sequence == expected_result_sequence + + +def test_traverse(): + def walker(node): + return True, node + + result = list(traverse(walker, A)) + assert result == [A, B, D, E, C] diff --git a/ibis/expr/decompile.py b/ibis/expr/decompile.py index 43452fa0c09a..c7ea1015975a 100644 --- a/ibis/expr/decompile.py +++ b/ibis/expr/decompile.py @@ -192,8 +192,8 @@ def self_reference(op, parent, identifier): return f"{parent}.view()" -@translate.register(ops.JoinTable) -def join_table(op, parent, index): +@translate.register(ops.JoinReference) +def join_reference(op, parent, identifier): return parent @@ -353,7 +353,7 @@ class CodeContext: ) always_ignore = ( - ops.JoinTable, + ops.JoinReference, ops.Field, dt.Primitive, dt.Variadic, diff --git a/ibis/expr/format.py b/ibis/expr/format.py index 51d827884e0b..00b93ae5f2f6 100644 --- a/ibis/expr/format.py +++ b/ibis/expr/format.py @@ -181,7 +181,7 @@ def mapper(op, _, **kwargs): if var := variables.get(op): refs[op] = result result = var - elif isinstance(op, ops.Relation) and not isinstance(op, ops.JoinTable): + elif isinstance(op, ops.Relation) and not isinstance(op, ops.JoinReference): refs[op] = result result = f"r{next(refcnt)}" return Rendered(result) @@ -364,8 +364,8 @@ def _self_reference(op, parent, **kwargs): return f"{op.__class__.__name__}[{parent}]" -@fmt.register(ops.JoinTable) -def _join_table(op, parent, index): +@fmt.register(ops.JoinReference) +def _join_reference(op, parent, **kwargs): return parent diff --git a/ibis/expr/operations/relations.py b/ibis/expr/operations/relations.py index 4e38359520e8..4618de7a4d52 100644 --- a/ibis/expr/operations/relations.py +++ b/ibis/expr/operations/relations.py @@ -133,11 +133,10 @@ def schema(self): return self.parent.schema -# TODO(kszucs): remove in favor of View @public -class SelfReference(Simple): +class Reference(Relation): _uid_counter = itertools.count() - + parent: Relation identifier: Optional[int] = None def __init__(self, parent, identifier): @@ -145,9 +144,22 @@ def __init__(self, parent, identifier): identifier = next(self._uid_counter) super().__init__(parent=parent, identifier=identifier) + @attribute + def schema(self): + return self.parent.schema + + +# TODO(kszucs): remove in favor of View +@public +class SelfReference(Reference): + values = FrozenDict() + + +@public +class JoinReference(Reference): @attribute def values(self): - return FrozenDict() + return self.parent.fields JoinKind = Literal[ @@ -164,21 +176,16 @@ def values(self): ] -@public -class JoinTable(Simple): - index: int - - @public class JoinLink(Node): how: JoinKind - table: JoinTable + table: Reference predicates: VarTuple[Value[dt.Boolean]] @public class JoinChain(Relation): - first: JoinTable + first: Reference rest: VarTuple[JoinLink] values: FrozenDict[str, Unaliased[Value]] @@ -194,6 +201,10 @@ def __init__(self, first, rest, values): _check_integrity(values.values(), allowed_parents) super().__init__(first=first, rest=rest, values=values) + @property + def tables(self): + return [self.first] + [link.table for link in self.rest] + @property def length(self): return len(self.rest) + 1 diff --git a/ibis/expr/rewrites.py b/ibis/expr/rewrites.py index 5b03b024d89e..a30603adaa87 100644 --- a/ibis/expr/rewrites.py +++ b/ibis/expr/rewrites.py @@ -3,16 +3,21 @@ from __future__ import annotations import functools +from collections import defaultdict from collections.abc import Mapping import toolz import ibis.expr.datatypes as dt import ibis.expr.operations as ops +from ibis.common.collections import FrozenDict # noqa: TCH001 from ibis.common.deferred import Item, _, deferred, var -from ibis.common.exceptions import ExpressionError +from ibis.common.exceptions import ExpressionError, IbisInputError +from ibis.common.graph import Node as Traversable +from ibis.common.grounds import Concrete from ibis.common.patterns import Check, pattern, replace -from ibis.util import Namespace +from ibis.common.typing import VarTuple # noqa: TCH001 +from ibis.util import Namespace, promote_list p = Namespace(pattern, module=ops) d = Namespace(deferred, module=ops) @@ -23,6 +28,132 @@ name = var("name") +class DerefMap(Concrete, Traversable): + """Trace and replace fields from earlier relations in the hierarchy. + + In order to provide a nice user experience, we need to allow expressions + from earlier relations in the hierarchy. Consider the following example: + + t = ibis.table([('a', 'int64'), ('b', 'string')], name='t') + t1 = t.select([t.a, t.b]) + t2 = t1.filter(t.a > 0) # note that not t1.a is referenced here + t3 = t2.select(t.a) # note that not t2.a is referenced here + + However the relational operations in the IR are strictly enforcing that + the expressions are referencing the immediate parent only. So we need to + track fields upwards the hierarchy to replace `t.a` with `t1.a` and `t2.a` + in the example above. This is called dereferencing. + + Whether we can treat or not a field of a relation semantically equivalent + with a field of an earlier relation in the hierarchy depends on the + `.values` mapping of the relation. Leaf relations, like `t` in the example + above, have an empty `.values` mapping, so we cannot dereference fields + from them. On the other hand a projection, like `t1` in the example above, + has a `.values` mapping like `{'a': t.a, 'b': t.b}`, so we can deduce that + `t1.a` is semantically equivalent with `t.a` and so on. + """ + + """The relations we want the values to point to.""" + rels: VarTuple[ops.Relation] + + """Substitution mapping from values of earlier relations to the fields of `rels`.""" + subs: FrozenDict[ops.Value, ops.Field] + + """Ambiguous field references.""" + ambigs: FrozenDict[ops.Value, VarTuple[ops.Value]] + + @classmethod + def from_targets(cls, rels, extra=None): + """Create a dereference map from a list of target relations. + + Usually a single relation is passed except for joins where multiple + relations are involved. + + Parameters + ---------- + rels : list of ops.Relation + The target relations to dereference to. + extra : dict, optional + Extra substitutions to be added to the dereference map. + + Returns + ------- + DerefMap + """ + rels = promote_list(rels) + mapping = defaultdict(dict) + for rel in rels: + for field in rel.fields.values(): + for value, distance in cls.backtrack(field): + mapping[value][field] = distance + + subs, ambigs = {}, {} + for from_, to in mapping.items(): + mindist = min(to.values()) + minkeys = [k for k, v in to.items() if v == mindist] + # if all the closest fields are from the same relation, then we + # can safely substitute them and we pick the first one arbitrarily + if all(minkeys[0].relations == k.relations for k in minkeys): + subs[from_] = minkeys[0] + else: + ambigs[from_] = minkeys + + if extra is not None: + subs.update(extra) + + return cls(rels, subs, ambigs) + + @classmethod + def backtrack(cls, value): + """Backtrack the field in the relation hierarchy. + + The field is traced back until no modification is made, so only follow + ops.Field nodes not arbitrary values. + + Parameters + ---------- + value : ops.Value + The value to backtrack. + + Yields + ------ + tuple[ops.Field, int] + The value node and the distance from the original value. + """ + distance = 0 + # track down the field in the hierarchy until no modification + # is made so only follow ops.Field nodes not arbitrary values; + while isinstance(value, ops.Field): + yield value, distance + value = value.rel.values.get(value.name) + distance += 1 + if value is not None and not value.find(ops.Impure, filter=ops.Value): + yield value, distance + + def dereference(self, value): + """Dereference a value to the target relations. + + Also check for ambiguous field references. If a field reference is found + which is marked as ambiguous, then raise an error. + + Parameters + ---------- + value : ops.Value + The value to dereference. + + Returns + ------- + ops.Value + The dereferenced value. + """ + ambigs = value.find(lambda x: x in self.ambigs, filter=ops.Value) + if ambigs: + raise IbisInputError( + f"Ambiguous field reference {ambigs!r} in expression {value!r}" + ) + return value.replace(self.subs, filter=ops.Value) + + @replace(p.Field(p.JoinChain)) def peel_join_field(_): return _.rel.values[_.name] diff --git a/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt b/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt index a29fc083cf3a..3e39b232a162 100644 --- a/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt @@ -6,15 +6,17 @@ r1 := UnboundTable: right time2 int32 value2 float64 +r2 := SelfReference[r1] + JoinChain[r0] JoinLink[asof, r1] r0.time1 >= r1.time2 - JoinLink[inner, r1] - r0.value == r1.value2 + JoinLink[inner, r2] + r0.value == r2.value2 values: time1: r0.time1 value: r0.value time2: r1.time2 value2: r1.value2 - time2_right: r1.time2 - value2_right: r1.value2 \ No newline at end of file + time2_right: r2.time2 + value2_right: r2.value2 \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt b/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt index 672faadf9ba2..0facaef4d746 100644 --- a/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt @@ -8,11 +8,13 @@ r1 := UnboundTable: right value2 float64 b string +r2 := SelfReference[r1] + JoinChain[r0] JoinLink[inner, r1] r0.a == r1.b - JoinLink[inner, r1] - r0.value == r1.value2 + JoinLink[inner, r2] + r0.value == r2.value2 values: time1: r0.time1 value: r0.value @@ -20,6 +22,6 @@ JoinChain[r0] time2: r1.time2 value2: r1.value2 b: r1.b - time2_right: r1.time2 - value2_right: r1.value2 - b_right: r1.b \ No newline at end of file + time2_right: r2.time2 + value2_right: r2.value2 + b_right: r2.b \ No newline at end of file diff --git a/ibis/expr/tests/test_dereference.py b/ibis/expr/tests/test_dereference.py index e19234f92084..4396c0ac5fb0 100644 --- a/ibis/expr/tests/test_dereference.py +++ b/ibis/expr/tests/test_dereference.py @@ -1,7 +1,7 @@ from __future__ import annotations import ibis -from ibis.expr.types.relations import dereference_mapping +from ibis.expr.types.relations import DerefMap t = ibis.table( [ @@ -20,7 +20,7 @@ def dereference_expect(expected): def test_dereference_project(): p = t.select([t.int_col, t.double_col]) - mapping = dereference_mapping([p.op()]) + mapping = DerefMap.from_targets([p.op()]) expected = dereference_expect( { p.int_col: p.int_col, @@ -29,13 +29,13 @@ def test_dereference_project(): t.double_col: p.double_col, } ) - assert mapping == expected + assert mapping.subs == expected def test_dereference_mapping_self_reference(): v = t.view() - mapping = dereference_mapping([v.op()]) + mapping = DerefMap.from_targets([v.op()]) expected = dereference_expect( { v.int_col: v.int_col, @@ -43,4 +43,4 @@ def test_dereference_mapping_self_reference(): v.string_col: v.string_col, } ) - assert mapping == expected + assert mapping.subs == expected diff --git a/ibis/expr/tests/test_format.py b/ibis/expr/tests/test_format.py index 6df58a0a9702..4c24f09c7bc6 100644 --- a/ibis/expr/tests/test_format.py +++ b/ibis/expr/tests/test_format.py @@ -309,8 +309,9 @@ def test_fillna(snapshot): def test_asof_join(snapshot): left = ibis.table([("time1", "int32"), ("value", "double")], name="left") right = ibis.table([("time2", "int32"), ("value2", "double")], name="right") + right_ = right.view() joined = left.asof_join(right, ("time1", "time2")).inner_join( - right, left.value == right.value2 + right_, left.value == right_.value2 ) result = repr(joined) @@ -324,8 +325,9 @@ def test_two_inner_joins(snapshot): right = ibis.table( [("time2", "int32"), ("value2", "double"), ("b", "string")], name="right" ) + right_ = right.view() joined = left.inner_join(right, left.a == right.b).inner_join( - right, left.value == right.value2 + right_, left.value == right_.value2 ) result = repr(joined) diff --git a/ibis/expr/tests/test_newrels.py b/ibis/expr/tests/test_newrels.py index 276cd6607bb5..22588deb65d8 100644 --- a/ibis/expr/tests/test_newrels.py +++ b/ibis/expr/tests/test_newrels.py @@ -35,8 +35,8 @@ @contextlib.contextmanager -def join_tables(*tables): - yield tuple(ops.JoinTable(t, i).to_expr() for i, t in enumerate(tables)) +def join_tables(table): + yield [t.to_expr() for t in table.op().tables] def test_field(): @@ -565,7 +565,7 @@ def test_join(): assert isinstance(joined.op(), JoinChain) assert isinstance(joined.op().to_expr(), ir.Join) - with join_tables(t1, t2) as (t1, t2): + with join_tables(joined) as (t1, t2): assert result.op() == JoinChain( first=t1, rest=[ @@ -584,16 +584,16 @@ def test_join_integrity_checks(): t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"}) # correct example - r1 = ops.JoinTable(t1, 10) - r2 = ops.JoinTable(t1, 20) + r1 = ops.JoinReference(t1, 10) + r2 = ops.JoinReference(t1, 20) assert r1 != r2 assert hash(r1) != hash(r2) chain = ops.JoinChain(r1, [ops.JoinLink("inner", r2, [True])], values={}) assert isinstance(chain, JoinChain) # not unique tables - r1 = ops.JoinTable(t1, 10) - r2 = ops.JoinTable(t1, 10) + r1 = ops.JoinReference(t1, 10) + r2 = ops.JoinReference(t1, 10) assert r1 == r2 assert hash(r1) == hash(r2) with pytest.raises(IntegrityError): @@ -609,7 +609,7 @@ def test_join_unambiguous_select(): expr2 = join.select("a_int", "b_int") assert expr1.equals(expr2) - with join_tables(a, b) as (r1, r2): + with join_tables(join) as (r1, r2): assert expr1.op() == JoinChain( first=r1, rest=[JoinLink("inner", r2, [r1.a_int == r2.b_int])], @@ -627,7 +627,7 @@ def test_join_with_subsequent_projection(): # a single computed value is pulled to a subsequent projection joined = t1.join(t2, [t1.a == t2.c]) expr = joined.select(t1.a, t1.b, col=t2.c + 1) - with join_tables(t1, t2) as (r1, r2): + with join_tables(joined) as (r1, r2): expected = JoinChain( first=r1, rest=[JoinLink("inner", r2, [r1.a == r2.c])], @@ -645,7 +645,7 @@ def test_join_with_subsequent_projection(): baz=t2.d.name("bar") + "3", baz2=(t2.c + t1.a).name("foo"), ) - with join_tables(t1, t2) as (r1, r2): + with join_tables(joined) as (r1, r2): expected = JoinChain( first=r1, rest=[JoinLink("inner", r2, [r1.a == r2.c])], @@ -674,7 +674,7 @@ def test_join_with_subsequent_projection_colliding_names(): foo=t2.a + 1, bar=t1.a + t2.a, ) - with join_tables(t1, t2) as (r1, r2): + with join_tables(expr) as (r1, r2): expected = JoinChain( first=r1, rest=[JoinLink("inner", r2, [r1.a == r2.a])], @@ -695,7 +695,7 @@ def test_chained_join(): joined = a.join(b, [a.a == b.c]).join(c, [a.a == c.e]) result = joined._finish() - with join_tables(a, b, c) as (r1, r2, r3): + with join_tables(joined) as (r1, r2, r3): assert result.op() == JoinChain( first=r1, rest=[ @@ -715,7 +715,7 @@ def test_chained_join(): joined = a.join(b, [a.a == b.c]).join(c, [b.c == c.e]) result = joined.select(a.a, b.d, c.f) - with join_tables(a, b, c) as (r1, r2, r3): + with join_tables(joined) as (r1, r2, r3): assert result.op() == JoinChain( first=r1, rest=[ @@ -739,7 +739,7 @@ def test_chained_join_referencing_intermediate_table(): abc = ab.join(c, [ab.a == c.e]) result = abc._finish() - with join_tables(a, b, c) as (r1, r2, r3): + with join_tables(abc) as (r1, r2, r3): assert result.op() == JoinChain( first=r1, rest=[ @@ -772,7 +772,7 @@ def test_join_predicate_dereferencing(): # dereference table.foo_id to filtered.foo_id j1 = filtered.left_join(table2, table["foo_id"] == table2["foo_id"]) - with join_tables(filtered, table2) as (r1, r2): + with join_tables(j1) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -793,7 +793,7 @@ def test_join_predicate_dereferencing(): j1 = filtered.left_join(table2, table["foo_id"] == table2["foo_id"]) j2 = j1.inner_join(table3, filtered["bar_id"] == table3["bar_id"]) view = j2[[filtered, table2["value1"], table3["value2"]]] - with join_tables(filtered, table2, table3) as (r1, r2, r3): + with join_tables(j2) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -835,7 +835,7 @@ def test_join_predicate_dereferencing_using_tuple_syntax(): j1 = ibis.join(t2, t3, [(t2.x, t3.x)]) j2 = ibis.join(t2, t4, [(t2.x, t4.x)]) - with join_tables(t2, t3) as (r1, r2): + with join_tables(j1) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -849,7 +849,7 @@ def test_join_predicate_dereferencing_using_tuple_syntax(): ) assert j1.op() == expected - with join_tables(t2, t4) as (r1, r2): + with join_tables(j2) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -864,6 +864,34 @@ def test_join_predicate_dereferencing_using_tuple_syntax(): assert j2.op() == expected +def test_join_rhs_dereferencing(): + t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"}) + t2 = ibis.table(name="t2", schema={"c": "int64", "d": "string"}) + + t3 = t2.mutate(e=t2.c + 1) + joined = t1.join(t3, [t1.a == t2.c]) + with join_tables(joined) as (r1, r2): + expected = JoinChain( + first=r1, + rest=[ + JoinLink("inner", r2, [r1.a == r2.c]), + ], + values={"a": r1.a, "b": r1.b, "c": r2.c, "d": r2.d, "e": r2.e}, + ) + assert joined.op() == expected + + joined = t1.join(t3, [t1.a == (t2.c + 1)]) + with join_tables(joined) as (r1, r2): + expected = JoinChain( + first=r1, + rest=[ + JoinLink("inner", r2, [r1.a == r2.e]), + ], + values={"a": r1.a, "b": r1.b, "c": r2.c, "d": r2.d, "e": r2.e}, + ) + assert joined.op() == expected + + def test_aggregate(): agg = t.aggregate(by=[t.bool_col], metrics=[t.int_col.sum()]) expected = Aggregate( @@ -1063,7 +1091,7 @@ def test_self_join(): t3 = t2.join(t2, ["key"]) t4 = t3.join(t3, ["key"]) - with join_tables(t2, t2, t3) as (r1, r2, r3): + with join_tables(t4) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1089,7 +1117,7 @@ def test_self_join_view(): t_view = t.view() expr = t.join(t_view, t.x == t_view.y).select("x", "y", "z", "z_right") - with join_tables(t, t_view) as (r1, r2): + with join_tables(expr) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1105,7 +1133,7 @@ def test_self_join_with_view_projection(): t2 = t1.view() expr = t1.inner_join(t2, ["x"])[[t1]] - with join_tables(t1, t2) as (r1, r2): + with join_tables(expr) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1120,10 +1148,14 @@ def test_joining_same_table_twice(): left = ibis.table(name="left", schema={"time1": int, "value": float, "a": str}) right = ibis.table(name="right", schema={"time2": int, "value2": float, "b": str}) - joined = left.inner_join(right, left.a == right.b).inner_join( - right, left.value == right.value2 - ) - with join_tables(left, right, right) as (r1, r2, r3): + first = left.inner_join(right, left.a == right.b) + + with pytest.raises(IbisInputError, match="Ambiguous field reference"): + first.inner_join(right, left.value == right.value2) + + right_ = right.view() + second = first.inner_join(right_, left.value == right_.value2) + with join_tables(second) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1142,7 +1174,7 @@ def test_joining_same_table_twice(): "b_right": r3.b, }, ) - assert joined.op() == expected + assert second.op() == expected def test_join_chain_gets_reused_and_continued_after_a_select(): @@ -1153,7 +1185,7 @@ def test_join_chain_gets_reused_and_continued_after_a_select(): ab = a.join(b, [a.a == b.c]) abc = ab[a.b, b.d].join(c, [a.a == c.e]) - with join_tables(a, b, c) as (r1, r2, r3): + with join_tables(abc) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1173,10 +1205,28 @@ def test_join_chain_gets_reused_and_continued_after_a_select(): def test_self_join_extensive(): a = ibis.table(name="a", schema={"a": "int64", "b": "string"}) - aa = a.join(a, [a.a == a.a]) + with pytest.raises(IbisInputError, match="Ambiguous field reference"): + a.join(a, [a.a == a.a]) + + a_ = a.view() + aa = a.join(a_, [a.a == a_.a]) + with join_tables(aa) as (r1, r2): + expected = ops.JoinChain( + first=r1, + rest=[ + ops.JoinLink("inner", r2, [r1.a == r2.a]), + ], + values={ + "a": r1.a, + "b": r1.b, + "b_right": r2.b, + }, + ) + assert aa.op() == expected + aa1 = a.join(a, "a") aa2 = a.join(a, [("a", "a")]) - with join_tables(a, a) as (r1, r2): + with join_tables(aa1) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1188,15 +1238,13 @@ def test_self_join_extensive(): "b_right": r2.b, }, ) - assert aa.op() == expected assert aa1.op() == expected assert aa2.op() == expected - aaa = a.join(a, [a.a == a.a]).join(a, [a.a == a.a]) - aaa1 = aa.join(a, [aa.a == a.a]) - aaa2 = aa.join(a, "a") - aaa3 = aa.join(a, [("a", "a")]) - with join_tables(a, a, a) as (r1, r2, r3): + aaa = a.join(a, "a").join(a, "a") + aaa1 = aa1.join(a, "a") + aaa2 = aa1.join(a, [("a", "a")]) + with join_tables(aaa) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1212,14 +1260,32 @@ def test_self_join_extensive(): assert aaa.op() == expected assert aaa1.op() == expected assert aaa2.op() == expected - assert aaa3.op() == expected def test_self_join_with_intermediate_selection(): a = ibis.table(name="a", schema={"a": "int64", "b": "string"}) proj = a[["b", "a"]] + # the predicate only references the original table, unless we enforce + # that the predicates must contain both sides of the join, we can't + # do much with this, perhaps raise a warning join = proj.join(a, [a.a == a.a]) - with join_tables(proj, a) as (r1, r2): + with join_tables(join) as (r1, r2): + expected = ops.JoinChain( + first=r1, + rest=[ + ops.JoinLink("inner", r2, [r2.a == r2.a]), + ], + values={ + "b": r1.b, + "a": r1.a, + "a_right": r2.a, + "b_right": r2.b, + }, + ) + assert join.op() == expected + + join = proj.join(a, [proj.a == a.a]) + with join_tables(join) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1233,9 +1299,11 @@ def test_self_join_with_intermediate_selection(): ) assert join.op() == expected - aa = a.join(a, [a.a == a.a])["a", "b_right"] - aaa = aa.join(a, [aa.a == a.a]) - with join_tables(a, a, a) as (r1, r2, r3): + a1 = a.view() + a2 = a.view() + aa = a.join(a1, [a.a == a1.a])["a", "b_right"] + aaa = aa.join(a2, [aa.a == a2.a]) + with join_tables(aaa) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1306,7 +1374,7 @@ def test_self_view_join_followed_by_aggregate_correctly_dereference_fields(): join = agged.inner_join(view, [agged.a == view.b]) agg = join.aggregate(metrics, by=[agged.g]) - with join_tables(agged, view) as (r1, r2): + with join_tables(join) as (r1, r2): expected_join = ops.JoinChain( first=r1, rest=[ @@ -1367,7 +1435,7 @@ def test_join_between_joins(): exprs = [left, right.value3, right.value4] expr = joined.select(exprs) - with join_tables(t1, t2, right) as (r1, r2, r3): + with join_tables(expr) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1393,7 +1461,7 @@ def test_join_with_filtered_join_of_left(): joined = t1.left_join(t2, [t1.a == t2.a]).filter(t1.a < 5) expr = t1.left_join(joined, [t1.a == joined.a]).select(t1) - with join_tables(t1, joined) as (r1, r2): + with join_tables(expr) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1451,7 +1519,7 @@ def test_join_with_compound_predicate(): ], ) expr = joined[t1] - with join_tables(t1, t2) as (r1, r2): + with join_tables(joined) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1486,7 +1554,7 @@ def test_inner_join_convenience(): t5 = ibis.table(name="t5", schema={"a": "int64", "f": "string"}) first_join = t1.inner_join(t2, [t1.a == t2.a]) - with join_tables(t1, t2) as (r1, r2): + with join_tables(first_join) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1504,7 +1572,7 @@ def test_inner_join_convenience(): # note that we are joining on r2.a which isn't among the values second_join = first_join.inner_join(t3, [r2.a == t3.a]) - with join_tables(t1, t2, t3) as (r1, r2, r3): + with join_tables(second_join) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1523,7 +1591,7 @@ def test_inner_join_convenience(): assert result == expected third_join = second_join.left_join(t4, [r3.a == t4.a]) - with join_tables(t1, t2, t3, t4) as (r1, r2, r3, r4): + with join_tables(third_join) as (r1, r2, r3, r4): expected = ops.JoinChain( first=r1, rest=[ @@ -1545,7 +1613,7 @@ def test_inner_join_convenience(): assert result == expected fourth_join = third_join.inner_join(t5, [r3.a == t5.a], rname="{name}_") - with join_tables(t1, t2, t3, t4, t5) as (r1, r2, r3, r4, r5): + with join_tables(fourth_join) as (r1, r2, r3, r4, r5): # equality groups are being reset expected = ops.JoinChain( first=r1, @@ -1575,7 +1643,7 @@ def test_inner_join_convenience(): third_join.inner_join(t5, [r4.a == t5.a])._finish() fifth_join = third_join.inner_join(t5, [r4.a == t5.a], rname="{name}_") - with join_tables(t1, t2, t3, t4, t5) as (r1, r2, r3, r4, r5): + with join_tables(fifth_join) as (r1, r2, r3, r4, r5): # equality groups are being reset expected = ops.JoinChain( first=r1, @@ -1631,3 +1699,15 @@ def test_impure_operation_dereferencing(func): parent=t1, values={"x": t1.x, "y": t1.y, "z": v2.cast("string")} ) assert t2.op() == expected + + +def test_mutate_ambiguty_check_not_too_strict(): + t = ibis.table({"id": "int64"}, name="t") + + first = t.mutate(v=t["id"]) + second = first.mutate(v2=t["id"]) + expected = ops.Project( + parent=first, + values={"id": first.id, "v": first.v, "v2": first.id}, + ) + assert second.op() == expected diff --git a/ibis/expr/types/joins.py b/ibis/expr/types/joins.py index 71a8b11ab061..ff081a353a51 100644 --- a/ibis/expr/types/joins.py +++ b/ibis/expr/types/joins.py @@ -20,9 +20,9 @@ from ibis.expr.rewrites import peel_join_field from ibis.expr.types.generic import Value from ibis.expr.types.relations import ( + DerefMap, Table, bind, - dereference_mapping, unwrap_aliases, ) @@ -112,47 +112,8 @@ def disambiguate_fields( return fields, collisions, equalities -def dereference_mapping_left(chain): - # construct the list of join table we wish to dereference fields to - rels = [chain.first] - for link in chain.rest: - if link.how not in ("semi", "anti"): - rels.append(link.table) - - # create the dereference mapping suitable to disambiguate field references - # from earlier in the relation hierarchy to one of the join tables - subs = dereference_mapping(rels) - - # also allow to dereference fields of the join chain itself - for k, v in chain.values.items(): - subs[ops.Field(chain, k)] = v - - return subs - - -def dereference_mapping_right(right): - # the right table is wrapped in a JoinTable the uniqueness of the underlying - # table which requires the predicates to be dereferenced to the wrapped - return {v: ops.Field(right, k) for k, v in right.values.items()} - - -def dereference_sides(left, right, deref_left, deref_right): - left = left.replace(deref_left, filter=ops.Value) - right = right.replace(deref_right, filter=ops.Value) - return left, right - - -def dereference_value(pred, deref_left, deref_right): - deref_both = {**deref_left, **deref_right} - if isinstance(pred, ops.Comparison) and pred.left.relations == pred.right.relations: - left, right = dereference_sides(pred.left, pred.right, deref_left, deref_right) - return pred.copy(left=left, right=right) - else: - return pred.replace(deref_both, filter=ops.Value) - - def prepare_predicates( - left: ops.JoinChain, + chain: ops.JoinChain, right: ops.Relation, predicates: Sequence[Any], comparison: type[ops.Comparison] = ops.Equals, @@ -187,8 +148,8 @@ def prepare_predicates( Parameters ---------- - left - The left table + chain + The join chain right The right table predicates @@ -197,21 +158,16 @@ def prepare_predicates( The comparison operation to construct if the input is a pair of expression-like objects """ - deref_left = dereference_mapping_left(left) - deref_right = dereference_mapping_right(right) + reverse = {ops.Field(chain, k): v for k, v in chain.values.items()} + deref_right = DerefMap.from_targets(right) + deref_left = DerefMap.from_targets(chain.tables, extra=reverse) + deref_both = DerefMap.from_targets([*chain.tables, right], extra=reverse) - left, right = left.to_expr(), right.to_expr() + left, right = chain.to_expr(), right.to_expr() for pred in util.promote_list(predicates): - if pred is True or pred is False: - yield ops.Literal(pred, dtype="bool") - elif isinstance(pred, Value): - for node in flatten_predicates(pred.op()): - yield dereference_value(node, deref_left, deref_right) - elif isinstance(pred, Deferred): - # resolve deferred expressions on the left table - pred = pred.resolve(left).op() - for node in flatten_predicates(pred): - yield dereference_value(node, deref_left, deref_right) + if isinstance(pred, (Value, Deferred, bool)): + for bound in bind(left, pred): + yield deref_both.dereference(bound.op()) else: if isinstance(pred, tuple): if len(pred) != 2: @@ -225,9 +181,9 @@ def prepare_predicates( (right_value,) = bind(right, rk) # dereference the left value to one of the relations in the join chain - left_value, right_value = dereference_sides( - left_value.op(), right_value.op(), deref_left, deref_right - ) + left_value = deref_left.dereference(left_value.op()) + right_value = deref_right.dereference(right_value.op()) + yield comparison(left_value, right_value) @@ -248,11 +204,8 @@ class Join(Table): def __init__(self, arg, collisions=(), equalities=()): assert isinstance(arg, ops.Node) if not isinstance(arg, ops.JoinChain): - # coerce the input node to a join chain operation by first wrapping - # the input relation in a JoinTable so that we can join the same - # table with itself multiple times and to enable optimization - # passes later on - arg = ops.JoinTable(arg, index=0) + # coerce the input node to a join chain operation + arg = ops.JoinReference(arg, identifier=0) arg = ops.JoinChain(arg, rest=(), values=arg.fields) super().__init__(arg) # the collisions and equalities are used to track the name collisions @@ -297,11 +250,14 @@ def join( elif how == "asof": raise IbisInputError("use table.asof_join(...) instead") - left = self.op() - right = ops.JoinTable(right, index=left.length) + chain = self.op() + right = right.op() + if not isinstance(right, ops.Reference): + right = ops.JoinReference(right, identifier=chain.length) # bind and dereference the predicates - preds = list(prepare_predicates(left, right, predicates)) + preds = prepare_predicates(chain, right, predicates) + preds = flatten_predicates(preds) if not preds and how != "cross": # if there are no predicates, default to every row matching unless # the join is a cross join, because a cross join already has this @@ -316,7 +272,7 @@ def join( how=how, predicates=preds, equalities=self._equalities, - left_fields=left.values, + left_fields=chain.values, right_fields=right.fields, left_template=lname, right_template=rname, @@ -324,10 +280,10 @@ def join( # construct a new join link and add it to the join chain link = ops.JoinLink(how, table=right, predicates=preds) - left = left.copy(rest=left.rest + (link,), values=values) + chain = chain.copy(rest=chain.rest + (link,), values=values) # return with a new JoinExpr wrapping the new join chain - return self.__class__(left, collisions=collisions, equalities=equalities) + return self.__class__(chain, collisions=collisions, equalities=equalities) @functools.wraps(Table.asof_join) def asof_join( @@ -382,19 +338,21 @@ def asof_join( values = {**self.op().values, **filtered.op().values} return result.select(values) - left = self.op() - right = ops.JoinTable(right, index=left.length) + chain = self.op() + right = right.op() + if not isinstance(right, ops.Reference): + right = ops.JoinReference(right, identifier=chain.length) # TODO(kszucs): add extra validation for `on` with clear error messages - (on,) = prepare_predicates(left, right, [on], comparison=ops.GreaterEqual) - preds = prepare_predicates(left, right, predicates, comparison=ops.Equals) + (on,) = prepare_predicates(chain, right, [on], comparison=ops.GreaterEqual) + preds = prepare_predicates(chain, right, predicates, comparison=ops.Equals) preds = [on, *preds] values, collisions, equalities = disambiguate_fields( how="asof", predicates=preds, equalities=self._equalities, - left_fields=left.values, + left_fields=chain.values, right_fields=right.fields, left_template=lname, right_template=rname, @@ -402,10 +360,10 @@ def asof_join( # construct a new join link and add it to the join chain link = ops.JoinLink("asof", table=right, predicates=preds) - left = left.copy(rest=left.rest + (link,), values=values) + chain = chain.copy(rest=chain.rest + (link,), values=values) # return with a new JoinExpr wrapping the new join chain - return self.__class__(left, collisions=collisions, equalities=equalities) + return self.__class__(chain, collisions=collisions, equalities=equalities) @functools.wraps(Table.cross_join) def cross_join( @@ -428,13 +386,15 @@ def select(self, *args, **kwargs): values = bind(self, (args, kwargs)) values = unwrap_aliases(values) + links = [link.table for link in chain.rest if link.how not in ("semi", "anti")] + derefmap = DerefMap.from_targets([chain.first, *links]) + # if there are values referencing fields from the join chain constructed # so far, we need to replace them the fields from one of the join links - subs = dereference_mapping_left(chain) values = { k: v.replace(peel_join_field, filter=ops.Value) for k, v in values.items() } - values = {k: v.replace(subs, filter=ops.Value) for k, v in values.items()} + values = {k: derefmap.dereference(v) for k, v in values.items()} node = chain.copy(values=values) return Table(node) diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 3a19bb8d5483..6678f3e9f690 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -5,12 +5,7 @@ import re from collections.abc import Iterable, Iterator, Mapping, Sequence from keyword import iskeyword -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Literal, -) +from typing import TYPE_CHECKING, Any, Callable, Literal import toolz from public import public @@ -22,6 +17,7 @@ import ibis.expr.schema as sch from ibis import util from ibis.common.deferred import Deferred, Resolver +from ibis.expr.rewrites import DerefMap from ibis.expr.types.core import Expr, _FixedTextJupyterMixin from ibis.expr.types.generic import Value, literal from ibis.expr.types.pretty import to_rich @@ -149,58 +145,12 @@ def unwrap_aliases(values: Iterator[ir.Value]) -> Mapping[str, ir.Value]: return result -def dereference_mapping(parents): - parents = util.promote_list(parents) - mapping = {} - - for parent in parents: - # do not defereference fields referencing the requested parents - for _, v in parent.fields.items(): - mapping[v] = v - - for parent in parents: - for k, v in parent.values.items(): - if isinstance(v, ops.Field): - # track down the field in the hierarchy until no modification - # is made so only follow ops.Field nodes not arbitrary values; - # also stop tracking if the field belongs to a parent which - # we want to dereference to, see the docstring of - # `dereference_values()` for more details - while isinstance(v, ops.Field) and v not in mapping: - mapping[v] = ops.Field(parent, k) - v = v.rel.values.get(v.name) - elif v not in mapping and not v.find(ops.Impure): - # do not dereference literal expressions - mapping[v] = ops.Field(parent, k) - - return mapping - - def dereference_values( parents: Iterable[ops.Parents], values: Mapping[str, ops.Value] ) -> Mapping[str, ops.Value]: """Trace and replace fields from earlier relations in the hierarchy. - In order to provide a nice user experience, we need to allow expressions - from earlier relations in the hierarchy. Consider the following example: - - t = ibis.table([('a', 'int64'), ('b', 'string')], name='t') - t1 = t.select([t.a, t.b]) - t2 = t1.filter(t.a > 0) # note that not t1.a is referenced here - t3 = t2.select(t.a) # note that not t2.a is referenced here - - However the relational operations in the IR are strictly enforcing that - the expressions are referencing the immediate parent only. So we need to - track fields upwards the hierarchy to replace `t.a` with `t1.a` and `t2.a` - in the example above. This is called dereferencing. - - Whether we can treat or not a field of a relation semantically equivalent - with a field of an earlier relation in the hierarchy depends on the - `.values` mapping of the relation. Leaf relations, like `t` in the example - above, have an empty `.values` mapping, so we cannot dereference fields - from them. On the other hand a projection, like `t1` in the example above, - has a `.values` mapping like `{'a': t.a, 'b': t.b}`, so we can deduce that - `t1.a` is semantically equivalent with `t.a` and so on. + For more details see :class:`ibis.expr.rewrites.DerefMap`. Parameters ---------- @@ -214,8 +164,8 @@ def dereference_values( The same mapping as `values` but with all the dereferenceable fields replaced with the fields from the parents. """ - subs = dereference_mapping(parents) - return {k: v.replace(subs, filter=ops.Value) for k, v in values.items()} + dm = DerefMap.from_targets(parents) + return {k: dm.dereference(v) for k, v in values.items()} @public @@ -1185,23 +1135,21 @@ def aggregate( groups = bind(self, by) metrics = bind(self, (metrics, kwargs)) - having = bind(self, having) + having = tuple(bind(self, having)) groups = unwrap_aliases(groups) metrics = unwrap_aliases(metrics) - having = unwrap_aliases(having) groups = dereference_values(node, groups) metrics = dereference_values(node, metrics) - having = dereference_values(node, having) # the user doesn't need to specify the metrics used in the having clause # explicitly, we implicitly add them to the metrics list by looking for # any metrics depending on self which are not specified explicitly pattern = p.Reduction(relations=Contains(node)) & ~In(set(metrics.values())) original_metrics = metrics.copy() - for pred in having.values(): - for metric in pred.find_topmost(pattern): + for pred in having: + for metric in pred.op().find_topmost(pattern): if metric.name in metrics: metrics[util.get_name("metric")] = metric else: @@ -1212,7 +1160,7 @@ def aggregate( if having: # apply the having clause - agg = agg.filter(*having.values()) + agg = agg.filter(*having) # remove any metrics that were only used in the having clause if metrics != original_metrics: agg = agg.select(*groups.keys(), *original_metrics.keys()) diff --git a/ibis/tests/expr/test_analysis.py b/ibis/tests/expr/test_analysis.py index aeacf249edf0..527bbab84c8f 100644 --- a/ibis/tests/expr/test_analysis.py +++ b/ibis/tests/expr/test_analysis.py @@ -27,7 +27,7 @@ def test_rewrite_join_projection_without_other_ops(con): # Project out the desired fields view = j2[[filtered, table2["value1"], table3["value2"]]] - with join_tables(filtered, table2, table3) as (r1, r2, r3): + with join_tables(j2) as (r1, r2, r3): # Construct the thing we expect to obtain expected = ops.JoinChain( first=r1, @@ -170,7 +170,7 @@ def test_filter_self_join(): what = [left.region, metric] projected = joined.select(what) - with join_tables(left, right) as (r1, r2): + with join_tables(joined) as (r1, r2): join = ops.JoinChain( first=r1, rest=[ diff --git a/ibis/tests/expr/test_struct.py b/ibis/tests/expr/test_struct.py index 75707c650f2c..de297a4e4a58 100644 --- a/ibis/tests/expr/test_struct.py +++ b/ibis/tests/expr/test_struct.py @@ -72,16 +72,15 @@ def test_unpack_from_table(t): def test_lift_join(t, s): join = t.join(s, t.d == s.a.g) result = join.a_right.lift() - - with join_tables(t, s) as (r1, r2): - join = ops.JoinChain( - first=r1, + with join_tables(join) as (t, s): + expected = ops.JoinChain( + first=t, rest=[ - ops.JoinLink("inner", r2, [r1.d == r2.a.g]), + ops.JoinLink("inner", s, [t.d == s.a.g]), ], - values={"f": r2.a.f, "g": r2.a.g}, + values={"f": s.a.f, "g": s.a.g}, ) - assert result.op() == join + assert result.op() == expected def test_unpack_join_from_table(t, s): diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index 420682ae1810..c8b2ed0103ce 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -901,7 +901,7 @@ def test_join_no_predicate_list(con): pred = region.r_regionkey == nation.n_regionkey joined = region.inner_join(nation, pred) - with join_tables(region, nation) as (r1, r2): + with join_tables(joined) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ops.JoinLink("inner", r2, [r1.r_regionkey == r2.n_regionkey])], @@ -924,7 +924,7 @@ def test_join_deferred(con): res = region.join(nation, _.r_regionkey == nation.n_regionkey) - with join_tables(region, nation) as (r1, r2): + with join_tables(res) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ops.JoinLink("inner", r2, [r1.r_regionkey == r2.n_regionkey])], @@ -970,7 +970,7 @@ def test_asof_join_with_by(): right = ibis.table([("time", "int32"), ("key", "int32"), ("value2", "double")]) join_without_by = api.asof_join(left, right, "time") - with join_tables(left, right) as (r1, r2): + with join_tables(join_without_by) as (r1, r2): r2 = join_without_by.op().rest[0].table.to_expr() expected = ops.JoinChain( first=r1, @@ -987,7 +987,7 @@ def test_asof_join_with_by(): assert join_without_by.op() == expected join_with_predicates = api.asof_join(left, right, "time", predicates="key") - with join_tables(left, right) as (r1, r2): + with join_tables(join_with_predicates) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1205,7 +1205,7 @@ def test_cross_join_multiple(table): c = table["f", "h"] joined = ibis.cross_join(a, b, c) - with join_tables(a, b, c) as (r1, r2, r3): + with join_tables(joined) as (r1, r2, r3): expected = ops.JoinChain( first=r1, rest=[ @@ -1292,7 +1292,7 @@ def test_join_key_alternatives(con, key_maker): key = key_maker(t1, t2) joined = t1.inner_join(t2, key) - with join_tables(t1, t2) as (r1, r2): + with join_tables(joined) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[ @@ -1376,7 +1376,7 @@ def test_unravel_compound_equijoin(table): p3 = t1.key3 == t2.key3 joined = t1.inner_join(t2, [p1 & p2 & p3]) - with join_tables(t1, t2) as (r1, r2): + with join_tables(joined) as (r1, r2): expected = ops.JoinChain( first=r1, rest=[