From 108a4e15b3c68fea2f803dc13b1b45291b00f15b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 28 Aug 2024 18:29:18 -0500 Subject: [PATCH] [Relax] Identify tuple unpack/repack in CanonicalizeBindings (#17313) Prior to this commit, the `CanonicalizeBindings` pass could identify and simplify a value that had been packed into a tuple, then extracted from it. (e.g. Simplifying `tup = (x,y); z = tup[0]` into `z = x`.) However, it could not identify a value that had been expanded from a tuple, and then re-bundled. (e.g. Simplifying `new_tuple = (tup[0], tup[1])` into `new_tuple = tup`.) This commit updates `CanonicalizeBindings` to identify and remove unnecessary tuple unpacking/repacking. --- src/relax/transform/canonicalize_bindings.cc | 112 ++++++++++++++---- .../test_transform_canonicalize_bindings.py | 51 ++++++++ 2 files changed, 143 insertions(+), 20 deletions(-) diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index d1a9f97337de..807914075e8d 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -262,33 +262,105 @@ class CanonicalizePlanner : public ExprVisitor { current_block_ = Optional(); } - void VisitBinding(const Binding& binding) override { - bool has_same_struct_info = true; - Expr value; - if (auto ptr = binding.as()) { - value = ptr->value; - } else if (auto ptr = binding.as()) { - has_same_struct_info = - StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(ptr->value)); - value = ptr->value; - } else { - LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey(); - } + Optional UnwrapKnownValue(Expr expr) { + // If the expression is a variable, then it can be unwrapped into + // its known value. + auto unwrap_var = [this](Expr expr) -> Expr { + if (auto var = expr.as()) { + if (auto opt = known_bindings_.Get(var.value())) { + return opt.value(); + } + } + return expr; + }; - // Unwrap TupleGetItem, if the Tuple being accessed is known. - if (auto tuple_get_item = value.as()) { - Expr tuple = tuple_get_item->tuple; - while (auto tuple_var = tuple.as()) { - if (auto opt = known_bindings_.Get(tuple_var.value())) { - tuple = opt.value(); + auto recursively_unwrap_var = [&unwrap_var](Expr expr) -> Expr { + while (true) { + auto new_expr = unwrap_var(expr); + if (new_expr.same_as(expr)) { + return expr; } else { - break; + expr = new_expr; } } + }; + // If the expression is a TupleGetItem, which accesses a field of + // a known tuple, then it can be unwrapped into a direct access of + // that field. + if (auto tuple_get_item = expr.as()) { + Expr tuple = recursively_unwrap_var(tuple_get_item->tuple); if (auto ptr = tuple.as()) { - value = ptr->fields[tuple_get_item->index]; + return ptr->fields[tuple_get_item->index]; + } + } + + // If the expression is a Tuple, and each element is + // `TupleGetItem(earlier_tuple, i)`, then this is just a copy of + // `earlier_tuple`. + auto earlier_tuple = [&]() -> Optional { + auto expr_tuple = expr.as(); + if (!expr_tuple) { + return NullOpt; + } + + if (expr_tuple->fields.empty()) { + return NullOpt; + } + + auto first_element = recursively_unwrap_var(expr_tuple->fields[0]).as(); + if (!first_element) { + return NullOpt; + } + + auto earlier_tuple_size = + Downcast(GetStructInfo(first_element->tuple))->fields.size(); + if (earlier_tuple_size != expr_tuple->fields.size()) { + return NullOpt; } + + Expr earlier_tuple = recursively_unwrap_var(first_element->tuple); + + for (size_t i = 0; i < expr_tuple->fields.size(); i++) { + auto element = recursively_unwrap_var(expr_tuple->fields[i]).as(); + if (!element) { + return NullOpt; + } + if (static_cast(element->index) != i) { + return NullOpt; + } + + auto source_of_element = recursively_unwrap_var(element->tuple); + + if (!earlier_tuple.same_as(source_of_element)) { + return NullOpt; + } + } + + return earlier_tuple; + }(); + if (earlier_tuple) { + return earlier_tuple.value(); + } + + return NullOpt; + } + + void VisitBinding(const Binding& binding) override { + bool has_same_struct_info = [&]() { + if (binding.as()) { + return true; + } else if (auto match_cast = binding.as()) { + return StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(match_cast->value)); + } else { + LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey(); + } + }(); + + Expr value = GetBoundValue(binding); + + if (auto unwrapped = UnwrapKnownValue(value)) { + value = unwrapped.value(); } if (auto parent = value.as(); parent && has_same_struct_info) { diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index a7ff8cdc3202..1d982b0972ed 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -1294,5 +1294,56 @@ def _get_binding_names(mod): assert after_names == expected_names +def test_trace_tuple_through_round_trip(): + """Canonicalize to the orignal tuple, without unwrap/rewrap.""" + + @I.ir_module + class Before: + @R.function + def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])): + with R.dataflow(): + A = param_tuple[0] + B = param_tuple[1] + C = param_tuple[2] + output = (A, B, C) + R.output(output) + return output + + @I.ir_module + class Expected: + @R.function + def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])): + with R.dataflow(): + A = param_tuple[0] + B = param_tuple[1] + C = param_tuple[2] + R.output() + + return param_tuple + + After = CanonicalizeBindings()(Before) + tvm.ir.assert_structural_equal(After, Expected) + + +def test_trace_partial_tuple_through_round_trip(): + """Canonicalize to the orignal tuple, without unwrap/rewrap.""" + + @I.ir_module + class Before: + @R.function + def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])): + with R.dataflow(): + A = param_tuple[0] + B = param_tuple[1] + output = (A, B) + R.output(output) + return output + + Expected = Before + + After = CanonicalizeBindings()(Before) + tvm.ir.assert_structural_equal(After, Expected) + + if __name__ == "__main__": tvm.testing.main()