[checkpoint] Another round of PR review fixes.
mbs-octoml committed Dec 1, 2021
1 parent 3271d5e commit 9484ec5
Showing 5 changed files with 20 additions and 16 deletions.
12 changes: 7 additions & 5 deletions src/relay/backend/te_compiler.cc
@@ -325,8 +325,10 @@ class TECompilerImpl : public TECompilerNode {
 VLOG(1) << "scheduling";
 IRModule scheduled_module =
     tvm::LowerSchedule(value->cached_func->schedule, all_args, func_name, binds);
+// Unfortunately the above machinery creates its own GlobalVars instead of using *the*
+// GlobalVar we established above. Fix this before the confusion spreads any further.
+// TODO(mbs): LowerSchedule should be given prim_fn_gvar instead of func_name.
 for (const auto& kv : scheduled_module->functions) {
-  // TODO(msb): LowerSchedule should accept prim_fn_var instead of func_name.
   GlobalVar global_var = kv.first->name_hint == value->cached_func->prim_fn_var->name_hint
                              ? value->cached_func->prim_fn_var
                              : kv.first;
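
The same fix-up recurs in te_compiler_cache.cc below. As a minimal standalone sketch of the pattern the new comment describes (the helper name RebindToCanonicalVar is hypothetical, not part of the commit; it assumes only the IRModule and GlobalVar APIs already visible in this diff):

#include <tvm/ir/module.h>

// Rebind any function that LowerSchedule registered under a freshly minted
// GlobalVar (which matches the canonical one only by name_hint) onto the
// GlobalVar established earlier, so downstream passes see one identity.
tvm::IRModule RebindToCanonicalVar(const tvm::IRModule& lowered,
                                   const tvm::GlobalVar& canonical) {
  tvm::IRModule fixed;  // empty module, mirroring fixed_lowered_module below
  for (const auto& kv : lowered->functions) {
    tvm::GlobalVar gv =
        kv.first->name_hint == canonical->name_hint ? canonical : kv.first;
    fixed->Add(gv, kv.second);
  }
  return fixed;
}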
@@ -570,7 +572,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
 Array<GlobalVar> all_prim_fn_vars;
 for (const auto& kv : cfunc->funcs->functions) {
   if (opt_compiler) {
-    // We expect just the original func but with just the ExternalSymbol attribute signalling
+    // We expect just the original func but with just the ExternalSymbol attribute signaling
     // the function (will be) compiled externally.
     ICHECK(kv.second.as<FunctionNode>())
         << PrettyPrint(kv.first) << " must be bound to an (external) Function";
@@ -737,7 +739,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
 Target target;
 if (primitive_func->GetAttr<String>(attr::kCompiler).defined()) {
   // The generic 'external device' target.
-  // TODO(mbs): Retire once replaced unified BYOC compiler and target macihnery.
+  // TODO(mbs): Retire once replaced unified BYOC compiler and target machinery
   target = Target("ext_dev");
 } else {
   // The target corresponding to the call_node expression's annotation.
@@ -1090,10 +1092,10 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr
   external_mods.push_back(mod); // copy-on-write.
 }

-// Annotate the module with C Device API context mapping (this is until we have Target's
+// Annotate the module with C Device API context mapping (this is until we have Targets
 // annotated for the C Device API)
 // TODO(Mousius) - Remove "device_contexts" as soon as we have the graph annotated properly with
-// Target's
+// Targets
 Map<GlobalVar, String> device_contexts =
     module->GetAttr<Map<GlobalVar, String>>("device_contexts", Map<GlobalVar, String>()).value();
 Map<GlobalVar, String> new_device_contexts = compiler->GetDeviceContexts();
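
How the two maps are combined is not shown in this hunk; below is a plausible sketch, assuming the freshly collected contexts take precedence (MergeDeviceContexts is a hypothetical name, and the precedence rule is an assumption, not something the commit states):

#include <tvm/ir/module.h>

// Hypothetical sketch: overlay the contexts the TECompiler collected during
// lowering onto whatever "device_contexts" the module already carried.
tvm::Map<tvm::GlobalVar, tvm::String> MergeDeviceContexts(
    tvm::Map<tvm::GlobalVar, tvm::String> existing,
    const tvm::Map<tvm::GlobalVar, tvm::String>& collected) {
  for (const auto& kv : collected) {
    existing.Set(kv.first, kv.second);  // collected entries win on conflict
  }
  return existing;
}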
6 changes: 4 additions & 2 deletions src/relay/backend/te_compiler_cache.cc
@@ -449,7 +449,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
   shape_function_res_types.push_back(TensorType(t->GetShape(), t->GetDataType()));
 }

-// Assign the shape function it's true type.
+// Assign the shape function its true type.
 FuncType type(shape_function_arg_types, TupleType(shape_function_res_types),
               /*type_params=*/{}, /*type_constraints=*/{});
 VLOG(1) << "shape function '" << prim_fn_gvar->name_hint << "' has type:" << std::endl
@@ -483,9 +483,11 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
 std::unordered_map<te::Tensor, tir::Buffer> binds;
 IRModule lowered_module = tvm::LowerSchedule(schedule, all_args, func_name, binds);

+// Unfortunately the above machinery creates its own GlobalVars instead of using *the*
+// GlobalVar we established above. Fix this before the confusion spreads any further.
+// TODO(mbs): LowerSchedule should be given prim_fn_gvar instead of func_name.
 IRModule fixed_lowered_module;
 for (const auto& kv : lowered_module->functions) {
-  // TODO(mbs): LowerSchedule should be given prim_fn_gvar instead of func_name.
   GlobalVar global_var =
       kv.first->name_hint == prim_fn_gvar->name_hint ? prim_fn_gvar : kv.first;
   fixed_lowered_module->Add(global_var, kv.second);
10 changes: 5 additions & 5 deletions src/relay/backend/vm/compiler.cc
@@ -961,13 +961,13 @@ transform::Sequential VMCompiler::MemoryOpt(const SEScope& host_se_scope) {
 pass_seqs.push_back(transform::FoldConstant());

 // Fuse & lower any new shape functions and device_copies.
-pass_seqs.push_back(LowerOperators(host_se_scope));
+pass_seqs.push_back(FuseAndLowerOperators(host_se_scope));

 // Manifest the allocations needed for the shape functions.
 pass_seqs.push_back(transform::ManifestAlloc(host_se_scope));

 // Fuse & lower any new allocations.
-pass_seqs.push_back(LowerOperators(host_se_scope));
+pass_seqs.push_back(FuseAndLowerOperators(host_se_scope));

 // TODO(mbrookhart, jroesch, masahi): this pass is very slow, and is
 // incomplete to provide memory resuse optimizations. Disable it until we can
@@ -979,7 +979,7 @@ transform::Sequential VMCompiler::MemoryOpt(const SEScope& host_se_scope) {
 pass_seqs.push_back(transform::FoldConstant());

 // Fuse & lower yet again
-pass_seqs.push_back(LowerOperators(host_se_scope));
+pass_seqs.push_back(FuseAndLowerOperators(host_se_scope));

 // Create allocations for math introduced by dynamic region math.
 pass_seqs.push_back(transform::ManifestAlloc(host_se_scope));
@@ -995,7 +995,7 @@
   return transform::Sequential(std::move(pass_seqs));
 }

-transform::Sequential VMCompiler::LowerOperators(const SEScope& host_se_scope) {
+transform::Sequential VMCompiler::FuseAndLowerOperators(const SEScope& host_se_scope) {
   Array<Pass> pass_seqs;
   // Hoist operators to "primitive" Functions.
   pass_seqs.push_back(FuseOps());
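
Only the first pass of the renamed helper's body is visible above; its overall shape is sketched below, assuming the tvm::relay namespace and the usual pass headers. Only FuseOps() is confirmed by the hunk; the lowering step is elided there, so it appears here only as a comment.

// Sketch of the renamed helper's shape (name suffixed to mark it as a sketch).
transform::Sequential FuseAndLowerOperatorsSketch(const SEScope& host_se_scope) {
  Array<Pass> pass_seqs;
  // Hoist operators into "primitive" Functions...
  pass_seqs.push_back(transform::FuseOps());
  // ...then lower those primitive Functions to TIR (pass elided in the hunk).
  return transform::Sequential(pass_seqs);
}

Bundling fusion and lowering into one Sequential lets MemoryOpt, above, re-run both after every pass that materializes new operator calls: shape functions, device copies, and allocations.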
@@ -1072,7 +1072,7 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
 pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));
 pass_seqs.push_back(transform::LabelOps());

-// Lower all function's annotated as "primitive" by FuseOps.
+// Lower all functions annotated as "primitive" by FuseOps.
 pass_seqs.push_back(tec::LowerTEPass(/*module_name=*/"vm_mod",
                                      [this](const BaseFunc& func) {
                                        if (func->GetAttr<String>(attr::kCompiler).defined()) {
2 changes: 1 addition & 1 deletion src/relay/backend/vm/compiler.h
@@ -137,7 +137,7 @@ class VMCompiler : public runtime::ModuleNode {
 IRModule OptimizeModuleImpl(IRModule mod);

 transform::Sequential MemoryOpt(const SEScope& host_se_scope);
-transform::Sequential LowerOperators(const SEScope& host_se_scope);
+transform::Sequential FuseAndLowerOperators(const SEScope& host_se_scope);

 /*!
  * \brief Populate the global function names in a map where the value is used
6 changes: 3 additions & 3 deletions src/relay/transforms/type_infer.cc
@@ -400,7 +400,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
 //
 // The result will be the return type of the operator.
 Type PrimitiveCall(const FuncTypeNode* op, Array<Type> arg_types, const Attrs& attrs,
-                   const Span& span, const Expr& expr) {
+                   const Span& span) {
   if (op->type_params.size() != arg_types.size() + 1) return Type();
   if (op->type_constraints.size() != 1) return Type();
   const TypeRelationNode* rel = op->type_constraints[0].as<TypeRelationNode>();
@@ -541,8 +541,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
 }

 if (const OpNode* opnode = call->op.as<OpNode>()) {
-  Type rtype = PrimitiveCall(opnode->op_type.as<FuncTypeNode>(), arg_types, call->attrs,
-                             call->span, GetRef<Call>(call));
+  Type rtype =
+      PrimitiveCall(opnode->op_type.as<FuncTypeNode>(), arg_types, call->attrs, call->span);

   if (rtype.defined()) {
     AddTypeArgs(GetRef<Call>(call), arg_types);
