diff --git a/include/tvm/meta_schedule/apply_history_best.h b/include/tvm/meta_schedule/apply_history_best.h index 9d6f46dd6c43..b5504a8ee0f8 100644 --- a/include/tvm/meta_schedule/apply_history_best.h +++ b/include/tvm/meta_schedule/apply_history_best.h @@ -33,6 +33,8 @@ class ApplyHistoryBestNode : public runtime::Object { public: /*! \brief The database to be queried from */ Database database{nullptr}; + /*! \brief The logging function to be used */ + PackedFunc logging_func; void VisitAttrs(AttrVisitor* v) { v->Visit("database", &database); } /*! @@ -58,8 +60,9 @@ class ApplyHistoryBest : public runtime::ObjectRef { /*! * \brief Constructor * \param database The database to be queried from + * \param logging_func The logging function to use */ - explicit ApplyHistoryBest(Database database); + explicit ApplyHistoryBest(Database database, PackedFunc logging_func); /*! * \brief The current ApplyHistoryBest in the context * \return The ApplyHistoryBest in the current scope. diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 81d340d33e6b..7453c2b484b9 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -83,6 +83,8 @@ class TaskSchedulerNode : public runtime::Object { Array<MeasureCallback> measure_callbacks; /*! \brief The number of trials already conducted. */ int num_trials_already; + /*! \brief The tuning task's logging function. */ + PackedFunc logging_func; /*! \brief The default destructor. */ virtual ~TaskSchedulerNode() = default; @@ -96,6 +98,7 @@ class TaskSchedulerNode : public runtime::Object { v->Visit("cost_model", &cost_model); v->Visit("measure_callbacks", &measure_callbacks); v->Visit("num_trials_already", &num_trials_already); + // `logging_func` is not visited } /*! \brief Auto-tuning. */ @@ -234,15 +237,17 @@ class TaskScheduler : public runtime::ObjectRef { * \param max_trials The maximum number of trials. * \param cost_model The cost model of the scheduler. * \param measure_callbacks The measure callbacks of the scheduler. + * \param logging_func The tuning task's logging function. * \return The task scheduler created. */ - TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, // - Builder builder, // - Runner runner, // - Database database, // - int max_trials, // - Optional<CostModel> cost_model, // - Optional<Array<MeasureCallback>> measure_callbacks); + TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, // + Builder builder, // + Runner runner, // + Database database, // + int max_trials, // + Optional<CostModel> cost_model, // + Optional<Array<MeasureCallback>> measure_callbacks, // + PackedFunc logging_func); /*! * \brief Create a task scheduler that fetches tasks in a gradient based fashion. * \param tasks The tasks to be tuned. * \param task_weights The weights of each task. * \param builder The builder of the scheduler. * \param runner The runner of the scheduler. * \param database The database of the scheduler. * \param max_trials The maximum number of trials. * \param cost_model The cost model of the scheduler. * \param measure_callbacks The measure callbacks of the scheduler. + * \param logging_func The tuning task's logging function. * \param alpha The parameter alpha to control gradient computation. * \param window_size The parameter to control backward window size. * \param seed The random seed.
@@ -266,6 +272,7 @@ class TaskScheduler : public runtime::ObjectRef { int max_trials, // Optional<CostModel> cost_model, // Optional<Array<MeasureCallback>> measure_callbacks, // + PackedFunc logging_func, // double alpha, // int window_size, // support::LinearCongruentialEngine::TRandState seed); @@ -278,6 +285,7 @@ class TaskScheduler : public runtime::ObjectRef { * \param max_trials The maximum number of trials. * \param cost_model The cost model of the scheduler. * \param measure_callbacks The measure callbacks of the scheduler. + * \param logging_func The tuning task's logging function. * \param f_tune The packed function of `Tune`. * \param f_initialize_task The packed function of `InitializeTask`. * \param f_touch_task The packed function of `TouchTask`. @@ -293,6 +301,7 @@ int max_trials, // Optional<CostModel> cost_model, // Optional<Array<MeasureCallback>> measure_callbacks, // + PackedFunc logging_func, // PyTaskSchedulerNode::FTune f_tune, // PyTaskSchedulerNode::FInitializeTask f_initialize_task, // PyTaskSchedulerNode::FTouchTask f_touch_task, // diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 1d2978c90533..faa24fc99f4c 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -54,6 +54,8 @@ class TuneContextNode : public runtime::Object { Map<Mutator, FloatImm> mutator_probs; /*! \brief The name of the tuning task. */ Optional<String> task_name; + /*! \brief The tuning task's logging function. */ + PackedFunc logging_func; /*! \brief The random state. */ support::LinearCongruentialEngine::TRandState rand_state; /*! \brief The number of threads to be used. */ @@ -85,6 +87,7 @@ class TuneContextNode : public runtime::Object { v->Visit("builder_results", &builder_results); v->Visit("runner_futures", &runner_futures); v->Visit("measure_candidates", &measure_candidates); + // `logging_func` is not visited } /*! \brief Initialize members that needs initialization with tune context. */ @@ -110,6 +113,7 @@ class TuneContext : public runtime::ObjectRef { * \param postprocs The postprocessors. * \param mutator_probs The probability of using certain mutator. * \param task_name The name of the tuning task. + * \param logging_func The tuning task's logging function. * \param rand_state The random state. * \param num_threads The number of threads to be used. */ @@ -121,6 +125,7 @@ Optional<Array<Postproc>> postprocs, // Optional<Map<Mutator, FloatImm>> mutator_probs, // Optional<String> task_name, // + PackedFunc logging_func, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TuneContext, ObjectRef, TuneContextNode); diff --git a/python/tvm/meta_schedule/apply_history_best.py b/python/tvm/meta_schedule/apply_history_best.py index 5e1e40bd154b..bcde7c97b04d 100644 --- a/python/tvm/meta_schedule/apply_history_best.py +++ b/python/tvm/meta_schedule/apply_history_best.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """A context manager that injects the best tuning record in the database into compilation""" +import logging from typing import List, Optional, Union from tvm._ffi import register_object @@ -24,6 +25,9 @@ from .
import _ffi_api from .database import Database +from .utils import make_logging_func + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name @register_object("meta_schedule.ApplyHistoryBest") @@ -38,11 +42,10 @@ class ApplyHistoryBest(Object): database: Database - def __init__( - self, - database: Database, - ) -> None: - self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database) # type: ignore # pylint: disable=no-member + def __init__(self, database: Database) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ApplyHistoryBest, database, make_logging_func(logger) # type: ignore # pylint: disable=no-member + ) def query( self, diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py index eb1b1f377b43..6f0f523b475d 100644 --- a/python/tvm/meta_schedule/builder/local_builder.py +++ b/python/tvm/meta_schedule/builder/local_builder.py @@ -26,7 +26,11 @@ from tvm.target import Target from ...contrib.popen_pool import MapResult, PopenPoolExecutor, StatusKind -from ..utils import cpu_count, derived_object, get_global_func_with_default_on_worker +from ..utils import ( + cpu_count, + derived_object, + get_global_func_with_default_on_worker, +) from .builder import BuilderInput, BuilderResult, PyBuilder logger = logging.getLogger(__name__) # pylint: disable=invalid-name diff --git a/python/tvm/meta_schedule/runner/local_runner.py b/python/tvm/meta_schedule/runner/local_runner.py index a574699b8b5f..d76fe0b840a4 100644 --- a/python/tvm/meta_schedule/runner/local_runner.py +++ b/python/tvm/meta_schedule/runner/local_runner.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. """Local Runner""" -from contextlib import contextmanager import logging +from contextlib import contextmanager from typing import Callable, List, Optional, Union import tvm @@ -33,6 +33,7 @@ run_evaluator_common, ) + logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -293,7 +294,7 @@ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: try: result: List[float] = future.result() error_message: str = None - except TimeoutError as exception: + except TimeoutError: result = None error_message = f"LocalRunner: Timeout, killed after {self.timeout_sec} seconds\n" except Exception as exception: # pylint: disable=broad-except diff --git a/python/tvm/meta_schedule/runner/rpc_runner.py b/python/tvm/meta_schedule/runner/rpc_runner.py index 5697f85f229e..16e422cc6073 100644 --- a/python/tvm/meta_schedule/runner/rpc_runner.py +++ b/python/tvm/meta_schedule/runner/rpc_runner.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. """RPC Runner""" -import concurrent.futures import logging +import concurrent.futures import os.path as osp from contextlib import contextmanager from typing import Callable, List, Optional, Union diff --git a/python/tvm/meta_schedule/task_scheduler/gradient_based.py b/python/tvm/meta_schedule/task_scheduler/gradient_based.py index b0b13001382a..6234449bf09b 100644 --- a/python/tvm/meta_schedule/task_scheduler/gradient_based.py +++ b/python/tvm/meta_schedule/task_scheduler/gradient_based.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
"""Gradient Based Task Scheduler""" +import logging from typing import TYPE_CHECKING, List, Optional from tvm._ffi import register_object @@ -25,11 +26,14 @@ from ..database import Database from ..measure_callback import MeasureCallback from ..runner import Runner +from ..utils import make_logging_func from .task_scheduler import TaskScheduler if TYPE_CHECKING: from ..tune_context import TuneContext +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + @register_object("meta_schedule.GradientBased") class GradientBased(TaskScheduler): @@ -87,6 +91,7 @@ def __init__( max_trials, cost_model, measure_callbacks, + make_logging_func(logger), alpha, window_size, seed, diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py index 6634d6193e26..a46135828394 100644 --- a/python/tvm/meta_schedule/task_scheduler/round_robin.py +++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py @@ -16,6 +16,7 @@ # under the License. """Round Robin Task Scheduler""" +import logging from typing import TYPE_CHECKING, List, Optional from tvm._ffi import register_object @@ -26,11 +27,14 @@ from ..cost_model import CostModel from ..database import Database from ..runner import Runner +from ..utils import make_logging_func from .task_scheduler import TaskScheduler if TYPE_CHECKING: from ..tune_context import TuneContext +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + @register_object("meta_schedule.RoundRobin") class RoundRobin(TaskScheduler): @@ -93,4 +97,5 @@ def __init__( max_trials, cost_model, measure_callbacks, + make_logging_func(logger), ) diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index d3bc25c1e03a..4454078a6f16 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -16,6 +16,7 @@ # under the License. 
"""Auto-tuning Task Scheduler""" +import logging from typing import Callable, List, Optional from tvm._ffi import register_object @@ -28,6 +29,10 @@ from ..measure_callback import MeasureCallback from ..runner import Runner, RunnerResult from ..tune_context import TuneContext +from ..utils import make_logging_func + + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name @register_object("meta_schedule.TaskScheduler") @@ -148,6 +153,7 @@ def __init__( max_trials, cost_model, measure_callbacks, + make_logging_func(logger), f_tune, f_initialize_task, f_touch_task, diff --git a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py index d8e6d38695ac..88de0c336073 100644 --- a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py +++ b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py @@ -90,8 +90,10 @@ def _parse_args(): return parsed -logging.basicConfig() -logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) +logging.basicConfig( + format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" +) +logging.getLogger("tvm.meta_schedule").setLevel(logging.INFO) ARGS = _parse_args() diff --git a/python/tvm/meta_schedule/testing/tune_te_meta_schedule.py b/python/tvm/meta_schedule/testing/tune_te_meta_schedule.py index 2e8b538b9cc9..b65761ba4fe5 100644 --- a/python/tvm/meta_schedule/testing/tune_te_meta_schedule.py +++ b/python/tvm/meta_schedule/testing/tune_te_meta_schedule.py @@ -79,8 +79,10 @@ def _parse_args(): return parsed -logging.basicConfig() -logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) +logging.basicConfig( + format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" +) +logging.getLogger("tvm.meta_schedule").setLevel(logging.INFO) ARGS = _parse_args() diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 0cdb03d20f5c..82d99295ff1d 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -17,7 +17,9 @@ """User-facing Tuning API""" # pylint: disable=import-outside-toplevel import logging -import os.path +import logging.config +import os +from os import path as osp from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union from tvm._ffi.registry import register_func @@ -43,7 +45,7 @@ from .space_generator import PostOrderApply, SpaceGenerator from .task_scheduler import GradientBased, RoundRobin from .tune_context import TuneContext -from .utils import autotvm_silencer +from .utils import autotvm_silencer, batch_parameterize_config logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -226,8 +228,8 @@ def _runner(runner: Optional[Runner]) -> Runner: @staticmethod def _database(database: Union[None, Database], path: str) -> Database: if database is None: - path_workload = os.path.join(path, "database_workload.json") - path_tuning_record = os.path.join(path, "database_tuning_record.json") + path_workload = osp.join(path, "database_workload.json") + path_tuning_record = osp.join(path, "database_tuning_record.json") logger.info( "Creating JSONDatabase. Workload at: %s. Tuning records at: %s", path_workload, @@ -358,6 +360,8 @@ class TuneConfig(NamedTuple): Configuration for task scheduler. search_strategy_config: Optional[Dict[str, Any]] = None Configuration for search strategy. + logger_config: Optional[Dict[str, Any]] = None + Configuration for logger. 
""" max_trials_global: int @@ -367,6 +371,7 @@ class TuneConfig(NamedTuple): strategy: str = "evolutionary" task_scheduler_config: Optional[Dict[str, Any]] = None search_strategy_config: Optional[Dict[str, Any]] = None + logger_config: Optional[Dict[str, Any]] = None def create_strategy(self, **kwargs): """Create search strategy from configuration""" @@ -416,6 +421,96 @@ def create_task_scheduler(self, **kwargs): **config, ) + def create_loggers( + self, + log_dir: str, + params: List[Dict[str, Any]], + disable_existing_loggers: bool = False, + ): + """Create loggers from configuration""" + if self.logger_config is None: + config = {} + else: + config = self.logger_config + + global_logger_name = "tvm.meta_schedule" + config.setdefault("loggers", {}) + config.setdefault("handlers", {}) + config.setdefault("formatters", {}) + + config["loggers"].setdefault( + global_logger_name, + { + "level": "INFO", + "handlers": [global_logger_name + ".console", global_logger_name + ".file"], + "propagate": False, + }, + ) + config["loggers"].setdefault( + "{logger_name}", + { + "level": "INFO", + "handlers": [ + "{logger_name}.file", + ], + "propagate": False, + }, + ) + config["handlers"].setdefault( + global_logger_name + ".console", + { + "class": "logging.StreamHandler", + "stream": "ext://sys.stdout", + "formatter": "tvm.meta_schedule.standard_formatter", + }, + ) + config["handlers"].setdefault( + global_logger_name + ".file", + { + "class": "logging.FileHandler", + "filename": "{log_dir}/" + __name__ + ".task_scheduler.log", + "mode": "a", + "level": "INFO", + "formatter": "tvm.meta_schedule.standard_formatter", + }, + ) + config["handlers"].setdefault( + "{logger_name}.file", + { + "class": "logging.FileHandler", + "filename": "{log_dir}/{logger_name}.log", + "mode": "a", + "level": "INFO", + "formatter": "tvm.meta_schedule.standard_formatter", + }, + ) + config["formatters"].setdefault( + "tvm.meta_schedule.standard_formatter", + { + "format": "%(asctime)s.%(msecs)03d %(levelname)s %(message)s", + "datefmt": "%Y-%m-%d %H:%M:%S", + }, + ) + + # set up dictConfig loggers + p_config = {"version": 1, "disable_existing_loggers": disable_existing_loggers} + for k, v in config.items(): + if k in ["formatters", "handlers", "loggers"]: + p_config[k] = batch_parameterize_config(v, params) # type: ignore + else: + p_config[k] = v + logging.config.dictConfig(p_config) + + # check global logger + global_logger = logging.getLogger(global_logger_name) + if global_logger.level not in [logging.DEBUG, logging.INFO]: + global_logger.critical( + "Logging level set to %s, please set to logging.INFO" + " or logging.DEBUG to view full log.", + logging._levelToName[logger.level], # pylint: disable=protected-access + ) + global_logger.info("Logging directory: %s", log_dir) + def tune_extracted_tasks( extracted_tasks: List[ExtractedTask], @@ -472,8 +567,25 @@ def tune_extracted_tasks( The database containing all the tuning results. 
""" - logger.info("Working directory: %s", work_dir) # pylint: disable=protected-access + # logging directory is set to `work_dir/logs` by default + log_dir = osp.join(work_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + max_width = len(str(len(extracted_tasks) - 1)) + logger_name_pattern = __name__ + ".task_{task_id:0" + f"{max_width}" + "d}_{task_name}" + + config.create_loggers( + log_dir=log_dir, + params=[ + { + "log_dir": log_dir, + "logger_name": logger_name_pattern.format(task_id=i, task_name=task.task_name), + } + for i, task in enumerate(extracted_tasks) + ], + ) + + logger.info("Working directory: %s", work_dir) database = Parse._database(database, work_dir) builder = Parse._builder(builder) runner = Parse._runner(runner) @@ -481,7 +593,7 @@ def tune_extracted_tasks( measure_callbacks = Parse._callbacks(measure_callbacks) # parse the tuning contexts tune_contexts = [] - for task in extracted_tasks: + for i, task in enumerate(extracted_tasks): assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now" tune_contexts.append( TuneContext( @@ -493,6 +605,9 @@ def tune_extracted_tasks( postprocs=Parse._postproc(postprocs, task.target), mutator_probs=Parse._mutator_probs(mutator_probs, task.target), task_name=task.task_name, + logger=logging.getLogger( + logger_name_pattern.format(task_id=i, task_name=task.task_name) + ), num_threads=num_threads, ) ) @@ -508,7 +623,7 @@ def tune_extracted_tasks( measure_callbacks=measure_callbacks, ) task_scheduler.tune() - cost_model.save(os.path.join(work_dir, "cost_model.xgb")) + cost_model.save(osp.join(work_dir, "cost_model.xgb")) return database @@ -558,6 +673,15 @@ def tune_tir( sch : Optional[Schedule] The tuned schedule. """ + # logging directory is set to `work_dir/logs` by default + log_dir = osp.join(work_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + + config.create_loggers( + log_dir=log_dir, + params=[{"log_dir": log_dir, "logger_name": __name__ + f".task_{task_name}"}], + ) + # pylint: disable=protected-access mod = Parse._mod(mod) target = Parse._target(target) @@ -712,14 +836,11 @@ def tune_relay( """ # pylint: disable=import-outside-toplevel from tvm.relay import build as relay_build - from .relay_integration import extract_task_from_relay - # pylint: enable=import-outside-toplevel - - logger.info("Working directory: %s", work_dir) - # pylint: disable=protected-access + # pylint: disable=protected-access, enable=import-outside-toplevel target = Parse._target(target) + # pylint: enable=protected-access, # parse the tuning contexts extracted_tasks = extract_task_from_relay(mod, target, params) database = tune_extracted_tasks( diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 196b1c16b6f2..ef2e4bcd8e6d 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -16,11 +16,12 @@ # under the License. """Meta Schedule tuning context.""" +import logging from typing import Optional, List, Dict, TYPE_CHECKING from tvm import IRModule from tvm._ffi import register_object -from tvm.meta_schedule.utils import cpu_count +from tvm.meta_schedule.utils import cpu_count, make_logging_func from tvm.runtime import Object from tvm.target import Target from tvm.tir import PrimFunc @@ -62,6 +63,8 @@ class TuneContext(Object): Mutators and their probability mass. task_name : Optional[str] = None The name of the tuning task. + logger : logging.Logger + The logger for the tuning task. 
rand_state : int = -1 The random state. Need to be in integer in [1, 2^31-1], -1 means using random number. @@ -84,6 +87,7 @@ class TuneContext(Object): postprocs: List["Postproc"] mutator_probs: Optional[Dict["Mutator", float]] task_name: str + logger: Optional[logging.Logger] rand_state: int num_threads: int @@ -98,6 +102,7 @@ def __init__( postprocs: Optional[List["Postproc"]] = None, mutator_probs: Optional[Dict["Mutator", float]] = None, task_name: str = "main", + logger: Optional[logging.Logger] = None, rand_state: int = -1, num_threads: Optional[int] = None, ): @@ -105,6 +110,10 @@ mod = IRModule.from_expr(mod) if num_threads is None: num_threads = cpu_count() + if logger is None: + logger = logging.getLogger(__name__) + self.logger = logger self.__init_handle_by_constructor__( _ffi_api.TuneContext, # type: ignore # pylint: disable=no-member @@ -116,6 +125,7 @@ postprocs, mutator_probs, task_name, + make_logging_func(logger), rand_state, num_threads, ) diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 8ea1c28b2dc6..919a29e6cf6c 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -17,10 +17,11 @@ """Utilities for meta schedule""" import ctypes import json +import logging import os import shutil from contextlib import contextmanager -from typing import Any, Callable, List, Optional, Union +from typing import Any, List, Dict, Callable, Optional, Union import psutil # type: ignore from tvm._ffi import get_global_func, register_func @@ -339,8 +340,11 @@ def shash2hex(mod: IRModule) -> str: def _get_default_str(obj: Any) -> str: return ( - f"meta_schedule.{obj.__class__.__name__}" + f"({_to_hex_address(obj._outer().handle)})" - ) # type: ignore + # pylint: disable=protected-access + f"meta_schedule.{obj.__class__.__name__}" + + f"({_to_hex_address(obj._outer().handle)})" # type: ignore + # pylint: enable=protected-access + ) def _to_hex_address(handle: ctypes.c_void_p) -> str: @@ -368,3 +372,87 @@ def autotvm_silencer(): yield finally: autotvm.GLOBAL_SCOPE.silent = silent + + +def make_logging_func(logger: Optional[logging.Logger]) -> Optional[Callable]: + """Get the logging function. + Parameters + ---------- + logger : Optional[logging.Logger] + The logger instance. + Returns + ------- + result : Optional[Callable] + The function to do the specified level of logging. + """ + if logger is None: + return None + + level2log = { + logging.DEBUG: logger.debug, + logging.INFO: logger.info, + logging.WARNING: logger.warning, + logging.ERROR: logger.error, + # logging.FATAL not included + } + + def logging_func(level: int, msg: str): + level2log[level](msg) + + return logging_func + + +def parameterize_config(config: Dict[str, Any], params: Dict[str, str]) -> Dict[str, Any]: + """Parameterize the given configuration. + + Parameters + ---------- + config : Dict[str, Any] + The given config dict. + params : Dict[str, str] + The given parameters. + + Returns + ------- + result : Dict[str, Any] + The parameterized configuration.
+ """ + result = {} + for k, v in config.items(): + if isinstance(k, str): + k = k.format(**params) + if isinstance(v, str): + v = v.format(**params) + elif isinstance(v, dict): + v = parameterize_config(v, params) + elif isinstance(v, list): + v = [t.format(**params) for t in v] + result[k] = v + return result + + +def batch_parameterize_config( + config: Dict[str, Any], params: List[Dict[str, str]] +) -> Dict[str, Any]: + """Parameterize the given configuration with multiple parameters sets. + + Parameters + ---------- + config : Dict[str, Any] + The given config dict. + Params : List[Dict[str, str]] + List of the given multiple parameters sets. + + Returns + ------- + result : Dict[str, Any] + The parameterized configuration. + """ + results = {} + for name, cfg in config.items(): + for p in params: + p_name = name.format(**p) + if p_name not in results: + p_cfg = parameterize_config(cfg, p) + results[p_name] = p_cfg + return results diff --git a/src/meta_schedule/apply_history_best.cc b/src/meta_schedule/apply_history_best.cc index 41714cf7b0ce..18135811f5f1 100644 --- a/src/meta_schedule/apply_history_best.cc +++ b/src/meta_schedule/apply_history_best.cc @@ -87,9 +87,10 @@ void ApplyHistoryBest::ExitWithScope() { /**************** ApplyHistoryBest ****************/ -ApplyHistoryBest::ApplyHistoryBest(Database database) { +ApplyHistoryBest::ApplyHistoryBest(Database database, PackedFunc logging_func) { ObjectPtr n = make_object(); n->database = database; + n->logging_func = logging_func; data_ = n; } @@ -122,15 +123,14 @@ Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRModu return IRModule({{gv, func}}); } } - LOG(WARNING) << "Cannot find workload: " << task_name; - DLOG(INFO) << tir::AsTVMScript(prim_mod); + TVM_PY_LOG(WARNING, logging_func) << "Cannot find workload: " << task_name; return NullOpt; } TVM_REGISTER_NODE_TYPE(ApplyHistoryBestNode); TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBest") - .set_body_typed([](Database database) -> ApplyHistoryBest { - return ApplyHistoryBest(database); + .set_body_typed([](Database database, PackedFunc logging_func) -> ApplyHistoryBest { + return ApplyHistoryBest(database, logging_func); }); TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBestEnterScope") .set_body_typed(ApplyHistoryBestInternal::EnterScope); diff --git a/src/meta_schedule/measure_callback/echo_statistics.cc b/src/meta_schedule/measure_callback/echo_statistics.cc index f287596ffbbb..e45f98b52ea0 100644 --- a/src/meta_schedule/measure_callback/echo_statistics.cc +++ b/src/meta_schedule/measure_callback/echo_statistics.cc @@ -39,8 +39,10 @@ struct TaskInfo { double best_ms = kMaxTime; double best_gflops = 0.0; int error_count = 0; + PackedFunc logging_func; - explicit TaskInfo(const String& name) : name(name) {} + explicit TaskInfo(const String& name, PackedFunc logging_func) + : name(name), logging_func(logging_func) {} void Update(double run_ms) { ++trials; @@ -49,11 +51,11 @@ struct TaskInfo { best_round = trials; best_gflops = flop / run_ms / 1e6; } - LOG(INFO) << "[" << name << "] Trial #" << trials // - << std::fixed << std::setprecision(4) // - << ": GFLOPs: " << (flop / run_ms / 1e6) // - << ". Time: " << run_ms << " ms" // - << ". Best GFLOPs: " << best_gflops; + TVM_PY_LOG(INFO, logging_func) << "[" << name << "] Trial #" << trials // + << std::fixed << std::setprecision(4) // + << ": GFLOPs: " << (flop / run_ms / 1e6) // + << ". Time: " << run_ms << " ms" // + << ". 
Best GFLOPs: " << best_gflops; } void UpdateError(std::string err, const MeasureCandidate& candidate) { @@ -62,11 +64,12 @@ struct TaskInfo { err = (*f_proc)(err).operator std::string(); ++error_count; ++trials; - LOG(INFO) << "[" << name << "] Trial #" << trials // - << std::fixed << std::setprecision(4) // - << ": Error in building: " << err << "\n" - << tir::AsTVMScript(candidate->sch->mod()) << "\n" - << Concat(candidate->sch->trace().value()->AsPython(false), "\n"); + TVM_PY_LOG(INFO, logging_func) + << "[" << name << "] Trial #" << trials // + << std::fixed << std::setprecision(4) // + << ": Error in building: " << err << "\n" + << tir::AsTVMScript(candidate->sch->mod()) << "\n" + << Concat(candidate->sch->trace().value()->AsPython(false), "\n"); } }; @@ -104,7 +107,7 @@ class EchoStatisticsNode : public MeasureCallbackNode { task_info.reserve(tasks.size()); int task_id = 0; for (const TuneContext& task : tasks) { - task_info.push_back(TaskInfo(GetTaskName(task, task_id))); + task_info.push_back(TaskInfo(GetTaskName(task, task_id), task->logging_func)); TaskInfo& info = task_info.back(); info.flop = tir::EstimateTIRFlops(task->mod.value()); ++task_id; diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index 0c8546ccfcdd..242f1aea89c5 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -32,12 +32,14 @@ class CrossThreadReductionNode : public ScheduleRuleNode { Optional opt_warp_size = target->GetAttr("thread_warp_size"); if (!opt_max_threads_per_block.defined()) { - LOG(WARNING) << "Target does not have attribute \"max_threads_per_block\", therefore the " - "rule CrossThreadReduction will not be applied"; + TVM_PY_LOG(WARNING, context->logging_func) + << "Target does not have attribute \"max_threads_per_block\", therefore the " + "rule CrossThreadReduction will not be applied"; } if (!opt_warp_size.defined()) { - LOG(WARNING) << "Target does not have attribute \"thread_warp_size\", therefore the rule " - "CrossThreadReduction will not be applied"; + TVM_PY_LOG(WARNING, context->logging_func) + << "Target does not have attribute \"thread_warp_size\", therefore the rule " + "CrossThreadReduction will not be applied"; } max_threads_per_block = opt_max_threads_per_block.value_or(Integer(-1))->value; warp_size = opt_warp_size.value_or(Integer(-1))->value; diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 0a3ea882b5eb..07c5ddd7ae70 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -68,7 +68,7 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context) if (Optional v = context->target.value()->GetAttr("thread_warp_size")) { this->thread_warp_size_ = v.value()->value; } else { - LOG(INFO) << "'thread_warp_size' is not defined in the target"; + TVM_PY_LOG(INFO, context->logging_func) << "'thread_warp_size' is not defined in the target"; } } } diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index 365d2d69225d..bdef26ef876e 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -491,7 +491,8 @@ std::vector EvolutionarySearchNode::State::SampleInitPopulation(int nu 
out_schs.push_back(results[i]); } } - LOG(INFO) << "Sample-Init-Population summary:\n" << pp.SummarizeFailures(); + TVM_PY_LOG(INFO, self->context_->logging_func) << "Sample-Init-Population summary:\n" + << pp.SummarizeFailures(); } return out_schs; } @@ -568,7 +569,8 @@ std::vector<Schedule> EvolutionarySearchNode::State::EvolveWithCostModel( }; support::parallel_for_dynamic(0, self->population_size, self->num_threads_, f_find_candidate); population.swap(next_population); - LOG(INFO) << "Evolve iter #" << iter << " done. Summary:\n" << pp.SummarizeFailures(); + TVM_PY_LOG(INFO, self->context_->logging_func) << "Evolve iter #" << iter << " done. Summary:\n" + << pp.SummarizeFailures(); } // Return the best states from the heap, sorting from higher score to lower ones std::sort(heap.heap.begin(), heap.heap.end()); @@ -592,7 +594,8 @@ std::vector<Schedule> EvolutionarySearchNode::State::EvolveWithCostModel( os << std::fixed << std::setprecision(4) << heap.heap.at(i).score; } } - LOG(INFO) << "Scores of the best " << n << " candidates:" << os.str(); + TVM_PY_LOG(INFO, self->context_->logging_func) + << "Scores of the best " << n << " candidates:" << os.str(); return results; } @@ -653,17 +656,21 @@ Optional<Array<MeasureCandidate>> EvolutionarySearchNode::State::GenerateMeasure std::vector<Schedule> inits; inits.reserve(pop); - LOG(INFO) << "Generating candidates......"; + TVM_PY_LOG(INFO, self->context_->logging_func) << "Generating candidates......"; std::vector<Schedule> measured = PickBestFromDatabase(pop * self->init_measured_ratio); - LOG(INFO) << "Picked top " << measured.size() << " candidate(s) from database"; + TVM_PY_LOG(INFO, self->context_->logging_func) + << "Picked top " << measured.size() << " candidate(s) from database"; std::vector<Schedule> unmeasured = SampleInitPopulation(pop - measured.size()); - LOG(INFO) << "Sampled " << unmeasured.size() << " candidate(s)"; + TVM_PY_LOG(INFO, self->context_->logging_func) + << "Sampled " << unmeasured.size() << " candidate(s)"; inits.insert(inits.end(), measured.begin(), measured.end()); inits.insert(inits.end(), unmeasured.begin(), unmeasured.end()); std::vector<Schedule> bests = EvolveWithCostModel(inits, sample_num); - LOG(INFO) << "Got " << bests.size() << " candidate(s) with evolutionary search"; + TVM_PY_LOG(INFO, self->context_->logging_func) + << "Got " << bests.size() << " candidate(s) with evolutionary search"; std::vector<Schedule> picks = PickWithEpsGreedy(unmeasured, bests, sample_num); - LOG(INFO) << "Sending " << picks.size() << " candidates(s) for measurement"; + TVM_PY_LOG(INFO, self->context_->logging_func) + << "Sending " << picks.size() << " candidate(s) for measurement"; if (picks.empty()) { ++this->num_empty_iters; if (this->num_empty_iters >= self->num_empty_iters_before_early_stop) { diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 09c134a101bc..dd1b0cd2cde4 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -79,6 +79,8 @@ class PostOrderApplyNode : public SpaceGeneratorNode { TRandState rand_state_ = -1; /*! \brief The schedule rules to be applied in order. */ Array<ScheduleRule> sch_rules_{nullptr}; + /*! \brief The logging function to use. 
*/ + PackedFunc logging_func; void VisitAttrs(tvm::AttrVisitor* v) { // `rand_state_` is not visited @@ -90,6 +92,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode { CHECK(context->sch_rules.defined()) << "ValueError: Schedules rules not given in PostOrderApply!"; this->sch_rules_ = context->sch_rules; + this->logging_func = context->logging_func; } Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod_) final { @@ -143,8 +146,9 @@ const bool has_schedule_rule = custom_schedule_fn != nullptr; if (ann.defined() && !has_schedule_rule) { - LOG(WARNING) << "Custom schedule rule not found, ignoring schedule_rule annotation: " - << ann.value(); + TVM_PY_LOG(WARNING, this->logging_func) + << "Custom schedule rule not found, ignoring schedule_rule annotation: " + << ann.value(); } if ((has_schedule_rule && sch_rule.defined()) || diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc index 25f4b227aecf..a95dbba6c3e1 100644 --- a/src/meta_schedule/task_scheduler/gradient_based.cc +++ b/src/meta_schedule/task_scheduler/gradient_based.cc @@ -108,7 +108,7 @@ class GradientBasedNode final : public TaskSchedulerNode { int n_tasks = task_records_.size(); // Round robin if (num_rounds_already_ == 0) { - LOG(INFO) << "\n" << this->TuningStatistics(); + TVM_PY_LOG(INFO, this->logging_func) << "\n" << this->TuningStatistics(); } if (num_rounds_already_ < n_tasks) { return num_rounds_already_++; } @@ -169,21 +169,24 @@ } record.best_time_cost_history.push_back(best_time_cost); record.trials += results.size(); - LOG(INFO) << "[Updated] Task #" << task_id << ": " << record.task->task_name << "\n" - << this->TuningStatistics(); + TVM_PY_LOG(INFO, this->logging_func) + << "[Updated] Task #" << task_id << ": " << record.task->task_name << "\n" + << this->TuningStatistics(); return results; } }; -TaskScheduler TaskScheduler::GradientBased(Array<TuneContext> tasks, // - Array<FloatImm> task_weights, // - Builder builder, // - Runner runner, // - Database database, // - int max_trials, // - Optional<CostModel> cost_model, // - Optional<Array<MeasureCallback>> measure_callbacks, - double alpha, int window_size, +TaskScheduler TaskScheduler::GradientBased(Array<TuneContext> tasks, // + Array<FloatImm> task_weights, // + Builder builder, // + Runner runner, // + Database database, // + int max_trials, // + Optional<CostModel> cost_model, // + Optional<Array<MeasureCallback>> measure_callbacks, // + PackedFunc logging_func, // + double alpha, // + int window_size, // support::LinearCongruentialEngine::TRandState seed) { CHECK_EQ(tasks.size(), task_weights.size()) << "The size of `tasks` should have the same as `task_weights`."; @@ -207,6 +210,7 @@ TaskScheduler TaskScheduler::GradientBased(Array<TuneContext> tasks, // n->max_trials = max_trials; n->cost_model = cost_model; n->measure_callbacks = measure_callbacks.value_or({}); + n->logging_func = logging_func; n->num_trials_already = 0; n->alpha = alpha; n->window_size = window_size; diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index a5731af1fc4d..446b11837930 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -55,13 +55,14 @@ class RoundRobinNode final : public TaskSchedulerNode { } }; -TaskScheduler TaskScheduler::RoundRobin(Array<TuneContext> tasks, // - Builder builder, // - Runner runner, // - Database database, // - int max_trials, // - Optional<CostModel> cost_model, // - Optional<Array<MeasureCallback>> measure_callbacks) { 
+TaskScheduler TaskScheduler::RoundRobin(Array<TuneContext> tasks, // + Builder builder, // + Runner runner, // + Database database, // + int max_trials, // + Optional<CostModel> cost_model, // + Optional<Array<MeasureCallback>> measure_callbacks, // + PackedFunc logging_func) { ObjectPtr<RoundRobinNode> n = make_object<RoundRobinNode>(); n->tasks = tasks; n->builder = builder; @@ -70,6 +71,7 @@ n->max_trials = max_trials; n->cost_model = cost_model; n->measure_callbacks = measure_callbacks.value_or({}); + n->logging_func = logging_func; n->num_trials_already = 0; n->task_id = -1; for (const TuneContext& task : tasks) { diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index cd287fc1d498..7485f4e076cd 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -27,9 +27,9 @@ namespace meta_schedule { * \param context The tuning context. * \param candidates The measure candidates. */ -void SendToBuilder(const Builder& builder, const TuneContext& context) { +void SendToBuilder(const Builder& builder, const TuneContext& context, PackedFunc logging_func) { Array<MeasureCandidate> candidates = context->measure_candidates.value(); - LOG(INFO) << "Sending " << candidates.size() << " sample(s) to builder"; + TVM_PY_LOG(INFO, logging_func) << "Sending " << candidates.size() << " sample(s) to builder"; Target target = context->target.value(); Array<BuilderInput> inputs; inputs.reserve(candidates.size()); @@ -48,10 +48,10 @@ * \param builder_results The builder results. * \return An array of the runner results. */ -void SendToRunner(const Runner& runner, const TuneContext& context) { +void SendToRunner(const Runner& runner, const TuneContext& context, PackedFunc logging_func) { Array<MeasureCandidate> candidates = context->measure_candidates.value(); Array<BuilderResult> builder_results = context->builder_results.value(); - LOG(INFO) << "Sending " << candidates.size() << " sample(s) to runner"; + TVM_PY_LOG(INFO, logging_func) << "Sending " << candidates.size() << " sample(s) to runner"; Target target = context->target.value(); ICHECK_EQ(candidates.size(), builder_results.size()); int n = candidates.size(); @@ -94,24 +94,26 @@ void TaskSchedulerNode::InitializeTask(int task_id) { TuneContext task = this->tasks[task_id]; - LOG(INFO) << "Initializing Task #" << task_id << ": " << task->task_name; + TVM_PY_LOG(INFO, task->logging_func) + << "Initializing Task #" << task_id << ": " << task->task_name; CHECK(task->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined"; CHECK(task->space_generator.defined()) << "ValueError: Require `context.space_generator`, but it is not defined"; CHECK(task->search_strategy.defined()) << "ValueError: Require `context.search_strategy`, but it is not defined"; - LOG(INFO) << "\n" << tir::AsTVMScript(task->mod); + TVM_PY_LOG(INFO, task->logging_func) << "\n" << tir::AsTVMScript(task->mod); task->Initialize(); Array<tir::Schedule> design_spaces = task->space_generator.value()->GenerateDesignSpace(task->mod.value()); - LOG(INFO) << "Total " << design_spaces.size() << " design space(s) generated"; + TVM_PY_LOG(INFO, task->logging_func) + << "Total " << design_spaces.size() << " design space(s) generated"; for (int i = 0, n = design_spaces.size(); i < n; ++i) { tir::Schedule sch = design_spaces[i]; tir::Trace trace = sch->trace().value(); trace = trace->Simplified(true); - LOG(INFO) << 
"Design space #" << i << ":\n" - << tir::AsTVMScript(sch->mod()) << "\n" - << Concat(trace->AsPython(false), "\n"); + TVM_PY_LOG(INFO, task->logging_func) << "Design space #" << i << ":\n" + << tir::AsTVMScript(sch->mod()) << "\n" + << Concat(trace->AsPython(false), "\n"); } task->search_strategy.value()->PreTuning(design_spaces); } @@ -123,20 +125,22 @@ void TaskSchedulerNode::Tune() { } int running_tasks = tasks.size(); for (int task_id; num_trials_already < max_trials && (task_id = NextTaskId()) != -1;) { - LOG(INFO) << "Scheduler picks Task #" << task_id << ": " << tasks[task_id]->task_name; + TVM_PY_LOG(INFO, this->logging_func) + << "Scheduler picks Task #" << task_id << ": " << tasks[task_id]->task_name; TuneContext task = tasks[task_id]; ICHECK(!task->is_terminated); ICHECK(!task->runner_futures.defined()); SearchStrategy strategy = task->search_strategy.value(); if ((task->measure_candidates = strategy->GenerateMeasureCandidates()).defined()) { num_trials_already += task->measure_candidates.value().size(); - SendToBuilder(this->builder, task); - SendToRunner(this->runner, task); + SendToBuilder(this->builder, task, this->logging_func); + SendToRunner(this->runner, task, this->logging_func); } else { ICHECK(!task->is_terminated); task->is_terminated = true; --running_tasks; - LOG(INFO) << "Task #" << task_id << " has finished. Remaining task(s): " << running_tasks; + TVM_PY_LOG(INFO, this->logging_func) + << "Task #" << task_id << " has finished. Remaining task(s): " << running_tasks; } } for (int task_id = 0; task_id < n_tasks; ++task_id) { @@ -147,7 +151,8 @@ void TaskSchedulerNode::Tune() { } task->is_terminated = true; --running_tasks; - LOG(INFO) << "Task #" << task_id << " has finished. Remaining task(s): " << running_tasks; + TVM_PY_LOG(INFO, this->logging_func) + << "Task #" << task_id << " has finished. 
Remaining task(s): " << running_tasks; } task->search_strategy.value()->PostTuning(); } @@ -200,6 +205,7 @@ TaskScheduler TaskScheduler::PyTaskScheduler( int max_trials, // Optional cost_model, // Optional> measure_callbacks, // + PackedFunc logging_func, // PyTaskSchedulerNode::FTune f_tune, // PyTaskSchedulerNode::FInitializeTask f_initialize_task, // PyTaskSchedulerNode::FTouchTask f_touch_task, // @@ -217,6 +223,7 @@ TaskScheduler TaskScheduler::PyTaskScheduler( } else { n->measure_callbacks = {}; } + n->logging_func = logging_func; n->num_trials_already = 0; n->f_tune = f_tune; n->f_initialize_task = f_initialize_task; diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index ba8ee58c5ba4..382dd961dee0 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -31,6 +31,7 @@ TuneContext::TuneContext(Optional mod, Optional> postprocs, // Optional> mutator_probs, // Optional task_name, // + PackedFunc logging_func, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) { ObjectPtr n = make_object(); @@ -42,6 +43,7 @@ TuneContext::TuneContext(Optional mod, n->postprocs = postprocs.value_or({}); n->mutator_probs = mutator_probs.value_or({}); n->task_name = task_name; + n->logging_func = logging_func; support::LinearCongruentialEngine(&n->rand_state).Seed(rand_state); n->num_threads = num_threads; n->is_terminated = false; @@ -79,10 +81,11 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") Optional> postprocs, // Optional> mutator_probs, // Optional task_name, // + PackedFunc logging_func, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) -> TuneContext { return TuneContext(mod, target, space_generator, search_strategy, sch_rules, postprocs, - mutator_probs, task_name, rand_state, num_threads); + mutator_probs, task_name, logging_func, rand_state, num_threads); }); TVM_REGISTER_GLOBAL("meta_schedule._SHash2Hex").set_body_typed(SHash2Hex); diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index a29f991cbb60..533d062d0425 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -53,9 +53,56 @@ #include "../tir/schedule/primitive.h" #include "../tir/schedule/utils.h" +#define TVM_PY_LOG(logging_level, logging_func) \ + ::tvm::meta_schedule::PyLogMessage(__FILE__, __LINE__, logging_func, \ + PyLogMessage::Level::logging_level) \ + .stream() + namespace tvm { namespace meta_schedule { +/*! + * \brief Class to accumulate an log message on the python side. Do not use directly, instead use + * TVM_PY_LOG(DEBUG), TVM_PY_LOG(INFO), TVM_PY_LOG(WARNING), TVM_PY_ERROR(ERROR). 
+ */ +class PyLogMessage { + public: + enum class Level : int32_t { + DEBUG = 10, + INFO = 20, + WARNING = 30, + ERROR = 40, + // FATAL not included + }; + + PyLogMessage(const std::string& file, int lineno, PackedFunc logging_func, Level logging_level) { + this->logging_func = logging_func; + this->logging_level = logging_level; + } + TVM_NO_INLINE ~PyLogMessage() { + if (this->logging_func.defined()) { + logging_func(static_cast<int>(logging_level), stream_.str()); + } else { + if (logging_level == Level::INFO) + LOG(INFO) << stream_.str(); + else if (logging_level == Level::WARNING) + LOG(WARNING) << stream_.str(); + else if (logging_level == Level::ERROR) + LOG(ERROR) << stream_.str(); + else if (logging_level == Level::DEBUG) + DLOG(INFO) << stream_.str(); + else + LOG(FATAL) << stream_.str(); + } + } + std::ostringstream& stream() { return stream_; } + + private: + std::ostringstream stream_; + PackedFunc logging_func; + Level logging_level; +}; + /*! \brief The type of the random state */ using TRandState = support::LinearCongruentialEngine::TRandState; @@ -321,6 +368,7 @@ struct ThreadedTraceApply { return NullOpt; } } catch (const std::exception& e) { + // Called from worker threads, so log directly to screen here; the failure summary is sent through the logging function afterwards LOG(WARNING) << "ThreadedTraceApply::Apply failed with error " << e.what(); return NullOpt; } diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py index 23f5ebac2c86..e154f9ff27b0 100644 --- a/tests/python/unittest/test_meta_schedule_tune_relay.py +++ b/tests/python/unittest/test_meta_schedule_tune_relay.py @@ -40,7 +40,9 @@ from tvm.tir.schedule.trace import Trace from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN -logging.basicConfig() +logging.basicConfig( + format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" +) logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
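A few usage sketches for the logging machinery introduced above. The `logging_func` threaded through the C++ constructors is a `PackedFunc` taking an integer level and a message string; the levels deliberately mirror Python's `logging` constants (DEBUG=10, INFO=20, WARNING=30, ERROR=40), which is exactly what `PyLogMessage::Level` encodes. A minimal illustration of that contract (the logger name here is made up):

```python
import logging

# A logging_func is any callable taking (level, msg); the integer levels
# match Python's logging constants, mirroring PyLogMessage::Level in C++.
def demo_logging_func(level: int, msg: str) -> None:
    logging.getLogger("tvm.meta_schedule.demo").log(level, msg)

demo_logging_func(logging.INFO, "Scheduler picks Task #0: main")
```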
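In practice the callable is produced by `make_logging_func` from `python/tvm/meta_schedule/utils.py`, so every C++-side message for a task lands in that task's Python logger. A sketch of the round trip, assuming this patch is applied (the logger name is arbitrary):

```python
import logging
from tvm.meta_schedule.utils import make_logging_func

task_logger = logging.getLogger("tvm.meta_schedule.task_00_demo")
logging_func = make_logging_func(task_logger)
# The C++ side invokes the PackedFunc wrapped around this callable;
# calling it directly shows the same (level, msg) contract.
logging_func(logging.WARNING, "Cannot find workload: demo")
```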
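The `{placeholder}` scheme consumed by `create_loggers` is implemented by `parameterize_config`, which substitutes into dict keys, string values, nested dicts, and lists of strings. A small self-contained demonstration with made-up values, again assuming the patch is applied:

```python
from tvm.meta_schedule.utils import parameterize_config

config = {
    "{logger_name}": {
        "handlers": ["{logger_name}.file"],         # list of strings is formatted
        "filename": "{log_dir}/{logger_name}.log",  # string value is formatted
    }
}
result = parameterize_config(
    config, {"logger_name": "tvm.meta_schedule.task_0_matmul", "log_dir": "/tmp/tune/logs"}
)
assert result["tvm.meta_schedule.task_0_matmul"]["filename"] == (
    "/tmp/tune/logs/tvm.meta_schedule.task_0_matmul.log"
)
```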
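`tune_extracted_tasks` then derives one logger, and hence one log file under `work_dir/logs`, per task, zero-padding the task id to the width of the largest index (`__name__` there is `tvm.meta_schedule.tune`). The naming logic in isolation, with hypothetical task names:

```python
tasks = ["matmul", "conv2d", "softmax"]  # hypothetical task names
max_width = len(str(len(tasks) - 1))     # width of the largest task index
pattern = "tvm.meta_schedule.tune" + ".task_{task_id:0" + f"{max_width}" + "d}_{task_name}"
names = [pattern.format(task_id=i, task_name=t) for i, t in enumerate(tasks)]
# -> ['tvm.meta_schedule.tune.task_0_matmul',
#     'tvm.meta_schedule.tune.task_1_conv2d',
#     'tvm.meta_schedule.tune.task_2_softmax']
```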
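Finally, because `create_loggers` installs its defaults with `setdefault`, any entry supplied through `TuneConfig.logger_config` takes precedence over the built-in one. For example, raising the global logger to DEBUG while keeping the default handlers (a sketch; it assumes `TuneConfig` is re-exported at the package root and that the remaining fields take their usual values):

```python
from tvm.meta_schedule import TuneConfig

config = TuneConfig(
    max_trials_global=64,
    num_trials_per_iter=32,  # assumed required field; match the actual signature
    logger_config={
        "loggers": {
            "tvm.meta_schedule": {
                "level": "DEBUG",
                "handlers": ["tvm.meta_schedule.console", "tvm.meta_schedule.file"],
                "propagate": False,
            }
        }
    },
)
```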