Skip to content

Commit

Permalink
Nits.
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh committed Apr 29, 2022
1 parent c8f21f2 commit 578bce0
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 12 deletions.
3 changes: 2 additions & 1 deletion python/tvm/meta_schedule/apply_history_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from . import _ffi_api
from .database import Database
from .utils import make_logging_func

logger = logging.getLogger(__name__) # pylint: disable=invalid-name

Expand All @@ -43,7 +44,7 @@ class ApplyHistoryBest(Object):

def __init__(self, database: Database) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ApplyHistoryBest, database, logger # type: ignore # pylint: disable=no-member
_ffi_api.ApplyHistoryBest, database, make_logging_func(logger) # type: ignore # pylint: disable=no-member
)

def query(
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/runner/rpc_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@
run_evaluator_common,
)


logger = logging.getLogger(__name__) # pylint: disable=invalid-name


T_CREATE_SESSION = Callable[ # pylint: disable=invalid-name
[RPCConfig], # The RPC configuration
RPCSession, # The RPC Session
Expand Down
10 changes: 3 additions & 7 deletions python/tvm/meta_schedule/task_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from ..tune_context import TuneContext


logger = logging.getLogger(__name__) # pylint: disable=invalid-name


@register_object("meta_schedule.TaskScheduler")
class TaskScheduler(Object):
"""The abstract task scheduler interface.
Expand All @@ -52,8 +55,6 @@ class TaskScheduler(Object):
The cost model used for search.
measure_callbacks: List[MeasureCallback] = None
The list of measure callbacks of the scheduler.
logger: Optional[logging.Logger]
The logger of the task scheduler.
num_trials_already : int
The number of trials already conducted.
"""
Expand All @@ -66,7 +67,6 @@ class TaskScheduler(Object):
cost_model: Optional[CostModel]
measure_callbacks: List[MeasureCallback]
num_trials_already: int
logger: Optional[logging.Logger]

def tune(self) -> None:
"""Auto-tuning."""
Expand Down Expand Up @@ -136,7 +136,6 @@ def __init__(
max_trials: int,
cost_model: Optional[CostModel] = None,
measure_callbacks: Optional[List[MeasureCallback]] = None,
logger: Optional[logging.Logger] = None,
f_tune: Callable = None,
f_initialize_task: Callable = None,
f_touch_task: Callable = None,
Expand Down Expand Up @@ -181,7 +180,6 @@ class PyTaskScheduler:
"max_trials",
"cost_model",
"measure_callbacks",
"logger",
],
"methods": [
"tune",
Expand All @@ -201,7 +199,6 @@ def __init__(
max_trials: int,
cost_model: Optional[CostModel] = None,
measure_callbacks: Optional[List[MeasureCallback]] = None,
logger: Optional[logging.Logger] = None,
):
self.tasks = tasks
self.builder = builder
Expand All @@ -210,7 +207,6 @@ def __init__(
self.max_trials = max_trials
self.cost_model = cost_model
self.measure_callbacks = measure_callbacks
self.logger = logger

def tune(self) -> None:
"""Auto-tuning."""
Expand Down
1 change: 0 additions & 1 deletion python/tvm/meta_schedule/testing/relay_workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from tvm.runtime import NDArray, load_param_dict, save_param_dict
from tvm.target import Target


logger = logging.getLogger(__name__) # pylint: disable=invalid-name


Expand Down
3 changes: 1 addition & 2 deletions python/tvm/meta_schedule/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,13 @@
from .tune_context import TuneContext
from .utils import autotvm_silencer

logger = logging.getLogger(__name__) # pylint: disable=invalid-name

FnSpaceGenerator = Callable[[], SpaceGenerator]
FnScheduleRule = Callable[[], List[ScheduleRule]]
FnPostproc = Callable[[], List[Postproc]]
FnMutatorProb = Callable[[], Dict[Mutator, float]]

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


class DefaultLLVM:
"""Default tuning configuration for LLVM."""
Expand Down

0 comments on commit 578bce0

Please sign in to comment.