Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] Identify tuple unpack/repack in CanonicalizeBindings #17313

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 92 additions & 20 deletions src/relax/transform/canonicalize_bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,33 +262,105 @@ class CanonicalizePlanner : public ExprVisitor {
current_block_ = Optional<BindingBlock>();
}

void VisitBinding(const Binding& binding) override {
bool has_same_struct_info = true;
Expr value;
if (auto ptr = binding.as<VarBindingNode>()) {
value = ptr->value;
} else if (auto ptr = binding.as<MatchCastNode>()) {
has_same_struct_info =
StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(ptr->value));
value = ptr->value;
} else {
LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey();
}
Optional<Expr> UnwrapKnownValue(Expr expr) {
// If the expression is a variable, then it can be unwrapped into
// its known value.
auto unwrap_var = [this](Expr expr) -> Expr {
if (auto var = expr.as<Var>()) {
if (auto opt = known_bindings_.Get(var.value())) {
return opt.value();
}
}
return expr;
};

// Unwrap TupleGetItem, if the Tuple being accessed is known.
if (auto tuple_get_item = value.as<TupleGetItemNode>()) {
Expr tuple = tuple_get_item->tuple;
while (auto tuple_var = tuple.as<Var>()) {
if (auto opt = known_bindings_.Get(tuple_var.value())) {
tuple = opt.value();
auto recursively_unwrap_var = [&unwrap_var](Expr expr) -> Expr {
while (true) {
auto new_expr = unwrap_var(expr);
if (new_expr.same_as(expr)) {
return expr;
} else {
break;
expr = new_expr;
}
}
};

// If the expression is a TupleGetItem, which accesses a field of
// a known tuple, then it can be unwrapped into a direct access of
// that field.
if (auto tuple_get_item = expr.as<TupleGetItemNode>()) {
Expr tuple = recursively_unwrap_var(tuple_get_item->tuple);
if (auto ptr = tuple.as<TupleNode>()) {
value = ptr->fields[tuple_get_item->index];
return ptr->fields[tuple_get_item->index];
}
}

// If the expression is a Tuple, and each element is
// `TupleGetItem(earlier_tuple, i)`, then this is just a copy of
// `earlier_tuple`.
auto earlier_tuple = [&]() -> Optional<Expr> {
auto expr_tuple = expr.as<TupleNode>();
if (!expr_tuple) {
return NullOpt;
}

if (expr_tuple->fields.empty()) {
return NullOpt;
}

auto first_element = recursively_unwrap_var(expr_tuple->fields[0]).as<TupleGetItemNode>();
if (!first_element) {
return NullOpt;
}

auto earlier_tuple_size =
Downcast<TupleStructInfo>(GetStructInfo(first_element->tuple))->fields.size();
if (earlier_tuple_size != expr_tuple->fields.size()) {
return NullOpt;
}

Expr earlier_tuple = recursively_unwrap_var(first_element->tuple);

for (size_t i = 0; i < expr_tuple->fields.size(); i++) {
auto element = recursively_unwrap_var(expr_tuple->fields[i]).as<TupleGetItemNode>();
if (!element) {
return NullOpt;
}
if (static_cast<size_t>(element->index) != i) {
return NullOpt;
}

auto source_of_element = recursively_unwrap_var(element->tuple);

if (!earlier_tuple.same_as(source_of_element)) {
return NullOpt;
}
}

return earlier_tuple;
}();
if (earlier_tuple) {
return earlier_tuple.value();
}

return NullOpt;
}

void VisitBinding(const Binding& binding) override {
bool has_same_struct_info = [&]() {
if (binding.as<VarBindingNode>()) {
return true;
} else if (auto match_cast = binding.as<MatchCastNode>()) {
return StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(match_cast->value));
} else {
LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey();
}
}();

Expr value = GetBoundValue(binding);

if (auto unwrapped = UnwrapKnownValue(value)) {
value = unwrapped.value();
}

if (auto parent = value.as<Var>(); parent && has_same_struct_info) {
Expand Down
51 changes: 51 additions & 0 deletions tests/python/relax/test_transform_canonicalize_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,5 +1260,56 @@ def _get_binding_names(mod):
assert after_names == expected_names


def test_trace_tuple_through_round_trip():
"""Canonicalize to the orignal tuple, without unwrap/rewrap."""

@I.ir_module
class Before:
@R.function
def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])):
with R.dataflow():
A = param_tuple[0]
B = param_tuple[1]
C = param_tuple[2]
output = (A, B, C)
R.output(output)
return output

@I.ir_module
class Expected:
@R.function
def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])):
with R.dataflow():
A = param_tuple[0]
B = param_tuple[1]
C = param_tuple[2]
R.output()

return param_tuple

After = CanonicalizeBindings()(Before)
tvm.ir.assert_structural_equal(After, Expected)


def test_trace_partial_tuple_through_round_trip():
"""Canonicalize to the orignal tuple, without unwrap/rewrap."""

@I.ir_module
class Before:
@R.function
def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])):
with R.dataflow():
A = param_tuple[0]
B = param_tuple[1]
output = (A, B)
R.output(output)
return output

Expected = Before

After = CanonicalizeBindings()(Before)
tvm.ir.assert_structural_equal(After, Expected)


if __name__ == "__main__":
tvm.testing.main()
Loading