Skip to content

Commit

Permalink
[RELAY] [VIRTUALDEVICE] Change syntax for device planning and store p…
Browse files Browse the repository at this point in the history
…arameter virtual devices in virtual_device_ field (#10352)

* parent 33082e0
author electriclilies <[email protected]> 1643141097 -0800
committer Lily Orth-Smith <[email protected]> 1645560059 -0800

Store function param virtual devices in virtual_device_ field

Fix test_annotation.py and change result_virtual_device to virtual_device

* Change plan devices tests to use the new syntax for function parameters

* Fix free var problem

* Fix attribute parsing if there is virtual device; most device planning tests passgit status

* fixed lambda lifting

* Debugging high order functions -- right now FunctionOnDevice and Bind are mutually recursive. This needs to not be the case.

* tests pass wootgit status

* Remove FunctionOnDevice from device planner

* Don't use MaybeFunctionOnDevice in VM compiler

* Remove MaybeFunctionOnDevice from lambda lifter

* Delete FunctionOnDevice and MaybeFunctionOnDevice!

* Reomve GetFunctionResultVirtualDevice

* Remove GetFunctionParamVirtualDevice

* lint

* lint

* Python formatting

* Remove FunctionOnDevice python test

* Fix bug in binds & debug output

* Fix text printer

* lint

* Remove function on device from fold constant tests

* Mark nits

* Revert behavior of bind

* clean up debug

* Make ExprBinder public interface and use instead of Bind

* Fix lambda lift

* This is broken but not sure how to fix

* passes all device planning tests yay!

* Add substitution helper and use in device planner

* Remove unnecessary check

* Respond to comments

* Update comment
  • Loading branch information
electriclilies authored Feb 25, 2022
1 parent d9fac4f commit 308d320
Show file tree
Hide file tree
Showing 17 changed files with 245 additions and 281 deletions.
10 changes: 0 additions & 10 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,16 +190,6 @@ constexpr const char* kTarget = "target";
*/
constexpr const char* kGlobalSymbol = "global_symbol";

/*!
* \brief The \p VirtualDevice which will hold each of the functions parameters.
*
* Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but
* may be included as an annotation on user programs.
*
* Type: Array<VirtualDevice>
*/
constexpr const char* kParamVirtualDevice = "param_virtual_devices";

} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
17 changes: 17 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,10 @@ TVM_DLL Pass PlanDevices(CompilationConfig config);
/*!
* \brief Bind the free variables to a Relay expression. This is a helper
* function usually called by other pass functions to help optimizations.
* If any free variables are introduced into a function, those are added
* to the functoin parameters.
* Additionally this may change the order of parameters if you map a variable
* to a variable.
*
* \param expr The input expression.
* \param binds The variable to expression map that will be used to help the
Expand All @@ -508,6 +512,19 @@ TVM_DLL Pass PlanDevices(CompilationConfig config);
*/
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);

/*!
* \brief Substitute variables with new variables (including function parameters) in a function.
* This is a helper function usually called by other pass functions to help optimizations.
* Expects all values in the bind map to be Vars.
*
* \param func The input function.
* \param binds The variable to expression map that will be used to help the
* binding.
*
* \return The updated expression.
*/
TVM_DLL Function SubstituteBoundVars(const Function& func, const tvm::Map<Var, Expr>& binds);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order. This
* function is used as a helper function to rewrtie an expression in a pass.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/target/virtual_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ class VirtualDeviceCache {
*
* Type: VirtualDevice
*/
constexpr const char* kVirtualDevice = "result_virtual_device";
constexpr const char* kVirtualDevice = "virtual_device";

} // namespace tvm

Expand Down
34 changes: 30 additions & 4 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,13 @@ class Parser {
*
* "x" -> Var("x"), these are needed to map from the raw string names
* to unique variable nodes.
* If a virtual device is specified, sets the virtual device of the variable.
*/
Var BindVar(const std::string& name, const relay::Type& type_annotation) {
Var BindVar(const std::string& name, const relay::Type& type_annotation,
Optional<VirtualDevice> virtual_device = Optional<VirtualDevice>()) {
auto var = Var(name, type_annotation);
var->virtual_device_ = virtual_device.value_or(VirtualDevice::FullyUnconstrained());
VLOG(1) << "Binding var named " << name << " to variable node " << PrettyPrint(var);
this->expr_scopes.Add(name, var);
return var;
}
Expand Down Expand Up @@ -1113,11 +1117,26 @@ class Parser {
[&]() {
auto token = Match(TokenType::kLocal);
auto string = token.ToString();

// The fake attributes where the virtual device is specified.
VirtualDevice virtual_device;
if (WhenMatch(TokenType::kLCurly)) {
Map<String, ObjectRef> fake_attrs = ParseAttrs();
VLOG(9) << "Fake attributes for function parameter: " << fake_attrs;
Match(TokenType::kRCurly);
if (fake_attrs.size() == 1 && fake_attrs.count(kVirtualDevice)) {
ICHECK(fake_attrs[kVirtualDevice].as<VirtualDeviceNode>())
<< "Expected the " << kVirtualDevice
<< " to have type VirtualDeviceNode, but got " << virtual_device->GetTypeKey();
virtual_device = Downcast<VirtualDevice>(fake_attrs[kVirtualDevice]);
}
}

Type type;
if (WhenMatch(TokenType::kColon)) {
type = ParseType();
}
return BindVar(string, type);
return BindVar(string, type, virtual_device);
},
[&] {
auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier;
Expand Down Expand Up @@ -1150,8 +1169,15 @@ class Parser {
ICHECK(vid.as<VirtualDeviceNode>())
<< "Expected the " << kVirtualDevice << " to have type VirtualDeviceNode, but got "
<< vid->GetTypeKey();
raw_attrs.erase(kVirtualDevice);
Function func = relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs));

DictAttrs attrs;
// Don't fill the raw_attrs in if there's nothing other than kVirtualDevice in the
// attributes
if (raw_attrs.size() > 1) {
raw_attrs.erase(kVirtualDevice);
attrs = DictAttrs(raw_attrs);
}
Function func = relay::Function(params, body, ret_type, generics, attrs);
func->virtual_device_ = vid;
return func;
} else {
Expand Down
7 changes: 5 additions & 2 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,13 @@ Doc RelayTextPrinter::AllocVar(const Var& var) {
}
Doc val = GetUniqueName("%" + name);
memo_[var] = val;
if (!var->virtual_device()->IsFullyUnconstrained()) {
val << " {" << kVirtualDevice << "=" << PrintAttributeValue(var->virtual_device()) << "}";
}
if (var->type_annotation.defined()) {
val << ": " << Print(var->type_annotation);
}

val << PrintOptionalInfo(var);
return val;
}
Expand Down Expand Up @@ -445,7 +449,7 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
for (const Doc& d : PrintDictAttrs(fn->attrs)) {
params.push_back(d);
}
if (fn->virtual_device() != VirtualDevice::FullyUnconstrained()) {
if (!fn->virtual_device()->IsFullyUnconstrained()) {
Doc vid_doc;
vid_doc << kVirtualDevice << "=" << PrintAttributeValue(fn->virtual_device());
params.push_back(vid_doc);
Expand All @@ -454,7 +458,6 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
if (fn->ret_type.defined()) {
doc << "-> " << Print(fn->ret_type) << " ";
}

doc << PrintBody(fn->body);
return doc;
}
Expand Down
18 changes: 6 additions & 12 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,21 +252,16 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
// Do that flattening on-the-fly here.
Function inner_func = Downcast<Function>(func->body);
std::vector<Var> params;
std::vector<VirtualDevice> param_virtual_devices;
params.reserve(func->params.size() + inner_func->params.size());
param_virtual_devices.reserve(func->params.size() + inner_func->params.size());
param_device_indexes.reserve(func->params.size() + inner_func->params.size());
for (size_t i = 0; i < func->params.size(); ++i) {
params.emplace_back(func->params[i]);
VirtualDevice param_virtual_device = GetFunctionParamVirtualDevice(func.get(), i);
param_virtual_devices.push_back(param_virtual_device);
param_device_indexes.push_back(GetDeviceIndex(param_virtual_device));
param_device_indexes.push_back(GetDeviceIndex(func->params[i]->virtual_device()));
}
for (size_t i = 0; i < inner_func->params.size(); ++i) {
params.emplace_back(inner_func->params[i]);
VirtualDevice param_virtual_device = GetFunctionParamVirtualDevice(inner_func.get(), i);
param_virtual_devices.push_back(param_virtual_device);
param_device_indexes.push_back(GetDeviceIndex(param_virtual_device));

param_device_indexes.push_back(GetDeviceIndex(inner_func->params[i]->virtual_device()));
}
std::vector<TypeVar> type_params;
type_params.reserve(func->type_params.size() + inner_func->type_params.size());
Expand All @@ -278,13 +273,12 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
}
Function flattened_func = Function(params, inner_func->body, inner_func->ret_type,
type_params, func->attrs, func->span);
VisitExpr(MaybeFunctionOnDevice(flattened_func, param_virtual_devices,
GetFunctionResultVirtualDevice(inner_func.get())));
flattened_func->virtual_device_ = inner_func->virtual_device();
VisitExpr(flattened_func);
} else {
param_device_indexes.reserve(func->params.size());
for (size_t i = 0; i < func->params.size(); ++i) {
param_device_indexes.push_back(
GetDeviceIndex(GetFunctionParamVirtualDevice(func.get(), i)));
param_device_indexes.push_back(GetDeviceIndex(func->params[i]->virtual_device()));
}
VisitExpr(func);
}
Expand Down
9 changes: 5 additions & 4 deletions src/relay/backend/vm/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,22 +111,21 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {
auto free_type_vars = FreeTypeVars(func, module_);

Array<Var> captured_vars;
std::vector<VirtualDevice> captured_var_virtual_devices;
bool recursive = false;
for (const auto& var : free_vars) {
if (!letrec_.empty() && var == letrec_.back()) {
recursive = true;
continue;
}
captured_vars.push_back(var);
captured_var_virtual_devices.push_back(GetVirtualDevice(var));
}

// Freshen all the captured vars.
Array<Var> typed_captured_vars;
Map<Var, Expr> rebinding_map;
for (auto free_var : captured_vars) {
auto var = Var(free_var->name_hint(), free_var->checked_type());
var->virtual_device_ = GetVirtualDevice(free_var);
typed_captured_vars.push_back(var);
rebinding_map.Set(free_var, var);
}
Expand Down Expand Up @@ -173,6 +172,8 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {
if (captured_vars.empty() && free_type_vars.empty()) {
lifted_func = Function(body->params, body->body, body->ret_type, body->type_params,
body->attrs, body->span);
// We also need to copy the virtual device
lifted_func->virtual_device_ = body->virtual_device();
} else {
// When a closure is locally bound in a program, we have its full type information
// avalible to us.
Expand All @@ -187,14 +188,14 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {
// construct the "closure" function with fully annotated arguments, no longer relying
// on type inference.
size_t before_arity = body->params.size();
VLOG(9) << "Binding " << rebinding_map << " into\n" << PrettyPrint(body->body);
auto rebound_body = WithFields(func, func->params, Bind(body->body, rebinding_map));
size_t after_arity = rebound_body->params.size();
CHECK_EQ(before_arity, after_arity);
lifted_func =
Function(typed_captured_vars, rebound_body, /*ret_type=*/func->func_type_annotation(),
free_type_vars, /*attrs=*/{}, func->span);
lifted_func =
MaybeFunctionOnDevice(lifted_func, captured_var_virtual_devices, result_virtual_device);
lifted_func->virtual_device_ = result_virtual_device;
lifted_func = MarkClosure(lifted_func);
}

Expand Down
46 changes: 32 additions & 14 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -472,45 +472,42 @@ class ExprBinder : public MixedModeMutator, PatternMutator {
const tvm::Map<Var, Expr>& args_map_;
};

// This function should be called SubstAndBind, since it assumes any variables introduced
// in the substitution right hand side should be implicitly bound in the function.
Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
if (const FunctionNode* func = expr.as<FunctionNode>()) {
Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
Array<Var> new_params;
std::vector<VirtualDevice> new_param_virtual_devices;
for (size_t i = 0; i < func->params.size(); ++i) {
if (!args_map.count(func->params[i])) {
new_params.push_back(func->params[i]);
new_param_virtual_devices.push_back(GetFunctionParamVirtualDevice(func, i));
}
}
if (new_body.same_as(func->body) && new_params.size() == func->params.size()) {
return expr;
}

auto ret =
Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span);
ret =
MaybeFunctionOnDevice(ret, new_param_virtual_devices, GetFunctionResultVirtualDevice(func));
ret->virtual_device_ = func->virtual_device();

std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> set;
for (const auto& v : FreeVars(expr)) {
set.insert(v);
}
for (const auto& v : FreeVars(ret)) {
if (set.count(v) == 0) {
new_params.push_back(v);
if (!GetFunctionResultVirtualDevice(func)->IsFullyUnconstrained()) {
// TODO(mbs): The function has been annotated with a device, which means we are supposed
// to be preserving device annotations on every transformation. However there's no
// such context for the free vars in args_map.
LOG(WARNING) << "introduced free var '" << PrettyPrint(v)
<< "' into function body but no device is known for it";
}
new_param_virtual_devices.push_back(VirtualDevice::FullyUnconstrained());
}
}

ret =
Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span);
ret =
MaybeFunctionOnDevice(ret, new_param_virtual_devices, GetFunctionResultVirtualDevice(func));
ret->virtual_device_ = func->virtual_device();

VLOG(4) << "Expr:\n" << expr;
VLOG(4) << "Ret:\n" << ret;

ICHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
return std::move(ret);
} else {
Expand All @@ -528,6 +525,27 @@ TVM_REGISTER_GLOBAL("relay.ir.Bind").set_body([](TVMArgs args, TVMRetValue* ret)
}
});

Function SubstituteBoundVars(const Function& func, const tvm::Map<Var, Expr>& args_map) {
Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
Array<Var> new_params;
for (size_t i = 0; i < func->params.size(); i++) {
if (!args_map.count(func->params[i])) {
new_params.push_back(func->params[i]);
} else {
if (const VarNode* var = args_map[func->params[i]].as<VarNode>()) {
new_params.push_back(GetRef<Var>(var));
} else {
ICHECK(false) << "Expected all values in args_map to be vars, but found "
<< args_map[func->params[i]]->GetTypeKey();
}
}
}
auto ret =
Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span);
ret->virtual_device_ = func->virtual_device();
return ret;
}

void ExpandANormalForm(const LetNode* op, std::function<void(const LetNode*)> pre_visit,
std::function<void(const LetNode*)> post_visit) {
std::stack<const LetNode*> stack;
Expand Down
44 changes: 1 addition & 43 deletions src/relay/op/memory/on_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include "../../transforms/infer_layout_utils.h"
#include "../type_relations.h"
Expand Down Expand Up @@ -142,48 +143,5 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) {
return {};
}

Function FunctionOnDevice(Function function, Array<VirtualDevice> param_virtual_devices,
VirtualDevice result_virtual_device) {
auto func = WithAttr(
WithFields(std::move(function), {}, {}, {}, {}, {}, std::move(result_virtual_device)),
tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices));
VLOG(1) << "Annotated func: " << PrettyPrint(func);
return func;
}

TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice);

Function MaybeFunctionOnDevice(Function function, Array<VirtualDevice> param_virtual_devices,
VirtualDevice result_virtual_device) {
if (std::all_of(param_virtual_devices.begin(), param_virtual_devices.end(),
[](const VirtualDevice& virtual_device) {
return virtual_device->IsFullyUnconstrained();
}) &&
result_virtual_device->IsFullyUnconstrained()) {
// Nothing to annotate.
return function;
}
return FunctionOnDevice(function, std::move(param_virtual_devices),
std::move(result_virtual_device));
}

VirtualDevice GetFunctionResultVirtualDevice(const FunctionNode* function_node) {
return function_node->virtual_device();
}

VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, size_t i) {
ICHECK_LT(i, function_node->params.size())
<< "param index " << i << " out of range for function of arity "
<< function_node->params.size();
auto opt_array = function_node->GetAttr<Array<VirtualDevice>>(tvm::attr::kParamVirtualDevice);
if (!opt_array) {
// No annotation.
return VirtualDevice::FullyUnconstrained();
}
ICHECK_EQ(opt_array.value().size(), function_node->params.size())
<< "annotation parameters do not match function arity";
return opt_array.value()[i];
}

} // namespace relay
} // namespace tvm
Loading

0 comments on commit 308d320

Please sign in to comment.