Skip to content

Commit

Permalink
[Target] Add python binding to new JSON target construction. (apache#…
Browse files Browse the repository at this point in the history
…6315)

* Add python binding to new JSON target construction.

* Added json string parsing and new test.

* Add error type.

* Add error type in json decoding check.

* Fix sphinx formatting.
  • Loading branch information
jwfromm authored and Trevor Morris committed Aug 26, 2020
1 parent 0e2f3ef commit d4031db
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 10 deletions.
56 changes: 47 additions & 9 deletions python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""Target data structure."""
import os
import re
import json
import warnings
import tvm._ffi

Expand Down Expand Up @@ -347,13 +348,43 @@ def create_llvm(llvm_args):
return _ffi_api.TargetCreate('hexagon', *args_list)


def create(target_str):
def create(target):
"""Get a target given target string.
Parameters
----------
target_str : str
The target string.
target : str or dict
Can be one of a literal target string, a json string describing
a configuration, or a dictionary of configuration options.
When using a dictionary or json string to configure target, the
possible values are:
kind : str (required)
Which codegen path to use, for example 'llvm' or 'cuda'.
keys : List of str (optional)
A set of strategies that can be dispatched to. When using
"kind=opencl" for example, one could set keys to ["mali", "opencl", "gpu"].
device : str (optional)
A single key that corresponds to the actual device being run on.
This will be effectively appended to the keys.
libs : List of str (optional)
The set of external libraries to use. For example ['cblas', 'mkl'].
system-lib : bool (optional)
If True, build a module that contains self registered functions.
Useful for environments where dynamic loading like dlopen is banned.
mcpu : str (optional)
The specific cpu being run on. Serves only as an annotation.
model : str (optional)
An annotation indicating what model a workload came from.
runtime : str (optional)
An annotation indicating which runtime to use with a workload.
mtriple : str (optional)
The llvm triplet describing the target, for example "arm64-linux-android".
mattr : List of str (optional)
The llvm features to compile with, for example ["+avx512f", "+mmx"].
mfloat-abi : str (optional)
An llvm setting that is one of 'hard' or 'soft' indicating whether to use
hardware or software floating-point operations.
Returns
-------
Expand All @@ -364,9 +395,16 @@ def create(target_str):
----
See the note on :py:mod:`tvm.target` on target string format.
"""
if isinstance(target_str, Target):
return target_str
if not isinstance(target_str, str):
raise ValueError("target_str has to be string type")

return _ffi_api.TargetFromString(target_str)
if isinstance(target, Target):
return target
if isinstance(target, dict):
return _ffi_api.TargetFromConfig(target)
if isinstance(target, str):
# Check if target is a valid json string by trying to load it.
# If we cant, then assume it is a non-json target string.
try:
return _ffi_api.TargetFromConfig(json.loads(target))
except json.decoder.JSONDecodeError:
return _ffi_api.TargetFromString(target)

raise ValueError("target has to be a string or dictionary.")
4 changes: 3 additions & 1 deletion src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ Target Target::FromConfig(const Map<String, ObjectRef>& config_dict) {
const auto* cfg_keys = config[kKeys].as<ArrayNode>();
CHECK(cfg_keys != nullptr)
<< "AttributeError: Expect type of field 'keys' is an Array, but get: "
<< config[kTag]->GetTypeKey();
<< config[kKeys]->GetTypeKey();
for (const ObjectRef& e : *cfg_keys) {
const auto* key = e.as<StringObj>();
CHECK(key != nullptr) << "AttributeError: Expect 'keys' to be an array of strings, but it "
Expand Down Expand Up @@ -525,6 +525,8 @@ TVM_REGISTER_GLOBAL("target.GetCurrentTarget").set_body_typed(Target::Current);

TVM_REGISTER_GLOBAL("target.TargetFromString").set_body_typed(Target::Create);

TVM_REGISTER_GLOBAL("target.TargetFromConfig").set_body_typed(Target::FromConfig);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TargetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const TargetNode*>(node.get());
Expand Down
47 changes: 47 additions & 0 deletions tests/python/unittest/test_target_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import tvm
from tvm import te
from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost, hexagon
Expand Down Expand Up @@ -80,7 +81,53 @@ def test_target_create():
assert tgt is not None


def test_target_config():
"""
Test that constructing a target from a dictionary works.
"""
target_config = {
'kind': 'llvm',
'keys': ['arm_cpu', 'cpu'],
'device': 'arm_cpu',
'libs': ['cblas'],
'system-lib': True,
'mfloat-abi': 'hard',
'mattr': ['+neon', '-avx512f'],
}
# Convert config dictionary to json string.
target_config_str = json.dumps(target_config)
# Test both dictionary input and json string.
for config in [target_config, target_config_str]:
target = tvm.target.create(config)
assert target.kind.name == 'llvm'
assert all([key in target.keys for key in ['arm_cpu', 'cpu']])
assert target.device_name == 'arm_cpu'
assert target.libs == ['cblas']
assert 'system-lib' in str(target)
assert target.attrs['mfloat-abi'] == 'hard'
assert all([attr in target.attrs['mattr'] for attr in ['+neon', '-avx512f']])


def test_config_map():
"""
Confirm that constructing a target with invalid
attributes fails as expected.
"""
target_config = {
'kind': 'llvm',
'libs': {'a': 'b', 'c': 'd'}
}
failed = False
try:
target = tvm.target.create(target_config)
except AttributeError:
failed = True
assert failed == True


if __name__ == "__main__":
test_target_dispatch()
test_target_string_parse()
test_target_create()
test_target_config()
test_config_map()

0 comments on commit d4031db

Please sign in to comment.