Skip to content

Commit

Permalink
[TIR] Use PopenPool instead of multiprocessing.pool (apache#8492)
Browse files Browse the repository at this point in the history
Co-authored-by: Wuwei Lin <wuwei@apache.org>
2 people authored and ylc committed Jan 13, 2022
1 parent be74cbb commit 01274b4
Showing 20 changed files with 205 additions and 153 deletions.
216 changes: 119 additions & 97 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
@@ -44,6 +44,7 @@
from tvm.ir import transform
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
from tvm.contrib import tar, ndk
from tvm.contrib.popen_pool import PopenWorker, PopenPoolExecutor, StatusKind
from tvm.target import Target


@@ -599,7 +600,7 @@ class MeasureErrorNo(object):
UNKNOWN_ERROR = 8 # Unknown error


def _timed_func(inp_serialized, build_func, verbose):
def _local_build_worker(inp_serialized, build_func, verbose):
tic = time.time()
inp = MeasureInput.deserialize(inp_serialized)
task = inp.task
@@ -664,15 +665,13 @@ def local_build_worker(args):
)
build_func = BuildFunc.build_func

res = call_func_with_timeout(timeout, _timed_func, args=(inp, build_func, verbose))
if isinstance(res, TimeoutError):
if verbose >= 1:
print(".T", end="", flush=True) # Build timeout
res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout
elif isinstance(res, Exception):
try:
res = _local_build_worker(inp, build_func, verbose)
# pylint: disable=broad-except
except Exception:
if verbose >= 1:
print(".E", end="", flush=True) # Build error
res = None, [], MeasureErrorNo.COMPILE_HOST, str(res), timeout
res = None, [], MeasureErrorNo.COMPILE_HOST, make_traceback_info(), timeout

return res

@@ -701,9 +700,8 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
res : List[BuildResult]
The build results of these MeasureInputs.
"""
# This pool is not doing computationally intensive work, so we can use threads
pool = multiprocessing.pool.ThreadPool(n_parallel)
tuple_res = pool.map(
executor = PopenPoolExecutor(n_parallel, timeout)
tuple_res = executor.map_with_error_catching(
local_build_worker,
[
(
@@ -715,13 +713,16 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
for i in inputs
],
)
pool.terminate()
pool.join()
del pool

results = []
for res in tuple_res:
results.append(BuildResult(*res))
if res.status == StatusKind.COMPLETE:
results.append(BuildResult(*res.value))
else:
assert res.status == StatusKind.TIMEOUT
if verbose >= 1:
print(".T", end="", flush=True) # Build timeout
results.append(BuildResult(None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout))

return results

@@ -817,21 +818,66 @@ def prepare_input_map(args):
return tensor_input_map


def prepare_runner_args(inp, build_res):
"""This function prepares the pre-defined arguments in `TASK_INPUT_BUFFER_TABLE` for local/rpc
runner in main process
Parameters
----------
inp : MeasureInput
Measure input to be measured.
build_res : BuildResult
Build result to be measured.
Returns
-------
List[Optional[numpy.ndarray]] :
List of arguments for running the program. If the argument does not have a pre-defined input
buffer, None is added to the list as a placeholder.
"""
# pylint: disable=import-outside-toplevel
from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency

task_input_names = inp.task.task_input_names
tensor_input_map = prepare_input_map(build_res.args)
if not task_input_names:
tensor_input_map = {}
args = []
task_inputs_count = 0
for arg in build_res.args:
if arg in tensor_input_map:
tensor_name = tensor_input_map[arg]
if tensor_name in task_input_names:
task_input_buffer = get_task_input_buffer(inp.task.workload_key, tensor_name)
# convert tvm.NDArray to picklable numpy.ndarray
args.append(task_input_buffer.numpy())
task_inputs_count += 1
else:
raise ValueError(
"%s not found in task_inputs, " % (tensor_name)
+ "should provide with `SearchTask(..., task_inputs={...})`"
)
else:
args.append(None)
if task_inputs_count != len(task_input_names):
raise RuntimeError("task_inputs not fully matched, check if there's any unexpected error")
return args


def _timed_eval_func(
inp_serialized,
build_res,
args,
number,
repeat,
min_repeat_ms,
cooldown_interval,
enable_cpu_cache_flush,
verbose,
):
# pylint: disable=import-outside-toplevel
from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency

inp = MeasureInput.deserialize(inp_serialized)
task_input_names = inp.task.task_input_names
tic = time.time()
error_no = 0
error_msg = None
@@ -862,33 +908,18 @@ def _timed_eval_func(
try:
random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True)
assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake"

tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {}
args = []
task_inputs_count = 0
for arg in build_res.args:
if arg in tensor_input_map:
tensor_name = tensor_input_map[arg]
if tensor_name in task_input_names:
args.append(
ndarray.array(
get_task_input_buffer(inp.task.workload_key, tensor_name), dev
)
)
task_inputs_count += 1
else:
raise ValueError(
"%s not found in task_inputs, " % (tensor_name)
+ "should provide with `SearchTask(..., task_inputs={...})`"
)
else:
empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, dev)
assert len(args) == len(build_res.args)
# pylint: disable=consider-using-enumerate
for idx in range(len(args)):
if args[idx] is None:
build_res_arg = build_res.args[idx]
empty_array = ndarray.empty(
get_const_tuple(build_res_arg.shape), build_res_arg.dtype, dev
)
random_fill(empty_array)
args.append(empty_array)
if task_inputs_count != len(task_input_names):
raise RuntimeError(
"task_inputs not fully matched, check if there's any unexpected error"
)
args[idx] = empty_array
else:
args[idx] = ndarray.array(args[idx], dev)
dev.sync()
costs = time_f(*args).results
# pylint: disable=broad-except
@@ -968,6 +999,7 @@ def local_run(

measure_results = []
assert len(inputs) == len(build_results), "Measure input size should be equal to build results"
worker = PopenWorker()
for inp, build_res in zip(inputs, build_results):
if build_res.error_no != 0:
res = (
@@ -978,20 +1010,22 @@ def local_run(
time.time(),
)
else:
args = prepare_runner_args(inp, build_res)
res = call_func_with_timeout(
worker,
timeout,
_timed_eval_func,
args=(
inp.serialize(),
build_res,
args,
number,
repeat,
min_repeat_ms,
cooldown_interval,
enable_cpu_cache_flush,
verbose,
),
add_thread_wrapper=True,
)
if isinstance(res, TimeoutError):
if verbose >= 1:
@@ -1022,9 +1056,10 @@ def local_run(
return measure_results


def _timed_rpc_run(
def _rpc_run(
inp_serialized,
build_res,
args,
key,
host,
port,
@@ -1037,11 +1072,7 @@ def _timed_rpc_run(
enable_cpu_cache_flush,
verbose,
):
# pylint: disable=import-outside-toplevel
from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency

inp = MeasureInput.deserialize(inp_serialized)
task_input_names = inp.task.task_input_names
tic = time.time()
error_no = 0
error_msg = None
@@ -1080,32 +1111,18 @@ def _timed_rpc_run(
random_fill
), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices"

tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {}
args = []
task_inputs_count = 0
for arg in build_res.args:
if arg in tensor_input_map:
tensor_name = tensor_input_map[arg]
if tensor_name in task_input_names:
args.append(
ndarray.array(
get_task_input_buffer(inp.task.workload_key, tensor_name), dev
)
)
task_inputs_count += 1
else:
raise ValueError(
"%s not found in task_inputs, " % (tensor_name)
+ "should provide with `SearchTask(..., task_inputs={...})`"
)
else:
empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, dev)
assert len(args) == len(build_res.args)
# pylint: disable=consider-using-enumerate
for idx in range(len(args)):
if args[idx] is None:
build_res_arg = build_res.args[idx]
empty_array = ndarray.empty(
get_const_tuple(build_res_arg.shape), build_res_arg.dtype, dev
)
random_fill(empty_array)
args.append(empty_array)
if task_inputs_count != len(task_input_names):
logger.warning(
"task_inputs not fully matched, check if there's any unexpected error"
)
args[idx] = empty_array
else:
args[idx] = ndarray.array(args[idx], dev)
dev.sync()

# First run for check that the kernel is correct
@@ -1152,7 +1169,7 @@ def _rpc_run_worker(args):
res : MeasureResult
The measure result of this Runner thread.
"""
_, build_res, _, _, _, _, timeout, _, _, _, _, _, verbose = args
_, build_res, _, _, _, _, _, timeout, _, _, _, _, _, verbose = args
if build_res.error_no != MeasureErrorNo.NO_ERROR:
return (
(MAX_FLOAT,),
@@ -1162,24 +1179,16 @@ def _rpc_run_worker(args):
time.time(),
)

res = call_func_with_timeout(timeout, _timed_rpc_run, args=args)
if isinstance(res, TimeoutError):
if verbose >= 1:
print("*T", end="") # Run timeout
res = (
(MAX_FLOAT,),
MeasureErrorNo.RUN_TIMEOUT,
None,
build_res.time_cost + timeout,
time.time(),
)
elif isinstance(res, Exception):
try:
res = _rpc_run(*args)
# pylint: disable=broad-except
except Exception:
if verbose >= 1:
print("*E", end="") # Run error
res = (
(MAX_FLOAT,),
MeasureErrorNo.RUNTIME_DEVICE,
str(res),
make_traceback_info(),
build_res.time_cost + timeout,
time.time(),
)
@@ -1259,13 +1268,14 @@ def rpc_runner_run(
"""
assert len(inputs) == len(build_results), "Measure input size should be equal to build results"
# This pool is not doing computationally intensive work, so we can use threads
pool = multiprocessing.pool.ThreadPool(n_parallel)
tuple_res = pool.map(
executor = PopenPoolExecutor(n_parallel)
tuple_res = executor.map_with_error_catching(
_rpc_run_worker,
[
(
inp.serialize(),
build_res,
prepare_runner_args(inp, build_res),
key,
host,
port,
@@ -1281,13 +1291,25 @@ def rpc_runner_run(
for inp, build_res in zip(inputs, build_results)
],
)
pool.terminate()
pool.join()
del pool

results = []
for res in tuple_res:
results.append(MeasureResult(*res))
for i, res in enumerate(tuple_res):
if res.status == StatusKind.COMPLETE:
results.append(MeasureResult(*res.value))
else:
assert res.status == StatusKind.TIMEOUT
if verbose >= 1:
print("*T", end="") # Run timeout
build_res = build_results[i]
results.append(
MeasureResult(
(MAX_FLOAT,),
MeasureErrorNo.RUN_TIMEOUT,
None,
build_res.time_cost + timeout,
time.time(),
)
)

if verbose >= 1:
print("")
43 changes: 7 additions & 36 deletions python/tvm/auto_scheduler/utils.py
Original file line number Diff line number Diff line change
@@ -20,9 +20,6 @@

from typing import Hashable
import json
import multiprocessing
import multiprocessing.pool
import queue
import signal
import threading
import traceback
@@ -289,41 +286,15 @@ def wrapper():
return res[0]


def _func_wrapper(que, func, args, kwargs, add_thread_wrapper):
"""Call function and return the result over the queue."""
try:
if add_thread_wrapper:
# Add a new layer of threadinng to avoid the conflict between
# python's multiprocessing and tvm's thread pool.
res = call_func_with_thread(func, args, kwargs)
else:
res = func(*args, **kwargs)
que.put(res)
except Exception: # pylint: disable=broad-except
que.put(Exception(make_traceback_info()))


def call_func_with_timeout(timeout, func, args=(), kwargs=None, add_thread_wrapper=False):
def call_func_with_timeout(
worker, timeout, func, args=(), kwargs=None
): # pylint: disable=unused-argument
"""Call a function with timeout"""
que = multiprocessing.Queue(2)
process = multiprocessing.Process(
target=_func_wrapper, args=(que, func, args, kwargs or {}, add_thread_wrapper)
)
process.start()

worker.send(func, args, kwargs, timeout)
try:
res = que.get(timeout=timeout)
except queue.Empty:
res = TimeoutError()

# clean queue and process
kill_child_processes(process.pid)
process.terminate()
process.join()
que.close()
que.join_thread()
del process
del que
res = worker.recv()
except Exception: # pylint: disable=broad-except
res = Exception(make_traceback_info())

return res

4 changes: 2 additions & 2 deletions python/tvm/autotvm/record.py
Original file line number Diff line number Diff line change
@@ -21,7 +21,6 @@
import argparse
import base64
import logging
import multiprocessing
import pickle
import json
import time
@@ -32,6 +31,7 @@

from .. import build, lower
from ..target import Target
from ..contrib import popen_pool
from .. import __version__
from . import task
from .task import ConfigEntity, ApplyHistoryBest
@@ -230,7 +230,7 @@ def split_workload(in_file, clean=True):
lines = list(open(in_file).readlines())

logger.info("start converting...")
pool = multiprocessing.Pool()
pool = popen_pool.PopenPoolExecutor()
lines = [rec for rec in pool.map(decode, lines) if rec is not None]
logger.info("map done %.2f", time.time() - tic)

4 changes: 2 additions & 2 deletions python/tvm/autotvm/utils.py
Original file line number Diff line number Diff line change
@@ -17,14 +17,14 @@
# pylint: disable=invalid-name
"""Utilities"""
import logging
import multiprocessing
import time

from random import randrange

import numpy as np
import tvm.arith
from tvm.tir import expr
from tvm.contrib.popen_pool import PopenPoolExecutor

logger = logging.getLogger("autotvm")

@@ -111,7 +111,7 @@ def pool_map(func, args, batch_size, verbose=False, pool=None):

ret = None
tic = time.time()
local_pool = pool or multiprocessing.Pool()
local_pool = pool or PopenPoolExecutor()
if verbose:
logger.info("mapping begin")
for i in range(0, len(args), batch_size):
9 changes: 8 additions & 1 deletion python/tvm/contrib/popen_pool.py
Original file line number Diff line number Diff line change
@@ -269,9 +269,16 @@ class PopenPoolExecutor:
timeout : float
Timeout value for each function submit.
Note
----
If max_workers is NONE then the number returned by
os.cpu_count() is used. This method aligns with the
behavior of multiprocessing.pool().
"""

def __init__(self, max_workers, timeout=None):
def __init__(self, max_workers=None, timeout=None):
if max_workers is None:
max_workers = os.cpu_count()
# Use an internal thread pool to send to popen workers
self._threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
self._timeout = timeout
34 changes: 34 additions & 0 deletions python/tvm/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin, wildcard-import
"""Utility Python functions for TVM testing"""
from .utils import assert_allclose, assert_prim_expr_equal, check_bool_expr_is_true
from .utils import check_int_constraints_trans_consistency, check_numerical_grads
from .utils import device_enabled, enabled_targets, exclude_targets
from .utils import fixture, parameter, parameters, parametrize_targets, uses_gpu
from .utils import known_failing_targets, requires_cuda, requires_cudagraph
from .utils import requires_gpu, requires_llvm, requires_rocm, requires_rpc
from .utils import requires_tensorcore, requires_metal, requires_micro, requires_opencl
from .utils import _auto_parametrize_target, _count_num_fixture_uses
from .utils import _remove_global_fixture_definitions, _parametrize_correlated_parameters
from .utils import _pytest_target_params, identity_after, terminate_self

from ._ffi_api import nop, echo, device_test, run_check_signal, object_use_count
from ._ffi_api import test_wrap_callback, test_raise_error_callback, test_check_eq_callback
from ._ffi_api import ErrorTest, FrontendTestModule

from . import auto_scheduler
21 changes: 21 additions & 0 deletions python/tvm/testing/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.testing"""
import tvm._ffi


tvm._ffi._init_api("testing", __name__)
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name, missing-function-docstring
"""Common functions for auto_scheduler test cases"""
import tvm
from tvm import auto_scheduler, te, topi
3 changes: 0 additions & 3 deletions python/tvm/testing.py → python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
@@ -1376,6 +1376,3 @@ def identity_after(x, sleep):
def terminate_self():
"""Testing function to terminate the process."""
sys.exit(-1)


tvm._ffi._init_api("testing", __name__)
2 changes: 1 addition & 1 deletion tests/python/unittest/test_auto_scheduler_compute_dag.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@
from tvm import topi
from tvm import auto_scheduler, te

from test_auto_scheduler_common import (
from tvm.testing.auto_scheduler import (
get_tiled_matmul,
invalid_compute_definition,
matmul_auto_scheduler_test,
2 changes: 1 addition & 1 deletion tests/python/unittest/test_auto_scheduler_cost_model.py
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@
import tvm
from tvm import auto_scheduler

from test_auto_scheduler_common import matmul_auto_scheduler_test
from tvm.testing.auto_scheduler import matmul_auto_scheduler_test


def get_sample_records(number):
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@

import tvm
import pytest
from test_auto_scheduler_common import matmul_auto_scheduler_test
from tvm.testing.auto_scheduler import matmul_auto_scheduler_test
from tvm import auto_scheduler, te
from tvm.auto_scheduler.cost_model.cost_model import PythonBasedModel

2 changes: 1 addition & 1 deletion tests/python/unittest/test_auto_scheduler_feature.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@
import tvm
from tvm import te, auto_scheduler

from test_auto_scheduler_common import matmul_auto_scheduler_test
from tvm.testing.auto_scheduler import matmul_auto_scheduler_test


def fequal(a, b):
Original file line number Diff line number Diff line change
@@ -26,7 +26,7 @@
from tvm import topi
from tvm import auto_scheduler, te

from test_auto_scheduler_common import get_tiled_matmul, matmul_auto_scheduler_test
from tvm.testing.auto_scheduler import get_tiled_matmul, matmul_auto_scheduler_test


def test_apply_steps_with_layout_rewrite():
2 changes: 1 addition & 1 deletion tests/python/unittest/test_auto_scheduler_loop_state.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@
from tvm import auto_scheduler, te
from tvm import topi

from test_auto_scheduler_common import (
from tvm.testing.auto_scheduler import (
matmul_auto_scheduler_test,
conv2d_nchw_bn_relu_auto_scheduler_test,
)
2 changes: 1 addition & 1 deletion tests/python/unittest/test_auto_scheduler_measure.py
Original file line number Diff line number Diff line change
@@ -26,7 +26,7 @@
import tempfile
import tvm.testing
import pickle
from test_auto_scheduler_common import matmul_auto_scheduler_test
from tvm.testing.auto_scheduler import matmul_auto_scheduler_test
from tvm.auto_scheduler import workload_registry


2 changes: 1 addition & 1 deletion tests/python/unittest/test_auto_scheduler_search_policy.py
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@
from tvm import auto_scheduler
from tvm.auto_scheduler.utils import get_const_tuple

from test_auto_scheduler_common import (
from tvm.testing.auto_scheduler import (
matmul_auto_scheduler_test,
zero_rank_compute_auto_scheduler_test,
zero_rank_reduce_auto_scheduler_test,
2 changes: 1 addition & 1 deletion tests/python/unittest/test_auto_scheduler_search_task.py
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@
import tvm.testing
from tvm import auto_scheduler
from tvm.auto_scheduler.utils import get_const_tuple
from test_auto_scheduler_common import (
from tvm.testing.auto_scheduler import (
matmul_auto_scheduler_test,
zero_rank_compute_auto_scheduler_test,
zero_rank_reduce_auto_scheduler_test,
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@
from tvm.auto_scheduler import _ffi_api
from tvm.auto_scheduler.loop_state import Stage

from test_auto_scheduler_common import (
from tvm.testing.auto_scheduler import (
matmul_auto_scheduler_test,
double_matmul_auto_scheduler_test,
conv2d_nchw_bn_relu_auto_scheduler_test,
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@
import tvm.testing
from tvm import auto_scheduler

from test_auto_scheduler_common import matmul_auto_scheduler_test
from tvm.testing.auto_scheduler import matmul_auto_scheduler_test


@tvm.testing.requires_llvm

0 comments on commit 01274b4

Please sign in to comment.