Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Logging Interface Unification #11157

Merged
merged 19 commits into from
May 4, 2022
5 changes: 4 additions & 1 deletion include/tvm/meta_schedule/apply_history_best.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class ApplyHistoryBestNode : public runtime::Object {
public:
/*! \brief The database to be queried from */
Database database{nullptr};
/*! \brief The logging function to be used */
PackedFunc logging_func;

void VisitAttrs(AttrVisitor* v) { v->Visit("database", &database); }
/*!
Expand All @@ -58,8 +60,9 @@ class ApplyHistoryBest : public runtime::ObjectRef {
/*!
* \brief Constructor
* \param database The database to be queried from
* \param logging_func The logging function to use
*/
explicit ApplyHistoryBest(Database database);
explicit ApplyHistoryBest(Database database, PackedFunc logging_func);
/*!
* \brief The current ApplyHistoryBest in the context
* \return The ApplyHistoryBest in the current scope.
Expand Down
23 changes: 16 additions & 7 deletions include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ class TaskSchedulerNode : public runtime::Object {
Array<MeasureCallback> measure_callbacks;
/*! \brief The number of trials already conducted. */
int num_trials_already;
/*! \brief The tuning task's logging function. */
PackedFunc logging_func;

/*! \brief The default destructor. */
virtual ~TaskSchedulerNode() = default;
Expand All @@ -96,6 +98,7 @@ class TaskSchedulerNode : public runtime::Object {
v->Visit("cost_model", &cost_model);
v->Visit("measure_callbacks", &measure_callbacks);
v->Visit("num_trials_already", &num_trials_already);
// v->Visit("logging_func", &logging_func);
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
}

/*! \brief Auto-tuning. */
Expand Down Expand Up @@ -234,15 +237,17 @@ class TaskScheduler : public runtime::ObjectRef {
* \param max_trials The maximum number of trials.
* \param cost_model The cost model of the scheduler.
* \param measure_callbacks The measure callbacks of the scheduler.
* \param logging_func The tuning task's logging function.
* \return The task scheduler created.
*/
TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database, //
int max_trials, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks);
TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database, //
int max_trials, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
PackedFunc logging_func);
/*!
* \brief Create a task scheduler that fetches tasks in a gradient based fashion.
* \param tasks The tasks to be tuned.
Expand All @@ -253,6 +258,7 @@ class TaskScheduler : public runtime::ObjectRef {
* \param max_trials The maximum number of trials.
* \param cost_model The cost model of the scheduler.
* \param measure_callbacks The measure callbacks of the scheduler.
* \param logging_func The tuning task's logging function.
* \param alpha The parameter alpha to control gradient computation.
* \param window_size The parameter to control backward window size.
* \param seed The random seed.
Expand All @@ -266,6 +272,7 @@ class TaskScheduler : public runtime::ObjectRef {
int max_trials, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
PackedFunc logging_func, //
double alpha, //
int window_size, //
support::LinearCongruentialEngine::TRandState seed);
Expand All @@ -278,6 +285,7 @@ class TaskScheduler : public runtime::ObjectRef {
* \param max_trials The maximum number of trials.
* \param cost_model The cost model of the scheduler.
* \param measure_callbacks The measure callbacks of the scheduler.
* \param logging_func The tuning task's logging function.
* \param f_tune The packed function of `Tune`.
* \param f_initialize_task The packed function of `InitializeTask`.
* \param f_touch_task The packed function of `TouchTask`.
Expand All @@ -293,6 +301,7 @@ class TaskScheduler : public runtime::ObjectRef {
int max_trials, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
PackedFunc logging_func, //
PyTaskSchedulerNode::FTune f_tune, //
PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
PyTaskSchedulerNode::FTouchTask f_touch_task, //
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class TuneContextNode : public runtime::Object {
Map<Mutator, FloatImm> mutator_probs;
/*! \brief The name of the tuning task. */
Optional<String> task_name;
/*! \brief The tuning task's logging function. */
PackedFunc logging_func;
/*! \brief The random state. */
support::LinearCongruentialEngine::TRandState rand_state;
/*! \brief The number of threads to be used. */
Expand Down Expand Up @@ -85,6 +87,7 @@ class TuneContextNode : public runtime::Object {
v->Visit("builder_results", &builder_results);
v->Visit("runner_futures", &runner_futures);
v->Visit("measure_candidates", &measure_candidates);
// v->Visit("logging_func", &logging_func);
}

/*! \brief Initialize members that needs initialization with tune context. */
Expand All @@ -110,6 +113,7 @@ class TuneContext : public runtime::ObjectRef {
* \param postprocs The postprocessors.
* \param mutator_probs The probability of using certain mutator.
* \param task_name The name of the tuning task.
* \param logging_func The tuning task's logging function.
* \param rand_state The random state.
* \param num_threads The number of threads to be used.
*/
Expand All @@ -121,6 +125,7 @@ class TuneContext : public runtime::ObjectRef {
Optional<Array<Postproc>> postprocs, //
Optional<Map<Mutator, FloatImm>> mutator_probs, //
Optional<String> task_name, //
PackedFunc logging_func, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TuneContext, ObjectRef, TuneContextNode);
Expand Down
15 changes: 10 additions & 5 deletions python/tvm/meta_schedule/apply_history_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""A context manager that injects the best tuning record in the database into compilation"""
import logging
from typing import List, Optional, Union

from tvm._ffi import register_object
Expand All @@ -25,6 +26,8 @@
from . import _ffi_api
from .database import Database

logger = logging.getLogger("tvm.meta_schedule") # pylint: disable=invalid-name
zxybazh marked this conversation as resolved.
Show resolved Hide resolved


@register_object("meta_schedule.ApplyHistoryBest")
class ApplyHistoryBest(Object):
Expand All @@ -34,15 +37,17 @@ class ApplyHistoryBest(Object):
----------
database : Database
The database to be queried from
logger : logging.Logger
The logger to be used
"""

database: Database
logger: logging.Logger
zxybazh marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
self,
database: Database,
) -> None:
self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database) # type: ignore # pylint: disable=no-member
def __init__(self, database: Database) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ApplyHistoryBest, database, logger # type: ignore # pylint: disable=no-member
)

def query(
self,
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/meta_schedule/builder/local_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,14 @@
from tvm.target import Target

from ...contrib.popen_pool import MapResult, PopenPoolExecutor, StatusKind
from ..utils import cpu_count, derived_object, get_global_func_with_default_on_worker
from ..utils import (
cpu_count,
derived_object,
get_global_func_with_default_on_worker,
)
from .builder import BuilderInput, BuilderResult, PyBuilder

logger = logging.getLogger(__name__) # pylint: disable=invalid-name
logger = logging.getLogger("tvm.meta_schedule") # pylint: disable=invalid-name
zxybazh marked this conversation as resolved.
Show resolved Hide resolved


T_BUILD = Callable[ # pylint: disable=invalid-name
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/cost_model/xgb_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from ..tune_context import TuneContext


logger = logging.getLogger(__name__) # pylint: disable=invalid-name
logger = logging.getLogger("tvm.meta_schedule") # pylint: disable=invalid-name
zxybazh marked this conversation as resolved.
Show resolved Hide resolved


def make_metric_sorter(focused_metric):
Expand Down
7 changes: 4 additions & 3 deletions python/tvm/meta_schedule/runner/local_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""Local Runner"""
from contextlib import contextmanager
import logging
from contextlib import contextmanager
from typing import Callable, List, Optional, Union

import tvm
Expand All @@ -33,7 +33,8 @@
run_evaluator_common,
)

logger = logging.getLogger(__name__) # pylint: disable=invalid-name

logger = logging.getLogger("tvm.meta_schedule") # pylint: disable=invalid-name
zxybazh marked this conversation as resolved.
Show resolved Hide resolved


T_ALLOC_ARGUMENT = Callable[ # pylint: disable=invalid-name
Expand Down Expand Up @@ -293,7 +294,7 @@ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
try:
result: List[float] = future.result()
error_message: str = None
except TimeoutError as exception:
except TimeoutError:
result = None
error_message = f"LocalRunner: Timeout, killed after {self.timeout_sec} seconds\n"
except Exception as exception: # pylint: disable=broad-except
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/runner/rpc_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""RPC Runner"""
import concurrent.futures
import logging
import concurrent.futures
import os.path as osp
from contextlib import contextmanager
from typing import Callable, List, Optional, Union
Expand All @@ -39,8 +39,8 @@
run_evaluator_common,
)

logger = logging.getLogger(__name__) # pylint: disable=invalid-name

logger = logging.getLogger("tvm.meta_schedule") # pylint: disable=invalid-name
zxybazh marked this conversation as resolved.
Show resolved Hide resolved

T_CREATE_SESSION = Callable[ # pylint: disable=invalid-name
[RPCConfig], # The RPC configuration
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/meta_schedule/task_scheduler/gradient_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
# specific language governing permissions and limitations
# under the License.
"""Gradient Based Task Scheduler"""
import logging
from typing import TYPE_CHECKING, List, Optional

from tvm._ffi import register_object
from tvm.meta_schedule.utils import make_logging_func
zxybazh marked this conversation as resolved.
Show resolved Hide resolved

from .. import _ffi_api
from ..builder import Builder
Expand Down Expand Up @@ -46,6 +48,7 @@ def __init__(
*,
cost_model: Optional[CostModel] = None,
measure_callbacks: Optional[List[MeasureCallback]] = None,
logger: Optional[logging.Logger] = None,
alpha: float = 0.2,
window_size: int = 3,
seed: int = -1,
Expand All @@ -70,6 +73,8 @@ def __init__(
The cost model of the scheduler.
measure_callbacks : Optional[List[MeasureCallback]] = None
The list of measure callbacks of the scheduler.
logger: Optional[logging.Logger]
The logger of the task scheduler.
alpha : float = 0.2
The parameter alpha in gradient computation.
window_size : int = 3
Expand All @@ -87,6 +92,7 @@ def __init__(
max_trials,
cost_model,
measure_callbacks,
make_logging_func(logger),
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
alpha,
window_size,
seed,
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/meta_schedule/task_scheduler/round_robin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
# under the License.
"""Round Robin Task Scheduler"""

import logging
from typing import TYPE_CHECKING, List, Optional

from tvm._ffi import register_object
from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback
from tvm.meta_schedule.utils import make_logging_func
zxybazh marked this conversation as resolved.
Show resolved Hide resolved

from .. import _ffi_api
from ..builder import Builder
Expand Down Expand Up @@ -48,6 +50,8 @@ class RoundRobin(TaskScheduler):
The database of the scheduler.
measure_callbacks: Optional[List[MeasureCallback]] = None
The list of measure callbacks of the scheduler.
logger: Optional[logging.Logger]
The logger of the task scheduler.
"""

def __init__(
Expand All @@ -61,6 +65,7 @@ def __init__(
*,
cost_model: Optional[CostModel] = None,
measure_callbacks: Optional[List[MeasureCallback]] = None,
logger: Optional[logging.Logger] = None,
) -> None:
"""Constructor.

Expand All @@ -82,6 +87,8 @@ def __init__(
The cost model.
measure_callbacks: Optional[List[MeasureCallback]]
The list of measure callbacks of the scheduler.
logger: Optional[logging.Logger]
The logger of the task scheduler.
"""
del task_weights
self.__init_handle_by_constructor__(
Expand All @@ -93,4 +100,5 @@ def __init__(
max_trials,
cost_model,
measure_callbacks,
make_logging_func(logger),
)
10 changes: 10 additions & 0 deletions python/tvm/meta_schedule/task_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
# under the License.
"""Auto-tuning Task Scheduler"""

import logging
from typing import Callable, List, Optional

from tvm._ffi import register_object
from tvm.meta_schedule.utils import make_logging_func
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
from tvm.runtime import Object

from .. import _ffi_api
Expand Down Expand Up @@ -50,6 +52,8 @@ class TaskScheduler(Object):
The cost model used for search.
measure_callbacks: List[MeasureCallback] = None
The list of measure callbacks of the scheduler.
logger: Optional[logging.Logger]
The logger of the task scheduler.
num_trials_already : int
The number of trials already conducted.
"""
Expand All @@ -62,6 +66,7 @@ class TaskScheduler(Object):
cost_model: Optional[CostModel]
measure_callbacks: List[MeasureCallback]
num_trials_already: int
logger: Optional[logging.Logger]

def tune(self) -> None:
"""Auto-tuning."""
Expand Down Expand Up @@ -131,6 +136,7 @@ def __init__(
max_trials: int,
cost_model: Optional[CostModel] = None,
measure_callbacks: Optional[List[MeasureCallback]] = None,
logger: Optional[logging.Logger] = None,
f_tune: Callable = None,
f_initialize_task: Callable = None,
f_touch_task: Callable = None,
Expand All @@ -148,6 +154,7 @@ def __init__(
max_trials,
cost_model,
measure_callbacks,
make_logging_func(logger),
f_tune,
f_initialize_task,
f_touch_task,
Expand All @@ -174,6 +181,7 @@ class PyTaskScheduler:
"max_trials",
"cost_model",
"measure_callbacks",
"logger",
],
"methods": [
"tune",
Expand All @@ -193,6 +201,7 @@ def __init__(
max_trials: int,
cost_model: Optional[CostModel] = None,
measure_callbacks: Optional[List[MeasureCallback]] = None,
logger: Optional[logging.Logger] = None,
):
self.tasks = tasks
self.builder = builder
Expand All @@ -201,6 +210,7 @@ def __init__(
self.max_trials = max_trials
self.cost_model = cost_model
self.measure_callbacks = measure_callbacks
self.logger = logger

def tune(self) -> None:
"""Auto-tuning."""
Expand Down
Loading