Skip to content

Commit

Permalink
fix IRModule parsing by resolving GlobalVars later (apache#41)
Browse files Browse the repository at this point in the history
* fix IRModule parsing by resolving GlobalVars later

* disable fast path that causes type inference problem for now

* print checked type on vars if present

* document ResolveGlobals
  • Loading branch information
altanh authored and yongwww committed Aug 14, 2022
1 parent 3fc8133 commit 8a9b978
Show file tree
Hide file tree
Showing 12 changed files with 297 additions and 197 deletions.
40 changes: 27 additions & 13 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,68 +19,82 @@
import tvm.ir
from . import _ffi_api


@tvm._ffi.register_object("relax.FunctionPass")
class FunctionPass(tvm.ir.transform.Pass):
"""A pass that works on each tvm.relax.Function in a module. A function
pass class should be created through `function_pass`.
"""

def FMARewrite() -> tvm.transform.Pass:

def FMARewrite() -> tvm.ir.transform.Pass:
"""Perform fused multiply add rewriting in dataflow blocks.
Returns
-------
ret: tvm.transform.Pass
ret: tvm.ir.transform.Pass
"""
return _ffi_api.FMARewrite()


def ToNonDataflow() -> tvm.transform.Pass:
def ToNonDataflow() -> tvm.ir.transform.Pass:
"""Transform all dataflow structure to non-dataflow version.
Returns
-------
ret: tvm.transform.Pass
ret: tvm.ir.transform.Pass
"""
return _ffi_api.ToNonDataflow()


def CallDPSRewrite() -> tvm.transform.Pass:
def CallDPSRewrite() -> tvm.ir.transform.Pass:
"""Perform explicit tensor allocation for call_dps.
Returns
-------
ret: tvm.transform.Pass
ret: tvm.ir.transform.Pass
"""
return _ffi_api.CallDPSRewrite()


def VMMemoryLower() -> tvm.transform.Pass:
def VMMemoryLower() -> tvm.ir.transform.Pass:
"""Perform memory lowering. Lowers the relax.builtin.alloc_tensor intrinsic to VM intrinsics.
Returns
-------
ret: tvm.transform.Pass
ret: tvm.ir.transform.Pass
"""
return _ffi_api.VMMemoryLower()


def VMShapeLower() -> tvm.transform.Pass:
"""Lower the shape expressions in relax to VM shape heap manipulations and generate related
def VMShapeLower() -> tvm.ir.transform.Pass:
"""Lower the shape expressions in relax to VM shape heap manipulations and generate related
TIR functions to do shape calculations.
Returns
-------
ret: tvm.transform.Pass
ret: tvm.ir.transform.Pass
"""
return _ffi_api.VMShapeLower()


def ToANF() -> tvm.transform.Pass:
def ToANF() -> tvm.ir.transform.Pass:
"""Transforming Relax IR to A-normal form.
Returns
-------
ret: tvm.transform.Pass
ret: tvm.ir.transform.Pass
"""
return _ffi_api.ToANF()


def ResolveGlobals() -> tvm.ir.transform.Pass:
"""Resolve global variables using string equality. This ensures all GlobalVars in the IR refer
to the correct GlobalVar of the input IRModule. An error is reported if any GlobalVar cannot be
resolved.
Returns
-------
ret: tvm.ir.transform.Pass
"""
return _ffi_api.ResolveGlobals()
8 changes: 6 additions & 2 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from synr import ast, Transformer, to_ast

import tvm
from tvm import IRModule
from tvm import IRModule, relax
from tvm._ffi.base import TVMError
from tvm.ir import GlobalVar
from tvm.ir.function import BaseFunc
Expand Down Expand Up @@ -1381,5 +1381,9 @@ def ir_module(input_module: type) -> IRModule:
func_dict = {
name: f for name, f in input_module.__dict__.items() if isinstance(f, BaseFunc)
}
return IRModule(func_dict)
mod = IRModule(func_dict)
mod = relax.transform.ResolveGlobals()(mod)
# FIXME(@altanh): where is the source map?
return mod

raise TypeError("Only class definitions are supported.")
9 changes: 6 additions & 3 deletions python/tvm/script/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,9 +970,12 @@ def transform_expr(self, expr: ast.Expr) -> relax.Expr:
var_name = expr.id.name
if _is_registered(var_name, op_set=self._registered_ops):
return relay.op.get(var_name)
if var_name not in self.scope:
self.report_error("undefined variable", expr.span)
return self.scope[var_name]
if var_name in self.scope:
return self.scope[var_name]
# NOTE: this is a "hack" to get around Python eagerly parsing class method decorators
# first (meaning we need to resolve them after the functions are parsed). These
# GlobalVars need to be resolved using string equality only.
return relay.GlobalVar(var_name)

elif isinstance(expr, ast.Constant):
# FIXME(@altanh): use internal representation that doesn't have precision limits here
Expand Down
11 changes: 8 additions & 3 deletions src/printer/relax_script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -497,13 +497,18 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function&
}

Doc RelaxScriptPrinter::PrintVarAnnotation(const relax::Var& var) {
// TODO(@altanh): we should consider moving annotation into binding
Doc doc;
if (var->type_annotation.defined()) {
Type annotation = var->checked_type_;
if (!annotation.defined()) {
annotation = var->type_annotation.value_or(Type());
}
if (annotation.defined()) {
doc << ": ";
if (const relax::DynTensorTypeNode* tty = var->type_annotation.as<relax::DynTensorTypeNode>()) {
if (const relax::DynTensorTypeNode* tty = annotation.as<relax::DynTensorTypeNode>()) {
doc << PrintTensorAnnotation(GetRef<DynTensorType>(tty), var->shape_);
} else {
doc << Print(var->type_annotation);
doc << Print(annotation);
}
}
return doc;
Expand Down
28 changes: 14 additions & 14 deletions src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,10 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor<Expr(const Expr&)> {

private:
/*!
* \brief Memoization map for expressions using Id for equality of variables.
*/
* \brief Memoization map for expressions using Id for equality of variables.
*/
class ExprMemo {
public:
public:
Optional<Expr> Get(const Expr& expr) {
if (const VarNode* var = expr.as<VarNode>()) {
auto it = var_memo_.find(var->vid);
Expand All @@ -230,7 +230,7 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor<Expr(const Expr&)> {
}
}

private:
private:
std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_;
std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> expr_memo_;
};
Expand Down Expand Up @@ -370,7 +370,9 @@ Var BlockBuilderNode::Emit(const Expr& expr, bool is_dataflow, std::string name_
Var BlockBuilderNode::Emit(const VarBinding& binding) {
BlockFrame* cur_frame = CurrentFrame();
if (cur_frame->is_dataflow) {
ICHECK(binding->var.as<DataflowVarNode>());
ICHECK(binding->var.as<DataflowVarNode>())
<< "Emit can only be used for local bindings in a dataflow block, use EmitOutput for "
"output bindings instead";
}
cur_frame->bindings.push_back(binding);
binding_table_[binding->var->vid] = binding->value;
Expand Down Expand Up @@ -408,9 +410,11 @@ Var BlockBuilderNode::EmitMatchShape(const Expr& value, const Array<PrimExpr>& p

Var BlockBuilderNode::EmitMatchShape(const MatchShape& binding) {
BlockFrame* cur_frame = CurrentFrame();
if (cur_frame->is_dataflow && binding->var.defined()) {
ICHECK(!binding->var.as<DataflowVarNode>())
<< "cannot bind DataflowVar outside dataflow block.";
if (binding->var.defined()) {
ICHECK(!cur_frame->is_dataflow || binding->var.as<DataflowVarNode>())
<< "EmitMatchShape can only be used for local bindings in a dataflow block.";
ICHECK(cur_frame->is_dataflow || !binding->var.as<DataflowVarNode>())
<< "cannot emit dataflow vars outside a dataflow block: " << binding->var->name_hint();
}
cur_frame->bindings.push_back(binding);
// TODO(@altanh, @yuchen): what value should we bind? Consider
Expand Down Expand Up @@ -511,13 +515,9 @@ BlockBuilderNode::BlockFrame* BlockBuilderNode::CurrentFrame() {
return &block_stack_.top();
}

NameTable* BlockBuilderNode::name_table() {
return name_table_.get();
}
NameTable* BlockBuilderNode::name_table() { return name_table_.get(); }

BlockBuilder BlockBuilder::Create() {
return BlockBuilder(make_object<BlockBuilderNode>());
}
BlockBuilder BlockBuilder::Create() { return BlockBuilder(make_object<BlockBuilderNode>()); }

TVM_REGISTER_GLOBAL("relax.BlockBuilderCreate").set_body_typed(BlockBuilder::Create);

Expand Down
29 changes: 17 additions & 12 deletions src/relax/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,20 @@ void ExprMutator::VisitBinding_(const VarBindingNode* binding) {
Expr new_value = this->VisitExpr(binding->value);
Var new_var = this->VisitVarDef(binding->var);

if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
// no-op if there is no change
builder_->Emit(GetRef<VarBinding>(binding));
return;
}
auto emit = [this](VarBinding b) {
if (this->builder_->CurrentBlockIsDataFlow() && !b->var.as<DataflowVarNode>()) {
this->builder_->EmitOutput(b);
} else {
this->builder_->Emit(b);
}
};

// FIXME(@altanh): try to clean up all the fast paths and ty/shape infer, it's getting unwieldy
// if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
// // no-op if there is no change
// emit(GetRef<VarBinding>(binding));
// return;
// }

{
Var temp = WithShapeAndType(new_var, new_value->shape_, new_value->checked_type_);
Expand All @@ -368,11 +377,7 @@ void ExprMutator::VisitBinding_(const VarBindingNode* binding) {
}
}

if (builder_->CurrentBlockIsDataFlow() && !new_var.as<DataflowVarNode>()) {
builder_->EmitOutput(VarBinding(new_var, new_value));
} else {
builder_->Emit(VarBinding(new_var, new_value));
}
emit(VarBinding(new_var, new_value));
}

void ExprMutator::VisitBinding_(const MatchShapeNode* binding) {
Expand All @@ -387,8 +392,8 @@ void ExprMutator::VisitBinding_(const MatchShapeNode* binding) {
if (new_value->checked_type_.defined() && new_value->checked_type_.as<DynTensorTypeNode>()) {
new_shape = new_pattern;
}
Var temp =
WithShapeAndType(this->VisitVarDef(binding->var), new_shape, new_value->checked_type_);
new_var = this->VisitVarDef(binding->var);
Var temp = WithShapeAndType(new_var, new_shape, new_value->checked_type_);
if (!temp.same_as(new_var)) {
new_var = temp;
this->var_remap_[binding->var->vid] = new_var;
Expand Down
3 changes: 2 additions & 1 deletion src/relax/op/tensor/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ Type InferTypeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx) {
auto* t1 = rhs_type.as<DynTensorTypeNode>();
if (!t0 || !t1) {
diag_ctx.EmitFatal(Diagnostic::Error(call->span)
<< "Both lhs and rhs should be DynTensor for broadcasting");
<< "Both lhs and rhs should be DynTensor for broadcasting, but got "
<< lhs_type->GetTypeKey() << " and " << rhs_type->GetTypeKey());
}

DataType output_dtype;
Expand Down
65 changes: 65 additions & 0 deletions src/relax/transform/resolve_globals.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relax/transform/resolve_globals.cc
* \brief Resolve GlobalVars using string equality.
*/
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>

namespace tvm {
namespace relax {

class GlobalVarResolver : public ExprMutator {
public:
GlobalVarResolver(IRModule mod, DiagnosticContext diag_ctx) : mod_(mod), diag_ctx_(diag_ctx) {}

Expr VisitExpr_(const GlobalVarNode* gvar) {
if (!mod_->ContainGlobalVar(gvar->name_hint)) {
diag_ctx_.Emit(Diagnostic::Error(gvar->span)
<< "undefined variable/global \"" << gvar->name_hint << "\"");
return GetRef<GlobalVar>(gvar);
}
return mod_->GetGlobalVar(gvar->name_hint);
}

private:
/*! \brief the IRModule used for GlobalVar lookup. */
IRModule mod_;
DiagnosticContext diag_ctx_;
};

namespace transform {

Pass ResolveGlobals() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[](Function f, IRModule m, PassContext pc) {
// TODO(@altanh): make sure pc always has diag_ctx?
GlobalVarResolver resolver(m, pc->diag_ctx.value());
return Downcast<Function>(resolver.VisitExpr(f));
};
return CreateFunctionPass(pass_func, 0, "ResolveGlobals", {});
}

TVM_REGISTER_GLOBAL("relax.transform.ResolveGlobals").set_body_typed(ResolveGlobals);

} // namespace transform

} // namespace relax
} // namespace tvm
Loading

0 comments on commit 8a9b978

Please sign in to comment.