Skip to content

Commit

Permalink
[TVM] upgrade to generic schedule (apache#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 29, 2018
1 parent 08e71b7 commit 9fb13a6
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 64 deletions.
19 changes: 12 additions & 7 deletions nnvm/python/nnvm/compiler/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def optimize(graph, shape, dtype="float32"):
return graph


def build(graph, target, shape, dtype="float32", params=None):
def build(graph, target=None, shape=None, dtype="float32", params=None):
"""Build graph into runtime library.
The build function will optimize the graph and do the compilation.
Expand All @@ -175,10 +175,10 @@ def build(graph, target, shape, dtype="float32", params=None):
graph : Graph
The graph to be used in lowering
target : str
target : str or :any:`tvm.target.Target`, optional
The build target
shape : dict of str to tuple
shape : dict of str to tuple, optional
The input shape to the graph
dtype : str or dict of str to str
Expand All @@ -201,8 +201,12 @@ def build(graph, target, shape, dtype="float32", params=None):
The updated parameters of graph if params is passed.
This can be different from the params passed in.
"""
if not isinstance(target, str):
raise TypeError("require target to be str")
target = target if target else tvm.target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
target = tvm.target.create(target)

shape = shape if shape else {}
if not isinstance(shape, dict):
raise TypeError("require shape to be dict")
cfg = BuildConfig.current
Expand All @@ -223,13 +227,14 @@ def build(graph, target, shape, dtype="float32", params=None):
# Operator Fusion and generatiom
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph_attr.set_dtype_inputs(graph, dtype)
graph._set_json_attr("target", target, "str")
graph._set_json_attr("target", str(target), "str")
if cfg.pass_enabled("OpFusion"):
graph._set_json_attr("opt_level", 1, "int")
else:
graph._set_json_attr("opt_level", 0, "int")
graph = graph.apply("InferShape").apply("InferType")
graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile")
with target:
graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile")
libmod = graph_attr._move_out_module(graph, "module")
return graph, libmod, params

Expand Down
57 changes: 18 additions & 39 deletions nnvm/python/nnvm/top/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,9 @@ def compute_softmax(attrs, inputs, _):
@reg.register_schedule("softmax")
def schedule_softmax(_, outs, target):
"""Schedule definition of softmax"""
if target == "cuda":
return topi.cuda.schedule_softmax(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])
with tvm.target.create(target):
return topi.generic.schedule_softmax(outs)


reg.register_pattern("softmax", OpPattern.OPAQUE)

Expand All @@ -68,10 +67,8 @@ def compute_log_softmax(attrs, inputs, _):
@reg.register_schedule("log_softmax")
def schedule_log_softmax(_, outs, target):
"""Schedule definition of softmax"""
if target == "cuda":
return topi.cuda.schedule_softmax(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])
with tvm.target.create(target):
return topi.generic.schedule_softmax(outs)

# Mark softmax as extern as we do not fuse it in call cases
reg.register_pattern("log_softmax", OpPattern.OPAQUE)
Expand All @@ -87,10 +84,8 @@ def compute_dense(attrs, inputs, _):
@reg.register_schedule("dense")
def schedule_dense(_, outs, target):
"""Schedule definition of dense"""
if target == "cuda":
return topi.cuda.schedule_dense(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])
with tvm.target.create(target):
return topi.generic.schedule_dense(outs)

reg.register_pattern("dense", OpPattern.OUT_ELEMWISE_FUSABLE)

Expand Down Expand Up @@ -123,18 +118,10 @@ def compute_conv2d(attrs, inputs, _):
def schedule_conv2d(attrs, outs, target):
"""Schedule definition of conv2d"""
groups = attrs.get_int("groups")
if target == "cuda":
if groups == 1:
return topi.cuda.schedule_conv2d_nchw(outs)
return topi.cuda.schedule_depthwise_conv2d_nchw(outs)
# naive schedule

if tvm.target.current_target() == tvm.target.rasp():
with tvm.target.create(target):
if groups == 1:
return topi.rasp.schedule_conv2d(outs)
return topi.rasp.schedule_depthwise_conv2d(outs)

return tvm.create_schedule([x.op for x in outs])
return topi.generic.schedule_conv2d_nchw(outs)
return topi.generic.schedule_depthwise_conv2d_nchw(outs)

reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)

Expand All @@ -155,10 +142,8 @@ def compute_max_pool2d(attrs, inputs, _):
@reg.register_schedule("max_pool2d")
def schedule_max_pool2d(_, outs, target):
"""Schedule definition of max_pool2d"""
if target == "cuda":
return topi.cuda.schedule_pool(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])
with tvm.target.create(target):
return topi.generic.schedule_pool(outs)

reg.register_pattern("max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)

Expand All @@ -179,10 +164,8 @@ def compute_avg_pool2d(attrs, inputs, _):
@reg.register_schedule("avg_pool2d")
def schedule_avg_pool2d(_, outs, target):
"""Schedule definition of avg_pool2d"""
if target == "cuda":
return topi.cuda.schedule_pool(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])
with tvm.target.create(target):
return topi.generic.schedule_pool(outs)

reg.register_pattern("avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)

Expand All @@ -198,10 +181,8 @@ def compute_global_max_pool2d(attrs, inputs, _):
@reg.register_schedule("global_max_pool2d")
def schedule_global_max_pool2d(_, outs, target):
"""Schedule definition of global_max_pool2d"""
if target == "cuda":
return topi.cuda.schedule_global_pool(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])
with tvm.target.create(target):
return topi.generic.schedule_global_pool(outs)

reg.register_pattern("global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)

Expand All @@ -217,9 +198,7 @@ def compute_global_avg_pool2d(attrs, inputs, _):
@reg.register_schedule("global_avg_pool2d")
def schedule_global_avg_pool2d(_, outs, target):
"""Schedule definition of global_avg_pool2d"""
if target == "cuda":
return topi.cuda.schedule_global_pool(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])
with tvm.target.create(target):
return topi.generic.schedule_global_pool(outs)

reg.register_pattern("global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
11 changes: 3 additions & 8 deletions nnvm/python/nnvm/top/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,9 @@

def _schedule_reduce(_, outs, target):
"""Generic schedule for reduce"""
if target == "cuda":
return topi.cuda.schedule_reduce(outs)
assert target.startswith("llvm")
s = tvm.create_schedule([x.op for x in outs])
x = outs[0]
tvm.schedule.AutoInlineInjective(s)
s[x].fuse(s[x].op.axis)
return s
with tvm.target.create(target):
return topi.generic.schedule_reduce(outs)


_fschedule_reduce = tvm.convert(_schedule_reduce)

Expand Down
13 changes: 3 additions & 10 deletions nnvm/python/nnvm/top/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,8 @@

def _schedule_injective(_, outs, target):
"""Generic schedule for binary bcast"""
if target == "cuda":
return topi.cuda.schedule_injective(outs)
assert target.startswith("llvm")
s = tvm.create_schedule([x.op for x in outs])
x = outs[0]
tvm.schedule.AutoInlineInjective(s)
s[x].fuse(s[x].op.axis)
return s

with tvm.target.create(target):
return topi.generic.schedule_injective(outs)

def _compute_binary_scalar(f):
"""auxiliary function"""
Expand Down Expand Up @@ -174,7 +167,7 @@ def _compute(attrs, x, _):

# broadcast_to
@reg.register_compute("broadcast_to")
def compute_softmax(attrs, inputs, out_info):
def compute_broadcast_to(attrs, inputs, out_info):
"""Compute definition of softmax"""
return topi.broadcast_to(inputs[0], shape=out_info[0].shape)
reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
Expand Down

0 comments on commit 9fb13a6

Please sign in to comment.