Skip to content

Commit

Permalink
[MetaSchedule] Logging Interface Unification (apache#11157)
Browse files Browse the repository at this point in the history
* Implement new logging interface.

* Major interface usage update.

* Functionality fix.

* Switch logging conditions.

* Tweak logging interface.

* Minor fix.

* Feature updates.

* Logging usage.

* Linting.

* Fix linting.

* Fix handler type.

* Fix issues.

* Nits.

* Address issues.

* Add DEBUG level fall back.

* Minor fixes.

* Allow parameterized configuration.

* Linting.

* Polish interface.
  • Loading branch information
zxybazh authored and Sergey Shtin committed May 17, 2022
1 parent 1929bd4 commit b3e0bc0
Show file tree
Hide file tree
Showing 27 changed files with 451 additions and 105 deletions.
5 changes: 4 additions & 1 deletion include/tvm/meta_schedule/apply_history_best.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class ApplyHistoryBestNode : public runtime::Object {
public:
/*! \brief The database to be queried from */
Database database{nullptr};
/*! \brief The logging function to be used */
PackedFunc logging_func;

void VisitAttrs(AttrVisitor* v) { v->Visit("database", &database); }
/*!
Expand All @@ -58,8 +60,9 @@ class ApplyHistoryBest : public runtime::ObjectRef {
/*!
* \brief Constructor
* \param database The database to be queried from
* \param logging_func The logging function to use
*/
explicit ApplyHistoryBest(Database database);
explicit ApplyHistoryBest(Database database, PackedFunc logging_func);
/*!
* \brief The current ApplyHistoryBest in the context
* \return The ApplyHistoryBest in the current scope.
Expand Down
23 changes: 16 additions & 7 deletions include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ class TaskSchedulerNode : public runtime::Object {
Array<MeasureCallback> measure_callbacks;
/*! \brief The number of trials already conducted. */
int num_trials_already;
/*! \brief The tuning task's logging function. t*/
PackedFunc logging_func;

/*! \brief The default destructor. */
virtual ~TaskSchedulerNode() = default;
Expand All @@ -96,6 +98,7 @@ class TaskSchedulerNode : public runtime::Object {
v->Visit("cost_model", &cost_model);
v->Visit("measure_callbacks", &measure_callbacks);
v->Visit("num_trials_already", &num_trials_already);
// `logging_func` is not visited
}

/*! \brief Auto-tuning. */
Expand Down Expand Up @@ -234,15 +237,17 @@ class TaskScheduler : public runtime::ObjectRef {
* \param max_trials The maximum number of trials.
* \param cost_model The cost model of the scheduler.
* \param measure_callbacks The measure callbacks of the scheduler.
* \param logging_func The tuning task's logging function.
* \return The task scheduler created.
*/
TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database, //
int max_trials, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks);
TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database, //
int max_trials, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
PackedFunc logging_func);
/*!
* \brief Create a task scheduler that fetches tasks in a gradient based fashion.
* \param tasks The tasks to be tuned.
Expand All @@ -253,6 +258,7 @@ class TaskScheduler : public runtime::ObjectRef {
* \param max_trials The maximum number of trials.
* \param cost_model The cost model of the scheduler.
* \param measure_callbacks The measure callbacks of the scheduler.
* \param logging_func The tuning task's logging function.
* \param alpha The parameter alpha to control gradient computation.
* \param window_size The parameter to control backward window size.
* \param seed The random seed.
Expand All @@ -266,6 +272,7 @@ class TaskScheduler : public runtime::ObjectRef {
int max_trials, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
PackedFunc logging_func, //
double alpha, //
int window_size, //
support::LinearCongruentialEngine::TRandState seed);
Expand All @@ -278,6 +285,7 @@ class TaskScheduler : public runtime::ObjectRef {
* \param max_trials The maximum number of trials.
* \param cost_model The cost model of the scheduler.
* \param measure_callbacks The measure callbacks of the scheduler.
* \param logging_func The tuning task's logging function.
* \param f_tune The packed function of `Tune`.
* \param f_initialize_task The packed function of `InitializeTask`.
* \param f_touch_task The packed function of `TouchTask`.
Expand All @@ -293,6 +301,7 @@ class TaskScheduler : public runtime::ObjectRef {
int max_trials, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
PackedFunc logging_func, //
PyTaskSchedulerNode::FTune f_tune, //
PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
PyTaskSchedulerNode::FTouchTask f_touch_task, //
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class TuneContextNode : public runtime::Object {
Map<Mutator, FloatImm> mutator_probs;
/*! \brief The name of the tuning task. */
Optional<String> task_name;
/*! \brief The tuning task's logging function. t*/
PackedFunc logging_func;
/*! \brief The random state. */
support::LinearCongruentialEngine::TRandState rand_state;
/*! \brief The number of threads to be used. */
Expand Down Expand Up @@ -85,6 +87,7 @@ class TuneContextNode : public runtime::Object {
v->Visit("builder_results", &builder_results);
v->Visit("runner_futures", &runner_futures);
v->Visit("measure_candidates", &measure_candidates);
// `logging_func` is not visited
}

/*! \brief Initialize members that needs initialization with tune context. */
Expand All @@ -110,6 +113,7 @@ class TuneContext : public runtime::ObjectRef {
* \param postprocs The postprocessors.
* \param mutator_probs The probability of using certain mutator.
* \param task_name The name of the tuning task.
* \param logging_func The tuning task's logging function.
* \param rand_state The random state.
* \param num_threads The number of threads to be used.
*/
Expand All @@ -121,6 +125,7 @@ class TuneContext : public runtime::ObjectRef {
Optional<Array<Postproc>> postprocs, //
Optional<Map<Mutator, FloatImm>> mutator_probs, //
Optional<String> task_name, //
PackedFunc logging_func, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TuneContext, ObjectRef, TuneContextNode);
Expand Down
13 changes: 8 additions & 5 deletions python/tvm/meta_schedule/apply_history_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""A context manager that injects the best tuning record in the database into compilation"""
import logging
from typing import List, Optional, Union

from tvm._ffi import register_object
Expand All @@ -24,6 +25,9 @@

from . import _ffi_api
from .database import Database
from .utils import make_logging_func

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


@register_object("meta_schedule.ApplyHistoryBest")
Expand All @@ -38,11 +42,10 @@ class ApplyHistoryBest(Object):

database: Database

def __init__(
self,
database: Database,
) -> None:
self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database) # type: ignore # pylint: disable=no-member
def __init__(self, database: Database) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ApplyHistoryBest, database, make_logging_func(logger) # type: ignore # pylint: disable=no-member
)

def query(
self,
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/meta_schedule/builder/local_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
from tvm.target import Target

from ...contrib.popen_pool import MapResult, PopenPoolExecutor, StatusKind
from ..utils import cpu_count, derived_object, get_global_func_with_default_on_worker
from ..utils import (
cpu_count,
derived_object,
get_global_func_with_default_on_worker,
)
from .builder import BuilderInput, BuilderResult, PyBuilder

logger = logging.getLogger(__name__) # pylint: disable=invalid-name
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/meta_schedule/runner/local_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""Local Runner"""
from contextlib import contextmanager
import logging
from contextlib import contextmanager
from typing import Callable, List, Optional, Union

import tvm
Expand All @@ -33,6 +33,7 @@
run_evaluator_common,
)


logger = logging.getLogger(__name__) # pylint: disable=invalid-name


Expand Down Expand Up @@ -293,7 +294,7 @@ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
try:
result: List[float] = future.result()
error_message: str = None
except TimeoutError as exception:
except TimeoutError:
result = None
error_message = f"LocalRunner: Timeout, killed after {self.timeout_sec} seconds\n"
except Exception as exception: # pylint: disable=broad-except
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/runner/rpc_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""RPC Runner"""
import concurrent.futures
import logging
import concurrent.futures
import os.path as osp
from contextlib import contextmanager
from typing import Callable, List, Optional, Union
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/meta_schedule/task_scheduler/gradient_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Gradient Based Task Scheduler"""
import logging
from typing import TYPE_CHECKING, List, Optional

from tvm._ffi import register_object
Expand All @@ -25,11 +26,14 @@
from ..database import Database
from ..measure_callback import MeasureCallback
from ..runner import Runner
from ..utils import make_logging_func
from .task_scheduler import TaskScheduler

if TYPE_CHECKING:
from ..tune_context import TuneContext

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


@register_object("meta_schedule.GradientBased")
class GradientBased(TaskScheduler):
Expand Down Expand Up @@ -87,6 +91,7 @@ def __init__(
max_trials,
cost_model,
measure_callbacks,
make_logging_func(logger),
alpha,
window_size,
seed,
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/meta_schedule/task_scheduler/round_robin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
"""Round Robin Task Scheduler"""

import logging
from typing import TYPE_CHECKING, List, Optional

from tvm._ffi import register_object
Expand All @@ -26,11 +27,14 @@
from ..cost_model import CostModel
from ..database import Database
from ..runner import Runner
from ..utils import make_logging_func
from .task_scheduler import TaskScheduler

if TYPE_CHECKING:
from ..tune_context import TuneContext

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


@register_object("meta_schedule.RoundRobin")
class RoundRobin(TaskScheduler):
Expand Down Expand Up @@ -93,4 +97,5 @@ def __init__(
max_trials,
cost_model,
measure_callbacks,
make_logging_func(logger),
)
6 changes: 6 additions & 0 deletions python/tvm/meta_schedule/task_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
"""Auto-tuning Task Scheduler"""

import logging
from typing import Callable, List, Optional

from tvm._ffi import register_object
Expand All @@ -28,6 +29,10 @@
from ..measure_callback import MeasureCallback
from ..runner import Runner, RunnerResult
from ..tune_context import TuneContext
from ..utils import make_logging_func


logger = logging.getLogger(__name__) # pylint: disable=invalid-name


@register_object("meta_schedule.TaskScheduler")
Expand Down Expand Up @@ -148,6 +153,7 @@ def __init__(
max_trials,
cost_model,
measure_callbacks,
make_logging_func(logger),
f_tune,
f_initialize_task,
f_touch_task,
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,10 @@ def _parse_args():
return parsed


logging.basicConfig()
logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
logging.basicConfig(
format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
logging.getLogger("tvm.meta_schedule").setLevel(logging.INFO)
ARGS = _parse_args()


Expand Down
6 changes: 4 additions & 2 deletions python/tvm/meta_schedule/testing/tune_te_meta_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ def _parse_args():
return parsed


logging.basicConfig()
logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
logging.basicConfig(
format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
logging.getLogger("tvm.meta_schedule").setLevel(logging.INFO)
ARGS = _parse_args()


Expand Down
Loading

0 comments on commit b3e0bc0

Please sign in to comment.