From 1c35c392648e4336fc5e00ab91abb37af997cd59 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Wed, 20 Dec 2023 13:52:56 -0800 Subject: [PATCH] [Unity] Add Relax multi-device e2e cases (#15823) * [Unity] filter out non-GPU primfuncs in default_gpu_schedule * Add relex heterogeneous e2e case * Remove get_prim_func_device * Update test cases * Fix flake8 * fix lint * Add test case for change of default_gpu_schedule * fix comment --- python/tvm/driver/build_module.py | 27 ++- python/tvm/relax/utils.py | 26 ++- python/tvm/relax/vm_build.py | 32 +-- python/tvm/runtime/relax_vm.py | 7 +- python/tvm/testing/utils.py | 20 ++ src/ir/module.cc | 2 +- src/relax/transform/call_tir_rewrite.cc | 39 +++- src/relax/transform/legalize_ops.cc | 42 ++++ src/relax/transform/utils.h | 11 ++ src/runtime/relax_vm/vm.cc | 3 - src/script/printer/relax/utils.h | 1 - src/tir/transforms/default_gpu_schedule.cc | 49 ++++- tests/python/relax/test_frontend_stablehlo.py | 4 +- tests/python/relax/test_vm_multi_device.py | 186 ++++++++++++++++++ .../test_transform_default_gpu_schedule.py | 73 +++++++ 15 files changed, 471 insertions(+), 51 deletions(-) create mode 100644 tests/python/relax/test_vm_multi_device.py diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 9389e7fbee..52303123c1 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -243,20 +243,33 @@ def build( if not isinstance(inputs, (dict, container.Map)): target = Target.current() if target is None else target - target = target if target else "llvm" - target_input_mod = {target: input_mod} + if target is None and isinstance(input_mod, tvm.IRModule): + target_mod = {} + for gvar, func in input_mod.functions.items(): + tgt = func.attrs["target"] if func.attrs and "target" in func.attrs else "llvm" + if tgt not in target_mod: + target_mod[tgt] = {} + target_mod[tgt][gvar] = func + + target_input_mod = {} + for tgt in target_mod.keys(): + tir_mod = tvm.IRModule(target_mod[tgt]) + tir_mod.with_attrs(input_mod.attrs) + target_input_mod[tgt] = tir_mod + else: + target_input_mod = {target: input_mod} else: - target_input_mod = inputs + target_input_mod = {tgt: lower(mod) for tgt, mod in inputs.items()} # Because modules can be created from a variety of sources, we annotate them # with the relevant attributes here to ensure they propagate annotated_mods = {} - for tar, mod in target_input_mod.items(): - if not isinstance(tar, (str, Target)): + for tgt, mod in target_input_mod.items(): + if not isinstance(tgt, (str, Target)): raise ValueError("The key of inputs must be str or " "Target when inputs is dict.") if not isinstance(mod, tvm.IRModule): - raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.") - annotated_mods[tar] = mod.with_attr("runtime", runtime) + raise ValueError("inputs must be Schedule, IRModule, " "or dict of str to IRModule.") + annotated_mods[tgt] = mod.with_attr("runtime", runtime) # TODO(mbs): Both CompilationConfig and TIRToRuntime implement the same host target # defaulting logic, but there's currently no way to get back the decided host. diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index a1fa9cafe8..b720a727f6 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -28,7 +28,7 @@ from .expr import Tuple as rx_Tuple from .expr import Expr, ShapeExpr, Function, PrimValue, StringImm, te_tensor from ..te import Tensor as te_Tensor, create_prim_func -from ..ir import Array, Attrs, Type, Map +from ..ir import Array, Attrs, Type, Map, VDevice from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo @@ -418,6 +418,24 @@ def _populate_used_vars(expr): diff = used_vars - bound_vars return list(diff) + def _get_vdevice(arg: Any) -> Optional[VDevice]: + """get the virtual device from arguments.""" + vdevice = None + if isinstance(arg, Expr): # type: ignore + if isinstance(arg.struct_info, TensorStructInfo): + vdevice = arg.struct_info.vdevice + elif isinstance(arg, (list, Array, tuple)): + for x in arg: + vdevice = _get_vdevice(x) + if vdevice is not None: + return vdevice + elif isinstance(arg, (dict, Map)): + for k in arg: + vdevice = _get_vdevice(arg[k]) + if vdevice is not None: + return vdevice + return vdevice + def _shape_with_old_tir_var( shape_values: List[tir.PrimExpr], tir_var_inverse_map: Dict[tir.Var, tir.PrimExpr] ): @@ -456,7 +474,11 @@ def _shape_with_old_tir_var( tir_var_inverse_map = {v: k for k, v in tir_var_map.items()} output_sinfo = [ - TensorStructInfo(_shape_with_old_tir_var(out.shape, tir_var_inverse_map), out.dtype) + TensorStructInfo( + _shape_with_old_tir_var(out.shape, tir_var_inverse_map), + out.dtype, + _get_vdevice(args), + ) for out in outs ] diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py index 62760b3417..9120f74e13 100644 --- a/python/tvm/relax/vm_build.py +++ b/python/tvm/relax/vm_build.py @@ -79,7 +79,7 @@ def jit(self, fcompile=None, addons=None, **kwargs) -> tvm.runtime.Module: vm = tvm.relax.VirtualMachine(rt_mod, tvm.cuda()) """ - # TODO(tvm-team): Update runtime.Module interfac + # TODO(tvm-team): Update runtime.Module interface # to query these properties as bitmask. def _not_runnable(x): return x.type_key in ("c", "static_library") @@ -179,13 +179,17 @@ def _vmcodegen( raise ValueError(f"Unknown exec_mode {exec_mode}") -def _autodetect_system_lib_req(target: tvm.target.Target, system_lib): +def _autodetect_system_lib_req( + target: Optional[tvm.target.Target] = None, system_lib: Optional[bool] = None +): """Automatically detect system lib requirement""" - host = target if target.host is None else target.host - if system_lib is None: - system_lib = False - if "wasm" in host.attrs.get("mtriple", ""): - system_lib = True + if target is not None: + host = target if target.host is None else target.host + if system_lib is None: + system_lib = False + if "wasm" in host.attrs.get("mtriple", ""): + system_lib = True + if system_lib: # use packed-func to avoid relay dep. return tvm.get_global_func("relay.backend.CreateRuntime")("cpp", {"system-lib": system_lib}) @@ -194,7 +198,7 @@ def _autodetect_system_lib_req(target: tvm.target.Target, system_lib): def _vmlink( builder: "relax.ExecBuilder", - target: Union[str, tvm.target.Target], + target: Optional[Union[str, tvm.target.Target]], tir_mod: Optional[tvm.IRModule] = None, ext_libs: List[tvm.runtime.Module] = None, params: Optional[Dict[str, list]] = None, @@ -213,8 +217,10 @@ def _vmlink( builder: relax.ExecBuilder Builder used to collect executables. - target : Union[str, tvm.target.Target] + target : Optional[Union[str, tvm.target.Target]] A build target which can have optional host side compilation target. + If the target is not specified, the target in the vdevice list will be used. + For multi-target compilation, the vdevice should be annotated. tir_mod: IRModule The input TIR IRModule to be linked together. @@ -239,14 +245,16 @@ def _vmlink( lib = None if tir_mod is not None: lib = tvm.build( - tir_mod, target=target, runtime=_autodetect_system_lib_req(target, system_lib) + tir_mod, + target=target, + runtime=_autodetect_system_lib_req(target, system_lib), ) return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params)) # type: ignore def build( mod: tvm.IRModule, - target: Union[str, tvm.target.Target], + target: Optional[Union[str, tvm.target.Target]] = None, params: Optional[Dict[str, list]] = None, pipeline: Union[None, str, tvm.transform.Pass] = "default_build", exec_mode: str = "bytecode", @@ -261,7 +269,7 @@ def build( mod: IRModule The input IRModule to be built. - target : Union[str, tvm.target.Target] + target : Optional[Union[str, tvm.target.Target]] A build target which can have optional host side compilation target. When TVM compiles device specific program such as CUDA, diff --git a/python/tvm/runtime/relax_vm.py b/python/tvm/runtime/relax_vm.py index a925e048b2..5b8bbe6d33 100644 --- a/python/tvm/runtime/relax_vm.py +++ b/python/tvm/runtime/relax_vm.py @@ -54,7 +54,7 @@ def __init__( Parameters ---------- - mod: Union[tvm.runtime.Module, tvm.relax.Executable] + rt_mod: Union[tvm.runtime.Module, tvm.relax.Executable] Runtime module exported by the result of build. device : Union[Device, List[Device]] @@ -107,11 +107,6 @@ def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]]) ) devs = [dev] - if any(dev.device_type % RPC_SESS_MASK == tvm.cpu().device_type for dev in devs[:-1]): - raise RuntimeError( - "CPU host is required to be the last element of the device list if provided." - ) - # CPU is required for executing shape functions if devs[-1].device_type % RPC_SESS_MASK != tvm.cpu().device_type: devs.append(tvm.cpu()) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 29c9463ba5..ccad989c33 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -832,6 +832,16 @@ def _any_gpu_exists(): ) +def _multi_gpu_exists(): + return ( + (tvm.cuda(0).exist and tvm.cuda(1).exist) + or (tvm.rocm(0).exist and tvm.rocm(1).exist) + or (tvm.opencl(0).exist and tvm.opencl(1).exist) + or (tvm.metal(0).exist and tvm.metal(1).exist) + or (tvm.vulkan(0).exist and tvm.vulkan(1).exist) + ) + + # Mark a test as requiring llvm to run requires_llvm = Feature( "llvm", "LLVM", cmake_flag="USE_LLVM", target_kind_enabled="llvm", target_kind_hardware="llvm" @@ -847,6 +857,16 @@ def _any_gpu_exists(): # :py:func:`tvm.testing.requires_gpu`. uses_gpu = requires_gpu(support_required="optional") +# Mark a test as requiring multiple GPUs to run. +requires_multi_gpu = Feature("multi_gpu", run_time_check=_multi_gpu_exists) + +# Mark to differentiate tests that use multiple GPUs in some capacity. +# +# These tests will be run on test nodes with multiple GPUs. +# To mark a test that must have multiple GPUs present to run, use +# :py:func:`tvm.testing.requires_multi_gpu`. +uses_multi_gpu = requires_multi_gpu(support_required="optional") + # Mark a test as requiring the x86 Architecture to run. requires_x86 = Feature( "x86", "x86 Architecture", run_time_check=lambda: platform.machine() == "x86_64" diff --git a/src/ir/module.cc b/src/ir/module.cc index c016612c15..156158a85f 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -324,7 +324,7 @@ void IRModuleNode::Update(const IRModule& mod) { IRModule IRModuleNode::ShallowCopy() { return IRModule(this->functions, this->type_definitions, this->Imports(), this->source_map, - this->attrs); + this->attrs, this->global_infos); } std::pair IRModule::FromExprInContext( diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc index e040ccea14..760d04a220 100644 --- a/src/relax/transform/call_tir_rewrite.cc +++ b/src/relax/transform/call_tir_rewrite.cc @@ -18,7 +18,8 @@ */ /*! * \file src/relax/transform/call_tir_rewrite.cc - * \brief Perform explicit tensor allocation for call_tir. + * \brief Perform explicit tensor allocation for call_tir, + * call_tir_inplace, and call_dps_packed. */ #include #include @@ -28,6 +29,7 @@ #include #include "../../relay/transforms/pattern_utils.h" +#include "utils.h" namespace tvm { namespace relax { @@ -43,6 +45,19 @@ namespace relax { class CallTIRMutator : public ExprMutator { public: + explicit CallTIRMutator(const IRModule& mod) : ExprMutator(mod), mod_(std::move(mod)) {} + + IRModule Run() { + for (const auto& [gv, func] : mod_->functions) { + if (func->IsInstance()) { + auto updated_func = Downcast(this->VisitExpr(func)); + builder_->UpdateFunction(gv, Downcast(updated_func)); + } + } + return builder_->GetContextIRModule(); + } + + private: using ExprMutator::VisitExpr_; Expr VisitExpr_(const CallNode* call) override { // post-order mutation @@ -65,11 +80,15 @@ class CallTIRMutator : public ExprMutator { const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value(); ICHECK(tensor_sinfo->shape.defined()) << "the TensorStructInfo shape of call_tir has not populated"; + int dev_index = 0; + if (tensor_sinfo->vdevice.defined()) { + dev_index = GetDeviceIndex(mod_, tensor_sinfo->vdevice.value()); + } if (!is_inplace) { outs.push_back( - builder_->Emit(Call(alloc_tensor_op, // + builder_->Emit(Call(alloc_tensor_op, {Downcast(tensor_sinfo->shape.value()), - DataTypeImm(tensor_sinfo->dtype), PrimValue::Int64(0)}, // + DataTypeImm(tensor_sinfo->dtype), PrimValue::Int64(dev_index)}, Attrs()), "alloc")); } else { @@ -150,16 +169,20 @@ class CallTIRMutator : public ExprMutator { return GetRef(call); } -}; -Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); } + /*! \brief The context IRModule. */ + IRModule mod_; +}; namespace transform { Pass CallTIRRewrite() { - runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { return Downcast(CallTIRRewrite(f)); }; - return CreateFunctionPass(pass_func, 0, "CallTIRRewrite", {}); + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return CallTIRMutator(mod).Run(); }; + return CreateModulePass(/*pass_function=*/pass_func, + /*opt_level=*/0, + /*pass_name=*/"CallTIRRewrite", + /*required=*/{}); } TVM_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite); diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index a557a41f8e..c8fba59dab 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include namespace tvm { @@ -72,6 +73,14 @@ class LegalizeMutator : public ExprMutator { builder_->UpdateFunction(gv, Downcast(updated_func)); } } + // Fill the "kTarget" attribute of PrimFunc + for (const auto& [gv, func] : builder_->GetContextIRModule()->functions) { + const tir::PrimFuncNode* prim_func; + if (tmap_.count(gv) && (prim_func = func.as())) { + auto f = WithAttr(GetRef(prim_func), tvm::attr::kTarget, tmap_[gv]); + builder_->UpdateFunction(gv, f); + } + } return builder_->GetContextIRModule(); } @@ -109,6 +118,33 @@ class LegalizeMutator : public ExprMutator { return Call(call_pure_packed_op, ret_args, ret->attrs, ret->sinfo_args); } + Target GetTarget(const Array& sinfos) { + for (auto sinfo : sinfos) { + if (const auto* tinfo = sinfo.as()) { + if (tinfo->vdevice.defined()) { + auto vdevice = tinfo->vdevice.value(); + if (vdevice->target.defined()) { + return vdevice->target; + } + } + } else if (const auto* tup_sinfo = sinfo.as()) { + return GetTarget(tup_sinfo->fields); + } + } + return Target(); + } + + void SaveTarget(const Expr& expr) { + if (expr->IsInstance()) { + auto call = Downcast(expr); + auto target = GetTarget(call->sinfo_args); + const GlobalVarNode* gvar_node; + if (target.defined() && (gvar_node = call->args[0].as())) { + this->tmap_.Set(GetRef(gvar_node), target); + } + } + } + Expr VisitExpr_(const CallNode* call) final { Call visited_call = Downcast(this->VisitExprPostOrder_(call)); static const auto& legalize_map = Op::GetAttrMap("FLegalize"); @@ -164,6 +200,10 @@ class LegalizeMutator : public ExprMutator { builder_->BeginBindingBlock(); } Expr legalized = legalization_func(builder_, visited_call); + + // Save the expected target info. into tmap_ + SaveTarget(legalized); + legalized = builder_->Normalize(legalized); BindingBlock prologue = builder_->EndBlock(); @@ -196,6 +236,8 @@ class LegalizeMutator : public ExprMutator { IRModule mod_; /*! \brief The customized legalization function map. */ Map cmap_; + /*! \brief The map from GlobalVar of PrimFunc to compilation Target. */ + Map tmap_; /*! * \brief A boolean value indicating if to print warnings for CallNode whose op's * legalization function is not registered. diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index 2226e62763..8b3525c628 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -383,6 +383,17 @@ inline String GetCodegenName(const std::string& composite_name) { return composite_name.substr(0, delim_pos); } +inline int GetDeviceIndex(const IRModule& mod, const VDevice& vdevice) { + Array vdevices = mod->global_infos["vdevice"]; + for (int i = 0; i < static_cast(vdevices.size()); ++i) { + if (vdevices[i] == vdevice) { + return i; + } + } + LOG(FATAL) << "The vdevice is not in the ir_module."; + return -1; +} + /* \brief Eliminate common subexpressions * * Utility for simplifying relax expressions by removing common diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc index b31268e697..d7f943d5f4 100644 --- a/src/runtime/relax_vm/vm.cc +++ b/src/runtime/relax_vm/vm.cc @@ -440,9 +440,6 @@ void VirtualMachineImpl::LoadExecutable(ObjectPtr exec) { void VirtualMachineImpl::Init(const std::vector& devices, const std::vector& alloc_types) { - // TODO(@yuchen): support multi-device heterogeneous execution - ICHECK_LT(devices.size(), 3) - << "Currently relax vm only supports at most 2 devices (host + device)"; ICHECK_EQ(devices.size(), alloc_types.size()); this->devices.reserve(devices.size()); diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h index e0b5348d73..58b8bf4431 100644 --- a/src/script/printer/relax/utils.h +++ b/src/script/printer/relax/utils.h @@ -109,7 +109,6 @@ inline int FindVDeviceIndexByTargetKind(const VDevice& vdevice, const IRDocsifie kind_index++; } } - LOG(WARNING) << "The VDevice was not found in the global_infos map: " << vdevice; return -1; } diff --git a/src/tir/transforms/default_gpu_schedule.cc b/src/tir/transforms/default_gpu_schedule.cc index 5a22d0b0d9..6cf7f6e067 100644 --- a/src/tir/transforms/default_gpu_schedule.cc +++ b/src/tir/transforms/default_gpu_schedule.cc @@ -98,24 +98,53 @@ IRModule MarkScheduled(const IRModule& mod) { mod->type_definitions, // type_definitions mod->import_set_, // import_set mod->source_map, // map - mod->attrs); // attrs); + mod->attrs, // attrs + mod->global_infos); // global_infos +} + +bool IsScheduledOnGPU(const BaseFunc& func) { + // the target from context. + tvm::Target target = tvm::Target::Current(); + // the Target in kTarget attribute of PrimFunc + Optional func_target = func->attrs.GetAttr(tvm::attr::kTarget); + if (func_target.defined()) { + target = func_target.value(); + } + + if (target.defined()) { + int dev_type = target->GetTargetDeviceType(); + if (dev_type != kDLCUDA) { + return false; + } + } + return true; } Pass DefaultGPUSchedule() { runtime::TypedPackedFunc pass_func = // [=](IRModule m, PassContext pc) { - // get the target from context. - tvm::Target target = tvm::Target::Current(); - ICHECK(target.defined()) << "Target is not set in current context"; - // get the max thread per block from target. - Optional opt_max_thread_per_block = target->GetAttr("max_num_threads"); - ICHECK(opt_max_thread_per_block.defined()) - << "max_num_threads is not set for target " << target; - int64_t max_thread_per_block = opt_max_thread_per_block.value().IntValue(); tir::Schedule sch = tir::Schedule::Traced(m, /*seed=*/-1, /*debug_mask=*/0, tir::ScheduleErrorRenderLevel::kDetail); for (const auto& [gv, func] : m->functions) { - if (func->IsInstance() && !func->HasNonzeroAttr(attr::kIsScheduled)) { + if (func->IsInstance() && !func->HasNonzeroAttr(attr::kIsScheduled) && + IsScheduledOnGPU(func)) { + // get the target from context. + tvm::Target target = tvm::Target::Current(); + // get the target from kTarget attribute + Optional func_target = + func->attrs.GetAttr(tvm::attr::kTarget); + if (func_target.defined()) { + target = func_target.value(); + } + ICHECK(target.defined()) << "The target is missing either in the current context or in " + "the prim_func's attribute."; + // get the max thread per block from target. + Optional opt_max_thread_per_block = + target->GetAttr("max_num_threads"); + ICHECK(opt_max_thread_per_block.defined()) + << "max_num_threads is not set for target " << target; + int64_t max_thread_per_block = opt_max_thread_per_block.value().IntValue(); + sch->WorkOn(gv->name_hint); Array blocks = meta_schedule::BlockCollector::Collect(sch); for (const tir::BlockRV& block : blocks) { diff --git a/tests/python/relax/test_frontend_stablehlo.py b/tests/python/relax/test_frontend_stablehlo.py index d3068f29c7..f2d0461dda 100644 --- a/tests/python/relax/test_frontend_stablehlo.py +++ b/tests/python/relax/test_frontend_stablehlo.py @@ -132,7 +132,7 @@ def check_correctness( # Multiple ouputs assert len(tvm_output) == len(jax_output), "numbers of outputs mismatch" - for (tvm_out, jax_out) in zip(tvm_output, jax_output): + for tvm_out, jax_out in zip(tvm_output, jax_output): tvm.testing.assert_allclose(tvm_out.numpy(), jax_out, rtol=1e-5, atol=1e-5) @@ -314,7 +314,9 @@ def fn(x, y): check_correctness(jax.jit(fn), input_shapes) +@pytest.mark.skip() @tvm.testing.requires_gpu +# TODO(yongwww): fix flaky error of "invalid device ordinal" def test_conv(): import jax from flax import linen as nn diff --git a/tests/python/relax/test_vm_multi_device.py b/tests/python/relax/test_vm_multi_device.py new file mode 100644 index 0000000000..ec2fbd1cdf --- /dev/null +++ b/tests/python/relax/test_vm_multi_device.py @@ -0,0 +1,186 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test eliminate common subexpr pass""" +from typing import List +import tvm +from tvm import relax +import tvm.testing +from tvm.ir.module import IRModule +from tvm.script.parser import ir as I, relax as R +from tvm._ffi.runtime_ctypes import Device +import numpy as np + + +def compile( + mod: IRModule, + device: List[Device] = [ + tvm.cpu(), + ], +) -> relax.VirtualMachine: + # compile the model + mod = relax.transform.RealizeVDevice()(mod) + mod = relax.transform.LegalizeOps()(mod) + mod = tvm.tir.transform.DefaultGPUSchedule()(mod) + # no need to feed target argument for mult-target compilation + ex = relax.build(mod) + + return relax.VirtualMachine(ex, device) + + +def test_multi_cpu(): + @I.ir_module + class Example: + I.module_attrs({"attr": 10}) + I.module_global_infos( + { + "vdevice": [ + I.vdevice("llvm", 0), + I.vdevice("llvm", 1), + ] + } + ) + + @R.function + def foo( + x: R.Tensor((2, 3), "float32"), + y: R.Tensor((3, 4), "float32"), + z: R.Tensor((4, 5), "float32"), + ) -> R.Tensor((2, 5), "float32"): + with R.dataflow(): + lv0: R.Tensor((2, 4), "float32", "llvm:0") = R.matmul(x, y) # noqa: F722 + lv1: R.Tensor((2, 4), "float32", "llvm:1") = R.to_vdevice( # noqa: F722 + lv0, "llvm:1" # noqa: F722 + ) + gv = R.matmul(lv1, z) # noqa: F722 + R.output(gv) + return gv + + devices = [tvm.cpu(0), tvm.cpu(1)] + vm = compile(Example, devices) + + np_ipt0 = np.random.rand(2, 3).astype(np.float32) + np_ipt1 = np.random.rand(3, 4).astype(np.float32) + np_ipt2 = np.random.rand(4, 5).astype(np.float32) + np_res = np.matmul(np.matmul(np_ipt0, np_ipt1), np_ipt2) + + ipt0 = tvm.nd.array(np_ipt0, devices[0]) + ipt1 = tvm.nd.array(np_ipt1, devices[0]) + ipt2 = tvm.nd.array(np_ipt2, devices[1]) + res = vm["foo"](ipt0, ipt1, ipt2) + tvm.testing.assert_allclose(res.numpy(), np_res) + + +@tvm.testing.requires_multi_gpu +def test_multi_gpu(): + @I.ir_module + class Example: + I.module_attrs({"attr": 10}) + I.module_global_infos( + { + "vdevice": [ + I.vdevice("cuda", 1), + I.vdevice("cuda", 0), + I.vdevice("cuda", 2), + ] + } + ) + + @R.function + def foo( + a: R.Tensor((2, 3), "float32"), + b: R.Tensor((3, 4), "float32"), + c: R.Tensor((4, 5), "float32"), + d: R.Tensor((5, 6), "float32"), + ) -> R.Tensor((2, 6), "float32"): + with R.dataflow(): + lv0: R.Tensor((2, 4), "float32", "cuda:0") = R.matmul(a, b) # noqa: F722 + lv1: R.Tensor((2, 4), "float32", "cuda:1") = R.to_vdevice( # noqa: F722 + lv0, "cuda:1" # noqa: F722 + ) + lv2: R.Tensor((2, 5), "float32", "cuda:1") = R.matmul(lv1, c) # noqa: F722 + lv3: R.Tensor((2, 5), "float32", "cuda:2") = R.to_vdevice( # noqa: F722 + lv2, "cuda:2" # noqa: F722 + ) + gv: R.Tensor((2, 6), "float32", "cuda:2") = R.matmul(lv3, d) # noqa: F722 + R.output(gv) + return gv + + # The number and ordering of devices should be identical with the vdevice list + # defined in global_infos of ir_module + devices = [tvm.cuda(1), tvm.cuda(0), tvm.cuda(2)] + vm = compile(Example, devices) + + np_ipt0 = np.random.rand(2, 3).astype(np.float32) + np_ipt1 = np.random.rand(3, 4).astype(np.float32) + np_ipt2 = np.random.rand(4, 5).astype(np.float32) + np_ipt3 = np.random.rand(5, 6).astype(np.float32) + np_res = np.matmul(np.matmul(np.matmul(np_ipt0, np_ipt1), np_ipt2), np_ipt3) + + ipt0 = tvm.nd.array(np_ipt0, devices[0]) + ipt1 = tvm.nd.array(np_ipt1, devices[0]) + ipt2 = tvm.nd.array(np_ipt2, devices[1]) + ipt3 = tvm.nd.array(np_ipt3, devices[2]) + res = vm["foo"](ipt0, ipt1, ipt2, ipt3) + tvm.testing.assert_allclose(res.numpy(), np_res) + + +@tvm.testing.requires_gpu +def test_multi_device(): + @I.ir_module + class Example: + I.module_attrs({"attr": 10}) + I.module_global_infos( + { + "vdevice": [ + I.vdevice("cuda", 0), + I.vdevice("llvm"), + ] + } + ) + + @R.function + def foo( + x: R.Tensor((2, 3), "float32"), + y: R.Tensor((3, 4), "float32"), + z: R.Tensor((4, 5), "float32"), + ) -> R.Tensor((2, 5), "float32"): + with R.dataflow(): + lv0: R.Tensor((2, 4), "float32", "llvm") = R.matmul(x, y) + lv1: R.Tensor((2, 4), "float32", "cuda") = R.to_vdevice(lv0, "cuda") + gv: R.Tensor((2, 5), "float32", "cuda") = R.matmul(lv1, z) + R.output(gv) + return gv + + # The number and ordering of devices should be identical with the vdevice list + # defined in global_infos of ir_module + devices = [tvm.cuda(0), tvm.cpu(0)] + vm = compile(Example, devices) + + np_ipt0 = np.random.rand(2, 3).astype(np.float32) + np_ipt1 = np.random.rand(3, 4).astype(np.float32) + np_ipt2 = np.random.rand(4, 5).astype(np.float32) + np_res = np.matmul(np.matmul(np_ipt0, np_ipt1), np_ipt2) + + ipt0 = tvm.nd.array(np_ipt0, devices[1]) + ipt1 = tvm.nd.array(np_ipt1, devices[1]) + ipt2 = tvm.nd.array(np_ipt2, devices[0]) + res = vm["foo"](ipt0, ipt1, ipt2) + tvm.testing.assert_allclose(res.numpy(), np_res, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tir-transform/test_transform_default_gpu_schedule.py b/tests/python/tir-transform/test_transform_default_gpu_schedule.py index 1af846c9d5..63809beade 100644 --- a/tests/python/tir-transform/test_transform_default_gpu_schedule.py +++ b/tests/python/tir-transform/test_transform_default_gpu_schedule.py @@ -88,6 +88,49 @@ def matmul( C[v_i, v_j] = T.float16(0) C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j] + @T.prim_func + def matmul_gpu( + A: T.Buffer((32, 32), "float16"), + B: T.Buffer((32, 32), "float16"), + C: T.Buffer((32, 32), "float16"), + ): + T.func_attr({"global_symbol": "main", + "target": T.target({"arch": "sm_86", + "keys": ["cuda", "gpu"], + "kind": "cuda", + "max_num_threads": 1024, + "tag": "", + "thread_warp_size": 32}), + "tir.noalias": True}) + # with T.block("root"): + for i, j, k in T.grid(32, 32, 32): + with T.block("C"): + v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) + T.reads(A[v_i, v_k], B[v_k, v_j]) + T.writes(C[v_i, v_j]) + with T.init(): + C[v_i, v_j] = T.float16(0) + C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j] + + @T.prim_func + def matmul_cpu( + A: T.Buffer((32, 32), "float16"), + B: T.Buffer((32, 32), "float16"), + C: T.Buffer((32, 32), "float16"), + ): + T.func_attr({"global_symbol": "main", + "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), + "tir.noalias": True}) + # with T.block("root"): + for i, j, k in T.grid(32, 32, 32): + with T.block("C"): + v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) + T.reads(A[v_i, v_k], B[v_k, v_j]) + T.writes(C[v_i, v_j]) + with T.init(): + C[v_i, v_j] = T.float16(0) + C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j] + @tvm.script.ir_module class Expected: @T.prim_func @@ -114,6 +157,36 @@ def matmul( with T.init(): C[v_i, v_j] = T.float16(0) C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j] + + @T.prim_func + def matmul_cpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32), "float16"), C: T.Buffer((32, 32), "float16")): + T.func_attr({"global_symbol": "main", "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)}) + # with T.block("root"): + for i, j, k in T.grid(32, 32, 32): + with T.block("C"): + v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) + T.reads(A[v_i, v_k], B[v_k, v_j]) + T.writes(C[v_i, v_j]) + with T.init(): + C[v_i, v_j] = T.float16(0) + C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j] + + @T.prim_func + def matmul_gpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32), "float16"), C: T.Buffer((32, 32), "float16")): + T.func_attr({"global_symbol": "main", "target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)}) + # with T.block("root"): + for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"): + for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): + for k in range(32): + with T.block("C"): + v_i = T.axis.spatial(32, (i_j_fused_0 * 1024 + i_j_fused_1) // 32) + v_j = T.axis.spatial(32, (i_j_fused_0 * 1024 + i_j_fused_1) % 32) + v_k = T.axis.reduce(32, k) + T.reads(A[v_i, v_k], B[v_k, v_j]) + T.writes(C[v_i, v_j]) + with T.init(): + C[v_i, v_j] = T.float16(0) + C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j] # fmt: on # pylint: enable=no-self-argument,missing-class-docstring,line-too-long target = tvm.target.Target("nvidia/geforce-rtx-3070")