[Unity] Add Relax multi-device e2e cases (#15823)
* [Unity] filter out non-GPU primfuncs in default_gpu_schedule

* Add Relax heterogeneous e2e case

* Remove get_prim_func_device

* Update test cases

* Fix flake8

* fix lint

* Add test case for change of default_gpu_schedule

* fix comment
yongwww authored Dec 20, 2023
1 parent f328e9b commit 1c35c39
Showing 15 changed files with 471 additions and 51 deletions.
27 changes: 20 additions & 7 deletions python/tvm/driver/build_module.py
@@ -243,20 +243,33 @@ def build(

if not isinstance(inputs, (dict, container.Map)):
target = Target.current() if target is None else target
target = target if target else "llvm"
target_input_mod = {target: input_mod}
if target is None and isinstance(input_mod, tvm.IRModule):
target_mod = {}
for gvar, func in input_mod.functions.items():
tgt = func.attrs["target"] if func.attrs and "target" in func.attrs else "llvm"
if tgt not in target_mod:
target_mod[tgt] = {}
target_mod[tgt][gvar] = func

target_input_mod = {}
for tgt in target_mod.keys():
tir_mod = tvm.IRModule(target_mod[tgt])
tir_mod.with_attrs(input_mod.attrs)
target_input_mod[tgt] = tir_mod
else:
target_input_mod = {target: input_mod}
else:
target_input_mod = inputs
target_input_mod = {tgt: lower(mod) for tgt, mod in inputs.items()}

# Because modules can be created from a variety of sources, we annotate them
# with the relevant attributes here to ensure they propagate
annotated_mods = {}
for tar, mod in target_input_mod.items():
if not isinstance(tar, (str, Target)):
for tgt, mod in target_input_mod.items():
if not isinstance(tgt, (str, Target)):
raise ValueError("The key of inputs must be str or " "Target when inputs is dict.")
if not isinstance(mod, tvm.IRModule):
raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.")
annotated_mods[tar] = mod.with_attr("runtime", runtime)
raise ValueError("inputs must be Schedule, IRModule, " "or dict of str to IRModule.")
annotated_mods[tgt] = mod.with_attr("runtime", runtime)

# TODO(mbs): Both CompilationConfig and TIRToRuntime implement the same host target
# defaulting logic, but there's currently no way to get back the decided host.
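A minimal sketch of the new tvm.build behavior, using a hypothetical PrimFunc named add_one that carries its own "target" attribute (the names and shapes are illustrative, not taken from this PR):

import tvm
from tvm import te

# Build a tiny PrimFunc and annotate it with a per-function target.
A = te.placeholder((8,), dtype="float32", name="A")
B = te.compute((8,), lambda i: A[i] + 1.0, name="B")
prim_func = te.create_prim_func([A, B])
prim_func = prim_func.with_attr("global_symbol", "add_one")
prim_func = prim_func.with_attr("target", tvm.target.Target("llvm"))

mod = tvm.IRModule({"add_one": prim_func})
# With target=None and an IRModule input, build now groups functions by their
# "target" attribute; functions without the attribute fall back to "llvm".
lib = tvm.build(mod)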
26 changes: 24 additions & 2 deletions python/tvm/relax/utils.py
Expand Up @@ -28,7 +28,7 @@
from .expr import Tuple as rx_Tuple
from .expr import Expr, ShapeExpr, Function, PrimValue, StringImm, te_tensor
from ..te import Tensor as te_Tensor, create_prim_func
from ..ir import Array, Attrs, Type, Map
from ..ir import Array, Attrs, Type, Map, VDevice
from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo


@@ -418,6 +418,24 @@ def _populate_used_vars(expr):
diff = used_vars - bound_vars
return list(diff)

def _get_vdevice(arg: Any) -> Optional[VDevice]:
"""get the virtual device from arguments."""
vdevice = None
if isinstance(arg, Expr): # type: ignore
if isinstance(arg.struct_info, TensorStructInfo):
vdevice = arg.struct_info.vdevice
elif isinstance(arg, (list, Array, tuple)):
for x in arg:
vdevice = _get_vdevice(x)
if vdevice is not None:
return vdevice
elif isinstance(arg, (dict, Map)):
for k in arg:
vdevice = _get_vdevice(arg[k])
if vdevice is not None:
return vdevice
return vdevice

def _shape_with_old_tir_var(
shape_values: List[tir.PrimExpr], tir_var_inverse_map: Dict[tir.Var, tir.PrimExpr]
):
Expand Down Expand Up @@ -456,7 +474,11 @@ def _shape_with_old_tir_var(
tir_var_inverse_map = {v: k for k, v in tir_var_map.items()}

output_sinfo = [
TensorStructInfo(_shape_with_old_tir_var(out.shape, tir_var_inverse_map), out.dtype)
TensorStructInfo(
_shape_with_old_tir_var(out.shape, tir_var_inverse_map),
out.dtype,
_get_vdevice(args),
)
for out in outs
]

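A rough sketch of the effect on call_te/emit_te, assuming the VDevice constructor shown below; the emitted output's TensorStructInfo is expected to inherit the input's vdevice rather than dropping it:

import tvm
from tvm import relax, topi

# Hypothetical vdevice; VDevice(target, vdevice_id, memory_scope) is assumed here.
vdev = tvm.ir.VDevice(tvm.target.Target("llvm"), 0, "global")
x = relax.Var("x", relax.TensorStructInfo((8,), "float32", vdev))

bb = relax.BlockBuilder()
with bb.function("main", [x]):
    # The output struct info should now carry x's vdevice as well.
    y = bb.emit_te(topi.add, x, x)
    bb.emit_func_output(y)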
32 changes: 20 additions & 12 deletions python/tvm/relax/vm_build.py
@@ -79,7 +79,7 @@ def jit(self, fcompile=None, addons=None, **kwargs) -> tvm.runtime.Module:
vm = tvm.relax.VirtualMachine(rt_mod, tvm.cuda())
"""

# TODO(tvm-team): Update runtime.Module interfac
# TODO(tvm-team): Update runtime.Module interface
# to query these properties as bitmask.
def _not_runnable(x):
return x.type_key in ("c", "static_library")
@@ -179,13 +179,17 @@ def _vmcodegen(
raise ValueError(f"Unknown exec_mode {exec_mode}")


def _autodetect_system_lib_req(target: tvm.target.Target, system_lib):
def _autodetect_system_lib_req(
target: Optional[tvm.target.Target] = None, system_lib: Optional[bool] = None
):
"""Automatically detect system lib requirement"""
host = target if target.host is None else target.host
if system_lib is None:
system_lib = False
if "wasm" in host.attrs.get("mtriple", ""):
system_lib = True
if target is not None:
host = target if target.host is None else target.host
if system_lib is None:
system_lib = False
if "wasm" in host.attrs.get("mtriple", ""):
system_lib = True

if system_lib:
# use packed-func to avoid relay dep.
return tvm.get_global_func("relay.backend.CreateRuntime")("cpp", {"system-lib": system_lib})
@@ -194,7 +198,7 @@ def _autodetect_system_lib_req(target: tvm.target.Target, system_lib):

def _vmlink(
builder: "relax.ExecBuilder",
target: Union[str, tvm.target.Target],
target: Optional[Union[str, tvm.target.Target]],
tir_mod: Optional[tvm.IRModule] = None,
ext_libs: List[tvm.runtime.Module] = None,
params: Optional[Dict[str, list]] = None,
@@ -213,8 +217,10 @@ def _vmlink(
builder: relax.ExecBuilder
Builder used to collect executables.
target : Union[str, tvm.target.Target]
target : Optional[Union[str, tvm.target.Target]]
A build target which can have optional host side compilation target.
If the target is not specified, the target in the vdevice list will be used.
For multi-target compilation, the vdevice should be annotated.
tir_mod: IRModule
The input TIR IRModule to be linked together.
@@ -239,14 +245,16 @@
lib = None
if tir_mod is not None:
lib = tvm.build(
tir_mod, target=target, runtime=_autodetect_system_lib_req(target, system_lib)
tir_mod,
target=target,
runtime=_autodetect_system_lib_req(target, system_lib),
)
return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params)) # type: ignore


def build(
mod: tvm.IRModule,
target: Union[str, tvm.target.Target],
target: Optional[Union[str, tvm.target.Target]] = None,
params: Optional[Dict[str, list]] = None,
pipeline: Union[None, str, tvm.transform.Pass] = "default_build",
exec_mode: str = "bytecode",
@@ -261,7 +269,7 @@
mod: IRModule
The input IRModule to be built.
target : Union[str, tvm.target.Target]
target : Optional[Union[str, tvm.target.Target]]
A build target which can have optional host side compilation target.
When TVM compiles device specific program such as CUDA,
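A hedged sketch of the new call pattern; mod is assumed to be a Relax IRModule whose tensors and global_infos carry vdevice annotations (its construction is omitted here):

import tvm
from tvm import relax

# No explicit target: it is taken from the module's vdevice list.
ex = relax.build(mod)
# Device list for the VM; assumed here to match the vdevices used by the module.
vm = relax.VirtualMachine(ex, [tvm.cuda(0), tvm.cpu(0)])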
7 changes: 1 addition & 6 deletions python/tvm/runtime/relax_vm.py
@@ -54,7 +54,7 @@ def __init__(
Parameters
----------
mod: Union[tvm.runtime.Module, tvm.relax.Executable]
rt_mod: Union[tvm.runtime.Module, tvm.relax.Executable]
Runtime module exported by the result of build.
device : Union[Device, List[Device]]
@@ -107,11 +107,6 @@ def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]])
)
devs = [dev]

if any(dev.device_type % RPC_SESS_MASK == tvm.cpu().device_type for dev in devs[:-1]):
raise RuntimeError(
"CPU host is required to be the last element of the device list if provided."
)

# CPU is required for executing shape functions
if devs[-1].device_type % RPC_SESS_MASK != tvm.cpu().device_type:
devs.append(tvm.cpu())
20 changes: 20 additions & 0 deletions python/tvm/testing/utils.py
@@ -832,6 +832,16 @@ def _any_gpu_exists():
)


def _multi_gpu_exists():
return (
(tvm.cuda(0).exist and tvm.cuda(1).exist)
or (tvm.rocm(0).exist and tvm.rocm(1).exist)
or (tvm.opencl(0).exist and tvm.opencl(1).exist)
or (tvm.metal(0).exist and tvm.metal(1).exist)
or (tvm.vulkan(0).exist and tvm.vulkan(1).exist)
)


# Mark a test as requiring llvm to run
requires_llvm = Feature(
"llvm", "LLVM", cmake_flag="USE_LLVM", target_kind_enabled="llvm", target_kind_hardware="llvm"
@@ -847,6 +857,16 @@ def _any_gpu_exists():
# :py:func:`tvm.testing.requires_gpu`.
uses_gpu = requires_gpu(support_required="optional")

# Mark a test as requiring multiple GPUs to run.
requires_multi_gpu = Feature("multi_gpu", run_time_check=_multi_gpu_exists)

# Mark to differentiate tests that use multiple GPUs in some capacity.
#
# These tests will be run on test nodes with multiple GPUs.
# To mark a test that must have multiple GPUs present to run, use
# :py:func:`tvm.testing.requires_multi_gpu`.
uses_multi_gpu = requires_multi_gpu(support_required="optional")

# Mark a test as requiring the x86 Architecture to run.
requires_x86 = Feature(
"x86", "x86 Architecture", run_time_check=lambda: platform.machine() == "x86_64"
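For example, a test could opt in to the new marks roughly as follows (the test body is illustrative only):

import tvm
import tvm.testing

@tvm.testing.requires_multi_gpu
def test_runs_on_two_gpus():
    # Skipped unless two GPUs of the same kind are present at run time.
    dev0, dev1 = tvm.cuda(0), tvm.cuda(1)
    assert dev0.exist and dev1.exist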
2 changes: 1 addition & 1 deletion src/ir/module.cc
@@ -324,7 +324,7 @@ void IRModuleNode::Update(const IRModule& mod) {

IRModule IRModuleNode::ShallowCopy() {
return IRModule(this->functions, this->type_definitions, this->Imports(), this->source_map,
this->attrs);
this->attrs, this->global_infos);
}

std::pair<IRModule, GlobalVar> IRModule::FromExprInContext(
39 changes: 31 additions & 8 deletions src/relax/transform/call_tir_rewrite.cc
@@ -18,7 +18,8 @@
*/
/*!
* \file src/relax/transform/call_tir_rewrite.cc
* \brief Perform explicit tensor allocation for call_tir.
* \brief Perform explicit tensor allocation for call_tir,
* call_tir_inplace, and call_dps_packed.
*/
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/expr_functor.h>
@@ -28,6 +29,7 @@
#include <tvm/tir/op.h>

#include "../../relay/transforms/pattern_utils.h"
#include "utils.h"

namespace tvm {
namespace relax {
@@ -43,6 +45,19 @@ namespace relax {

class CallTIRMutator : public ExprMutator {
public:
explicit CallTIRMutator(const IRModule& mod) : ExprMutator(mod), mod_(std::move(mod)) {}

IRModule Run() {
for (const auto& [gv, func] : mod_->functions) {
if (func->IsInstance<FunctionNode>()) {
auto updated_func = Downcast<Function>(this->VisitExpr(func));
builder_->UpdateFunction(gv, Downcast<BaseFunc>(updated_func));
}
}
return builder_->GetContextIRModule();
}

private:
using ExprMutator::VisitExpr_;
Expr VisitExpr_(const CallNode* call) override {
// post-order mutation
@@ -65,11 +80,15 @@ class CallTIRMutator : public ExprMutator {
const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value();
ICHECK(tensor_sinfo->shape.defined())
<< "the TensorStructInfo shape of call_tir has not populated";
int dev_index = 0;
if (tensor_sinfo->vdevice.defined()) {
dev_index = GetDeviceIndex(mod_, tensor_sinfo->vdevice.value());
}
if (!is_inplace) {
outs.push_back(
builder_->Emit(Call(alloc_tensor_op, //
builder_->Emit(Call(alloc_tensor_op,
{Downcast<ShapeExpr>(tensor_sinfo->shape.value()),
DataTypeImm(tensor_sinfo->dtype), PrimValue::Int64(0)}, //
DataTypeImm(tensor_sinfo->dtype), PrimValue::Int64(dev_index)},
Attrs()),
"alloc"));
} else {
@@ -150,16 +169,20 @@

return GetRef<Expr>(call);
}
};

Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); }
/*! \brief The context IRModule. */
IRModule mod_;
};

namespace transform {

Pass CallTIRRewrite() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(CallTIRRewrite(f)); };
return CreateFunctionPass(pass_func, 0, "CallTIRRewrite", {});
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule mod, PassContext pc) { return CallTIRMutator(mod).Run(); };
return CreateModulePass(/*pass_function=*/pass_func,
/*opt_level=*/0,
/*pass_name=*/"CallTIRRewrite",
/*required=*/{});
}

TVM_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite);
42 changes: 42 additions & 0 deletions src/relax/transform/legalize_ops.cc
@@ -26,6 +26,7 @@
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>

namespace tvm {
@@ -72,6 +73,14 @@ class LegalizeMutator : public ExprMutator {
builder_->UpdateFunction(gv, Downcast<BaseFunc>(updated_func));
}
}
// Fill the "kTarget" attribute of PrimFunc
for (const auto& [gv, func] : builder_->GetContextIRModule()->functions) {
const tir::PrimFuncNode* prim_func;
if (tmap_.count(gv) && (prim_func = func.as<tir::PrimFuncNode>())) {
auto f = WithAttr(GetRef<tir::PrimFunc>(prim_func), tvm::attr::kTarget, tmap_[gv]);
builder_->UpdateFunction(gv, f);
}
}
return builder_->GetContextIRModule();
}

@@ -109,6 +118,33 @@ class LegalizeMutator : public ExprMutator {
return Call(call_pure_packed_op, ret_args, ret->attrs, ret->sinfo_args);
}

Target GetTarget(const Array<StructInfo>& sinfos) {
for (auto sinfo : sinfos) {
if (const auto* tinfo = sinfo.as<TensorStructInfoNode>()) {
if (tinfo->vdevice.defined()) {
auto vdevice = tinfo->vdevice.value();
if (vdevice->target.defined()) {
return vdevice->target;
}
}
} else if (const auto* tup_sinfo = sinfo.as<TupleStructInfoNode>()) {
return GetTarget(tup_sinfo->fields);
}
}
return Target();
}

void SaveTarget(const Expr& expr) {
if (expr->IsInstance<CallNode>()) {
auto call = Downcast<Call>(expr);
auto target = GetTarget(call->sinfo_args);
const GlobalVarNode* gvar_node;
if (target.defined() && (gvar_node = call->args[0].as<GlobalVarNode>())) {
this->tmap_.Set(GetRef<GlobalVar>(gvar_node), target);
}
}
}

Expr VisitExpr_(const CallNode* call) final {
Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
@@ -164,6 +200,10 @@ class LegalizeMutator : public ExprMutator {
builder_->BeginBindingBlock();
}
Expr legalized = legalization_func(builder_, visited_call);

// Save the expected target info. into tmap_
SaveTarget(legalized);

legalized = builder_->Normalize(legalized);

BindingBlock prologue = builder_->EndBlock();
@@ -196,6 +236,8 @@
IRModule mod_;
/*! \brief The customized legalization function map. */
Map<String, PackedFunc> cmap_;
/*! \brief The map from GlobalVar of PrimFunc to compilation Target. */
Map<GlobalVar, Target> tmap_;
/*!
* \brief A boolean value indicating if to print warnings for CallNode whose op's
* legalization function is not registered.
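A small sketch of how the new attribute could be inspected; mod is assumed to be a Relax IRModule with vdevice-annotated tensors (not built here). After LegalizeOps, the generated PrimFuncs are expected to carry the vdevice's target, which CallTIRRewrite and tvm.build later use for device-aware allocation and per-target compilation:

import tvm
from tvm import relax

legalized = relax.transform.LegalizeOps()(mod)
for gv, func in legalized.functions.items():
    if isinstance(func, tvm.tir.PrimFunc):
        tgt = func.attrs["target"] if func.attrs and "target" in func.attrs else None
        print(gv.name_hint, tgt)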