diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py
index 82c8dbb8cf5cb..1de4e8d49df6a 100644
--- a/python/tvm/autotvm/task/topi_integration.py
+++ b/python/tvm/autotvm/task/topi_integration.py
@@ -229,7 +229,11 @@ def wrapper(outs, *args, **kwargs):
             """wrapper function for topi schedule"""
             workload = get_workload(outs)
             if workload is None:
-                raise RuntimeError(f"Cannot find workload for {task_name}. You may need to register a compute function for it with `@tvm.autotvm.register_topi_compute(\"{task_name}\")`")
+                raise RuntimeError(
+                    f"Cannot find workload for {task_name}. You may need to "
+                    "register a compute function for it with "
+                    f'`@tvm.autotvm.register_topi_compute("{task_name}")`'
+                )
             tgt = Target.current()
             cfg = DispatchContext.current.query(tgt, workload)
             return topi_schedule(cfg, outs, *args, **kwargs)
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index ee07886153340..8e0909fdcabba 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -748,7 +748,7 @@ def compute_space_to_depth(attrs, inputs, out_dtype):
 
 # embed
 @reg.register_compute("nn.embed")
-def compute_embed_grad(attrs, inputs, out_type):
+def compute_embed(attrs, inputs, out_type):
     """Compute definition of embed"""
     return [topi.nn.embed(inputs[0], inputs[1])]
 
diff --git a/python/tvm/te/hybrid/calls.py b/python/tvm/te/hybrid/calls.py
index d89663dda7224..09ba3713f0f8d 100644
--- a/python/tvm/te/hybrid/calls.py
+++ b/python/tvm/te/hybrid/calls.py
@@ -137,7 +137,7 @@ def _cast(func_id, args):
     return _expr.Cast(func_id, args[0])
 
 
-def cast(func_id, args):
+def cast(_, args):
     _internal_assert(args.__len__() == 2, "cast requires two arguments: dtype, value")
     return _expr.Cast(args[0], args[1])
 
diff --git a/python/tvm/topi/cuda/nn.py b/python/tvm/topi/cuda/nn.py
index 96da5b42ac140..c6820946689f8 100644
--- a/python/tvm/topi/cuda/nn.py
+++ b/python/tvm/topi/cuda/nn.py
@@ -18,10 +18,9 @@
 """scheduler functions for cuda backend"""
 from __future__ import absolute_import as _abs
 
+from tvm import te
 from .. import cpp
 
-from tvm import autotvm, te
-
 
 def schedule_lrn(outs):
     """Schedule for LRN
@@ -44,8 +43,8 @@ def loads_per_thread(dtype):
     """Number elements per load per thread"""
     s = regex.search("[0-9]+", dtype)
     assert s is not None
-    bytes = int(s.group()) // 8
-    return 16 // bytes
+    byts = int(s.group()) // 8
+    return 16 // byts
 
 
 def schedule_embed_grad(outs):
@@ -63,7 +62,8 @@ def schedule_embed_grad(outs):
         The computation schedule for the op.
     """
     s = te.create_schedule([outs[0].op])
-    vec_size = loads_per_thread(outs[0].dtype)  # this should be autotuned, but we can't with hybrid script
+    # this should be autotuned, but we can't with hybrid script
+    vec_size = loads_per_thread(outs[0].dtype)
     warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
     num_warps = 4
     out = s.outputs[0].output(0)
@@ -72,7 +72,7 @@ def schedule_embed_grad(outs):
     s[out].vectorize(ji)
     joo, joi = s[out].split(jo, factor=warp_size)
     s[out].bind(joi, te.thread_axis("threadIdx.x"))
-    jooo, jooi = s[out].split(joo, factor=num_warps)
+    _, jooi = s[out].split(joo, factor=num_warps)
     s[out].bind(jooi, te.thread_axis("threadIdx.y"))
     s[out].bind(i, te.thread_axis("blockIdx.x"))
 
diff --git a/python/tvm/topi/x86/nn.py b/python/tvm/topi/x86/nn.py
index 9701629715d80..35f5288d107a6 100644
--- a/python/tvm/topi/x86/nn.py
+++ b/python/tvm/topi/x86/nn.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name,too-many-locals,unused-variable
 """x86 nn operators"""
-from tvm import autotvm, te
+from tvm import te
 
 
 def schedule_softmax(outs):
diff --git a/tests/python/topi/python/test_topi_embed.py b/tests/python/topi/python/test_topi_embed.py
index 4ef22c45e6c6f..078f7f63afb84 100644
--- a/tests/python/topi/python/test_topi_embed.py
+++ b/tests/python/topi/python/test_topi_embed.py
@@ -1,3 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
 import tvm
 import tvm.testing
 import tvm.topi.testing
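
For reference, here is a standalone sketch of the split/vectorize/bind pattern that schedule_embed_grad applies in the python/tvm/topi/cuda/nn.py hunk above. It is not part of the patch: the shape and the factors 4 (vec_size), 32 (warp_size), and 4 (num_warps) are illustrative stand-ins for the values the schedule derives from loads_per_thread and the current target.

# Minimal sketch of the schedule_embed_grad thread-binding pattern,
# applied to a dummy elementwise compute; values are illustrative only.
import tvm
from tvm import te

A = te.placeholder((128, 1024), name="A")
B = te.compute(A.shape, lambda i, j: A[i, j] * 2.0, name="B")

s = te.create_schedule(B.op)
i, j = s[B].op.axis
jo, ji = s[B].split(j, factor=4)      # vec_size: elements per vector load
s[B].vectorize(ji)                    # issue vectorized loads/stores
joo, joi = s[B].split(jo, factor=32)  # warp_size: lanes per warp
s[B].bind(joi, te.thread_axis("threadIdx.x"))
_, jooi = s[B].split(joo, factor=4)   # num_warps: warps per block
s[B].bind(jooi, te.thread_axis("threadIdx.y"))
s[B].bind(i, te.thread_axis("blockIdx.x"))

# Inspect the generated loop nest before building for CUDA.
print(tvm.lower(s, [A, B], simple_mode=True))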