Skip to content

Commit

Permalink
[Relax] Identify tuple unpack/repack in CanonicalizeBindings (#17313)
Browse files Browse the repository at this point in the history
Prior to this commit, the `CanonicalizeBindings` pass could identify
and simplify a value that had been packed into a tuple, then
extracted from it.  (e.g. Simplifying `tup = (x,y); z = tup[0]` into
`z = x`.)  However, it could not identify a value that had been
expanded from a tuple, and then re-bundled.  (e.g. Simplifying
`new_tuple = (tup[0], tup[1])` into `new_tuple = tup`.)

This commit updates `CanonicalizeBindings` to identify and remove
unnecessary tuple unpacking/repacking.
  • Loading branch information
Lunderberg authored Aug 28, 2024
1 parent be8607d commit 108a4e1
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 20 deletions.
112 changes: 92 additions & 20 deletions src/relax/transform/canonicalize_bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,33 +262,105 @@ class CanonicalizePlanner : public ExprVisitor {
current_block_ = Optional<BindingBlock>();
}

void VisitBinding(const Binding& binding) override {
bool has_same_struct_info = true;
Expr value;
if (auto ptr = binding.as<VarBindingNode>()) {
value = ptr->value;
} else if (auto ptr = binding.as<MatchCastNode>()) {
has_same_struct_info =
StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(ptr->value));
value = ptr->value;
} else {
LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey();
}
Optional<Expr> UnwrapKnownValue(Expr expr) {
// If the expression is a variable, then it can be unwrapped into
// its known value.
auto unwrap_var = [this](Expr expr) -> Expr {
if (auto var = expr.as<Var>()) {
if (auto opt = known_bindings_.Get(var.value())) {
return opt.value();
}
}
return expr;
};

// Unwrap TupleGetItem, if the Tuple being accessed is known.
if (auto tuple_get_item = value.as<TupleGetItemNode>()) {
Expr tuple = tuple_get_item->tuple;
while (auto tuple_var = tuple.as<Var>()) {
if (auto opt = known_bindings_.Get(tuple_var.value())) {
tuple = opt.value();
auto recursively_unwrap_var = [&unwrap_var](Expr expr) -> Expr {
while (true) {
auto new_expr = unwrap_var(expr);
if (new_expr.same_as(expr)) {
return expr;
} else {
break;
expr = new_expr;
}
}
};

// If the expression is a TupleGetItem, which accesses a field of
// a known tuple, then it can be unwrapped into a direct access of
// that field.
if (auto tuple_get_item = expr.as<TupleGetItemNode>()) {
Expr tuple = recursively_unwrap_var(tuple_get_item->tuple);
if (auto ptr = tuple.as<TupleNode>()) {
value = ptr->fields[tuple_get_item->index];
return ptr->fields[tuple_get_item->index];
}
}

// If the expression is a Tuple, and each element is
// `TupleGetItem(earlier_tuple, i)`, then this is just a copy of
// `earlier_tuple`.
auto earlier_tuple = [&]() -> Optional<Expr> {
auto expr_tuple = expr.as<TupleNode>();
if (!expr_tuple) {
return NullOpt;
}

if (expr_tuple->fields.empty()) {
return NullOpt;
}

auto first_element = recursively_unwrap_var(expr_tuple->fields[0]).as<TupleGetItemNode>();
if (!first_element) {
return NullOpt;
}

auto earlier_tuple_size =
Downcast<TupleStructInfo>(GetStructInfo(first_element->tuple))->fields.size();
if (earlier_tuple_size != expr_tuple->fields.size()) {
return NullOpt;
}

Expr earlier_tuple = recursively_unwrap_var(first_element->tuple);

for (size_t i = 0; i < expr_tuple->fields.size(); i++) {
auto element = recursively_unwrap_var(expr_tuple->fields[i]).as<TupleGetItemNode>();
if (!element) {
return NullOpt;
}
if (static_cast<size_t>(element->index) != i) {
return NullOpt;
}

auto source_of_element = recursively_unwrap_var(element->tuple);

if (!earlier_tuple.same_as(source_of_element)) {
return NullOpt;
}
}

return earlier_tuple;
}();
if (earlier_tuple) {
return earlier_tuple.value();
}

return NullOpt;
}

void VisitBinding(const Binding& binding) override {
bool has_same_struct_info = [&]() {
if (binding.as<VarBindingNode>()) {
return true;
} else if (auto match_cast = binding.as<MatchCastNode>()) {
return StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(match_cast->value));
} else {
LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey();
}
}();

Expr value = GetBoundValue(binding);

if (auto unwrapped = UnwrapKnownValue(value)) {
value = unwrapped.value();
}

if (auto parent = value.as<Var>(); parent && has_same_struct_info) {
Expand Down
51 changes: 51 additions & 0 deletions tests/python/relax/test_transform_canonicalize_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,5 +1294,56 @@ def _get_binding_names(mod):
assert after_names == expected_names


def test_trace_tuple_through_round_trip():
"""Canonicalize to the orignal tuple, without unwrap/rewrap."""

@I.ir_module
class Before:
@R.function
def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])):
with R.dataflow():
A = param_tuple[0]
B = param_tuple[1]
C = param_tuple[2]
output = (A, B, C)
R.output(output)
return output

@I.ir_module
class Expected:
@R.function
def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])):
with R.dataflow():
A = param_tuple[0]
B = param_tuple[1]
C = param_tuple[2]
R.output()

return param_tuple

After = CanonicalizeBindings()(Before)
tvm.ir.assert_structural_equal(After, Expected)


def test_trace_partial_tuple_through_round_trip():
"""Canonicalize to the orignal tuple, without unwrap/rewrap."""

@I.ir_module
class Before:
@R.function
def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])):
with R.dataflow():
A = param_tuple[0]
B = param_tuple[1]
output = (A, B)
R.output(output)
return output

Expected = Before

After = CanonicalizeBindings()(Before)
tvm.ir.assert_structural_equal(After, Expected)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 108a4e1

Please sign in to comment.