Skip to content

Commit

Permalink
[TVMScript] Report error if add attr to implicit root block (#9507)
Browse files Browse the repository at this point in the history
* fix implict root block attrs

* lint
  • Loading branch information
Hzfengsy authored Nov 15, 2021
1 parent 76c78a9 commit 3f9b72d
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 18 deletions.
10 changes: 9 additions & 1 deletion python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ class ContextMaintainer:
_report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
"""Callable[[str, Union[Span, synr.ast.Span]], None]: The report error function handle"""

# root alloc_buffer
root_alloc_buffers: List[Buffer] = []
"""List[Buffer]: The buffers allocated under root block"""

def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
# scope context
self.node_stack = []
Expand All @@ -152,6 +156,8 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No
# parser and analyzer
self._report_error = _report_error
self.analyzer = tvm.arith.Analyzer()
# root alloc_buffer
self.root_alloc_buffers = []

def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
"""Creates a new scope
Expand Down Expand Up @@ -230,4 +236,6 @@ def report_error(self, message: str, span: Union[Span, synr.ast.Span]):
self._report_error(message, span)

def current_block_scope(self) -> BlockInfo:
return self.block_info_stack[-1]
if self.block_info_stack:
return self.block_info_stack[-1]
return None
21 changes: 8 additions & 13 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from . import _ffi_api
from . import tir

from .context_maintainer import BlockInfo, ContextMaintainer
from .context_maintainer import ContextMaintainer
from .meta_unparser import MetaUnparser
from .registry import Registry
from .diagnostics import TVMDiagnosticCtx
Expand Down Expand Up @@ -449,19 +449,8 @@ def check_decorator(decorators: List[ast.Expr]) -> bool:
node.span,
)

# New Scope : Implicit root block
# Each function contains an implicit root block in TensorIR,
# so here we need a block scope for it. Please note that `enter_block_scope`
# will not create a block directly but just stores some information.
# If the PrimFunc is not a TensorIR func (e.g. TE scheduled func or low-level func),
# the root block will not be added. The logic to add root block is in `_ffi_api.Complete`
self.context.enter_block_scope(nodes=node.body.stmts)

# fetch the body of root block
body = self.parse_body(node.body)
# Emit Scope : Implicit root block
root_info: BlockInfo = self.context.current_block_scope()
self.context.exit_block_scope()

# return a tir.PrimFunc
dict_attr = self.context.func_dict_attr
Expand All @@ -475,6 +464,12 @@ def check_decorator(decorators: List[ast.Expr]) -> bool:
span=tvm_span_from_synr(node.span),
)

# New Scope : Implicit root block
# Each function contains an implicit root block in TensorIR,
# so here we need a block scope for it.
# If the PrimFunc is not a TensorIR func (e.g. TE scheduled func or low-level func),
# the root block will not be added. The logic to add root block is in `_ffi_api.Complete`

# Fix the PrimFunc
# 1. generate root block if necessary
# 2. generate surrounding loops for blocks if necessary
Expand All @@ -484,7 +479,7 @@ def check_decorator(decorators: List[ast.Expr]) -> bool:
node.span,
_ffi_api.Complete,
func,
root_info.alloc_buffers,
self.context.root_alloc_buffers,
)

self.context.exit_scope()
Expand Down
31 changes: 30 additions & 1 deletion python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,11 @@ def alloc_buffer(
buffer_type,
span=span,
)
self.context.current_block_scope().alloc_buffers.append(buffer)
if self.context.current_block_scope():
self.context.current_block_scope().alloc_buffers.append(buffer)
else:
# If it is allocated outside all blocks, allocate it under root block.
self.context.root_alloc_buffers.append(buffer)
self.context.update_symbol(buffer_name, buffer, self.node)

super().__init__(alloc_buffer, def_symbol=True)
Expand All @@ -309,6 +313,11 @@ def __init__(self):
def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span: Span = None):
assert self.context, "call 'exit_scope' before 'enter_scope'"
block_scope = self.context.current_block_scope()
if block_scope is None:
self.context.report_error(
"Expected to declare read regions inside a block.",
span,
)
if block_scope.reads is not None:
self.context.report_error(
"Duplicate write region declaration, "
Expand Down Expand Up @@ -344,6 +353,11 @@ def __init__(self):
def writes(write_region: Union[BufferSlice, List[BufferSlice]], span: Span = None):
assert self.context, "call 'exit_scope' before 'enter_scope'"
block_scope = self.context.current_block_scope()
if block_scope is None:
self.context.report_error(
"Expected to declare write regions inside a block.",
span,
)
if block_scope.writes is not None:
self.context.report_error(
"Duplicate write region declaration, "
Expand Down Expand Up @@ -381,6 +395,11 @@ def __init__(self):
def block_attr(attrs: Mapping[str, Object], span: Span = None):
assert self.context, "call 'exit_scope' before 'enter_scope'"
block_scope = self.context.current_block_scope()
if block_scope is None:
self.context.report_error(
"Expected to declare block annotations inside a block.",
span,
)
if block_scope.annotations is not None:
self.context.report_error(
"Duplicate block annotations declaration, "
Expand Down Expand Up @@ -438,6 +457,11 @@ def axis(
"""
assert self.context, "call 'exit_scope' before 'enter_scope'"
block_scope: BlockInfo = self.context.current_block_scope()
if block_scope is None:
self.context.report_error(
"Expected to declare block axes inside a block.",
self.node.span,
)
if var_name in [iter_var.var.name for iter_var in block_scope.iter_vars]:
self.context.report_error("Duplicate block axis " + var_name, self.node.span)

Expand Down Expand Up @@ -721,6 +745,11 @@ def __init__(self):
def where(predicate, span=None):
assert self.context, "call 'exit_scope' before 'enter_scope'"
block_scope = self.context.current_block_scope()
if block_scope is None:
self.context.report_error(
"Expected to declare the predicate inside a block.",
span,
)
if block_scope.predicate is not None:
self.context.report_error(
"Duplicate block predicate declaration, "
Expand Down
36 changes: 33 additions & 3 deletions tests/python/unittest/test_tvmscript_error_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,39 @@ def test_block_has_option_vars():
check_error(block_has_option_vars, 2)


def implicit_root_has_read():
T.reads([]) # error: implicit root does not support reads
T.evaluate(0.0)


def implicit_root_has_write():
T.writes([]) # error: implicit root does not support writes
T.evaluate(0.0)


def implicit_root_has_attrs():
T.block_attr({}) # error: implicit root does not support block_attr
T.evaluate(0.0)


def implicit_root_has_predicate():
T.where(True) # error: implicit root does not support predicate
T.evaluate(0.0)


def implicit_root_has_axes():
v = T.axis.S(0, 0) # error: implicit root does not support axis define
T.evaluate(0.0)


def test_implicit_root_has_attrs():
check_error(implicit_root_has_read, 2)
check_error(implicit_root_has_write, 2)
check_error(implicit_root_has_attrs, 2)
check_error(implicit_root_has_predicate, 2)
check_error(implicit_root_has_axes, 2)


def check_error(func, rel_lineno):
# Override the default renderer to accumulate errors
errors = []
Expand All @@ -510,9 +543,6 @@ def render(e):
), f"Expected error to be on line {rel_lineno}, but it was on {d.span.line - 1}"


# TODO(Siyuan): block iter errors.


@T.prim_func
def elementwise_not_affine(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (128, 128, 128, 128))
Expand Down

0 comments on commit 3f9b72d

Please sign in to comment.