diff --git a/docs/conf.py b/docs/conf.py
index 32bc095272aa..a7198bf22355 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -251,6 +251,7 @@
         "tune_relay_mobile_gpu.py",
     ],
     "auto_scheduler": ["tune_matmul_x86.py", "tune_conv2d_layer_cuda.py"],
+    "dev": ["low_level_custom_pass.py", "use_pass_infra.py", "bring_your_own_datatypes.py"],
 }
diff --git a/include/tvm/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h
index 85154b5e406b..6d85835d2e4b 100755
--- a/include/tvm/auto_scheduler/search_task.h
+++ b/include/tvm/auto_scheduler/search_task.h
@@ -44,17 +44,16 @@ class HardwareParamsNode : public Object {
   int cache_line_bytes;
 
   // GPU related parameters got from device query API
-
-  /*! \brief The max shared memory per block. */
-  int max_shared_memory_per_block{INT32_MAX};
-  /*! \brief The max register memory per block. */
-  int max_registers_per_block{INT32_MAX};
-  /*! \brief The max threads per block. */
-  int max_threads_per_block{INT32_MAX};
+  /*! \brief The max shared memory per block in bytes. */
+  int max_shared_memory_per_block;
+  /*! \brief The max number of registers per block. */
+  int max_registers_per_block;
+  /*! \brief The max number of threads per block. */
+  int max_threads_per_block;
   /*! \brief The max vthread extent. */
-  int max_vthread_extent{INT32_MAX};
+  int max_vthread_extent;
   /*! \brief The thread numbers of a warp. */
-  int warp_size{INT32_MAX};
+  int warp_size;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("num_cores", &num_cores);
@@ -90,8 +89,15 @@ class HardwareParams : public ObjectRef {
    * \param num_cores The number of cores.
    * \param vector_unit_bytes The width of vector units in bytes.
    * \param cache_line_bytes The size of cache line in bytes.
+   * \param max_shared_memory_per_block The max amount of shared memory per block for GPU.
+   * \param max_registers_per_block The max number of registers per block for GPU.
+   * \param max_threads_per_block The max number of threads per block for GPU.
+   * \param max_vthread_extent The max extent of vthread for GPU.
+   * \param warp_size The warp size for GPU.
    */
-  HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes);
+  HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes,
+                 int max_shared_memory_per_block, int max_registers_per_block,
+                 int max_threads_per_block, int max_vthread_extent, int warp_size);
 
   TVM_DEFINE_OBJECT_REF_METHODS(HardwareParams, ObjectRef, HardwareParamsNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(HardwareParamsNode);
diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py
index 5bf2335ec7cf..bee2e7f423b6 100644
--- a/python/tvm/auto_scheduler/__init__.py
+++ b/python/tvm/auto_scheduler/__init__.py
@@ -44,7 +44,12 @@
     LocalRPCMeasureContext,
 )
 from .measure_record import RecordToFile, RecordReader, load_best, load_records, save_records
-from .relay_integration import extract_tasks, remove_index_check, rewrite_compute_body
+from .relay_integration import (
+    extract_tasks,
+    remove_index_check,
+    rewrite_compute_body,
+    is_auto_scheduler_enabled,
+)
 from .search_task import SearchTask
 from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
 from .task_scheduler import TaskScheduler
diff --git a/python/tvm/auto_scheduler/auto_schedule.py b/python/tvm/auto_scheduler/auto_schedule.py
index 5bc13fec62a9..57dc9588df51 100644
--- a/python/tvm/auto_scheduler/auto_schedule.py
+++ b/python/tvm/auto_scheduler/auto_schedule.py
@@ -46,11 +46,39 @@ class HardwareParams(Object):
         The width of vector units in bytes.
     cache_line_bytes : int
         The size of cache line in bytes.
+    max_shared_memory_per_block : int
+        The max shared memory per block in bytes.
+    max_registers_per_block : int
+        The max number of registers per block.
+    max_threads_per_block : int
+        The max number of threads per block.
+    max_vthread_extent : int
+        The max vthread extent.
+    warp_size : int
+        The thread numbers of a warp.
     """
 
-    def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes):
+    def __init__(
+        self,
+        num_cores,
+        vector_unit_bytes,
+        cache_line_bytes,
+        max_shared_memory_per_block,
+        max_registers_per_block,
+        max_threads_per_block,
+        max_vthread_extent,
+        warp_size,
+    ):
         self.__init_handle_by_constructor__(
-            _ffi_api.HardwareParams, num_cores, vector_unit_bytes, cache_line_bytes
+            _ffi_api.HardwareParams,
+            num_cores,
+            vector_unit_bytes,
+            cache_line_bytes,
+            max_shared_memory_per_block,
+            max_registers_per_block,
+            max_threads_per_block,
+            max_vthread_extent,
+            warp_size,
        )
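Reviewer note: a sketch of the extended constructor at a Python call site, following the new signature above. The zero arguments for a CPU task mirror the unit tests at the end of this patch; the GPU numbers are illustrative placeholders, not values taken from this change.

    from tvm import auto_scheduler

    # CPU task: the five new GPU-related arguments are unused and may be zero.
    cpu_params = auto_scheduler.HardwareParams(8, 64, 64, 0, 0, 0, 0, 0)

    # GPU task: illustrative limits for a hypothetical CUDA device.
    gpu_params = auto_scheduler.HardwareParams(
        -1,  # num_cores (unused for GPU targets)
        16,  # vector_unit_bytes
        64,  # cache_line_bytes
        48 * 1024,  # max_shared_memory_per_block, in bytes
        64 * 1024,  # max_registers_per_block
        1024,  # max_threads_per_block
        8,  # max_vthread_extent (warp_size / 4, as in the C++ defaults below)
        32,  # warp_size
    )
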
diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index 25b88811709e..5a197910e334 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -28,6 +28,7 @@
 import tvm
 from tvm import autotvm, te, transform
+from tvm.ir.transform import PassContext
 from tvm.runtime import convert_to_object
 from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
 from tvm.tir import expr as _expr
@@ -342,3 +343,14 @@ def rewrite_compute_body(compute_tensor, new_layout):
     num = op_node.num_outputs
     outputs = tuple(op_node.output(i) for i in range(num))
     return outputs[0] if num == 1 else outputs
+
+
+def is_auto_scheduler_enabled():
+    """Return whether the auto-scheduler is enabled.
+
+    Returns
+    -------
+    enabled: bool
+        Whether the auto-scheduler is enabled.
+    """
+    return PassContext.current().config.get("relay.backend.use_auto_scheduler", False)
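Reviewer note: a minimal usage sketch of the new helper. The config key is exactly the one read above, and a PassContext is the intended way to flip it; outside any context the flag falls back to False.

    import tvm
    from tvm import auto_scheduler

    assert not auto_scheduler.is_auto_scheduler_enabled()  # default is False

    with tvm.transform.PassContext(
        opt_level=3, config={"relay.backend.use_auto_scheduler": True}
    ):
        # Inside the context, op strategies (see cuda.py and x86.py below)
        # register the auto-scheduler implementations.
        assert auto_scheduler.is_auto_scheduler_enabled()
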
diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py
index 35429552dc74..6f565edbd378 100644
--- a/python/tvm/auto_scheduler/search_policy.py
+++ b/python/tvm/auto_scheduler/search_policy.py
@@ -151,7 +151,7 @@ class SketchPolicy(SearchPolicy):
         "sample_init_min_population": 50,
         "sample_init_use_measured_ratio": 0.2,
         "evolutionary_search_population": 2048,
-        "evolutionary_search_num_iters": 3,
+        "evolutionary_search_num_iters": 4,
         "evolutionary_search_mutation_prob": 0.85,
         "cpu_multi_level_tiling_structure": "SSRSRS",
         "gpu_multi_level_tiling_structure": "SSSRRSRS",
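Reviewer note: the bump from 3 to 4 iterations only changes DEFAULT_PARAMS, so callers who prefer the previous behavior can still override it per policy. A sketch, assuming a SearchTask named `task` built elsewhere:

    from tvm import auto_scheduler

    # `task` is an auto_scheduler.SearchTask constructed elsewhere (assumed).
    policy = auto_scheduler.SketchPolicy(
        task, params={"evolutionary_search_num_iters": 3}  # restore the old default
    )
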
diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py
index 26bfa2e376b4..a3dbcae64b60 100644
--- a/python/tvm/auto_scheduler/task_scheduler.py
+++ b/python/tvm/auto_scheduler/task_scheduler.py
@@ -181,7 +181,7 @@ class TaskScheduler:
         The parameter used for 'gradient' strategy
     callbacks: Optional[List[TaskSchedulerCallback]]
         The task scheduler callbacks that will be called before and after tuning a task.
-        If None, then PrintTableInfo callback will be used.
+        If None, the PrintTableInfo and LogEstimatedLatency callbacks will be used.
     """
 
     def __init__(
@@ -214,7 +214,11 @@ def __init__(
         self.beta = beta
         self.gamma = gamma
         self.backward_window_size = backward_window_size
-        self.callbacks = callbacks if callbacks is not None else [PrintTableInfo()]
+        self.callbacks = (
+            callbacks
+            if callbacks is not None
+            else [PrintTableInfo(), LogEstimatedLatency("total_latency.tsv")]
+        )
 
         assert len(self.tasks) != 0, "No tasks"
         assert self.strategy in ["round-robin", "gradient"]
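Reviewer note: since the new default also writes total_latency.tsv, callers who want the previous print-only behavior can pass callbacks explicitly. A sketch, with `tasks` and `task_weights` assumed to come from auto_scheduler.extract_tasks:

    from tvm import auto_scheduler
    from tvm.auto_scheduler.task_scheduler import PrintTableInfo

    # Keep only the table printer, skipping the latency log file.
    tuner = auto_scheduler.TaskScheduler(tasks, task_weights, callbacks=[PrintTableInfo()])
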
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index bd96cad8ed02..029690680e7d 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -18,7 +18,7 @@
 # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
 from tvm import topi
 import tvm
-from tvm.ir.transform import PassContext
+from tvm.auto_scheduler import is_auto_scheduler_enabled
 from tvm.te import SpecializedCondition
 from tvm.contrib import nvcc
 from tvm._ffi import get_global_func
@@ -230,10 +230,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
             )
 
         # register auto-scheduler implementations
-        use_auto_scheduler = PassContext.current().config.get(
-            "relay.backend.use_auto_scheduler", False
-        )
-        if use_auto_scheduler and judge_winograd_auto_scheduler:
+        if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
                 naive_schedule,  # this implementation should never be picked by autotvm
@@ -460,7 +457,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty
             name="conv2d_nhwc_winograd_direct_without_weight_transform.cuda",
         )
 
-    if PassContext.current().config.get("relay.backend.use_auto_scheduler", False):
+    if is_auto_scheduler_enabled():
         strategy.add_implementation(
             wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform),
             naive_schedule,  # this implementation should never be picked by autotvm
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 98b56ef4d1c0..5dfeca65e5c3 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -20,6 +20,7 @@
 import re
 
 from tvm import topi
+from tvm.auto_scheduler import is_auto_scheduler_enabled
 from tvm.te import SpecializedCondition
 from tvm.relay.ty import is_dynamic
 from .generic import *
@@ -117,6 +118,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
                 return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
         elif layout == "NHWC":
             assert kernel_layout == "HWIO"
+            if not is_auto_scheduler_enabled():
+                logger.warning("conv2d NHWC layout is not optimized for x86 with autotvm.")
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True),
                 wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc),
@@ -124,7 +127,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
             )
         elif layout == "HWCN":
             assert kernel_layout == "HWIO"
-            logger.warning("conv2d HWCN layout is not optimized for x86.")
+            if not is_auto_scheduler_enabled():
+                logger.warning("conv2d HWCN layout is not optimized for x86 with autotvm.")
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.conv2d_hwcn),
                 wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn),
@@ -157,7 +161,10 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
                 return depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
-            logger.warning("depthwise_conv2d NHWC layout is not optimized for x86.")
+            if not is_auto_scheduler_enabled():
+                logger.warning(
+                    "depthwise_conv2d NHWC layout is not optimized for x86 with autotvm."
+                )
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
                 wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc),
@@ -168,7 +175,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
     else:  # group_conv2d
         if layout == "NCHW":
             assert kernel_layout == "OIHW"
-            logger.warning("group_conv2d is not optimized for x86.")
+            if not is_auto_scheduler_enabled():
+                logger.warning("group_conv2d is not optimized for x86 with autotvm.")
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.group_conv2d_nchw, has_groups=True),
                 wrap_topi_schedule(topi.generic.schedule_group_conv2d_nchw),
@@ -176,7 +184,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
             )
         elif layout == "NHWC":
             assert kernel_layout == "HWIO"
-            logger.warning("group_conv2d is not optimized for x86.")
+            if not is_auto_scheduler_enabled():
+                logger.warning("group_conv2d is not optimized for x86 with autotvm.")
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.nn.group_conv2d_nhwc, has_groups=True),
                 wrap_topi_schedule(topi.generic.schedule_group_conv2d_nhwc),
diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc
index bd09a70c0655..48b3fc5eb38f 100755
--- a/src/auto_scheduler/search_task.cc
+++ b/src/auto_scheduler/search_task.cc
@@ -35,22 +35,26 @@ namespace auto_scheduler {
 TVM_REGISTER_NODE_TYPE(HardwareParamsNode);
 TVM_REGISTER_NODE_TYPE(SearchTaskNode);
 
-HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes) {
+HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes,
+                               int max_shared_memory_per_block, int max_registers_per_block,
+                               int max_threads_per_block, int max_vthread_extent, int warp_size) {
   auto node = make_object<HardwareParamsNode>();
   node->num_cores = num_cores;
   node->vector_unit_bytes = vector_unit_bytes;
   node->cache_line_bytes = cache_line_bytes;
+  node->max_shared_memory_per_block = max_shared_memory_per_block;
+  node->max_registers_per_block = max_registers_per_block;
+  node->max_threads_per_block = max_threads_per_block;
+  node->max_vthread_extent = max_vthread_extent;
+  node->warp_size = warp_size;
   data_ = std::move(node);
 }
 
 HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target,
                                                             const Target& target_host) {
   if (target->kind->device_type == kDLCPU) {
-    return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64);
+    return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64, 0, 0, 0, 0, 0);
   } else if (target->kind->device_type == kDLGPU) {
-    auto hardware_params = HardwareParams(-1, 16, 64);
-    auto* p_hardware_params = hardware_params.CopyOnWrite();
-
     auto ctx = TVMContext{kDLGPU, 0};
     auto func = tvm::runtime::Registry::Get("device_api.gpu");
     ICHECK(func != nullptr) << "Cannot find GPU device_api in registry";
@@ -58,31 +62,30 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
     tvm::runtime::TVMRetValue ret;
     device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret);
-    p_hardware_params->max_shared_memory_per_block = ret;
+    int max_shared_memory_per_block = ret;
 
     device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock, &ret);
-    p_hardware_params->max_registers_per_block = ret;
+    int max_registers_per_block = ret;
 
     device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret);
-    p_hardware_params->max_threads_per_block = ret;
+    int max_threads_per_block = ret;
 
     device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret);
-    p_hardware_params->warp_size = ret;
-
-    p_hardware_params->max_vthread_extent = p_hardware_params->warp_size / 4;
+    int warp_size = ret;
 
-    return hardware_params;
+    int max_vthread_extent = warp_size / 4;
+    return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_registers_per_block,
+                          max_threads_per_block, max_vthread_extent, warp_size);
   } else if (target->kind->device_type == kDLMetal) {
     // Reference: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
     // This setting looks working for Metal GPUs later than A10
-    auto hardware_params = HardwareParams(-1, 16, 64);
-    auto* p_hardware_params = hardware_params.CopyOnWrite();
-    p_hardware_params->max_shared_memory_per_block = 32 * 1024;
-    p_hardware_params->max_registers_per_block = 4 * 1024;
-    p_hardware_params->max_threads_per_block = 1024;
-    p_hardware_params->warp_size = 8;
-    p_hardware_params->max_vthread_extent = p_hardware_params->warp_size / 4;
-    return hardware_params;
+    int max_shared_memory_per_block = 32 * 1024;
+    int max_registers_per_block = 4 * 1024;
+    int max_threads_per_block = 1024;
+    int warp_size = 8;
+    int max_vthread_extent = warp_size / 4;
+    return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_registers_per_block,
+                          max_threads_per_block, max_vthread_extent, warp_size);
   } else {
     LOG(FATAL) << "No default hardware parameters for target: " << target;
   }
@@ -106,8 +109,12 @@ SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target targe
 }
 
 TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams")
-    .set_body_typed([](int num_cores, int vector_unit_bytes, int cache_line_bytes) {
-      return HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes);
+    .set_body_typed([](int num_cores, int vector_unit_bytes, int cache_line_bytes,
+                       int max_shared_memory_per_block, int max_registers_per_block,
+                       int max_threads_per_block, int max_vthread_extent, int warp_size) {
+      return HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes,
+                            max_shared_memory_per_block, max_registers_per_block,
+                            max_threads_per_block, max_vthread_extent, warp_size);
     });
 
 TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask")
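Reviewer note: the GPU branch above reads its limits from the device API, and most of the same numbers are reachable from Python via TVMContext attributes, so a matching HardwareParams can be built by hand instead of hard-coding values as in the first sketch. max_registers_per_block is not exposed as a Python property here, so a placeholder constant stands in.

    import tvm
    from tvm import auto_scheduler

    ctx = tvm.gpu(0)
    gpu_params = auto_scheduler.HardwareParams(
        -1,  # num_cores is unused for GPU targets
        16,  # vector_unit_bytes
        64,  # cache_line_bytes
        ctx.max_shared_memory_per_block,  # queried from the device API
        64 * 1024,  # placeholder for max_registers_per_block (no Python property)
        ctx.max_threads_per_block,
        ctx.warp_size // 4,  # max_vthread_extent, as computed above
        ctx.warp_size,
    )
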
diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py b/tests/python/unittest/test_auto_scheduler_compute_dag.py
index bde3b786d370..1356154cacd6 100644
--- a/tests/python/unittest/test_auto_scheduler_compute_dag.py
+++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py
@@ -124,7 +124,7 @@ def test_stage_order():
         dag,
         json.dumps(("test-key",)),
         tvm.target.Target("llvm"),
-        hardware_params=auto_scheduler.HardwareParams(100000, 16, 64),
+        hardware_params=auto_scheduler.HardwareParams(100000, 16, 64, 0, 0, 0, 0, 0),
     )
 
     task2 = pickle.loads(pickle.dumps(task))
diff --git a/tests/python/unittest/test_auto_scheduler_feature.py b/tests/python/unittest/test_auto_scheduler_feature.py
index 8cbe201859cc..7412dbc1f8a4 100644
--- a/tests/python/unittest/test_auto_scheduler_feature.py
+++ b/tests/python/unittest/test_auto_scheduler_feature.py
@@ -153,7 +153,9 @@ def test_gpu_feature():
         inp.task.workload_key,
         inp.task.target,
         None,
-        auto_scheduler.HardwareParams(100000, 16, 64),
+        auto_scheduler.HardwareParams(
+            100000, 16, 64, 1 << 30, 1 << 30, 1 << 30, 1 << 30, 1 << 30
+        ),
     )
 
     state = dag.infer_bound_from_state(inputs[0].state)