Skip to content

Commit

Permalink
[Unity] Implement FNormalize for relax.op.call_tir (#16068)
Browse files Browse the repository at this point in the history
* [Unity] Implement FNormalize for relax.op.call_tir

Prior to this commit, `relax.op.call_tir` could express the TIR
arguments as either an in-line tuple, or as a variable bound to a
tuple.  Because several passes assume the arguments will always be an
in-line tuple, this is being codified as the normal form of
`relax.op.call_tir`.  Any upstream transform that produces
`relax.op.call_tir` with arguments provided as a by-variable tuple
will either be normalized to an in-line tuple if possible, or will
produce an error during the upstream transform otherwise.

This commit is specifically to allow the current usage of
`Downcast<Tuple>(call->args[1])` in passes such as `CallTIRRewrite`,
`FoldConstant`, `FuseTIR`, and `RewriteDataflowReshape`.

* Resolve unit test failures

* Added normalization for call_tir_inplace and call_tir_with_grad

* Normalize arg_tuple to (arg_tuple[0], ..., arg_tuple[N]) if unknown
  • Loading branch information
Lunderberg authored Nov 14, 2023
1 parent c9de001 commit 0ddfc65
Show file tree
Hide file tree
Showing 4 changed files with 341 additions and 19 deletions.
16 changes: 13 additions & 3 deletions src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,19 @@ class WellFormedChecker : public relax::ExprVisitor,

if (auto func_normalize = op_map_normalize_.get(call->op, nullptr); func_normalize != nullptr) {
auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_);
auto before_normalize = GetRef<Call>(call);
auto after_normalize = func_normalize(dummy_builder, before_normalize);
if (!before_normalize.same_as(after_normalize)) {
Call before_normalize = GetRef<Call>(call);
Optional<Expr> after_normalize = NullOpt;
try {
after_normalize = func_normalize(dummy_builder, before_normalize);
} catch (std::exception& err) {
Malformed(
Diagnostic::Error(call)
<< "If an operator defines an operator-specific normalization function (FNormalize), "
<< "calls to that operator must be normalized with it. "
<< "However, normalization of " << before_normalize << " resulted in the error: \n"
<< err.what());
}
if (after_normalize && !before_normalize.same_as(after_normalize)) {
Malformed(
Diagnostic::Error(call)
<< "If an operator defines an operator-specific normalization function (FNormalize), "
Expand Down
89 changes: 74 additions & 15 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,70 @@ StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) {
return call->sinfo_args[0];
}

Expr NormalizeCallTIR(const BlockBuilder&, Call call) {
// Temporary implementation to ensure that at least one op has a
// registered value for FNormalize. This temporary implementation
// is fully implemented in follow-up PR
// https://github.com/apache/tvm/pull/16068.
/*! \brief Normalize a call to `relax.call_tir` (or a variant) into its normal form.
 *
 * The normal form requires the TIR arguments (`call->args[1]`) to be an
 * in-line `Tuple` literal, so that downstream passes may safely apply
 * `Downcast<Tuple>(call->args[1])`.  If the arguments are instead given as a
 * variable, the variable is unwrapped through its bindings where possible, or
 * replaced by an explicit `(v[0], ..., v[N-1])` tuple of `TupleGetItem`
 * expressions as a fallback.
 *
 * \param ctx The block builder used to look up variable bindings.
 * \param call The call to normalize.  Must have 2 or 3 arguments whose second
 *   argument has `TupleStructInfo`; otherwise a CHECK failure is raised.
 * \return The normalized call.  Returned unchanged (same object) when the
 *   arguments are already an in-line tuple.
 */
Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) {
  // This function is used for normalization of `relax.call_tir`,
  // along with the variants `relax.call_tir_with_grad` and
  // `relax.call_tir_inplace`.  Therefore, all error messages should
  // be written in terms of `call->op`, and should not explicitly
  // reference the `relax.call_tir` operator.
  CHECK(call->args.size() == 2 || call->args.size() == 3)
      << "Operation " << call->op << " expects either two arguments [callee, arg_tuple], "
      << "or three arguments [callee, arg_tuple, tir_args], "
      << "but " << call << " has " << call->args.size() << " arguments.";

  Expr arg_expr = call->args[1];

  // The argument expression must at least be *typed* as a tuple, even if it
  // is not syntactically a tuple literal.
  CHECK(arg_expr->struct_info_.as<TupleStructInfoNode>())
      << "Operation " << call->op << " expects the second argument to be a tuple of relax Expr. "
      << "However, the second argument " << arg_expr << " has struct info "
      << arg_expr->struct_info_ << ".";

  // Already in normal form: the arguments are an in-line tuple literal.
  // Return the same object so that callers (e.g. the well-formed checker)
  // can detect "no change was required" via same_as.
  if (arg_expr.as<TupleNode>()) {
    return std::move(call);
  }

  // Anything other than a tuple literal or a variable (e.g. a TupleGetItem
  // or a call producing a tuple) cannot be normalized here.
  CHECK(arg_expr.as<VarNode>())
      << "Operation " << call->op << " must hold its arguments as an in-line tuple. "
      << "However, " << call << " has arguments " << arg_expr
      << ", which is neither an in-line tuple, "
      << "nor a variable binding that may be normalized to an in-line tuple.";

  // Returns the value bound to `expr` when `expr` is a variable with a known
  // binding in the current scope, and NullOpt otherwise.
  auto unwrap_binding = [&ctx](Expr expr) -> Optional<Expr> {
    if (auto var = expr.as<Var>()) {
      if (auto bound_value = ctx->LookupBinding(var.value())) {
        return bound_value.value();
      }
    }
    return NullOpt;
  };

  // Follow chains of variable-to-variable bindings (e.g. `b = a; c = b`)
  // until reaching either the defining tuple or an unbound variable.
  while (auto unwrapped = unwrap_binding(arg_expr)) {
    arg_expr = unwrapped.value();
  }

  Tuple new_arg_expr = [&]() {
    // Preferred replacement.  The argument tuple is provided as a
    // variable, but we know the value bound to that variable.
    if (auto opt = arg_expr.as<Tuple>()) {
      return opt.value();
    }

    // Fallback case.  The argument tuple is provided as a variable,
    // and we don't know the value bound to that variable.  For
    // example, if a relax function accepted a tuple as a parameter,
    // then provided that same tuple as an argument to call_tir.
    // Build an explicit (v[0], ..., v[N-1]) tuple instead, using the
    // field count from the variable's TupleStructInfo.
    Array<Expr> tuple_elements;
    size_t num_fields = Downcast<TupleStructInfo>(arg_expr->struct_info_)->fields.size();
    for (size_t i = 0; i < num_fields; i++) {
      tuple_elements.push_back(TupleGetItem(arg_expr, i));
    }
    return Tuple(tuple_elements);
  }();

  // Replace only args[1], copying the array first so that CopyOnWrite does
  // not mutate a call shared with other references.
  auto new_args = call->args;
  new_args.Set(1, new_arg_expr);
  call.CopyOnWrite()->args = new_args;

  return std::move(call);
}

Expand Down Expand Up @@ -314,6 +373,7 @@ RELAY_REGISTER_OP("relax.call_tir_with_grad")
"ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from "
"args if unused")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
.set_attr<FNormalize>("FNormalize", NormalizeCallTIR)
.set_attr<Bool>("FPurity", Bool(true));

Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
Expand Down Expand Up @@ -353,14 +413,12 @@ TVM_REGISTER_GLOBAL("relax.op.call_tir_with_grad").set_body_typed(MakeCallTIRWit

// call_tir_inplace

StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& ctx) {
if (call->sinfo_args.size() != 1) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "sinfo_args should have exactly 1 output struct info.");
}
CHECK(call->args[0]->IsInstance<GlobalVarNode>())
<< "call_tir expects the first argument to be a GlobalVar referring to a TIR PrimFunc. "
<< "However, gets " << call->args[0];
Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) {
// Apply normalization before error checks. This allows the error
// checks to safely apply `Downcast<Tuple>(call->args[1])`, which
// may result in an error if performed before normalization.
call = Downcast<Call>(NormalizeCallTIR(ctx, std::move(call)));

// there must be an inplace index for each output
const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
size_t num_outputs = 1U;
Expand Down Expand Up @@ -443,7 +501,7 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& c
}
}

return call->sinfo_args[0];
return std::move(call);
}

TVM_REGISTER_NODE_TYPE(CallTIRInplaceAttrs);
Expand All @@ -456,7 +514,8 @@ RELAY_REGISTER_OP("relax.call_tir_inplace")
.add_argument("packed_ints", "Expr",
"ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from "
"args if unused")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIRInplace)
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
.set_attr<FNormalize>("FNormalize", NormalizeCallTIRInPlace)
// Warning: considered pure, but it has the potential to create visible effects!
// This should only be used if it has been *checked* that it is safe (no aliases, in-place
// arguments will no longer be live)
Expand Down
47 changes: 47 additions & 0 deletions tests/python/relax/test_transform_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,5 +292,52 @@ def foo(x: R.Tensor((2, 3), dtype="float32")):
verify(Before, Expected)


def test_call_tir_tuple_arg():
    """CSE must not leave `R.call_tir` arguments in a non-normalized form.

    Both `R.call_tir` calls share the argument tuple `[A, B]`.  If
    EliminateCommonSubexpr extracted that tuple into a shared variable,
    the calls would violate the normal form of `relax.op.call_tir`
    (arguments must be an in-line tuple).  The module is expected to be
    unchanged after CSE + DCE.
    """

    @I.ir_module
    class Before:
        @R.function
        def main(A: R.Tensor([16, 16], "int32"), B: R.Tensor([16, 16], "int32")):
            cls = Before
            # Both calls use the same in-line argument tuple [A, B], which is
            # a candidate common subexpression.
            Prod = R.call_tir(cls.product, [A, B], out_sinfo=R.Tensor([16, 16], "int32"))
            Sum = R.call_tir(cls.sum, [A, B], out_sinfo=R.Tensor([16, 16], "int32"))
            return (Prod, Sum)

        @T.prim_func(private=True)
        def product(
            A: T.Buffer([16, 16], "int32"),
            B: T.Buffer([16, 16], "int32"),
            C: T.Buffer([16, 16], "int32"),
        ):
            # Elementwise product: C = A * B
            for iters in T.grid(*A.shape):
                with T.block("compute"):
                    i, j = T.axis.remap("SS", iters)
                    C[i, j] = A[i, j] * B[i, j]

        @T.prim_func(private=True)
        def sum(
            A: T.Buffer([16, 16], "int32"),
            B: T.Buffer([16, 16], "int32"),
            C: T.Buffer([16, 16], "int32"),
        ):
            # Elementwise sum: C = A + B
            for iters in T.grid(*A.shape):
                with T.block("compute"):
                    i, j = T.axis.remap("SS", iters)
                    C[i, j] = A[i, j] + B[i, j]

    # The transformed module should be structurally identical to the input.
    Expected = Before

    # If EliminateCommonSubexpr produces unnormalized expressions,
    # normalization of those expressions may produce additional
    # variable bindings.  This test case should be agnostic to those
    # additional bindings, so DCE is applied after CSE.
    After = tvm.ir.transform.Sequential(
        [
            EliminateCommonSubexpr(),
            tvm.relax.transform.DeadCodeElimination(),
        ]
    )(Before)
    tvm.ir.assert_structural_equal(Expected, After)


if __name__ == "__main__":
tvm.testing.main()
Loading

0 comments on commit 0ddfc65

Please sign in to comment.