From 01274b444fb26b02e0d800e68d5f2bbe5178e8e2 Mon Sep 17 00:00:00 2001
From: Yuanjing Shi
Date: Thu, 12 Aug 2021 13:35:04 -0700
Subject: [PATCH] [TIR] Use PopenPool instead of multiprocessing.pool (#8492)

Co-authored-by: Wuwei Lin
---
 python/tvm/auto_scheduler/measure.py          | 216 ++++++++++--------
 python/tvm/auto_scheduler/utils.py            |  43 +---
 python/tvm/autotvm/record.py                  |   4 +-
 python/tvm/autotvm/utils.py                   |   4 +-
 python/tvm/contrib/popen_pool.py              |   9 +-
 python/tvm/testing/__init__.py                |  34 +++
 python/tvm/testing/_ffi_api.py                |  21 ++
 .../tvm/testing/auto_scheduler.py             |   2 +-
 python/tvm/{testing.py => testing/utils.py}   |   3 -
 .../test_auto_scheduler_compute_dag.py        |   2 +-
 .../test_auto_scheduler_cost_model.py         |   2 +-
 ...test_auto_scheduler_evolutionary_search.py |   2 +-
 .../unittest/test_auto_scheduler_feature.py   |   2 +-
 .../test_auto_scheduler_layout_rewrite.py     |   2 +-
 .../test_auto_scheduler_loop_state.py         |   2 +-
 .../unittest/test_auto_scheduler_measure.py   |   2 +-
 .../test_auto_scheduler_search_policy.py      |   2 +-
 .../test_auto_scheduler_search_task.py        |   2 +-
 .../test_auto_scheduler_sketch_generation.py  |   2 +-
 .../test_auto_scheduler_task_scheduler.py     |   2 +-
 20 files changed, 205 insertions(+), 153 deletions(-)
 create mode 100644 python/tvm/testing/__init__.py
 create mode 100644 python/tvm/testing/_ffi_api.py
 rename tests/python/unittest/test_auto_scheduler_common.py => python/tvm/testing/auto_scheduler.py (99%)
 rename python/tvm/{testing.py => testing/utils.py} (99%)

diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 8d762602bfd1..a202e837bc9b 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -44,6 +44,7 @@
 from tvm.ir import transform
 from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
 from tvm.contrib import tar, ndk
+from tvm.contrib.popen_pool import PopenWorker, PopenPoolExecutor, StatusKind
 from tvm.target import Target


@@ -599,7 +600,7 @@ class MeasureErrorNo(object):
     UNKNOWN_ERROR = 8  # Unknown error


-def _timed_func(inp_serialized, build_func, verbose):
+def _local_build_worker(inp_serialized, build_func, verbose):
     tic = time.time()
     inp = MeasureInput.deserialize(inp_serialized)
     task = inp.task
@@ -664,15 +665,13 @@ def local_build_worker(args):
         )
         build_func = BuildFunc.build_func

-    res = call_func_with_timeout(timeout, _timed_func, args=(inp, build_func, verbose))
-    if isinstance(res, TimeoutError):
-        if verbose >= 1:
-            print(".T", end="", flush=True)  # Build timeout
-        res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout
-    elif isinstance(res, Exception):
+    try:
+        res = _local_build_worker(inp, build_func, verbose)
+    # pylint: disable=broad-except
+    except Exception:
         if verbose >= 1:
             print(".E", end="", flush=True)  # Build error
-        res = None, [], MeasureErrorNo.COMPILE_HOST, str(res), timeout
+        res = None, [], MeasureErrorNo.COMPILE_HOST, make_traceback_info(), timeout

     return res


@@ -701,9 +700,8 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
     res : List[BuildResult]
         The build results of these MeasureInputs.
""" - # This pool is not doing computationally intensive work, so we can use threads - pool = multiprocessing.pool.ThreadPool(n_parallel) - tuple_res = pool.map( + executor = PopenPoolExecutor(n_parallel, timeout) + tuple_res = executor.map_with_error_catching( local_build_worker, [ ( @@ -715,13 +713,16 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo for i in inputs ], ) - pool.terminate() - pool.join() - del pool results = [] for res in tuple_res: - results.append(BuildResult(*res)) + if res.status == StatusKind.COMPLETE: + results.append(BuildResult(*res.value)) + else: + assert res.status == StatusKind.TIMEOUT + if verbose >= 1: + print(".T", end="", flush=True) # Build timeout + results.append(BuildResult(None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout)) return results @@ -817,9 +818,58 @@ def prepare_input_map(args): return tensor_input_map +def prepare_runner_args(inp, build_res): + """This function prepares the pre-defined arguments in `TASK_INPUT_BUFFER_TABLE` for local/rpc + runner in main process + + Parameters + ---------- + inp : MeasureInput + Measure input to be measured. + + build_res : BuildResult + Build result to be measured. + + Returns + ------- + List[Optional[numpy.ndarray]] : + List of arguments for running the program. If the argument does not have a pre-defined input + buffer, None is added to the list as a placeholder. + + """ + # pylint: disable=import-outside-toplevel + from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency + + task_input_names = inp.task.task_input_names + tensor_input_map = prepare_input_map(build_res.args) + if not task_input_names: + tensor_input_map = {} + args = [] + task_inputs_count = 0 + for arg in build_res.args: + if arg in tensor_input_map: + tensor_name = tensor_input_map[arg] + if tensor_name in task_input_names: + task_input_buffer = get_task_input_buffer(inp.task.workload_key, tensor_name) + # convert tvm.NDArray to picklable numpy.ndarray + args.append(task_input_buffer.numpy()) + task_inputs_count += 1 + else: + raise ValueError( + "%s not found in task_inputs, " % (tensor_name) + + "should provide with `SearchTask(..., task_inputs={...})`" + ) + else: + args.append(None) + if task_inputs_count != len(task_input_names): + raise RuntimeError("task_inputs not fully matched, check if there's any unexpected error") + return args + + def _timed_eval_func( inp_serialized, build_res, + args, number, repeat, min_repeat_ms, @@ -827,11 +877,7 @@ def _timed_eval_func( enable_cpu_cache_flush, verbose, ): - # pylint: disable=import-outside-toplevel - from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency - inp = MeasureInput.deserialize(inp_serialized) - task_input_names = inp.task.task_input_names tic = time.time() error_no = 0 error_msg = None @@ -862,33 +908,18 @@ def _timed_eval_func( try: random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True) assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake" - - tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {} - args = [] - task_inputs_count = 0 - for arg in build_res.args: - if arg in tensor_input_map: - tensor_name = tensor_input_map[arg] - if tensor_name in task_input_names: - args.append( - ndarray.array( - get_task_input_buffer(inp.task.workload_key, tensor_name), dev - ) - ) - task_inputs_count += 1 - else: - raise ValueError( - "%s not found in task_inputs, " % (tensor_name) - + "should provide with 
-                            + "should provide with `SearchTask(..., task_inputs={...})`"
-                        )
-                else:
-                    empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, dev)
+            assert len(args) == len(build_res.args)
+            # pylint: disable=consider-using-enumerate
+            for idx in range(len(args)):
+                if args[idx] is None:
+                    build_res_arg = build_res.args[idx]
+                    empty_array = ndarray.empty(
+                        get_const_tuple(build_res_arg.shape), build_res_arg.dtype, dev
+                    )
                     random_fill(empty_array)
-                    args.append(empty_array)
-            if task_inputs_count != len(task_input_names):
-                raise RuntimeError(
-                    "task_inputs not fully matched, check if there's any unexpected error"
-                )
+                    args[idx] = empty_array
+                else:
+                    args[idx] = ndarray.array(args[idx], dev)
             dev.sync()
             costs = time_f(*args).results
         # pylint: disable=broad-except
@@ -968,6 +999,7 @@ def local_run(

     measure_results = []
     assert len(inputs) == len(build_results), "Measure input size should be equal to build results"
+    worker = PopenWorker()
     for inp, build_res in zip(inputs, build_results):
         if build_res.error_no != 0:
             res = (
@@ -978,12 +1010,15 @@ def local_run(
                 time.time(),
             )
         else:
+            args = prepare_runner_args(inp, build_res)
             res = call_func_with_timeout(
+                worker,
                 timeout,
                 _timed_eval_func,
                 args=(
                     inp.serialize(),
                     build_res,
+                    args,
                     number,
                     repeat,
                     min_repeat_ms,
@@ -991,7 +1026,6 @@ def local_run(
                     enable_cpu_cache_flush,
                     verbose,
                 ),
-                add_thread_wrapper=True,
             )
             if isinstance(res, TimeoutError):
                 if verbose >= 1:
@@ -1022,9 +1056,10 @@ def local_run(
     return measure_results


-def _timed_rpc_run(
+def _rpc_run(
     inp_serialized,
     build_res,
+    args,
     key,
     host,
     port,
@@ -1037,11 +1072,7 @@ def _timed_rpc_run(
     enable_cpu_cache_flush,
     verbose,
 ):
-    # pylint: disable=import-outside-toplevel
-    from .search_task import get_task_input_buffer  # lazily import to avoid recursive dependency
-
     inp = MeasureInput.deserialize(inp_serialized)
-    task_input_names = inp.task.task_input_names
     tic = time.time()
     error_no = 0
     error_msg = None
@@ -1080,32 +1111,18 @@ def _timed_rpc_run(
                 random_fill
             ), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices"

-            tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {}
-            args = []
-            task_inputs_count = 0
-            for arg in build_res.args:
-                if arg in tensor_input_map:
-                    tensor_name = tensor_input_map[arg]
-                    if tensor_name in task_input_names:
-                        args.append(
-                            ndarray.array(
-                                get_task_input_buffer(inp.task.workload_key, tensor_name), dev
-                            )
-                        )
-                        task_inputs_count += 1
-                    else:
-                        raise ValueError(
-                            "%s not found in task_inputs, " % (tensor_name)
-                            + "should provide with `SearchTask(..., task_inputs={...})`"
-                        )
-                else:
-                    empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, dev)
+            assert len(args) == len(build_res.args)
+            # pylint: disable=consider-using-enumerate
+            for idx in range(len(args)):
+                if args[idx] is None:
+                    build_res_arg = build_res.args[idx]
+                    empty_array = ndarray.empty(
+                        get_const_tuple(build_res_arg.shape), build_res_arg.dtype, dev
+                    )
                     random_fill(empty_array)
-                    args.append(empty_array)
-            if task_inputs_count != len(task_input_names):
-                logger.warning(
-                    "task_inputs not fully matched, check if there's any unexpected error"
-                )
+                    args[idx] = empty_array
+                else:
+                    args[idx] = ndarray.array(args[idx], dev)
             dev.sync()

             # First run for check that the kernel is correct
@@ -1152,7 +1169,7 @@ def _rpc_run_worker(args):
     res : MeasureResult
         The measure result of this Runner thread.
""" - _, build_res, _, _, _, _, timeout, _, _, _, _, _, verbose = args + _, build_res, _, _, _, _, _, timeout, _, _, _, _, _, verbose = args if build_res.error_no != MeasureErrorNo.NO_ERROR: return ( (MAX_FLOAT,), @@ -1162,24 +1179,16 @@ def _rpc_run_worker(args): time.time(), ) - res = call_func_with_timeout(timeout, _timed_rpc_run, args=args) - if isinstance(res, TimeoutError): - if verbose >= 1: - print("*T", end="") # Run timeout - res = ( - (MAX_FLOAT,), - MeasureErrorNo.RUN_TIMEOUT, - None, - build_res.time_cost + timeout, - time.time(), - ) - elif isinstance(res, Exception): + try: + res = _rpc_run(*args) + # pylint: disable=broad-except + except Exception: if verbose >= 1: print("*E", end="") # Run error res = ( (MAX_FLOAT,), MeasureErrorNo.RUNTIME_DEVICE, - str(res), + make_traceback_info(), build_res.time_cost + timeout, time.time(), ) @@ -1259,13 +1268,14 @@ def rpc_runner_run( """ assert len(inputs) == len(build_results), "Measure input size should be equal to build results" # This pool is not doing computationally intensive work, so we can use threads - pool = multiprocessing.pool.ThreadPool(n_parallel) - tuple_res = pool.map( + executor = PopenPoolExecutor(n_parallel) + tuple_res = executor.map_with_error_catching( _rpc_run_worker, [ ( inp.serialize(), build_res, + prepare_runner_args(inp, build_res), key, host, port, @@ -1281,13 +1291,25 @@ def rpc_runner_run( for inp, build_res in zip(inputs, build_results) ], ) - pool.terminate() - pool.join() - del pool results = [] - for res in tuple_res: - results.append(MeasureResult(*res)) + for i, res in enumerate(tuple_res): + if res.status == StatusKind.COMPLETE: + results.append(MeasureResult(*res.value)) + else: + assert res.status == StatusKind.TIMEOUT + if verbose >= 1: + print("*T", end="") # Run timeout + build_res = build_results[i] + results.append( + MeasureResult( + (MAX_FLOAT,), + MeasureErrorNo.RUN_TIMEOUT, + None, + build_res.time_cost + timeout, + time.time(), + ) + ) if verbose >= 1: print("") diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py index 1c03491c5614..9919bcb470ee 100644 --- a/python/tvm/auto_scheduler/utils.py +++ b/python/tvm/auto_scheduler/utils.py @@ -20,9 +20,6 @@ from typing import Hashable import json -import multiprocessing -import multiprocessing.pool -import queue import signal import threading import traceback @@ -289,41 +286,15 @@ def wrapper(): return res[0] -def _func_wrapper(que, func, args, kwargs, add_thread_wrapper): - """Call function and return the result over the queue.""" - try: - if add_thread_wrapper: - # Add a new layer of threadinng to avoid the conflict between - # python's multiprocessing and tvm's thread pool. 
-            res = call_func_with_thread(func, args, kwargs)
-        else:
-            res = func(*args, **kwargs)
-        que.put(res)
-    except Exception:  # pylint: disable=broad-except
-        que.put(Exception(make_traceback_info()))
-
-
-def call_func_with_timeout(timeout, func, args=(), kwargs=None, add_thread_wrapper=False):
+def call_func_with_timeout(
+    worker, timeout, func, args=(), kwargs=None
+):  # pylint: disable=unused-argument
     """Call a function with timeout"""
-    que = multiprocessing.Queue(2)
-    process = multiprocessing.Process(
-        target=_func_wrapper, args=(que, func, args, kwargs or {}, add_thread_wrapper)
-    )
-    process.start()
-
+    worker.send(func, args, kwargs, timeout)
     try:
-        res = que.get(timeout=timeout)
-    except queue.Empty:
-        res = TimeoutError()
-
-    # clean queue and process
-    kill_child_processes(process.pid)
-    process.terminate()
-    process.join()
-    que.close()
-    que.join_thread()
-    del process
-    del que
+        res = worker.recv()
+    except Exception:  # pylint: disable=broad-except
+        res = Exception(make_traceback_info())

     return res

diff --git a/python/tvm/autotvm/record.py b/python/tvm/autotvm/record.py
index 4f11aea2911f..8145563f5075 100644
--- a/python/tvm/autotvm/record.py
+++ b/python/tvm/autotvm/record.py
@@ -21,7 +21,6 @@
 import argparse
 import base64
 import logging
-import multiprocessing
 import pickle
 import json
 import time
@@ -32,6 +31,7 @@
 from .. import build, lower
 from ..target import Target
+from ..contrib import popen_pool
 from .. import __version__
 from . import task
 from .task import ConfigEntity, ApplyHistoryBest
@@ -230,7 +230,7 @@ def split_workload(in_file, clean=True):
     lines = list(open(in_file).readlines())
     logger.info("start converting...")
-    pool = multiprocessing.Pool()
+    pool = popen_pool.PopenPoolExecutor()
     lines = [rec for rec in pool.map(decode, lines) if rec is not None]
     logger.info("map done %.2f", time.time() - tic)
diff --git a/python/tvm/autotvm/utils.py b/python/tvm/autotvm/utils.py
index fa1dcfd1241b..ec3f18daa6c9 100644
--- a/python/tvm/autotvm/utils.py
+++ b/python/tvm/autotvm/utils.py
@@ -17,7 +17,6 @@
 # pylint: disable=invalid-name
 """Utilities"""
 import logging
-import multiprocessing
 import time
 from random import randrange

@@ -25,6 +24,7 @@
 import numpy as np
 import tvm.arith
 from tvm.tir import expr
+from tvm.contrib.popen_pool import PopenPoolExecutor

 logger = logging.getLogger("autotvm")

@@ -111,7 +111,7 @@ def pool_map(func, args, batch_size, verbose=False, pool=None):
     ret = None
     tic = time.time()
-    local_pool = pool or multiprocessing.Pool()
+    local_pool = pool or PopenPoolExecutor()
     if verbose:
         logger.info("mapping begin")
     for i in range(0, len(args), batch_size):
diff --git a/python/tvm/contrib/popen_pool.py b/python/tvm/contrib/popen_pool.py
index 2f552034e9f8..68c21ef5f212 100644
--- a/python/tvm/contrib/popen_pool.py
+++ b/python/tvm/contrib/popen_pool.py
@@ -269,9 +269,16 @@ class PopenPoolExecutor:
     timeout : float
         Timeout value for each function submit.

+    Note
+    ----
+    If max_workers is None then the number returned by
+    os.cpu_count() is used. This aligns with the
+    behavior of multiprocessing.Pool().
""" - def __init__(self, max_workers, timeout=None): + def __init__(self, max_workers=None, timeout=None): + if max_workers is None: + max_workers = os.cpu_count() # Use an internal thread pool to send to popen workers self._threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) self._timeout = timeout diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py new file mode 100644 index 000000000000..bd1ada4fa284 --- /dev/null +++ b/python/tvm/testing/__init__.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin, wildcard-import +"""Utility Python functions for TVM testing""" +from .utils import assert_allclose, assert_prim_expr_equal, check_bool_expr_is_true +from .utils import check_int_constraints_trans_consistency, check_numerical_grads +from .utils import device_enabled, enabled_targets, exclude_targets +from .utils import fixture, parameter, parameters, parametrize_targets, uses_gpu +from .utils import known_failing_targets, requires_cuda, requires_cudagraph +from .utils import requires_gpu, requires_llvm, requires_rocm, requires_rpc +from .utils import requires_tensorcore, requires_metal, requires_micro, requires_opencl +from .utils import _auto_parametrize_target, _count_num_fixture_uses +from .utils import _remove_global_fixture_definitions, _parametrize_correlated_parameters +from .utils import _pytest_target_params, identity_after, terminate_self + +from ._ffi_api import nop, echo, device_test, run_check_signal, object_use_count +from ._ffi_api import test_wrap_callback, test_raise_error_callback, test_check_eq_callback +from ._ffi_api import ErrorTest, FrontendTestModule + +from . import auto_scheduler diff --git a/python/tvm/testing/_ffi_api.py b/python/tvm/testing/_ffi_api.py new file mode 100644 index 000000000000..56a77223b767 --- /dev/null +++ b/python/tvm/testing/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""FFI APIs for tvm.testing"""
+import tvm._ffi
+
+
+tvm._ffi._init_api("testing", __name__)
diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/python/tvm/testing/auto_scheduler.py
similarity index 99%
rename from tests/python/unittest/test_auto_scheduler_common.py
rename to python/tvm/testing/auto_scheduler.py
index 4890268c907b..bc335c82d324 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/python/tvm/testing/auto_scheduler.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+# pylint: disable=invalid-name, missing-function-docstring
 """Common functions for auto_scheduler test cases"""
 import tvm
 from tvm import auto_scheduler, te, topi
diff --git a/python/tvm/testing.py b/python/tvm/testing/utils.py
similarity index 99%
rename from python/tvm/testing.py
rename to python/tvm/testing/utils.py
index 9515189815e9..71ab0770d64e 100644
--- a/python/tvm/testing.py
+++ b/python/tvm/testing/utils.py
@@ -1376,6 +1376,3 @@ def identity_after(x, sleep):
 def terminate_self():
     """Testing function to terminate the process."""
     sys.exit(-1)
-
-
-tvm._ffi._init_api("testing", __name__)
diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py b/tests/python/unittest/test_auto_scheduler_compute_dag.py
index e394115619a4..81ee5cabbfbc 100644
--- a/tests/python/unittest/test_auto_scheduler_compute_dag.py
+++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py
@@ -23,7 +23,7 @@
 from tvm import topi
 from tvm import auto_scheduler, te

-from test_auto_scheduler_common import (
+from tvm.testing.auto_scheduler import (
     get_tiled_matmul,
     invalid_compute_definition,
     matmul_auto_scheduler_test,
diff --git a/tests/python/unittest/test_auto_scheduler_cost_model.py b/tests/python/unittest/test_auto_scheduler_cost_model.py
index 0b34615583db..50e3ceb6f5fa 100644
--- a/tests/python/unittest/test_auto_scheduler_cost_model.py
+++ b/tests/python/unittest/test_auto_scheduler_cost_model.py
@@ -24,7 +24,7 @@
 import tvm
 from tvm import auto_scheduler

-from test_auto_scheduler_common import matmul_auto_scheduler_test
+from tvm.testing.auto_scheduler import matmul_auto_scheduler_test


 def get_sample_records(number):
diff --git a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
index e28219d0979f..b5c99c0f05fd 100644
--- a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
+++ b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py
@@ -18,7 +18,7 @@
 import tvm
 import pytest

-from test_auto_scheduler_common import matmul_auto_scheduler_test
+from tvm.testing.auto_scheduler import matmul_auto_scheduler_test
 from tvm import auto_scheduler, te
 from tvm.auto_scheduler.cost_model.cost_model import PythonBasedModel
diff --git a/tests/python/unittest/test_auto_scheduler_feature.py b/tests/python/unittest/test_auto_scheduler_feature.py
index 82cfb1d6508b..96090e328328 100644
--- a/tests/python/unittest/test_auto_scheduler_feature.py
+++ b/tests/python/unittest/test_auto_scheduler_feature.py
@@ -23,7 +23,7 @@
 import tvm
 from tvm import te, auto_scheduler

-from test_auto_scheduler_common import matmul_auto_scheduler_test
+from tvm.testing.auto_scheduler import matmul_auto_scheduler_test


 def fequal(a, b):
diff --git a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
index c9291965613b..39673fad2495 100644
--- a/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
+++ b/tests/python/unittest/test_auto_scheduler_layout_rewrite.py
@@ -26,7 +26,7 @@
 from tvm import topi
 from tvm import auto_scheduler, te

-from test_auto_scheduler_common import get_tiled_matmul, matmul_auto_scheduler_test
+from tvm.testing.auto_scheduler import get_tiled_matmul, matmul_auto_scheduler_test


 def test_apply_steps_with_layout_rewrite():
diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py
index 44ed1fc42562..0965ed9efbac 100644
--- a/tests/python/unittest/test_auto_scheduler_loop_state.py
+++ b/tests/python/unittest/test_auto_scheduler_loop_state.py
@@ -23,7 +23,7 @@
 from tvm import auto_scheduler, te
 from tvm import topi

-from test_auto_scheduler_common import (
+from tvm.testing.auto_scheduler import (
     matmul_auto_scheduler_test,
     conv2d_nchw_bn_relu_auto_scheduler_test,
 )
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index 375f8167ff08..9eae3dd33672 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -26,7 +26,7 @@
 import tempfile
 import tvm.testing
 import pickle

-from test_auto_scheduler_common import matmul_auto_scheduler_test
+from tvm.testing.auto_scheduler import matmul_auto_scheduler_test
 from tvm.auto_scheduler import workload_registry
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index d114ce4f9d16..a9f6596a8548 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -27,7 +27,7 @@
 from tvm import auto_scheduler
 from tvm.auto_scheduler.utils import get_const_tuple

-from test_auto_scheduler_common import (
+from tvm.testing.auto_scheduler import (
     matmul_auto_scheduler_test,
     zero_rank_compute_auto_scheduler_test,
     zero_rank_reduce_auto_scheduler_test,
diff --git a/tests/python/unittest/test_auto_scheduler_search_task.py b/tests/python/unittest/test_auto_scheduler_search_task.py
index cd47f1e468ff..f23b02c24298 100644
--- a/tests/python/unittest/test_auto_scheduler_search_task.py
+++ b/tests/python/unittest/test_auto_scheduler_search_task.py
@@ -24,7 +24,7 @@
 import tvm.testing
 from tvm import auto_scheduler
 from tvm.auto_scheduler.utils import get_const_tuple
-from test_auto_scheduler_common import (
+from tvm.testing.auto_scheduler import (
     matmul_auto_scheduler_test,
     zero_rank_compute_auto_scheduler_test,
     zero_rank_reduce_auto_scheduler_test,
diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
index 101730feaa77..70a4edd81616 100644
--- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py
+++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
@@ -27,7 +27,7 @@
 from tvm.auto_scheduler import _ffi_api
 from tvm.auto_scheduler.loop_state import Stage

-from test_auto_scheduler_common import (
+from tvm.testing.auto_scheduler import (
     matmul_auto_scheduler_test,
     double_matmul_auto_scheduler_test,
     conv2d_nchw_bn_relu_auto_scheduler_test,
diff --git a/tests/python/unittest/test_auto_scheduler_task_scheduler.py b/tests/python/unittest/test_auto_scheduler_task_scheduler.py
index bbe29b1ba4f9..a3f356929dd1 100644
--- a/tests/python/unittest/test_auto_scheduler_task_scheduler.py
+++ b/tests/python/unittest/test_auto_scheduler_task_scheduler.py
@@ -25,7 +25,7 @@
 import tvm.testing
 from tvm import auto_scheduler

-from test_auto_scheduler_common import matmul_auto_scheduler_test
+from tvm.testing.auto_scheduler import matmul_auto_scheduler_test


 @tvm.testing.requires_llvm
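
The pattern this patch uses throughout -- PopenPoolExecutor.map_with_error_catching returning
status-wrapped results that the caller branches on via StatusKind -- can be exercised on its own.
A minimal sketch, assuming a TVM build that contains this patch; _double is a hypothetical
stand-in for local_build_worker / _rpc_run_worker and must live at module scope so the spawned
worker process can deserialize it:

    from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind


    def _double(x):
        # Hypothetical worker function; the real callers above ship
        # local_build_worker / _rpc_run_worker to the pool instead.
        return 2 * x


    # Each element is dispatched to a popen worker process; timeouts and worker
    # exceptions come back in-band instead of being raised in the parent.
    executor = PopenPoolExecutor(max_workers=4, timeout=10)
    for res in executor.map_with_error_catching(_double, range(8)):
        if res.status == StatusKind.COMPLETE:
            print(res.value)
        else:
            # Mirrors the else-branches in local_builder_build / rpc_runner_run.
            assert res.status in (StatusKind.TIMEOUT, StatusKind.EXCEPTION)

Because failures are reported per result rather than by exception, a single hung or crashed
measurement no longer tears down the whole pool, which is the motivation for replacing
multiprocessing.pool here.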
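Similarly, local_run now keeps one persistent PopenWorker and funnels each measurement through
call_func_with_timeout(worker, timeout, func, args). A rough sketch of that send/recv round trip,
assuming the same patched build (identity_after is the helper re-exported from tvm.testing above):

    import tvm.testing
    from tvm.contrib.popen_pool import PopenWorker

    # One long-lived worker process, reused across calls the way local_run reuses it.
    worker = PopenWorker()
    # send() ships the function and its arguments to the child process; recv()
    # blocks for the result and raises TimeoutError once the deadline passes.
    worker.send(tvm.testing.identity_after, [42, 0], timeout=5)
    print(worker.recv())  # -> 42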