From 8c001876e8d5df4c59bdf9821867852ffe1aa6a0 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 21 Aug 2020 07:04:56 -0700 Subject: [PATCH] [Target] Add python binding to new JSON target construction. (#6315) * Add python binding to new JSON target construction. * Added json string parsing and new test. * Add error type. * Add error type in json decoding check. * Fix sphinx formatting. --- python/tvm/target/target.py | 56 +++++++++++++++++---- src/target/target.cc | 4 +- tests/python/unittest/test_target_target.py | 47 +++++++++++++++++ 3 files changed, 97 insertions(+), 10 deletions(-) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 0ea19f81f7653..9dcc164be78ae 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -17,6 +17,7 @@ """Target data structure.""" import os import re +import json import warnings import tvm._ffi @@ -347,13 +348,43 @@ def create_llvm(llvm_args): return _ffi_api.TargetCreate('hexagon', *args_list) -def create(target_str): +def create(target): """Get a target given target string. Parameters ---------- - target_str : str - The target string. + target : str or dict + Can be one of a literal target string, a json string describing + a configuration, or a dictionary of configuration options. + When using a dictionary or json string to configure target, the + possible values are: + + kind : str (required) + Which codegen path to use, for example 'llvm' or 'cuda'. + keys : List of str (optional) + A set of strategies that can be dispatched to. When using + "kind=opencl" for example, one could set keys to ["mali", "opencl", "gpu"]. + device : str (optional) + A single key that corresponds to the actual device being run on. + This will be effectively appended to the keys. + libs : List of str (optional) + The set of external libraries to use. For example ['cblas', 'mkl']. + system-lib : bool (optional) + If True, build a module that contains self registered functions. + Useful for environments where dynamic loading like dlopen is banned. + mcpu : str (optional) + The specific cpu being run on. Serves only as an annotation. + model : str (optional) + An annotation indicating what model a workload came from. + runtime : str (optional) + An annotation indicating which runtime to use with a workload. + mtriple : str (optional) + The llvm triplet describing the target, for example "arm64-linux-android". + mattr : List of str (optional) + The llvm features to compile with, for example ["+avx512f", "+mmx"]. + mfloat-abi : str (optional) + An llvm setting that is one of 'hard' or 'soft' indicating whether to use + hardware or software floating-point operations. Returns ------- @@ -364,9 +395,16 @@ def create(target_str): ---- See the note on :py:mod:`tvm.target` on target string format. """ - if isinstance(target_str, Target): - return target_str - if not isinstance(target_str, str): - raise ValueError("target_str has to be string type") - - return _ffi_api.TargetFromString(target_str) + if isinstance(target, Target): + return target + if isinstance(target, dict): + return _ffi_api.TargetFromConfig(target) + if isinstance(target, str): + # Check if target is a valid json string by trying to load it. + # If we cant, then assume it is a non-json target string. + try: + return _ffi_api.TargetFromConfig(json.loads(target)) + except json.decoder.JSONDecodeError: + return _ffi_api.TargetFromString(target) + + raise ValueError("target has to be a string or dictionary.") diff --git a/src/target/target.cc b/src/target/target.cc index 6a245973315ed..47b405430ead6 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -410,7 +410,7 @@ Target Target::FromConfig(const Map& config_dict) { const auto* cfg_keys = config[kKeys].as(); CHECK(cfg_keys != nullptr) << "AttributeError: Expect type of field 'keys' is an Array, but get: " - << config[kTag]->GetTypeKey(); + << config[kKeys]->GetTypeKey(); for (const ObjectRef& e : *cfg_keys) { const auto* key = e.as(); CHECK(key != nullptr) << "AttributeError: Expect 'keys' to be an array of strings, but it " @@ -525,6 +525,8 @@ TVM_REGISTER_GLOBAL("target.GetCurrentTarget").set_body_typed(Target::Current); TVM_REGISTER_GLOBAL("target.TargetFromString").set_body_typed(Target::Create); +TVM_REGISTER_GLOBAL("target.TargetFromConfig").set_body_typed(Target::FromConfig); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py index 4258da9f576ed..d19f122d08f7f 100644 --- a/tests/python/unittest/test_target_target.py +++ b/tests/python/unittest/test_target_target.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import json import tvm from tvm import te from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost, hexagon @@ -80,7 +81,53 @@ def test_target_create(): assert tgt is not None +def test_target_config(): + """ + Test that constructing a target from a dictionary works. + """ + target_config = { + 'kind': 'llvm', + 'keys': ['arm_cpu', 'cpu'], + 'device': 'arm_cpu', + 'libs': ['cblas'], + 'system-lib': True, + 'mfloat-abi': 'hard', + 'mattr': ['+neon', '-avx512f'], + } + # Convert config dictionary to json string. + target_config_str = json.dumps(target_config) + # Test both dictionary input and json string. + for config in [target_config, target_config_str]: + target = tvm.target.create(config) + assert target.kind.name == 'llvm' + assert all([key in target.keys for key in ['arm_cpu', 'cpu']]) + assert target.device_name == 'arm_cpu' + assert target.libs == ['cblas'] + assert 'system-lib' in str(target) + assert target.attrs['mfloat-abi'] == 'hard' + assert all([attr in target.attrs['mattr'] for attr in ['+neon', '-avx512f']]) + + +def test_config_map(): + """ + Confirm that constructing a target with invalid + attributes fails as expected. + """ + target_config = { + 'kind': 'llvm', + 'libs': {'a': 'b', 'c': 'd'} + } + failed = False + try: + target = tvm.target.create(target_config) + except AttributeError: + failed = True + assert failed == True + + if __name__ == "__main__": test_target_dispatch() test_target_string_parse() test_target_create() + test_target_config() + test_config_map()