Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
tkonolige committed Nov 5, 2020
1 parent 57e6506 commit f2d8f26
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 10 deletions.
6 changes: 5 additions & 1 deletion python/tvm/autotvm/task/topi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,11 @@ def wrapper(outs, *args, **kwargs):
"""wrapper function for topi schedule"""
workload = get_workload(outs)
if workload is None:
raise RuntimeError(f"Cannot find workload for {task_name}. You may need to register a compute function for it with `@tvm.autotvm.register_topi_compute(\"{task_name}\")`")
raise RuntimeError(
f"Cannot find workload for {task_name}. You may need to "
"register a compute function for it with "
f'`@tvm.autotvm.register_topi_compute("{task_name}")`'
)
tgt = Target.current()
cfg = DispatchContext.current.query(tgt, workload)
return topi_schedule(cfg, outs, *args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ def compute_space_to_depth(attrs, inputs, out_dtype):

# embed
@reg.register_compute("nn.embed")
def compute_embed_grad(attrs, inputs, out_type):
def compute_embed(attrs, inputs, out_type):
"""Compute definition of embed"""
return [topi.nn.embed(inputs[0], inputs[1])]

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/te/hybrid/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _cast(func_id, args):
return _expr.Cast(func_id, args[0])


def cast(func_id, args):
def cast(_, args):
_internal_assert(args.__len__() == 2, "cast requires two arguments: dtype, value")
return _expr.Cast(args[0], args[1])

Expand Down
12 changes: 6 additions & 6 deletions python/tvm/topi/cuda/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
"""scheduler functions for cuda backend"""
from __future__ import absolute_import as _abs

from tvm import te
from .. import cpp

from tvm import autotvm, te


def schedule_lrn(outs):
"""Schedule for LRN
Expand All @@ -44,8 +43,8 @@ def loads_per_thread(dtype):
"""Number elements per load per thread"""
s = regex.search("[0-9]+", dtype)
assert s is not None
bytes = int(s.group()) // 8
return 16 // bytes
byts = int(s.group()) // 8
return 16 // byts


def schedule_embed_grad(outs):
Expand All @@ -63,7 +62,8 @@ def schedule_embed_grad(outs):
The computation schedule for the op.
"""
s = te.create_schedule([outs[0].op])
vec_size = loads_per_thread(outs[0].dtype) # this should be autotuned, but we can't with hybrid script
# this should be autotuned, but we can't with hybrid script
vec_size = loads_per_thread(outs[0].dtype)
warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
num_warps = 4
out = s.outputs[0].output(0)
Expand All @@ -72,7 +72,7 @@ def schedule_embed_grad(outs):
s[out].vectorize(ji)
joo, joi = s[out].split(jo, factor=warp_size)
s[out].bind(joi, te.thread_axis("threadIdx.x"))
jooo, jooi = s[out].split(joo, factor=num_warps)
_, jooi = s[out].split(joo, factor=num_warps)
s[out].bind(jooi, te.thread_axis("threadIdx.y"))
s[out].bind(i, te.thread_axis("blockIdx.x"))

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/topi/x86/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name,too-many-locals,unused-variable
"""x86 nn operators"""
from tvm import autotvm, te
from tvm import te


def schedule_softmax(outs):
Expand Down
16 changes: 16 additions & 0 deletions tests/python/topi/python/test_topi_embed.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
import tvm.topi.testing
Expand Down

0 comments on commit f2d8f26

Please sign in to comment.