Skip to content

Commit

Permalink
[FRONTEND] Fix return op related control flow issues (#1637)
Browse files Browse the repository at this point in the history
- Case 1: A `return` follows a static control-flow branch that is taken. Remove
the unreachable instructions after the first `return` in each basic block.

```python
if static_condition:
    tl.store(...)
    return
return
```

- Case 2: A `return` exists in both the `if` and the `else` branch of an inlined
`JITFunction`, so no block is needed after the `if`

```python
def foo():
    if dynamic_condition:
        return a
    else:
        return b
```

- Case 3: Return exists in a `JITFunction` from another module

```python
import module
if cond:
    a = module.func()
```

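- Case 4: A chain of calls on a local variable that is not yet defined when the statement is checked for return ops (such a name can only be a tensor or a constexpr)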

```python
import module
if cond:
    a = x
    a = a.to(tl.int32).to(tl.int32)
```

- Case 5: Call a function `func` without using its return value. The statement
`func()` is recognized as an `Expr` first instead of a `Call`.

```python
if cond:
    foo()
else:
    bar()
```

- Case 6: Call a `noinline` function. Because the callee is not inlined, we
don't need to check whether it contains any return op.
  • Loading branch information
Jokeren authored May 9, 2023
1 parent 319af1f commit b19b274
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 29 deletions.
24 changes: 24 additions & 0 deletions python/src/triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,11 @@ void init_triton_ir(py::module &&m) {
return !self.empty() &&
self.back().hasTrait<mlir::OpTrait::IsTerminator>();
})
.def("has_return",
[](mlir::Block &self) {
return !self.empty() &&
self.back().hasTrait<mlir::OpTrait::ReturnLike>();
})
.def("erase", [](mlir::Block &self) { self.erase(); });

// using eattr = ir::attribute_kind_t;
Expand Down Expand Up @@ -428,6 +433,25 @@ void init_triton_ir(py::module &&m) {
self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
},
ret::reference)
.def("finalize",
[](mlir::triton::FuncOp &self) -> void {
// Remove dead code
// 1. Unreachable code after return
self.walk([&](mlir::Block *block) {
mlir::Operation *retOp = nullptr;
block->walk([&](mlir::Operation *op) {
if (mlir::isa<mlir::triton::ReturnOp>(op))
if (retOp == nullptr)
retOp = op;
});
if (retOp && retOp != &block->back()) {
auto pos = retOp->getIterator();
pos++;
auto *newBlock = block->splitBlock(pos);
newBlock->erase();
}
});
})
.def_property_readonly("type", &mlir::triton::FuncOp::getFunctionType)
.def("reset_type", &mlir::triton::FuncOp::setType);

Expand Down
74 changes: 64 additions & 10 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2550,24 +2550,30 @@ def kernel(Cond, TrueVal, FalseVal, Out):
assert to_numpy(out)[0] == false_val[0]


def test_if_return():
@pytest.mark.parametrize("mode", ["dynamic", "static"])
def test_if_return(mode):

@triton.jit
def kernel(ExitEarly, Out):
if tl.load(ExitEarly):
tl.store(Out, 0)
return
def kernel(ExitEarly, Out, cond: tl.constexpr, mode: tl.constexpr):
if mode == "dynamic":
if tl.load(ExitEarly):
tl.store(Out, 0)
return
else:
if cond:
tl.store(Out, 0)
return
tl.store(Out, 1)

out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
exit_early = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
# exit early path taken
exit_early[0] = 1
kernel[(1,)](exit_early, out)
kernel[(1,)](exit_early, out, True, mode)
assert to_numpy(out)[0] == 0
# exit early path not taken
exit_early[0] = 0
kernel[(1,)](exit_early, out)
kernel[(1,)](exit_early, out, False, mode)
assert to_numpy(out)[0] == 1


Expand All @@ -2576,21 +2582,69 @@ def add_fn(x):
return x + 1


@pytest.mark.parametrize("call_type", ["attribute", "jit_function"])
# Same body as add_fn, but declared with noinline=True. test_if_call calls it
# in its "jit_function_noinline" case.
@triton.jit(noinline=True)
def add_fn_noinline(x):
return x + 1


# Returns from both the `if` and the `else` branch, so no statement follows the
# conditional. test_if_call uses it in the "jit_function_return" and "ifexp"
# cases.
@triton.jit
def add_fn_return(x, pid):
if pid == 0:
return x + 1
else:
return x + 2


# Stores x to Out and has no return statement. test_if_call invokes it as a
# bare statement (its result is not assigned) in the "expr" case.
@triton.jit
def add_fn_expr(Out, x):
tl.store(Out, x)


# Branches on `cond`, a tl.constexpr parameter, and returns from both branches.
# test_if_call passes its own call_type string as `cond` in the
# "jit_function_static_cond" case, so `cond == ""` is false there and x + 1 is
# returned.
@triton.jit
def add_fn_static_cond(x, cond: tl.constexpr):
if cond == "":
return x
else:
return x + 1


@pytest.mark.parametrize("call_type", ["attribute", "jit_function", "jit_function_return",
"ifexp", "expr", "jit_function_static_cond", "jit_function_noinline"])
def test_if_call(call_type):
@triton.jit
def kernel(Out, call_type: tl.constexpr):
pid = tl.program_id(0)
o = tl.load(Out)
if pid == 0:
if call_type == "attribute":
# call attribute
a = o + 1
a = a.to(tl.int32)
a = a.to(tl.int32).to(tl.int32)
o = a
else:
a = o
a = add_fn(a)
if call_type == "jit_function":
# regular function call
a = add_fn(a)
elif call_type == "jit_function_return":
# function without end_if block
a = add_fn_return(a, pid)
elif call_type == "ifexp":
# ifexp expression
a = add_fn(a) if pid == 0 else add_fn_return(a, pid)
elif call_type == "expr":
if pid == 1:
return
a = add_fn(a)
if pid == 0:
# call without return
add_fn_expr(Out, a)
elif call_type == "jit_function_static_cond":
a = add_fn_static_cond(a, call_type)
elif call_type == "jit_function_noinline":
a = add_fn_noinline(a)
o = a

tl.store(Out, o)

out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
Expand Down
71 changes: 52 additions & 19 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n
self.debug = debug
self.noinline = noinline
self.scf_stack = []
self.last_ret_type = None
# SSA-construction
# name => language.tensor
self.local_defs: Dict[str, tensor] = {}
Expand Down Expand Up @@ -138,7 +139,7 @@ def name_lookup(name: str) -> Any:
def set_value(self, name: str,
value: Union[tensor, constexpr]) -> None:
''' This function:
called by visit_Assign() & visit_FuncDef() to store left value (lvalue)
called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
1. record local defined name (FIXME: should consider control flow)
2. store tensor in self.lvalue
'''
Expand All @@ -150,10 +151,9 @@ def set_value(self, name: str,
#
def visit_compound_statement(self, stmts):
for stmt in stmts:
self.last_ret_type = self.visit(stmt)
if isinstance(stmt, ast.Return):
break
return stmts and isinstance(stmt, ast.Return)
ret_type = self.visit(stmt)
if ret_type is not None and isinstance(stmt, ast.Return):
self.last_ret_type = ret_type

# TODO: should be its own AST visitor
def contains_return_op(self, node):
Expand All @@ -168,10 +168,23 @@ def contains_return_op(self, node):
pred = lambda s: self.contains_return_op(s)
return any(pred(s) for s in node.body)
elif isinstance(node, ast.Call):
if isinstance(node.func, ast.Attribute):
def check_undefined_name(cur_node):
# Check if name is an undefined local variable,
# which can only be a tensor or a constexpr
if isinstance(cur_node.func, ast.Attribute):
if isinstance(cur_node.func.value, ast.Name):
name = cur_node.func.value.id
if name not in self.lscope and name not in self.gscope:
return True
return False
# chain of calls
# e.g., tl.load(a).to(tl.float32)
return check_undefined_name(cur_node.func.value)
return False
if check_undefined_name(node):
return False
fn = self.visit(node.func)
if isinstance(fn, JITFunction):
if isinstance(fn, JITFunction) and fn.noinline is False:
old_gscope = self.gscope
self.gscope = sys.modules[fn.fn.__module__].__dict__
ret = self.contains_return_op(fn.parse())
Expand All @@ -184,6 +197,18 @@ def contains_return_op(self, node):
if node.orelse:
ret = ret or any(pred(s) for s in node.orelse)
return ret
elif isinstance(node, ast.IfExp):
return self.contains_return_op(node.body) or self.contains_return_op(node.orelse)
elif isinstance(node, ast.Expr):
ret = False
for _, value in ast.iter_fields(node):
if isinstance(value, list):
for item in value:
if isinstance(item, ast.AST):
ret = ret or self.contains_return_op(item)
elif isinstance(value, ast.AST):
ret = ret or self.contains_return_op(value)
return ret
else:
return False

Expand Down Expand Up @@ -257,9 +282,9 @@ def visit_FunctionDef(self, node):
self.set_value(arg_name, arg_value)
self.builder.set_insertion_point_to_start(entry)
# visit function body
has_ret = self.visit_compound_statement(node.body)
self.visit_compound_statement(node.body)
# finalize function
if not has_ret:
if self.last_ret_type is None:
self.builder.ret([])
else:
# update return type
Expand All @@ -271,6 +296,8 @@ def visit_FunctionDef(self, node):
fn.reset_type(self.prototype.to_ir(self.builder))
if insert_pt:
self.builder.set_insertion_point_to_end(insert_pt)
# Remove dead code
fn.finalize()

def visit_arguments(self, node):
arg_names = []
Expand Down Expand Up @@ -421,6 +448,7 @@ def visit_then_else_blocks(self, node, liveins, then_block, else_block):
return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types

def visit_if_top_level(self, cond, node):
has_endif_block = True
with enter_sub_region(self) as sr:
liveins, ip_block = sr
then_block = self.builder.create_block()
Expand All @@ -435,20 +463,25 @@ def visit_if_top_level(self, cond, node):
self.visit_then_else_blocks(node, liveins, then_block, else_block)
# then terminator
self.builder.set_insertion_point_to_end(then_block)
if not then_block.has_terminator():
if then_block.has_return() and else_block.has_return():
has_endif_block = False
endif_block.erase()
if not then_block.has_terminator() and has_endif_block:
self.builder.create_branch(endif_block, [then_defs[n].handle for n in names])
# else terminator
self.builder.set_insertion_point_to_end(else_block)
if not else_block.has_terminator():
if not else_block.has_terminator() and has_endif_block:
self.builder.create_branch(endif_block, [else_defs[n].handle for n in names])
for ty in ir_ret_types:
endif_block.add_argument(ty)
# change block
self.builder.set_insertion_point_to_start(endif_block)
# update value
for i, name in enumerate(names):
new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i])
self.set_value(name, new_tensor)
if has_endif_block:
for ty in ir_ret_types:
endif_block.add_argument(ty)
if has_endif_block:
# change block
self.builder.set_insertion_point_to_start(endif_block)
# update value
for i, name in enumerate(names):
new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i])
self.set_value(name, new_tensor)

# TODO: refactor
def visit_if_scf(self, cond, node):
Expand Down

0 comments on commit b19b274

Please sign in to comment.