diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 72dc8a5c9bf9..1493544e7324 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -190,16 +190,6 @@ constexpr const char* kTarget = "target"; */ constexpr const char* kGlobalSymbol = "global_symbol"; -/*! - * \brief The \p VirtualDevice which will hold each of the functions parameters. - * - * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but - * may be included as an annotation on user programs. - * - * Type: Array - */ -constexpr const char* kParamVirtualDevice = "param_virtual_devices"; - } // namespace attr } // namespace tvm #endif // TVM_IR_FUNCTION_H_ diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 4bbc2df0ae48..ea3a5dba6bf7 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -499,6 +499,10 @@ TVM_DLL Pass PlanDevices(CompilationConfig config); /*! * \brief Bind the free variables to a Relay expression. This is a helper * function usually called by other pass functions to help optimizations. + * If any free variables are introduced into a function, those are added + * to the function parameters. + * Additionally this may change the order of parameters if you map a variable + * to a variable. * * \param expr The input expression. * \param binds The variable to expression map that will be used to help the * binding. @@ -508,6 +512,19 @@ TVM_DLL Pass PlanDevices(CompilationConfig config); */ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); +/*! + * \brief Substitute variables with new variables (including function parameters) in a function. + * This is a helper function usually called by other pass functions to help optimizations. + * Expects all values in the bind map to be Vars. + * + * \param func The input function. + * \param binds The variable to expression map that will be used to help the + * binding. + * + * \return The updated expression. 
+ */ +TVM_DLL Function SubstituteBoundVars(const Function& func, const tvm::Map& binds); + /*! * \brief Apply rewrite rules to rewrite the expr in post DFS order. This * function is used as a helper function to rewrtie an expression in a pass. diff --git a/include/tvm/target/virtual_device.h b/include/tvm/target/virtual_device.h index 4a40777563af..37f4b23b12c2 100644 --- a/include/tvm/target/virtual_device.h +++ b/include/tvm/target/virtual_device.h @@ -367,7 +367,7 @@ class VirtualDeviceCache { * * Type: VirtualDevice */ -constexpr const char* kVirtualDevice = "result_virtual_device"; +constexpr const char* kVirtualDevice = "virtual_device"; } // namespace tvm diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 62eb1d43153f..9b15893092f7 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -456,9 +456,13 @@ class Parser { * * "x" -> Var("x"), these are needed to map from the raw string names * to unique variable nodes. + * If a virtual device is specified, sets the virtual device of the variable. */ - Var BindVar(const std::string& name, const relay::Type& type_annotation) { + Var BindVar(const std::string& name, const relay::Type& type_annotation, + Optional virtual_device = Optional()) { auto var = Var(name, type_annotation); + var->virtual_device_ = virtual_device.value_or(VirtualDevice::FullyUnconstrained()); + VLOG(1) << "Binding var named " << name << " to variable node " << PrettyPrint(var); this->expr_scopes.Add(name, var); return var; } @@ -1113,11 +1117,26 @@ class Parser { [&]() { auto token = Match(TokenType::kLocal); auto string = token.ToString(); + + // The fake attributes where the virtual device is specified. 
+ VirtualDevice virtual_device; + if (WhenMatch(TokenType::kLCurly)) { + Map fake_attrs = ParseAttrs(); + VLOG(9) << "Fake attributes for function parameter: " << fake_attrs; + Match(TokenType::kRCurly); + if (fake_attrs.size() == 1 && fake_attrs.count(kVirtualDevice)) { + ICHECK(fake_attrs[kVirtualDevice].as()) + << "Expected the " << kVirtualDevice + << " to have type VirtualDeviceNode, but got " << virtual_device->GetTypeKey(); + virtual_device = Downcast(fake_attrs[kVirtualDevice]); + } + } + Type type; if (WhenMatch(TokenType::kColon)) { type = ParseType(); } - return BindVar(string, type); + return BindVar(string, type, virtual_device); }, [&] { auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier; @@ -1150,8 +1169,15 @@ class Parser { ICHECK(vid.as()) << "Expected the " << kVirtualDevice << " to have type VirtualDeviceNode, but got " << vid->GetTypeKey(); - raw_attrs.erase(kVirtualDevice); - Function func = relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs)); + + DictAttrs attrs; + // Don't fill the raw_attrs in if there's nothing other than kVirtualDevice in the + // attributes + if (raw_attrs.size() > 1) { + raw_attrs.erase(kVirtualDevice); + attrs = DictAttrs(raw_attrs); + } + Function func = relay::Function(params, body, ret_type, generics, attrs); func->virtual_device_ = vid; return func; } else { diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 0ef45d878393..97231931ad88 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -220,9 +220,13 @@ Doc RelayTextPrinter::AllocVar(const Var& var) { } Doc val = GetUniqueName("%" + name); memo_[var] = val; + if (!var->virtual_device()->IsFullyUnconstrained()) { + val << " {" << kVirtualDevice << "=" << PrintAttributeValue(var->virtual_device()) << "}"; + } if (var->type_annotation.defined()) { val << ": " << Print(var->type_annotation); } + val << PrintOptionalInfo(var); return val; } @@ -445,7 +449,7 @@ 
Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) { for (const Doc& d : PrintDictAttrs(fn->attrs)) { params.push_back(d); } - if (fn->virtual_device() != VirtualDevice::FullyUnconstrained()) { + if (!fn->virtual_device()->IsFullyUnconstrained()) { Doc vid_doc; vid_doc << kVirtualDevice << "=" << PrintAttributeValue(fn->virtual_device()); params.push_back(vid_doc); @@ -454,7 +458,6 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) { if (fn->ret_type.defined()) { doc << "-> " << Print(fn->ret_type) << " "; } - doc << PrintBody(fn->body); return doc; } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index ef8bccf42d34..e94919de7f20 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -252,21 +252,16 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { // Do that flattening on-the-fly here. Function inner_func = Downcast(func->body); std::vector params; - std::vector param_virtual_devices; params.reserve(func->params.size() + inner_func->params.size()); - param_virtual_devices.reserve(func->params.size() + inner_func->params.size()); param_device_indexes.reserve(func->params.size() + inner_func->params.size()); for (size_t i = 0; i < func->params.size(); ++i) { params.emplace_back(func->params[i]); - VirtualDevice param_virtual_device = GetFunctionParamVirtualDevice(func.get(), i); - param_virtual_devices.push_back(param_virtual_device); - param_device_indexes.push_back(GetDeviceIndex(param_virtual_device)); + param_device_indexes.push_back(GetDeviceIndex(func->params[i]->virtual_device())); } for (size_t i = 0; i < inner_func->params.size(); ++i) { params.emplace_back(inner_func->params[i]); - VirtualDevice param_virtual_device = GetFunctionParamVirtualDevice(inner_func.get(), i); - param_virtual_devices.push_back(param_virtual_device); - param_device_indexes.push_back(GetDeviceIndex(param_virtual_device)); + + 
param_device_indexes.push_back(GetDeviceIndex(inner_func->params[i]->virtual_device())); } std::vector type_params; type_params.reserve(func->type_params.size() + inner_func->type_params.size()); @@ -278,13 +273,12 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { } Function flattened_func = Function(params, inner_func->body, inner_func->ret_type, type_params, func->attrs, func->span); - VisitExpr(MaybeFunctionOnDevice(flattened_func, param_virtual_devices, - GetFunctionResultVirtualDevice(inner_func.get()))); + flattened_func->virtual_device_ = inner_func->virtual_device(); + VisitExpr(flattened_func); } else { param_device_indexes.reserve(func->params.size()); for (size_t i = 0; i < func->params.size(); ++i) { - param_device_indexes.push_back( - GetDeviceIndex(GetFunctionParamVirtualDevice(func.get(), i))); + param_device_indexes.push_back(GetDeviceIndex(func->params[i]->virtual_device())); } VisitExpr(func); } diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index f2bd9e6b9a8a..b2912f6263dc 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -111,7 +111,6 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { auto free_type_vars = FreeTypeVars(func, module_); Array captured_vars; - std::vector captured_var_virtual_devices; bool recursive = false; for (const auto& var : free_vars) { if (!letrec_.empty() && var == letrec_.back()) { @@ -119,7 +118,6 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { continue; } captured_vars.push_back(var); - captured_var_virtual_devices.push_back(GetVirtualDevice(var)); } // Freshen all the captured vars. 
@@ -127,6 +125,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { Map rebinding_map; for (auto free_var : captured_vars) { auto var = Var(free_var->name_hint(), free_var->checked_type()); + var->virtual_device_ = GetVirtualDevice(free_var); typed_captured_vars.push_back(var); rebinding_map.Set(free_var, var); } @@ -173,6 +172,8 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { if (captured_vars.empty() && free_type_vars.empty()) { lifted_func = Function(body->params, body->body, body->ret_type, body->type_params, body->attrs, body->span); + // We also need to copy the virtual device + lifted_func->virtual_device_ = body->virtual_device(); } else { // When a closure is locally bound in a program, we have its full type information // avalible to us. @@ -187,14 +188,14 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { // construct the "closure" function with fully annotated arguments, no longer relying // on type inference. size_t before_arity = body->params.size(); + VLOG(9) << "Binding " << rebinding_map << " into\n" << PrettyPrint(body->body); auto rebound_body = WithFields(func, func->params, Bind(body->body, rebinding_map)); size_t after_arity = rebound_body->params.size(); CHECK_EQ(before_arity, after_arity); lifted_func = Function(typed_captured_vars, rebound_body, /*ret_type=*/func->func_type_annotation(), free_type_vars, /*attrs=*/{}, func->span); - lifted_func = - MaybeFunctionOnDevice(lifted_func, captured_var_virtual_devices, result_virtual_device); + lifted_func->virtual_device_ = result_virtual_device; lifted_func = MarkClosure(lifted_func); } diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index b710c2791acf..98e2ac0433b0 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -472,24 +472,25 @@ class ExprBinder : public MixedModeMutator, PatternMutator { const tvm::Map& args_map_; }; +// This function should be called SubstAndBind, since it assumes any 
variables introduced +// in the substitution right hand side should be implicitly bound in the function. Expr Bind(const Expr& expr, const tvm::Map& args_map) { if (const FunctionNode* func = expr.as()) { Expr new_body = ExprBinder(args_map).VisitExpr(func->body); Array new_params; - std::vector new_param_virtual_devices; for (size_t i = 0; i < func->params.size(); ++i) { if (!args_map.count(func->params[i])) { new_params.push_back(func->params[i]); - new_param_virtual_devices.push_back(GetFunctionParamVirtualDevice(func, i)); } } if (new_body.same_as(func->body) && new_params.size() == func->params.size()) { return expr; } + auto ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span); - ret = - MaybeFunctionOnDevice(ret, new_param_virtual_devices, GetFunctionResultVirtualDevice(func)); + ret->virtual_device_ = func->virtual_device(); + std::unordered_set set; for (const auto& v : FreeVars(expr)) { set.insert(v); @@ -497,20 +498,16 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { for (const auto& v : FreeVars(ret)) { if (set.count(v) == 0) { new_params.push_back(v); - if (!GetFunctionResultVirtualDevice(func)->IsFullyUnconstrained()) { - // TODO(mbs): The function has been annotated with a device, which means we are supposed - // to be preserving device annotations on every transformation. However there's no - // such context for the free vars in args_map. 
- LOG(WARNING) << "introduced free var '" << PrettyPrint(v) - << "' into function body but no device is known for it"; - } - new_param_virtual_devices.push_back(VirtualDevice::FullyUnconstrained()); } } + ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span); - ret = - MaybeFunctionOnDevice(ret, new_param_virtual_devices, GetFunctionResultVirtualDevice(func)); + ret->virtual_device_ = func->virtual_device(); + + VLOG(4) << "Expr:\n" << expr; + VLOG(4) << "Ret:\n" << ret; + ICHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size()); return std::move(ret); } else { @@ -528,6 +525,27 @@ TVM_REGISTER_GLOBAL("relay.ir.Bind").set_body([](TVMArgs args, TVMRetValue* ret) } }); +Function SubstituteBoundVars(const Function& func, const tvm::Map& args_map) { + Expr new_body = ExprBinder(args_map).VisitExpr(func->body); + Array new_params; + for (size_t i = 0; i < func->params.size(); i++) { + if (!args_map.count(func->params[i])) { + new_params.push_back(func->params[i]); + } else { + if (const VarNode* var = args_map[func->params[i]].as()) { + new_params.push_back(GetRef(var)); + } else { + ICHECK(false) << "Expected all values in args_map to be vars, but found " + << args_map[func->params[i]]->GetTypeKey(); + } + } + } + auto ret = + Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span); + ret->virtual_device_ = func->virtual_device(); + return ret; +} + void ExpandANormalForm(const LetNode* op, std::function pre_visit, std::function post_visit) { std::stack stack; diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index c66c91ecc739..155b6daf0848 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include "../../transforms/infer_layout_utils.h" #include "../type_relations.h" @@ -142,48 +143,5 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) { return {}; } -Function 
FunctionOnDevice(Function function, Array param_virtual_devices, - VirtualDevice result_virtual_device) { - auto func = WithAttr( - WithFields(std::move(function), {}, {}, {}, {}, {}, std::move(result_virtual_device)), - tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)); - VLOG(1) << "Annotated func: " << PrettyPrint(func); - return func; -} - -TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice); - -Function MaybeFunctionOnDevice(Function function, Array param_virtual_devices, - VirtualDevice result_virtual_device) { - if (std::all_of(param_virtual_devices.begin(), param_virtual_devices.end(), - [](const VirtualDevice& virtual_device) { - return virtual_device->IsFullyUnconstrained(); - }) && - result_virtual_device->IsFullyUnconstrained()) { - // Nothing to annotate. - return function; - } - return FunctionOnDevice(function, std::move(param_virtual_devices), - std::move(result_virtual_device)); -} - -VirtualDevice GetFunctionResultVirtualDevice(const FunctionNode* function_node) { - return function_node->virtual_device(); -} - -VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, size_t i) { - ICHECK_LT(i, function_node->params.size()) - << "param index " << i << " out of range for function of arity " - << function_node->params.size(); - auto opt_array = function_node->GetAttr>(tvm::attr::kParamVirtualDevice); - if (!opt_array) { - // No annotation. 
- return VirtualDevice::FullyUnconstrained(); - } - ICHECK_EQ(opt_array.value().size(), function_node->params.size()) - << "annotation parameters do not match function arity"; - return opt_array.value()[i]; -} - } // namespace relay } // namespace tvm diff --git a/src/relay/op/memory/on_device.h b/src/relay/op/memory/on_device.h index 7489e3b62b0c..b597af8fc7fa 100644 --- a/src/relay/op/memory/on_device.h +++ b/src/relay/op/memory/on_device.h @@ -154,33 +154,6 @@ const NodeType* AsIgnoringOnDevice(const Expr& expr) { return props.body.as(); } -/*! - * \brief Returns \p function annotated with "param_virtual_devices" and "result_virtual_device" - * attributes capturing parameter and result \p VirtualDevices respectively. - */ -Function FunctionOnDevice(Function function, Array param_virtual_devices, - VirtualDevice body_virtual_device); - -/*! - * \brief As for \p FunctionOnDevice, but returns \p function unchanged if all parameters and - * result \p VirtualDevices are unconstrained. - */ -Function MaybeFunctionOnDevice(Function function, Array param_virtual_devices, - VirtualDevice result_virtual_device); - -/*! - * \brief Returns the \p VirtualDevice for the resut of \p function_node, or the unconstrained - * \p VirtualDevice if function does not have the "result_virtual_device" annotation. - */ -VirtualDevice GetFunctionResultVirtualDevice(const FunctionNode* function_node); - -/*! - * \brief Returns the \p VirtualDevice for the \p i'th parameter of \p function_node, or - * the unconstrained \p VirtualDevice if function does not have the "param_virtual_devices" - * annotation. 
- */ -VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, size_t i); - } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc index 10584da51976..b5ad64add89a 100644 --- a/src/relay/transforms/device_aware_visitors.cc +++ b/src/relay/transforms/device_aware_visitors.cc @@ -38,7 +38,7 @@ LexicalOnDeviceMixin::LexicalOnDeviceMixin(const Optional& maybe_mod) if (maybe_mod) { for (const auto& kv : maybe_mod.value()->functions) { if (const auto* function_node = kv.second.as()) { - VirtualDevice virtual_device = GetFunctionResultVirtualDevice(function_node); + VirtualDevice virtual_device = function_node->virtual_device(); if (!virtual_device->IsFullyUnconstrained()) { VLOG(2) << "global '" << kv.first->name_hint << "' has virtual device " << virtual_device; global_var_virtual_devices_.emplace(kv.first, virtual_device); @@ -74,7 +74,7 @@ VirtualDevice LexicalOnDeviceMixin::GetVirtualDevice(const Expr& expr) const { } // else: fallthrough to unconstrained } else { - return GetFunctionResultVirtualDevice(function_node); + return function_node->virtual_device(); } } else { if (!expr_virtual_devices_.empty()) { @@ -132,11 +132,11 @@ void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) { DeviceAwareVisitExpr_(function_node); } else { // Function parameters come into scope. - for (size_t i = 0; i < function_node->params.size(); ++i) { - PushBoundVar(function_node->params[i], GetFunctionParamVirtualDevice(function_node, i)); + for (auto param : function_node->params) { + PushBoundVar(param, param->virtual_device()); } // Entering scope of function body. 
- PushVirtualDevice(GetFunctionResultVirtualDevice(function_node)); + PushVirtualDevice(function_node->virtual_device()); EnterFunctionBody(); DeviceAwareVisitExpr_(function_node); @@ -218,11 +218,11 @@ Expr DeviceAwareExprMutator::VisitExpr_(const FunctionNode* function_node) { return DeviceAwareVisitExpr_(function_node); } else { // Function parameters come into scope. - for (size_t i = 0; i < function_node->params.size(); ++i) { - PushBoundVar(function_node->params[i], GetFunctionParamVirtualDevice(function_node, i)); + for (auto param : function_node->params) { + PushBoundVar(param, param->virtual_device()); } // Entering scope of function body. - PushVirtualDevice(GetFunctionResultVirtualDevice(function_node)); + PushVirtualDevice(function_node->virtual_device()); EnterFunctionBody(); Expr result = DeviceAwareVisitExpr_(function_node); diff --git a/src/relay/transforms/device_aware_visitors.h b/src/relay/transforms/device_aware_visitors.h index 9340c03fc2d5..8e669725b91d 100644 --- a/src/relay/transforms/device_aware_visitors.h +++ b/src/relay/transforms/device_aware_visitors.h @@ -145,11 +145,11 @@ class DeviceAwareExprFunctor : public ExprFunctorparams.size(); ++i) { - PushBoundVar(function_node->params[i], GetFunctionParamVirtualDevice(function_node, i)); + for (auto param : function_node->params) { + PushBoundVar(param, param->virtual_device()); } // Entering scope of function body. 
- VirtualDevice virtual_device = GetFunctionResultVirtualDevice(function_node); + VirtualDevice virtual_device = function_node->virtual_device(); VLOG(2) << "entering " << virtual_device << " for function:" << std::endl << PrettyPrint(GetRef(function_node)); PushVirtualDevice(virtual_device); diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index 161f7c0c3342..3562da3b0d6f 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -583,15 +583,14 @@ class DeviceAnalyzer : public MixedModeVisitor { // If the function already has VirtualDevice attributes then we can further constrain the // function's domain to match them. - if (!GetFunctionResultVirtualDevice(function_node)->IsFullyUnconstrained()) { + if (!function_node->virtual_device()->IsFullyUnconstrained()) { std::vector args_and_result; - for (size_t i = 0; i < function_node->params.size(); ++i) { + for (auto param : function_node->params) { args_and_result.emplace_back( - domains_->ForVirtualDevice(function_node->params[i]->checked_type(), - GetFunctionParamVirtualDevice(function_node, i))); + domains_->ForVirtualDevice(param->checked_type(), param->virtual_device())); } - args_and_result.emplace_back(domains_->ForVirtualDevice( - function_node->body->checked_type(), GetFunctionResultVirtualDevice(function_node))); + args_and_result.emplace_back(domains_->ForVirtualDevice(function_node->body->checked_type(), + function_node->virtual_device())); auto annotation_domain = domains_->MakeHigherOrderDomain(std::move(args_and_result)); if (domains_->UnifyOrNull(func_domain, annotation_domain) == nullptr) { // higher-order // TODO(mbs): Proper diagnostics. @@ -877,7 +876,6 @@ class DeviceDefaulter : public ExprVisitor { }; /* =============== Phase 3 =============== */ - /*! 
* \brief Inserts missing "device_copy" CallNodes, and ensures the device type of every * sub-expression in a module can be easily recovered by a later transformation using simple @@ -886,8 +884,9 @@ class DeviceDefaulter : public ExprVisitor { * - Discard any existing "on_device" CallNodes since their job is done. Similarly, discard * any existing "device_copy" CallNodes which are no-ops. * - * - Functions are given "param_virtual_devices" and "result_virtual_device" attributes to capture - * the device type for its parameters and result. + * - The result virtual device for a function is stored in the function's virtual_device_ field + * and the virtual devices of the function's parameters are stored in the parameter's + * virtual_device_ field. * * - Additional "device_copy" CallNodes are inserted wherever there's a transition between * storage device types. Since the DeviceAnalyzer phase succeeded this can only happen @@ -1003,14 +1002,25 @@ class DeviceCapturer : public ExprMutator { ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); VirtualDevice result_virtual_device = domains_->ResultVirtualDevice(func_domain); ICHECK(!result_virtual_device->IsFullyUnconstrained()); - Array param_virtual_devices; - param_virtual_devices.reserve(function_node->params.size()); + + // Map the function parameters to a new variable annotated with a virtual device so + // we can substitute them later. 
+ Map annotated_bind_map; + Array annotated_params; + annotated_params.reserve(function_node->params.size()); for (size_t i = 0; i < function_node->params.size(); ++i) { VirtualDevice param_virtual_device = domains_->ResultVirtualDevice(func_domain->function_param(i)); + VLOG(4) << "Param: " << function_node->params[i]; + Var annotated_var = WithFields(function_node->params[i], {}, {}, param_virtual_device); + VLOG(4) << "Annotated param: " << annotated_var; + VLOG(4) << "VirtualDevice: " << annotated_var->virtual_device(); ICHECK(!param_virtual_device->IsFullyUnconstrained()); - param_virtual_devices.push_back(param_virtual_device); + annotated_bind_map.Set(function_node->params[i], annotated_var); + annotated_params.push_back(annotated_var); } + // Eventually we probably want to bind before visiting, but for now this is causing an issue + // with the GetVirtualDevice utility, so leaving as is for now. // Rewrite the body. Note that the body may have begun with an "on_device" so // be prepared to insert a "device_copy". 
@@ -1018,10 +1028,14 @@ class DeviceCapturer : public ExprMutator { /*lexical_virtual_device=*/result_virtual_device, /*expected_virtual_device=*/result_virtual_device, /*child_virtual_device=*/GetVirtualDevice(function_node->body), function_node->body); - + VLOG(4) << "Visited body: " << body; Function func = WithFields(GetRef(function_node), function_node->params, body); - return FunctionOnDevice(func, std::move(param_virtual_devices), - std::move(result_virtual_device)); + VLOG(4) << "New function: " << func; + func = SubstituteBoundVars(func, annotated_bind_map); + VLOG(4) << "Func with bound params: " << func; + func->virtual_device_ = result_virtual_device; + VLOG(4) << "Func with bound params & result vid set: " << func; + return func; } Expr VisitExpr_(const CallNode* call_node) final { diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 2f6efb9cef9a..8319726b79c5 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -304,9 +304,9 @@ class Fill : ExprFunctor, private transform::Lexi } else { // Keep track of expression and bound variable device types for lexically enclosing // sub-expressions. 
- PushVirtualDevice(GetFunctionResultVirtualDevice(f)); - for (size_t i = 0; i < f->params.size(); ++i) { - PushBoundVar(f->params[i], GetFunctionParamVirtualDevice(f, i)); + PushVirtualDevice(f->virtual_device()); + for (auto param : f->params) { + PushBoundVar(param, param->virtual_device()); } EnterFunctionBody(); ret = WithFields(GetRef(f), f->params, diff --git a/tests/python/relay/op/annotation/test_annotation.py b/tests/python/relay/op/annotation/test_annotation.py index abc458313101..dcbb40cdcabc 100644 --- a/tests/python/relay/op/annotation/test_annotation.py +++ b/tests/python/relay/op/annotation/test_annotation.py @@ -61,18 +61,6 @@ def test_on_device_free(): assert not call.attrs.constrain_result -def test_function_on_device(): - x = relay.Var("x") - y = relay.Var("y") - f = relay.Function([x, y], relay.add(x, y)) - func = relay.annotation.function_on_device(f, ["cpu", "cuda"], "cuda") - assert isinstance(func, relay.Function) - assert len(func.attrs["param_virtual_devices"]) == 2 - assert func.attrs["param_virtual_devices"][0].device_type_int == 1 # ie kDLCPU - assert func.attrs["param_virtual_devices"][1].device_type_int == 2 # ie kDLCUDA - assert func.virtual_device_.device_type_int == 2 # ie KDLCUDA - - if __name__ == "__main__": import sys diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index 7b7ef0ce920f..c165d140b1a6 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -22,11 +22,6 @@ from tvm.relay.testing import run_infer_type, create_workload -def annot_func(f): - """Returns f with arg/result device attributes for the argument and result.""" - return relay.op.annotation.function_on_device(f, [tvm.cpu()], tvm.cpu()) - - def annot_expr(e): """Returns e wrapped with an on_device annotation.""" return relay.op.annotation.on_device(e, tvm.cpu(), constrain_result=True) @@ -96,20 +91,24 @@ def test_fold_const_with_on_device(): def 
before(): c = relay.const(c_data) x = relay.var("x", t) + x.virtual_device_ = tvm.cpu() y = relay.add(c, c) y = relay.multiply(y, relay.const(2, "float32")) y = relay.add(x, y) z = relay.add(y, c) f = relay.Function([x], z) - return annot_func(f) + f.virtual_device_ = tvm.cpu() + return f def expected(): x = relay.var("x", t) + x.virtual_device_ = tvm.cpu() c_folded = (c_data + c_data) * 2 y = relay.add(x, relay.const(c_folded)) z = relay.add(y, relay.const(c_data)) f = relay.Function([x], z) - return annot_func(f) + f.virtual_device_ = tvm.cpu() + return f zz = run_opt_pass(before(), transform.FoldConstant()) zexpected = run_opt_pass(expected(), transform.InferType()) @@ -151,21 +150,25 @@ def test_fold_let_with_on_device(): def before(): sb = relay.ScopeBuilder() x = relay.var("x", t) + x.virtual_device_ = tvm.cpu() t1 = sb.let("t1", annot_expr(relay.const(c_data))) t2 = sb.let("t2", annot_expr(relay.add(t1, t1))) t3 = sb.let("t3", annot_expr(relay.add(t2, x))) sb.ret(t3) f = relay.Function([x], sb.get()) - return annot_func(f) + f.virtual_device_ = tvm.cpu() + return f def expected(): sb = relay.ScopeBuilder() x = relay.var("x", t) + x.virtual_device_ = tvm.cpu() c_folded = c_data + c_data t3 = sb.let("t3", annot_expr(relay.add(annot_expr(relay.const(c_folded)), x))) sb.ret(t3) f = relay.Function([x], sb.get()) - return annot_func(f) + f.virtual_device_ = tvm.cpu() + return f zz = run_opt_pass(before(), transform.FoldConstant()) zexpected = run_opt_pass(expected(), transform.InferType()) diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py index 6d658de387b0..f9fe9cf3555b 100644 --- a/tests/python/relay/test_pass_plan_devices.py +++ b/tests/python/relay/test_pass_plan_devices.py @@ -133,10 +133,9 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - 
param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][1], meta[VirtualDevice][1], meta[VirtualDevice][1]], - result_virtual_device=meta[VirtualDevice][1]) { + def @main(%a {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], %b {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], + %c {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], %d {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][1]) { %0 = add(%a, %b); %1 = add(%c, %d); subtract(%0, %1) @@ -178,10 +177,9 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][1], meta[VirtualDevice][1]], - result_virtual_device=meta[VirtualDevice][1]) { + def @main(%a {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %b {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + %c {virtual_device= meta[VirtualDevice][1]}: Tensor[(5, 7), float32], %d {virtual_device= meta[VirtualDevice][1]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][1]) { %0 = add(%a, %b); %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); @@ -225,10 +223,9 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][1], meta[VirtualDevice][1]], - result_virtual_device=meta[VirtualDevice][1]) { + def @main(%a {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %b 
{virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + %c {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], %d {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][1]) { %0 = add(%a, %b); %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); @@ -272,10 +269,9 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0]], - result_virtual_device=meta[VirtualDevice][1]) { + def @main(%a {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %b {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + %c {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %d {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][1]) { %0 = add(%a, %b); %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); %2 = add(%c, %d); @@ -320,8 +316,8 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) { + def @main(%a {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %b {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][1]) { %0 = add(%a, %b); %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); %2 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); @@ -367,10 +363,9 @@ def expected(): return 
tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][1], meta[VirtualDevice][1]], - result_virtual_device=meta[VirtualDevice][1]) { + def @main(%a {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %b {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + %c {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], %d {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][1]) { %0 = add(%a, %b); let %l = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); let %r = on_device(add(%c, %d), virtual_device=meta[VirtualDevice][1], constrain_result=True); @@ -417,12 +412,11 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0]], - result_virtual_device=meta[VirtualDevice][0]) { - let %f = fn (%x, %y, - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) { + def @main(%a {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %b {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + %c {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %d {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][0]) { + let %f = fn (%x {virtual_device=meta[VirtualDevice][0]}, %y {virtual_device=meta[VirtualDevice][0]}, + virtual_device=meta[VirtualDevice][0]) { add(%x, %y) }; %0 = %f(%a, %b); @@ -469,12 +463,11 @@ def expected(): return tvm.parser.parse( """ #[version 
= "0.0.5"] - def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][1], meta[VirtualDevice][1]], - result_virtual_device=meta[VirtualDevice][1]) { - let %f = fn (%x, %y, - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) { + def @main(%a {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %b {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + %c {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], %d {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][1]) { + let %f = fn (%x {virtual_device=meta[VirtualDevice][0]}, %y {virtual_device=meta[VirtualDevice][0]}, + virtual_device=meta[VirtualDevice][0]) { add(%x, %y) }; %1 = %f(%a, %b); @@ -528,16 +521,16 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) { - let %f = fn (%g, param_virtual_devices=[meta[VirtualDevice][1]], result_virtual_device=meta[VirtualDevice][1]) { - fn (%a, param_virtual_devices=[meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) { + def @main(%x {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], %y {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][1]) { + let %f = fn (%g {virtual_device=meta[VirtualDevice][1]}, virtual_device=meta[VirtualDevice][1]) { + fn (%a {virtual_device=meta[VirtualDevice][0]}, virtual_device=meta[VirtualDevice][1]) { %0 = device_copy(%a, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); %1 = %g(%0); add(%1, %x) } }; - let 
%h = fn (%b, param_virtual_devices=[meta[VirtualDevice][1]], result_virtual_device=meta[VirtualDevice][1]) { + let %h = fn (%b {virtual_device=meta[VirtualDevice][1]}, virtual_device=meta[VirtualDevice][1]) { negative(%b) }; %2 = %f(%h); @@ -590,10 +583,10 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) { - let %f = fn (%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) { + def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %y {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][0]) { + let %f = fn (%a {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %b {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][0]) { add(%a, %b) }; let %t = on_device((%f, %x), virtual_device=meta[VirtualDevice][0], constrain_result=True); @@ -635,8 +628,8 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) { + def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][1]) { %0 = device_copy(%x, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); add(%0, meta[relay.Constant][0]) } @@ -676,8 +669,8 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(?, ?), float32], - param_virtual_devices=[meta[VirtualDevice][1]], result_virtual_device=meta[VirtualDevice][0]) { + def @main(%x {virtual_device=meta[VirtualDevice][1]}: Tensor[(?, ?), 
float32], + virtual_device=meta[VirtualDevice][0]) { vm.shape_of(%x, dtype="int64") } """, @@ -712,8 +705,8 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%size: int64, %alignment: int64, - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) { + def @main(%size {virtual_device=meta[VirtualDevice][0]}: int64, %alignment {virtual_device=meta[VirtualDevice][0]}: int64, + virtual_device=meta[VirtualDevice][1]) { memory.alloc_storage(%size, %alignment, virtual_device=meta[VirtualDevice][1]) } """, @@ -751,7 +744,7 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%sto: Storage[], param_virtual_devices=[meta[VirtualDevice][1]], result_virtual_device=meta[VirtualDevice][1]) { + def @main(%sto {virtual_device=meta[VirtualDevice][1]}: Storage[], virtual_device=meta[VirtualDevice][1]) { %0 = on_device(0, virtual_device=meta[VirtualDevice][0], constrain_result=True); %1 = on_device(meta[relay.Constant][0], virtual_device=meta[VirtualDevice][0], constrain_result=True); memory.alloc_tensor(%sto, %0, %1, const_shape=meta[relay.Constant][0], assert_shape=[]) @@ -790,8 +783,8 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(2, 8), float32], - param_virtual_devices=[meta[VirtualDevice][1]], result_virtual_device=meta[VirtualDevice][1]) { + def @main(%x {virtual_device=meta[VirtualDevice][1]}: Tensor[(2, 8), float32], + virtual_device=meta[VirtualDevice][1]) { %0 = on_device(meta[relay.Constant][0], virtual_device=meta[VirtualDevice][0], constrain_result=True); vm.reshape_tensor(%x, %0, newshape=[2, 4, 2]) } @@ -828,8 +821,8 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x0: Tensor[(?, ?), float32], %x1: Tensor[(?, ?), float32], - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) { + def @main(%x0 
{virtual_device=meta[VirtualDevice][0]}: Tensor[(?, ?), float32], %x1 {virtual_device=meta[VirtualDevice][0]}: Tensor[(?, ?), float32], + virtual_device=meta[VirtualDevice][0]) { add(%x0, %x1) } """, @@ -868,9 +861,8 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][1]], - result_virtual_device=meta[VirtualDevice][1]) { + def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %y {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %z {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][1]) { %0 = add(%x, %y); %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); @@ -915,9 +907,8 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][1], meta[VirtualDevice][0]], - result_virtual_device=meta[VirtualDevice][0]) { + def @main(%x {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], %y {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], %z {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][0]) { %0 = add(%x, %y); %1 = on_device(%0, virtual_device=meta[VirtualDevice][1], constrain_result=True); %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][0]); @@ -958,9 +949,8 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], - 
param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0]], - result_virtual_device=meta[VirtualDevice][0]) { + def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %y {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %z {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][0]) { %0 = add(%x, %y); subtract(%0, %z) } @@ -1016,10 +1006,9 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 56), float32], - %weight: Tensor[(64, 64, 3, 3), float32], - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0]], - result_virtual_device=meta[VirtualDevice][0]) { + def @main(%data1 {virtual_device=meta[VirtualDevice][0]}: Tensor[(1, 64, 56, 56), float32], %data2 {virtual_device=meta[VirtualDevice][0]}: Tensor[(1, 64, 56, 56), float32], + %weight {virtual_device=meta[VirtualDevice][0]}: Tensor[(64, 64, 3, 3), float32], + virtual_device=meta[VirtualDevice][0]) { %0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); %2 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); @@ -1069,8 +1058,8 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(3, 3, 4), float32], - param_virtual_devices=[meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) { + def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(3, 3, 4), float32], + virtual_device=meta[VirtualDevice][1]) { %0 = split(%x, indices_or_sections=3); let %t = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); %1 = %t.0; @@ -1137,8 +1126,8 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(5, 7), 
float32], - param_virtual_devices=[meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) { + def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][0]) { %0 = negative(%x); %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); @@ -1207,8 +1196,8 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][1]], result_virtual_device=meta[VirtualDevice][0]) { + def @main(%x {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], %y {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][0]) { %0 = add(%x, %y); %1 = on_device(%0, virtual_device=meta[VirtualDevice][1], constrain_result=True); %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][0]); @@ -1273,10 +1262,9 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][1], meta[VirtualDevice][1]], - result_virtual_device=meta[VirtualDevice][0]) { + def @main(%a {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %b {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + %c {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], %d {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][0]) { %0 = multiply(%c, %d); %1 = on_device(%0, virtual_device=meta[VirtualDevice][1], constrain_result=True); %2 = add(%a, %b); @@ -1328,13 +1316,12 @@ def 
expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0]], - result_virtual_device=meta[VirtualDevice][0]) { - let %f = fn (%a, param_virtual_devices=[meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) { + def @main(%x {virtual_device=meta[VirtualDevice][0]}: bool, %y {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %z {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][0]) { + let %f = fn (%a {virtual_device=meta[VirtualDevice][0]}, virtual_device=meta[VirtualDevice][0]) { add(%a, %y) }; - let %g = fn (%a1, param_virtual_devices=[meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) { + let %g = fn (%a1 {virtual_device=meta[VirtualDevice][0]}, virtual_device=meta[VirtualDevice][0]) { subtract(%a1, %y) }; let %h = on_device(if (%x) { @@ -1388,16 +1375,14 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @f(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][0]], - result_virtual_device=meta[VirtualDevice][1]) -> Tensor[(5, 7), float32] { + def @f(%a {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], %b {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][1]) -> Tensor[(5, 7), float32] { %0 = device_copy(%b, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); add(%a, %0) } - def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][1]], - result_virtual_device=meta[VirtualDevice][1]) -> Tensor[(5, 7), float32] { + def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %y 
{virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][1]) -> Tensor[(5, 7), float32] { @f(%y, %x) } """, @@ -1439,8 +1424,8 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) { + def @main(%x {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], %y {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][1]) { let %r = on_device(ref(%x), virtual_device=meta[VirtualDevice][1], constrain_result=True); %0 = device_copy(%y, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); on_device(ref_write(%r, %0), virtual_device=meta[VirtualDevice][1], constrain_result=True); @@ -1496,8 +1481,8 @@ def expected(): Cons(A, List[A]), Nil, } - def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) { + def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %y {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][0]) { %0 = Nil; %1 = Cons(%y, %0); let %l = on_device(Cons(%x, %1), virtual_device=meta[VirtualDevice][0], constrain_result=True); @@ -1535,14 +1520,12 @@ def input(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @on_scope_b(%x: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][2]], - result_virtual_device=meta[VirtualDevice][2]) -> Tensor[(5, 7), float32] { + def @on_scope_b(%x {virtual_device=meta[VirtualDevice][2]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][2]) -> Tensor[(5, 7), float32] { %x } - def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: 
Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][1], meta[VirtualDevice][2]], - result_virtual_device=meta[VirtualDevice][1]) { + def @main(%a {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32], %b {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], %c {virtual_device=meta[VirtualDevice][2]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][1]) { // %a's memory scope is unconstrained, so will take on "scopeB" and on_device has no effect %0 = @on_scope_b(on_device(%a, virtual_device=meta[VirtualDevice][0], constrain_body=False)); // %b's memory scope is "scopeA", so will require a "scopeA"->"scopeB" copy. @@ -1563,14 +1546,12 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @on_scope_b(%x: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][2]], - result_virtual_device=meta[VirtualDevice][2]) -> Tensor[(5, 7), float32] { + def @on_scope_b(%x {virtual_device=meta[VirtualDevice][2]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][2]) -> Tensor[(5, 7), float32] { %x } - def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], - param_virtual_devices=[meta[VirtualDevice][2], meta[VirtualDevice][1], meta[VirtualDevice][2]], - result_virtual_device=meta[VirtualDevice][1]) { + def @main(%a {virtual_device=meta[VirtualDevice][2]}: Tensor[(5, 7), float32], %b {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32], %c {virtual_device=meta[VirtualDevice][2]}: Tensor[(5, 7), float32], + virtual_device=meta[VirtualDevice][1]) { %0 = @on_scope_b(%a); %1 = device_copy(%b, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][2]); %2 = @on_scope_b(%1); @@ -1649,11 +1630,10 @@ def input(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x : Tensor[(128, 128), float32], - %y : Tensor[(128, 128), float32], - %z : Tensor[(128, 128), float32], - 
param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][2], meta[VirtualDevice][1]], - result_virtual_device=meta[VirtualDevice][2]) { + def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32], + %y {virtual_device=meta[VirtualDevice][2]}: Tensor[(128, 128), float32], + %z {virtual_device=meta[VirtualDevice][1]}: Tensor[(128, 128), float32], + virtual_device=meta[VirtualDevice][2]) { call_lowered(@gem, (%x, %y, %z)) } """, @@ -1672,11 +1652,10 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%x : Tensor[(128, 128), float32], - %y : Tensor[(128, 128), float32], - %z : Tensor[(128, 128), float32], - param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][2], meta[VirtualDevice][1]], - result_virtual_device=meta[VirtualDevice][2]) { + def @main(%x {virtual_device=meta[VirtualDevice][1]}: Tensor[(128, 128), float32], + %y {virtual_device=meta[VirtualDevice][2]}: Tensor[(128, 128), float32], + %z {virtual_device=meta[VirtualDevice][1]}: Tensor[(128, 128), float32], + virtual_device=meta[VirtualDevice][2]) { %0 = device_copy(%z, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][2]); %1 = on_device(%0, virtual_device=meta[VirtualDevice][2], constrain_result=True); %2 = call_lowered(@gem, (%x, %y, %1));