From fb77424d54452355cf5a16b8a7b5414594c02427 Mon Sep 17 00:00:00 2001
From: Jian Weng
Date: Mon, 25 Feb 2019 18:18:41 -0800
Subject: [PATCH] [Hybrid Script] Add `max_num_threads` (#2672)

* i think it works for now?

* fix lint

* fix 2/3 compat

* fix py2 again

* fine, i gave up
---
Review note: relative to the original commit, this revision looks a symbol up
in the parser's visit_Name only after asserting that it exists, avoids eval()
for True/False, uses len(args) in the new intrinsic, corrects the thread/block
index in the new test and adds documentation. Hunk line counts and the
diffstat are updated accordingly; the blob ids on the `index` lines still
refer to the original commit.

 python/tvm/hybrid/calls.py                  | 20 ++++++-
 python/tvm/hybrid/parser.py                 | 32 +++++++----
 python/tvm/hybrid/preprocessor.py           |  3 +
 python/tvm/hybrid/runtime.py                | 64 ++++++++++++---------
 src/contrib/hybrid/codegen_hybrid.cc        |  3 +
 tests/python/unittest/test_hybrid_script.py | 17 ++++++
 6 files changed, 100 insertions(+), 39 deletions(-)

diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py
index cd1e4e3a2085..56a73f784fa0 100644
--- a/python/tvm/hybrid/calls.py
+++ b/python/tvm/hybrid/calls.py
@@ -4,6 +4,7 @@
 from .. import api as _api
 from .. import expr as _expr
 from .. import make as _make
+from .. import target as _tgt
 from ..container import Array
 from .. import ir_pass
 from ..stmt import For
@@ -123,7 +124,7 @@ def ceil_div(func_id, args):
     _internal_assert(isinstance(args[0], _expr.Expr), "Only expressions can div")
     _internal_assert(isinstance(args[1], _expr.Expr), "Only expressions can div")
     a, b = args[0], args[1]
-    return (a + b - 1) / b
+    return (a + b - 1) // b
 
 
 def likely(func_id, args):
@@ -131,3 +132,20 @@ def likely(func_id, args):
                      "Only one expression can be likely")
     _internal_assert(func_id == "likely", "This function cannot be directly invoked!")
     return call_pure_intrin(args[0].dtype, 'likely', *args)
+
+
+def max_num_threads(func_id, args):
+    """Lower a hybrid-script `max_num_threads` call to a constant expression.
+
+    With no argument the current target is queried with its default settings;
+    with one argument (`allow_none`, which must already be lowered to a
+    `UIntImm`) its value is forwarded to `current_target`.
+    """
+    _internal_assert(func_id == "max_num_threads", "This function cannot be directly invoked!")
+    _internal_assert(len(args) <= 1, "At most one argument accepted!")
+    if not args:
+        res = _tgt.current_target().max_num_threads
+    else:
+        _internal_assert(isinstance(args[0], _expr.UIntImm), "In tvm bool should be uint")
+        res = _tgt.current_target(args[0].value).max_num_threads
+    return _api.convert(res)
diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py
index 0959c9df2e91..67a6f6632d16 100644
--- a/python/tvm/hybrid/parser.py
+++ b/python/tvm/hybrid/parser.py
@@ -219,6 +219,10 @@ def visit_Expr(self, node):
 
     def visit_Name(self, node):
         name = node.id
-        ty, entry = self.symbols[name]
+        # Python 2 parses True/False as ordinary names instead of constants.
+        if sys.version_info[0] == 2 and name in ['True', 'False']:
+            return _api.convert(name == 'True')
         _internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
+        # Look the symbol up only after the assertion confirmed it exists.
+        ty, entry = self.symbols[name]
         if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
@@ -248,6 +252,11 @@ def visit_Num(self, node):
         return _api.const(node.n, dtype)
 
 
+    def visit_NameConstant(self, node):
+        # Python 3 represents True/False as NameConstant nodes, not names.
+        return _api.convert(node.value)
+
+
     def visit_AugAssign(self, node):
         buf = self.visit(node.target)
         rhs = self.visit(node.value)
@@ -450,17 +459,18 @@ def visit_Call(self, node):
 
         func_id = node.func.id
         args = [self.visit(i) for i in node.args]
-        try:
+        # Intrinsics defined in `calls` take precedence
+        if hasattr(calls, func_id):
             return getattr(calls, func_id)(func_id, args)
-        except AttributeError:
-            _internal_assert(func_id in self.symbols.keys(), \
-                             "The function called is not in the context either!")
-            ty, entry = self.symbols[func_id]
-            _internal_assert(ty is Symbol.Callable, \
-                             "Are you sure what you call is a function?!")
-            outs = entry(*args)
-            op = outs.op if isinstance(outs, Tensor) else outs[0].op
-            return op
+        # Otherwise the callee must be a callable registered in the symbol table
+        _internal_assert(func_id in self.symbols.keys(), \
+                         "The function called (%s) is not in the context either!" % func_id)
+        ty, entry = self.symbols[func_id]
+        _internal_assert(ty is Symbol.Callable, \
+                         "Are you sure what you call is a function?!")
+        outs = entry(*args)
+        op = outs.op if isinstance(outs, Tensor) else outs[0].op
+        return op
 
 
     def visit_For(self, node):
diff --git a/python/tvm/hybrid/preprocessor.py b/python/tvm/hybrid/preprocessor.py
index 50b610567c74..a83fb2eae287 100644
--- a/python/tvm/hybrid/preprocessor.py
+++ b/python/tvm/hybrid/preprocessor.py
@@ -59,6 +59,9 @@ def visit_AugAssign(self, node):
 
 
     def visit_Name(self, node):
+        # If it is True or False, we do not worry about it!
+        if sys.version_info[0] == 2 and node.id in ['True', 'False']:
+            return
         # If it is from the argument list or loop variable, we do not worry about it!
         if node.id in self._args.keys():
             return
diff --git a/python/tvm/hybrid/runtime.py b/python/tvm/hybrid/runtime.py
index 293e069c24ea..b3c744f42652 100644
--- a/python/tvm/hybrid/runtime.py
+++ b/python/tvm/hybrid/runtime.py
@@ -1,6 +1,7 @@
 """Intrinsics of TVM-Python Hybrid Script for Python emulation runtime"""
 
 import numpy
+from .. import target
 
 
 class bind(object): #pylint: disable=invalid-name
@@ -72,34 +73,43 @@ def sigmoid(x):
     return 1 / (1 + numpy.exp(-x))
 
 
+def max_num_threads(allow_none=True):
+    """Get max number of threads for GPU targets.
+
+    `allow_none` is forwarded unchanged to `target.current_target`.
+    """
+    return target.current_target(allow_none).max_num_threads
+
+
 HYBRID_GLOBALS = {
-    'unroll'       : range,
-    'vectorize'    : range,
-    'parallel'     : range,
-    'const_range'  : range,
-    'bind'         : bind,
-    'allocate'     : allocate,
-    'output_tensor': allocate,
-    'sqrt'         : numpy.sqrt,
-    'log'          : numpy.log,
-    'tanh'         : numpy.tanh,
-    'power'        : numpy.power,
-    'exp'          : numpy.exp,
-    'sigmoid'      : sigmoid,
-    'popcount'     : popcount,
-    'likely'       : lambda cond: cond,
-    'uint8'        : numpy.uint8,
-    'uint16'       : numpy.uint16,
-    'uint32'       : numpy.uint32,
-    'uint64'       : numpy.uint64,
-    'int8'         : numpy.int8,
-    'int16'        : numpy.int16,
-    'int32'        : numpy.int32,
-    'int64'        : numpy.int64,
-    'float16'      : numpy.float16,
-    'float32'      : numpy.float32,
-    'float64'      : numpy.float64,
-    'ceil_div'     : lambda a, b: (a + b - 1) / b
+    'unroll'         : range,
+    'vectorize'      : range,
+    'parallel'       : range,
+    'const_range'    : range,
+    'bind'           : bind,
+    'allocate'       : allocate,
+    'output_tensor'  : allocate,
+    'sqrt'           : numpy.sqrt,
+    'log'            : numpy.log,
+    'tanh'           : numpy.tanh,
+    'power'          : numpy.power,
+    'exp'            : numpy.exp,
+    'sigmoid'        : sigmoid,
+    'popcount'       : popcount,
+    'likely'         : lambda cond: cond,
+    'uint8'          : numpy.uint8,
+    'uint16'         : numpy.uint16,
+    'uint32'         : numpy.uint32,
+    'uint64'         : numpy.uint64,
+    'int8'           : numpy.int8,
+    'int16'          : numpy.int16,
+    'int32'          : numpy.int32,
+    'int64'          : numpy.int64,
+    'float16'        : numpy.float16,
+    'float32'        : numpy.float32,
+    'float64'        : numpy.float64,
+    'ceil_div'       : lambda a, b: (a + b - 1) // b,
+    'max_num_threads': max_num_threads,
 }
 
 
diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc
index 2117d471eeee..56564d668001 100644
--- a/src/contrib/hybrid/codegen_hybrid.cc
+++ b/src/contrib/hybrid/codegen_hybrid.cc
@@ -400,6 +400,8 @@ void CodeGenHybrid::ReserveKeywords() {
   GetUniqueName("for");
   GetUniqueName("in");
   GetUniqueName("range");
+  GetUniqueName("True");
+  GetUniqueName("False");
   GetUniqueName("unroll");
   GetUniqueName("const_range");
   GetUniqueName("parallel");
@@ -434,6 +436,7 @@ void CodeGenHybrid::ReserveKeywords() {
   GetUniqueName("float32");
   GetUniqueName("float64");
   GetUniqueName("ceil_div");
+  GetUniqueName("max_num_threads");
 }
 
 void CodeGenHybrid::DumpStmt(const Stmt &stmt,
diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py
index d35c8ab3a0df..5bed58c8f617 100644
--- a/tests/python/unittest/test_hybrid_script.py
+++ b/tests/python/unittest/test_hybrid_script.py
@@ -350,6 +350,23 @@ def foo(a):
     func, ins, outs = run_and_check(foo, [a], target='cuda')
     run_and_check(func, ins, outs=outs, target='cuda')
 
+    @tvm.hybrid.script
+    def max_threads(a):
+        b = output_tensor(a.shape, a.dtype)
+        n = a.shape[0]
+        m = max_num_threads(True)
+        for i in bind('threadIdx.x', m):
+            for j in bind('blockIdx.x', ceil_div(n, m)):
+                # Global index: block j covers elements [j * m, (j + 1) * m).
+                if j * m + i < n:
+                    b[j * m + i] = a[j * m + i] + a[j * m + i]
+        return b
+
+    a = tvm.placeholder((10000, ), 'float32')
+    with tvm.target.create('cuda'):
+        func, ins, outs = run_and_check(max_threads, [a], target='cuda')
+        run_and_check(func, ins, outs=outs, target='cuda')
+
 
 def test_math_intrin():
     @script