diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h index e419a421115b..8e3e9d92f554 100644 --- a/include/tvm/relax/attrs/op.h +++ b/include/tvm/relax/attrs/op.h @@ -57,6 +57,20 @@ struct CallTIRInplaceAttrs : public tvm::AttrsNode { } }; // struct CallTIRInplaceAttrs +/*! \brief Attributes used in call_inplace_packed */ +struct CallInplacePackedAttrs : public tvm::AttrsNode { + Array inplace_indices; + + TVM_DECLARE_ATTRS(CallInplacePackedAttrs, "relax.attrs.CallInplacePackedAttrs") { + TVM_ATTR_FIELD(inplace_indices) + .describe( + "Indices that describe which input corresponds to which output. If the `i`th member " + "has the value `k` >= 0, then that means that input `k` should be used to store the " + "`i`th output. If an element has the value -1, that means the output will be newly " + "allocated."); + } +}; // struct CallInplacePackedAttrs + /*! \brief Attributes used in to_vdevice */ struct ToVDeviceAttrs : public tvm::AttrsNode { VDevice dst_vdevice; diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 77f1d0ff44e0..60a4332d838c 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -25,6 +25,7 @@ assert_op, call_builtin_with_ctx, call_dps_packed, + call_inplace_packed, call_pure_packed, call_tir, call_tir_inplace, diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 603148d0cf38..b363dc6952d8 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -604,6 +604,72 @@ def shape_to_tensor(expr: Expr) -> Expr: return _ffi_api.shape_to_tensor(expr) # type: ignore # pylint: disable=no-member +@args_converter.auto +def call_inplace_packed( + func: Union[str, ExternFunc, GlobalVar], + *args: Expr, + inplace_indices: Union[int, List[int]], + sinfo_args: Union[StructInfo, List[StructInfo]], +) -> Expr: + """ + Construct a call to a packed function that consumes some of its arguments "in-place" + and returns the mutated arguments (aliased), but should be considered to be otherwise pure. + The `inplace_indices` argument indicates which of the outputs are mutated arguments. + + The resulting call will have the same semantics as calling the packed function directly. + + Note: This should be used for cases when the user knows that calling the packed function + with these arguments will **in reality** not cause any other side effects. + If it is used for a call that **does** result in other side effects, then the compiler + may end up removing, reordering, or repeating that call, with no guarantees + made about any side effects from the callee. + + Warning: This operator as treated as pure by the type system even though it *is* performing + side effects (mutating some arguments). It is therefore incumbent upon the user to ensure + that it is being used safely (viz., that mutated arguments are not live after the mutation, + that they do not alias values live after the mutation). + + Parameters + ---------- + func : Union[str, ExternFunc] + The name (global symbol) for a PackedFunc or an ExternFunc node. + + args: Expr + The arguments for the PackedFunc. + + input_indices : Union[int, List[int]] + Specify which arguments should be used for in-place computations. + If `input_indices` is a single integer, it will be made into a singleton list. + Suppose `input_indices[i] = j`, where `j >= 0`. Then the `i`th output + will be an alias of `args[j]`. + If `input_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor. + At least one member of `input_indices` must not be -1. + + sinfo_args: Union[StructInfo, List[StructInfo]] + The list of structure info arguments (giving the structural info for the returned value). + + Returns + ------- + result : Expr + A Relax call, corresponding to + `call_pure_packed(ExternFunc(func), args, DictAttrs(kwargs), sinfo_args)` + """ + if isinstance(func, ExternFunc): + func = func.global_symbol + + op = ExternFunc(func) + if sinfo_args is None: + raise ValueError("R.call_pure_packed is required to have type_args") + if isinstance(sinfo_args, tuple): # type: ignore + sinfo_args = list(sinfo_args) + elif not isinstance(sinfo_args, list): + sinfo_args = [sinfo_args] + if not isinstance(inplace_indices, list): + inplace_indices = [inplace_indices] + + return _ffi_api.call_inplace_packed(op, args, inplace_indices, sinfo_args) # type: ignore # pylint: disable=no-member + + @args_converter.auto def call_pure_packed( func: Union[str, ExternFunc, GlobalVar], diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 2166827f48c4..142d0e6d96aa 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -52,6 +52,7 @@ broadcast_to, builtin, call_builtin_with_ctx, + call_inplace_packed, call_pure_packed, call_tir, call_tir_inplace, @@ -650,6 +651,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "bitwise_xor", "broadcast_to", "builtin", + "call_inplace_packed", "call_packed", "call_pure_packed", "call_tir", diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 326a219ed50f..01d0d04be0cc 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -121,6 +121,125 @@ Expr MakeCallPurePacked(const Expr& callee, Array args, const Attrs& attrs TVM_REGISTER_GLOBAL("relax.op.call_pure_packed").set_body_typed(MakeCallPurePacked); +// call_inplace_packed + +StructInfo InferStructInfoCallInplacePacked(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() <= 1) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "call_inplace_packed must be called with at least two arguments" + << " (the packed call and at least one argument to the packed call" + << "if the packed call does not need arguments, use call_pure_packed instead)"); + } + + // the callee must be an opaque function + auto callee = call->args[0]; + ICHECK(!callee.as()) << "call_pure_packed cannot be used with an op node"; + auto opt = MatchStructInfo(callee); + ICHECK(opt) << "Callee must have a function struct info"; + FuncStructInfo finfo = opt.value(); + ICHECK(finfo->IsOpaque()) << "call_pure_packed must be called with an opaque function, but " + << callee << " is not opaque"; + + // check the range for inplace indices, make sure at least one is not -1, ensure they're unique + const auto* attrs = call->attrs.as(); + size_t num_args = call->args.size() - 1; + std::unordered_set encountered; + for (size_t i = 0; i < attrs->inplace_indices.size(); i++) { + int index = attrs->inplace_indices[i].IntValue(); + if (index < -1 || index >= static_cast(num_args)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "In-place index " << i << " is out of range (must be between -1 and " + << (num_args - 1) << ", inclusive, but is " << index << ")"); + } + if (index != -1) { + if (encountered.count(index)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "All in-place indices must be unique, but index " << index + << " appears more than once."); + } + encountered.insert(index); + } + } + if (encountered.empty()) { + ctx->ReportFatal(Diagnostic::Error(call) << "At least one index must have a value other than " + "-1 (or else simply use call_pure_packed)"); + } + + // same logic as from DeriveCallRetStructInfo for ordinary calls + StructInfo ret; + if (finfo->derive_func.defined()) { + // derive using custom derivation function. + ret = finfo->derive_func.value()(call, ctx); + } else { + // directly return the normal value. + ret = finfo->ret; + } + + // make sure that the derived return struct info matches that of the in-place args + // (note: arg 0 is the packed func, so we add 1 to the arg index) + if (attrs->inplace_indices.size() == 1) { + auto arg_idx = attrs->inplace_indices[0].IntValue() + 1; + auto arg_sinfo = GetStructInfo(call->args[arg_idx]); + if (!IsBaseOf(ret, arg_sinfo, ctx->GetAnalyzer())) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The derived return StructInfo does not match that for " + << "the in-place argument at index " << (arg_idx - 1) << ": " << ret + << " vs " << arg_sinfo); + } + } else { + auto* tup_info = ret.as(); + if (!tup_info) { + ctx->ReportFatal(Diagnostic::Error(call) << "Multiple outputs given via the inplace indices " + "but the derived StructInfo is not a tuple"); + } + for (size_t i = 0; i < attrs->inplace_indices.size(); i++) { + if (attrs->inplace_indices[i] == -1) { + continue; + } + auto arg_idx = attrs->inplace_indices[i].IntValue() + 1; + auto arg_sinfo = GetStructInfo(call->args[arg_idx]); + auto ret_sinfo = tup_info->fields[i]; + if (!IsBaseOf(ret_sinfo, arg_sinfo, ctx->GetAnalyzer())) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The derived return StructInfo does not match that for " + << "the in-place argument at index " << (arg_idx - 1) << ": " << ret_sinfo + << " vs " << arg_sinfo); + } + } + } + + return ret; +} + +TVM_REGISTER_NODE_TYPE(CallInplacePackedAttrs); + +RELAY_REGISTER_OP("relax.call_inplace_packed") + .set_num_inputs(-1) + .set_attrs_type() + .add_argument("args", "Array", + "The first argument is the function being called. The rest are the " + "arguments to that function.") + .set_attr("FInferStructInfo", InferStructInfoCallInplacePacked) + // Warning: considered pure, but it has the potential to create visible effects! + // This should only be used if it has been *checked* that it is safe (no aliases, in-place + // arguments will no longer be live) and the user believes the packed func to have no + // side effects other than modifying the arguments specified as "inplace" + .set_attr("FPurity", Bool(true)); + +Expr MakeCallInplacePacked(Expr func, Array args, Array inplace_indices, + Array sinfo_args) { + ObjectPtr attrs = make_object(); + attrs->inplace_indices = Array(inplace_indices.begin(), inplace_indices.end()); + + static const Op& op = Op::Get("relax.call_inplace_packed"); + Array call_args = {func}; + call_args.insert(call_args.end(), args.begin(), args.end()); + return Call(op, call_args, Attrs(attrs), sinfo_args); +} + +TVM_REGISTER_GLOBAL("relax.op.call_inplace_packed").set_body_typed(MakeCallInplacePacked); + // call_tir StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { diff --git a/src/relax/transform/remove_purity_checking.cc b/src/relax/transform/remove_purity_checking.cc index f190968d9da1..7e3a31d0ba13 100644 --- a/src/relax/transform/remove_purity_checking.cc +++ b/src/relax/transform/remove_purity_checking.cc @@ -52,6 +52,12 @@ class PurityRemover : public ExprMutator { call->attrs, call->sinfo_args); return VisitExpr(ret); } + if (call->op == call_inplace_packed_op_) { + // call_inplace_packed has its own attrs so we don't pass those down + auto ret = Call(call->args[0], Array(call->args.begin() + 1, call->args.end()), + tvm::Attrs(), call->sinfo_args); + return VisitExpr(ret); + } if (call->op == invoke_pure_closure_op_) { auto ret = Call(invoke_closure_op_, call->args, call->attrs, call->sinfo_args); return VisitExpr(ret); @@ -66,6 +72,7 @@ class PurityRemover : public ExprMutator { private: const Op& call_pure_packed_op_ = Op::Get("relax.call_pure_packed"); + const Op& call_inplace_packed_op_ = Op::Get("relax.call_inplace_packed"); const Op& invoke_pure_closure_op_ = Op::Get("relax.invoke_pure_closure"); const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); }; diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py index 632ac96ff4ec..a278b0916772 100644 --- a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -39,7 +39,9 @@ def run_cpu(mod, func_name, *input): target = tvm.target.Target("llvm") ex = relax.build(mod, target) vm = relax.VirtualMachine(ex, tvm.cpu()) - return vm[func_name](*input) + vm.set_input(func_name, *input) + vm.invoke_stateful(func_name) + return vm.get_outputs(func_name) def test_unique(): @@ -248,6 +250,89 @@ def pure_copy(x: R.Tensor((3, 4), "float32")): assert (copy_found.numpy() == arr).all() +def test_op_call_inplace_packed(): + # in this case we can use the same test as above + @tvm.script.ir_module + class CallInplaceTest: + @R.function + def pure_copy(x: R.Tensor((3, 4), "float32")): + z = R.call_inplace_packed( + "vm.builtin.copy", + x, + inplace_indices=0, + sinfo_args=(R.Tensor((3, 4), dtype="float32")), + ) + return z + + @tvm.register_func("test.inplace.add") + def inplace_add(a, b): + arr_a = a.numpy() + arr_b = b.numpy() + for i in range(len(arr_a)): + for j in range(len(arr_a[i])): + arr_a[i][j] = arr_a[i][j] + arr_b[i][j] + a.copyfrom(arr_a) + return a + + @tvm.script.ir_module + class CallInplaceAddTest: + @R.function + def inplace_add(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): + z = R.call_inplace_packed( + "test.inplace.add", + x, + y, + inplace_indices=0, + sinfo_args=(R.Tensor((3, 4), dtype="float32")), + ) + return z + + np.random.seed(1) # to avoid flakiness + arr_a = np.random.rand(3, 4).astype("float32") + arr_b = np.random.rand(3, 4).astype("float32") + sum = arr_a + arr_b + tvm_arr_a = tvm.nd.array(arr_a) + result = run_cpu(CallInplaceAddTest, "inplace_add", tvm_arr_a, tvm.nd.array(arr_b)) + assert result == tvm_arr_a + assert (result.numpy() == sum).all() + + @tvm.register_func("test.inplace.tuple_add") + def inplace_tuple_add(a, b): + arr_a = a.numpy() + arr_b = b.numpy() + c = tvm.nd.array(arr_a + arr_b) + for i in range(len(arr_a)): + for j in range(len(arr_a[i])): + arr_a[i][j] = arr_a[i][j] + arr_b[i][j] + a.copyfrom(arr_a) + return tvm.runtime.container.ADT(0, [a, c]) + + @tvm.script.ir_module + class CallInplaceTuple: + @R.function + def inplace_tuple(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): + z = R.call_inplace_packed( + "test.inplace.tuple_add", + x, + y, + inplace_indices=[0, -1], + sinfo_args=(R.Tensor((3, 4), dtype="float32"), R.Tensor((3, 4), dtype="float32")), + ) + return z + + np.random.seed(2) # to avoid flakiness + arr_a = np.random.rand(3, 4).astype("float32") + arr_b = np.random.rand(3, 4).astype("float32") + sum = arr_a + arr_b + tvm_arr_a = tvm.nd.array(arr_a) + tvm_arr_b = tvm.nd.array(arr_b) + result = run_cpu(CallInplaceTuple, "inplace_tuple", tvm_arr_a, tvm_arr_b) + assert result[0] == tvm_arr_a + assert (result[0].numpy() == sum).all() + assert result[1] != tvm_arr_a and result[1] != tvm_arr_b + assert (result[1].numpy() == sum).all() + + def test_op_to_device(): @tvm.script.ir_module class CallToDevice: