Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Logging Interface Unification #11157

Merged
merged 19 commits into from
May 4, 2022
5 changes: 4 additions & 1 deletion include/tvm/meta_schedule/apply_history_best.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class ApplyHistoryBestNode : public runtime::Object {
public:
/*! \brief The database to be queried from */
Database database{nullptr};
/*! \brief The logging function to be used */
PackedFunc logging_func;

void VisitAttrs(AttrVisitor* v) { v->Visit("database", &database); }
/*!
Expand All @@ -58,8 +60,9 @@ class ApplyHistoryBest : public runtime::ObjectRef {
/*!
* \brief Constructor
* \param database The database to be queried from
* \param logging_func The logging function to use
*/
explicit ApplyHistoryBest(Database database);
explicit ApplyHistoryBest(Database database, PackedFunc logging_func);
/*!
* \brief The current ApplyHistoryBest in the context
* \return The ApplyHistoryBest in the current scope.
Expand Down
23 changes: 16 additions & 7 deletions include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ class TaskSchedulerNode : public runtime::Object {
Array<MeasureCallback> measure_callbacks;
/*! \brief The number of trials already conducted. */
int num_trials_already;
/*! \brief The tuning task's logging function. t*/
PackedFunc logging_func;

/*! \brief The default destructor. */
virtual ~TaskSchedulerNode() = default;
Expand All @@ -96,6 +98,7 @@ class TaskSchedulerNode : public runtime::Object {
v->Visit("cost_model", &cost_model);
v->Visit("measure_callbacks", &measure_callbacks);
v->Visit("num_trials_already", &num_trials_already);
// `logging_func` is not visited
}

/*! \brief Auto-tuning. */
Expand Down Expand Up @@ -234,15 +237,17 @@ class TaskScheduler : public runtime::ObjectRef {
* \param max_trials The maximum number of trials.
* \param cost_model The cost model of the scheduler.
* \param measure_callbacks The measure callbacks of the scheduler.
* \param logging_func The tuning task's logging function.
* \return The task scheduler created.
*/
TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database, //
int max_trials, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks);
TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database, //
int max_trials, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
PackedFunc logging_func);
/*!
* \brief Create a task scheduler that fetches tasks in a gradient based fashion.
* \param tasks The tasks to be tuned.
Expand All @@ -253,6 +258,7 @@ class TaskScheduler : public runtime::ObjectRef {
* \param max_trials The maximum number of trials.
* \param cost_model The cost model of the scheduler.
* \param measure_callbacks The measure callbacks of the scheduler.
* \param logging_func The tuning task's logging function.
* \param alpha The parameter alpha to control gradient computation.
* \param window_size The parameter to control backward window size.
* \param seed The random seed.
Expand All @@ -266,6 +272,7 @@ class TaskScheduler : public runtime::ObjectRef {
int max_trials, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
PackedFunc logging_func, //
double alpha, //
int window_size, //
support::LinearCongruentialEngine::TRandState seed);
Expand All @@ -278,6 +285,7 @@ class TaskScheduler : public runtime::ObjectRef {
* \param max_trials The maximum number of trials.
* \param cost_model The cost model of the scheduler.
* \param measure_callbacks The measure callbacks of the scheduler.
* \param logging_func The tuning task's logging function.
* \param f_tune The packed function of `Tune`.
* \param f_initialize_task The packed function of `InitializeTask`.
* \param f_touch_task The packed function of `TouchTask`.
Expand All @@ -293,6 +301,7 @@ class TaskScheduler : public runtime::ObjectRef {
int max_trials, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
PackedFunc logging_func, //
PyTaskSchedulerNode::FTune f_tune, //
PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
PyTaskSchedulerNode::FTouchTask f_touch_task, //
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class TuneContextNode : public runtime::Object {
Map<Mutator, FloatImm> mutator_probs;
/*! \brief The name of the tuning task. */
Optional<String> task_name;
/*! \brief The tuning task's logging function. t*/
PackedFunc logging_func;
/*! \brief The random state. */
support::LinearCongruentialEngine::TRandState rand_state;
/*! \brief The number of threads to be used. */
Expand Down Expand Up @@ -85,6 +87,7 @@ class TuneContextNode : public runtime::Object {
v->Visit("builder_results", &builder_results);
v->Visit("runner_futures", &runner_futures);
v->Visit("measure_candidates", &measure_candidates);
// `logging_func` is not visited
}

/*! \brief Initialize members that needs initialization with tune context. */
Expand All @@ -110,6 +113,7 @@ class TuneContext : public runtime::ObjectRef {
* \param postprocs The postprocessors.
* \param mutator_probs The probability of using certain mutator.
* \param task_name The name of the tuning task.
* \param logging_func The tuning task's logging function.
* \param rand_state The random state.
* \param num_threads The number of threads to be used.
*/
Expand All @@ -121,6 +125,7 @@ class TuneContext : public runtime::ObjectRef {
Optional<Array<Postproc>> postprocs, //
Optional<Map<Mutator, FloatImm>> mutator_probs, //
Optional<String> task_name, //
PackedFunc logging_func, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TuneContext, ObjectRef, TuneContextNode);
Expand Down
13 changes: 8 additions & 5 deletions python/tvm/meta_schedule/apply_history_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""A context manager that injects the best tuning record in the database into compilation"""
import logging
from typing import List, Optional, Union

from tvm._ffi import register_object
Expand All @@ -24,6 +25,9 @@

from . import _ffi_api
from .database import Database
from .utils import make_logging_func

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


@register_object("meta_schedule.ApplyHistoryBest")
Expand All @@ -38,11 +42,10 @@ class ApplyHistoryBest(Object):

database: Database

def __init__(
self,
database: Database,
) -> None:
self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database) # type: ignore # pylint: disable=no-member
def __init__(self, database: Database) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ApplyHistoryBest, database, make_logging_func(logger) # type: ignore # pylint: disable=no-member
)

def query(
self,
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/meta_schedule/builder/local_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
from tvm.target import Target

from ...contrib.popen_pool import MapResult, PopenPoolExecutor, StatusKind
from ..utils import cpu_count, derived_object, get_global_func_with_default_on_worker
from ..utils import (
cpu_count,
derived_object,
get_global_func_with_default_on_worker,
)
from .builder import BuilderInput, BuilderResult, PyBuilder

logger = logging.getLogger(__name__) # pylint: disable=invalid-name
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/meta_schedule/runner/local_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""Local Runner"""
from contextlib import contextmanager
import logging
from contextlib import contextmanager
from typing import Callable, List, Optional, Union

import tvm
Expand All @@ -33,6 +33,7 @@
run_evaluator_common,
)


logger = logging.getLogger(__name__) # pylint: disable=invalid-name


Expand Down Expand Up @@ -293,7 +294,7 @@ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
try:
result: List[float] = future.result()
error_message: str = None
except TimeoutError as exception:
except TimeoutError:
result = None
error_message = f"LocalRunner: Timeout, killed after {self.timeout_sec} seconds\n"
except Exception as exception: # pylint: disable=broad-except
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/runner/rpc_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""RPC Runner"""
import concurrent.futures
import logging
import concurrent.futures
import os.path as osp
from contextlib import contextmanager
from typing import Callable, List, Optional, Union
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/meta_schedule/task_scheduler/gradient_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Gradient Based Task Scheduler"""
import logging
from typing import TYPE_CHECKING, List, Optional

from tvm._ffi import register_object
Expand All @@ -25,11 +26,14 @@
from ..database import Database
from ..measure_callback import MeasureCallback
from ..runner import Runner
from ..utils import make_logging_func
from .task_scheduler import TaskScheduler

if TYPE_CHECKING:
from ..tune_context import TuneContext

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


@register_object("meta_schedule.GradientBased")
class GradientBased(TaskScheduler):
Expand Down Expand Up @@ -87,6 +91,7 @@ def __init__(
max_trials,
cost_model,
measure_callbacks,
make_logging_func(logger),
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
alpha,
window_size,
seed,
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/meta_schedule/task_scheduler/round_robin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
"""Round Robin Task Scheduler"""

import logging
from typing import TYPE_CHECKING, List, Optional

from tvm._ffi import register_object
Expand All @@ -26,11 +27,14 @@
from ..cost_model import CostModel
from ..database import Database
from ..runner import Runner
from ..utils import make_logging_func
from .task_scheduler import TaskScheduler

if TYPE_CHECKING:
from ..tune_context import TuneContext

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


@register_object("meta_schedule.RoundRobin")
class RoundRobin(TaskScheduler):
Expand Down Expand Up @@ -93,4 +97,5 @@ def __init__(
max_trials,
cost_model,
measure_callbacks,
make_logging_func(logger),
)
6 changes: 6 additions & 0 deletions python/tvm/meta_schedule/task_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
"""Auto-tuning Task Scheduler"""

import logging
from typing import Callable, List, Optional

from tvm._ffi import register_object
Expand All @@ -28,6 +29,10 @@
from ..measure_callback import MeasureCallback
from ..runner import Runner, RunnerResult
from ..tune_context import TuneContext
from ..utils import make_logging_func


logger = logging.getLogger(__name__) # pylint: disable=invalid-name


@register_object("meta_schedule.TaskScheduler")
Expand Down Expand Up @@ -148,6 +153,7 @@ def __init__(
max_trials,
cost_model,
measure_callbacks,
make_logging_func(logger),
f_tune,
f_initialize_task,
f_touch_task,
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,10 @@ def _parse_args():
return parsed


logging.basicConfig()
logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
logging.basicConfig(
format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
logging.getLogger("tvm.meta_schedule").setLevel(logging.INFO)
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
ARGS = _parse_args()


Expand Down
6 changes: 4 additions & 2 deletions python/tvm/meta_schedule/testing/tune_te_meta_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ def _parse_args():
return parsed


logging.basicConfig()
logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
logging.basicConfig(
format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
logging.getLogger("tvm.meta_schedule").setLevel(logging.INFO)
ARGS = _parse_args()


Expand Down
Loading